diff --git a/class.go b/class.go index 25f6291..a8b4d6f 100644 --- a/class.go +++ b/class.go @@ -52,6 +52,7 @@ func init() { type ( Class struct { ClassName string + Inheritance string Fields []Field Methods []Method Constructors []Constructor diff --git a/class.tmpl b/class.tmpl index 1282349..fd44371 100644 --- a/class.tmpl +++ b/class.tmpl @@ -1,6 +1,6 @@ --luacheck: ignore 212 111 ---@diagnostic disable: missing-return, lowercase-global ----@class {{.ClassName}} +---@class {{.ClassName}}{{if ne .Inheritance ""}} : {{.Inheritance}}{{end}} {{- range .Fields}} ---@field {{.Name}} {{.Type}}{{if ne .Comment ""}} {{.Comment}}{{end}} {{- end}} diff --git a/main.go b/main.go index 4436abd..ead6207 100644 --- a/main.go +++ b/main.go @@ -3,12 +3,14 @@ package main import ( "flag" "fmt" + "os" "path/filepath" "strings" _ "embed" logger "git.site.quack-lab.dev/dave/cylogger" + "github.com/PuerkitoBio/goquery" ) func main() { @@ -152,6 +154,7 @@ func MergeClasses(files []string) (*Class, error) { // Parse all classes var classes []*Class var baseName string + var inheritance string for _, file := range files { class, err := ParseClass(file) @@ -164,11 +167,42 @@ func MergeClasses(files []string) (*Class, error) { if baseName == "" { baseName = class.ClassName } + + // Check for inheritance information in the original file + originalClassName := getOriginalClassName(file) + if strings.Contains(originalClassName, " : ") { + parts := strings.Split(originalClassName, " : ") + if len(parts) == 2 { + // Extract the inherited class name and clean it + inheritedClass := strings.TrimSpace(parts[1]) + inheritedClass = strings.ReplaceAll(inheritedClass, "[Client]", "") + inheritedClass = strings.ReplaceAll(inheritedClass, "[Server]", "") + inheritedClass = strings.ReplaceAll(inheritedClass, "[", "") + inheritedClass = strings.ReplaceAll(inheritedClass, "]", "") + inheritedClass = strings.ReplaceAll(inheritedClass, "-", "_") + inheritedClass = strings.ReplaceAll(inheritedClass, ",", "") + inheritedClass = strings.ReplaceAll(inheritedClass, " ", "_") + // Clean up multiple underscores + for strings.Contains(inheritedClass, "__") { + inheritedClass = strings.ReplaceAll(inheritedClass, "__", "_") + } + inheritedClass = strings.Trim(inheritedClass, "_") + + if inheritance == "" { + inheritance = inheritedClass + } + } + } } // Merge all classes into the first one merged := classes[0] + // Set inheritance if found + if inheritance != "" { + merged.Inheritance = inheritance + } + // Create maps to track methods and fields by name methodMap := make(map[string]*Method) fieldMap := make(map[string]*Field) @@ -212,3 +246,24 @@ func MergeClasses(files []string) (*Class, error) { return merged, nil } + +// getOriginalClassName extracts the original class name from the HTML file +func getOriginalClassName(file string) string { + filehandle, err := os.Open(file) + if err != nil { + return "" + } + defer filehandle.Close() + + doc, err := goquery.NewDocumentFromReader(filehandle) + if err != nil { + return "" + } + + class := doc.Find("div.floatright > h1") + if class.Length() == 0 { + return "" + } + + return strings.TrimSpace(class.Text()) +}