diff --git a/class.go b/class.go index 83f0780..789f922 100644 --- a/class.go +++ b/class.go @@ -20,6 +20,21 @@ var fns = template.FuncMap{ "plus1": func(x int) int { return x + 1 }, + "sub": func(x, y int) int { + return x - y + }, + "truncateComment": func(comment string, maxLen int) string { + if len(comment) <= maxLen { + return comment + } + // Find the last space before maxLen + truncated := comment[:maxLen] + lastSpace := strings.LastIndex(truncated, " ") + if lastSpace > maxLen/2 { + return truncated[:lastSpace] + } + return truncated + }, } func init() { @@ -150,8 +165,17 @@ func ParseClass(file string) (*Class, error) { log.Error("No class found in document") return nil, fmt.Errorf("no class found") } - res.ClassName = strings.TrimSpace(class.Text()) - log.Info("Found class: %s", res.ClassName) + className := strings.TrimSpace(class.Text()) + // Clean up class name to be a valid Lua identifier + // Replace spaces and special characters with underscores + className = strings.ReplaceAll(className, " ", "_") + className = strings.ReplaceAll(className, "[", "") + className = strings.ReplaceAll(className, "]", "") + className = strings.ReplaceAll(className, ":", "_") + className = strings.ReplaceAll(className, "-", "_") + className = strings.ReplaceAll(className, ",", "") + res.ClassName = className + log.Info("Found class: %s (cleaned from: %s)", res.ClassName, strings.TrimSpace(class.Text())) log.Debug("Parsing constructors") res.Constructors, err = getConstructors(doc, log) @@ -281,7 +305,6 @@ func getFields(doc *goquery.Document, log *logger.Logger) ([]Field, error) { return res, nil } -// TODO: Implement parsing return value types and comments func getMethods(doc *goquery.Document, log *logger.Logger) ([]Method, error) { log.Debug("Starting method parsing") res := []Method{} @@ -289,32 +312,163 @@ func getMethods(doc *goquery.Document, log *logger.Logger) ([]Method, error) { codeblocks := doc.Find("div.floatright > div.codecontainer") log.Trace("Found %d code blocks for method parsing", codeblocks.Length()) - codeblocks.ChildrenFiltered("div.function").Each(func(i int, s *goquery.Selection) { + codeblocks.Each(func(blockIndex int, codeblock *goquery.Selection) { + functionDiv := codeblock.Find("div.function") + if functionDiv.Length() == 0 { + return + } + method := Method{} - method.Name = strings.TrimSpace(s.AttrOr("id", "")) - method.Comment = strings.TrimSpace(s.Find("span.comment").Text()) + method.Name = strings.TrimSpace(functionDiv.AttrOr("id", "")) + method.Comment = strings.TrimSpace(functionDiv.Find("span.comment").Text()) - log.Trace("Processing method %d: name='%s', comment='%s'", i, method.Name, method.Comment) + log.Trace("Processing method %d: name='%s', comment='%s'", blockIndex, method.Name, method.Comment) - types := s.Find("span.type") - parameters := s.Find("span.parameter") + // Parse parameters + types := functionDiv.Find("span.type") + parameters := functionDiv.Find("span.parameter") log.Trace("Method %s has %d types and %d parameters", method.Name, types.Length(), parameters.Length()) types.Each(func(i int, s *goquery.Selection) { param := Param{} param.Name = strings.TrimSpace(parameters.Eq(i).Text()) + + // Skip parameters with empty names + if param.Name == "" { + log.Trace("Method %s parameter %d has empty name, skipping", method.Name, i) + return + } + if IsReservedKeyword(param.Name) { log.Trace("Parameter name '%s' is reserved keyword, prefixing with __", param.Name) param.Name = fmt.Sprintf("__%s", param.Name) } - param.Type = strings.TrimSpace(types.Eq(i).Text()) - param.Type = MapType(param.Type) + + // Get the type text and handle cases where it might be split across lines + typeText := strings.TrimSpace(types.Eq(i).Text()) + // Replace newlines and multiple spaces with single spaces + typeText = strings.ReplaceAll(typeText, "\n", " ") + typeText = strings.ReplaceAll(typeText, "\r", " ") + // Replace multiple spaces with single space + for strings.Contains(typeText, " ") { + typeText = strings.ReplaceAll(typeText, " ", " ") + } + typeText = strings.TrimSpace(typeText) + + param.Type = MapType(typeText) log.Trace("Method %s parameter %d: name='%s', type='%s'", method.Name, i, param.Name, param.Type) method.Params = append(method.Params, param) }) - log.Trace("Method %s has %d parameters", method.Name, len(method.Params)) + // Parse return values + // First, try to get return type from function signature + functionText := functionDiv.Text() + if strings.Contains(functionText, "function") { + // Extract return types from function signature + // Look for patterns like "function var", "function int...", "function Matrix, int..." + lines := strings.Split(functionText, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "function") { + // Extract everything after "function" until the function name + functionPart := strings.TrimPrefix(line, "function") + // Find the function name (usually ends with "(" or " ") + funcNameIndex := strings.Index(functionPart, "(") + if funcNameIndex == -1 { + funcNameIndex = strings.Index(functionPart, " ") + } + if funcNameIndex != -1 { + returnTypesPart := strings.TrimSpace(functionPart[:funcNameIndex]) + // Remove the function name from the return types part + // The function name is the last word in the return types part + words := strings.Fields(returnTypesPart) + if len(words) > 1 { + // Remove the last word (function name) and join the rest + returnTypesPart = strings.Join(words[:len(words)-1], " ") + } + + // Handle complex return types like "table" and multiple returns like "Matrix, int..." + // We need to be careful about commas inside angle brackets + var returnTypes []string + var currentType strings.Builder + bracketDepth := 0 + + for i := 0; i < len(returnTypesPart); i++ { + char := returnTypesPart[i] + if char == '<' { + bracketDepth++ + currentType.WriteByte(char) + } else if char == '>' { + bracketDepth-- + currentType.WriteByte(char) + } else if char == ',' && bracketDepth == 0 { + // Only split on commas that are not inside angle brackets + returnTypes = append(returnTypes, strings.TrimSpace(currentType.String())) + currentType.Reset() + } else { + currentType.WriteByte(char) + } + } + // Add the last type + if currentType.Len() > 0 { + returnTypes = append(returnTypes, strings.TrimSpace(currentType.String())) + } + + for _, returnType := range returnTypes { + if returnType == "" { + continue + } + + // Handle cases like "int..." (multiple return values) + if strings.HasSuffix(returnType, "...") { + returnType = strings.TrimSuffix(returnType, "...") + returnType = MapType(returnType) + method.Returns = append(method.Returns, Return{ + Type: returnType, + Comment: "multiple return values", + }) + } else { + returnType = MapType(returnType) + method.Returns = append(method.Returns, Return{ + Type: returnType, + Comment: "", + }) + } + } + } + break + } + } + } + + // Parse return documentation from the second div + detailsDiv := codeblock.Find("div:not(.function)") + if detailsDiv.Length() > 0 { + // Look for "Returns" section + returnsHeader := detailsDiv.Find("p:contains('Returns')") + if returnsHeader.Length() > 0 { + // Get the indented content after the Returns header + indentedContent := detailsDiv.Find("div.indented p") + indentedContent.Each(func(i int, s *goquery.Selection) { + comment := strings.TrimSpace(s.Text()) + if comment != "" { + // If we have return types but no comments yet, add this as a comment + if len(method.Returns) > i { + method.Returns[i].Comment = comment + } else if len(method.Returns) == 0 { + // If no return types were found in signature, create a return with just comment + method.Returns = append(method.Returns, Return{ + Type: "any", + Comment: comment, + }) + } + } + }) + } + } + + log.Trace("Method %s has %d parameters and %d return values", method.Name, len(method.Params), len(method.Returns)) res = append(res, method) }) diff --git a/class.tmpl b/class.tmpl index 8d48ba5..9de03ab 100644 --- a/class.tmpl +++ b/class.tmpl @@ -11,25 +11,28 @@ {{- range $param := $method.Params}} ---@param {{.Name}} {{.Type}}{{if ne .Comment ""}} {{.Comment}}{{end}} {{- end}} - {{- range $ret := $method.Returns}} - ---@return {{.Type}}{{if ne .Comment ""}} #{{.Comment}}{{end}} + {{- if gt (len $method.Returns) 0}} + {{- range $retIndex, $ret := $method.Returns}} + {{- if eq $retIndex (sub (len $method.Returns) 1)}} + ---@return {{.Type}}{{if ne .Comment ""}} #{{truncateComment .Comment 80}}{{end}} + {{- else}} + ---@return {{.Type}}{{if ne .Comment ""}} #{{truncateComment .Comment 80}}{{end}} + {{- end}} {{- end}} {{.Name}} = function(self{{if gt (len .Params) 0}}, {{range $index, $param := .Params}}{{if $index}}, {{end}}{{$param.Name}}{{end}}{{end}}) end, + {{- else}} + {{.Name}} = function(self{{if gt (len .Params) 0}}, {{range $index, $param := .Params}}{{if $index}}, {{end}}{{$param.Name}}{{end}}{{end}}) end, + {{- end}} {{- if ne (plus1 $index) $n}} {{- end}} {{- end}} } +{{- if gt (len .Constructors) 0}} ---@type {{$.ClassName}} {{- range .Constructors}} - {{- range .Params}} - {{- if ne .Comment ""}} ----{{.Name}} ({{.Type}}) -> {{.Comment}} - {{- end}} - {{- end}} ----@overload fun({{range $index, $param := .Params}}{{if $index}}, {{end}}{{$param.Name}}: {{$param.Type}}{{end}}): {{$.ClassName}}{{- if ne .Comment ""}} ----{{.Comment}} +---@overload fun({{range $index, $param := .Params}}{{if $index}}, {{end}}{{$param.Name}}: {{$param.Type}}{{end}}): {{$.ClassName}} {{- end}} {{- end}} {{.ClassName}} = nil diff --git a/main.go b/main.go index 6b337ac..3a9ed05 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "flag" + "strings" _ "embed" @@ -35,11 +36,56 @@ func main() { } func MapType(t string) string { + // Handle complex types like table + if strings.Contains(t, "<") && strings.Contains(t, ">") { + // Extract the base type and inner types + openBracket := strings.Index(t, "<") + closeBracket := strings.LastIndex(t, ">") + if openBracket != -1 && closeBracket != -1 { + baseType := t[:openBracket] + innerTypes := t[openBracket+1 : closeBracket] + + // Split inner types by comma, but be careful about nested brackets + var mappedInnerTypes []string + var current strings.Builder + bracketDepth := 0 + + for i := 0; i < len(innerTypes); i++ { + char := innerTypes[i] + if char == '<' { + bracketDepth++ + current.WriteByte(char) + } else if char == '>' { + bracketDepth-- + current.WriteByte(char) + } else if char == ',' && bracketDepth == 0 { + // Only split on commas that are not inside nested brackets + mappedInnerTypes = append(mappedInnerTypes, MapType(strings.TrimSpace(current.String()))) + current.Reset() + } else { + current.WriteByte(char) + } + } + // Add the last inner type + if current.Len() > 0 { + mappedInnerTypes = append(mappedInnerTypes, MapType(strings.TrimSpace(current.String()))) + } + + // Reconstruct the complex type + return baseType + "<" + strings.Join(mappedInnerTypes, ", ") + ">" + } + } + + // Handle simple types switch t { case "var": return "any" + case "var...": + return "any..." case "int": return "number" + case "unsigned int": + return "number" case "float": return "number" case "double":