37 Commits

Author SHA1 Message Date
2fa99ec3a2 Add lua meta to be used with luals and dumped via cook -m 2025-12-19 13:37:23 +01:00
8dd212fc71 Code polish 2025-12-19 13:24:44 +01:00
a4bbaf9f27 Fix up the lua tests
To be less retarded...
2025-12-19 13:22:17 +01:00
419a8118fc Update example with xml 2025-12-19 13:13:26 +01:00
da5b621cb6 Remove some unused shit and write tests for coverage 2025-12-19 13:13:26 +01:00
1df0263a42 Rework modifiers to variables again 2025-12-19 11:54:25 +01:00
74394cbde9 Integrate the xml processing with the rest of the project 2025-12-19 11:54:25 +01:00
f1ea0f9156 Hallucinate some xml helper functions in lua 2025-12-19 11:54:20 +01:00
fff8869aff Add testfiles 2025-12-19 11:54:20 +01:00
a0c5a5f18c Hallucinate up an xml parser implementation
Who knows if this will work...
2025-12-19 11:54:20 +01:00
b309e3e6f0 Update example file with json commands 2025-12-19 11:54:20 +01:00
09cdc91761 Rework variables to not be commands... 2025-12-19 11:32:59 +01:00
a18573c9f8 Memoize the glob statement
Because we were doing the same work many times :(
2025-12-02 16:50:39 +01:00
eacc92ce4b Rename GLob to glob 2025-12-02 16:36:11 +01:00
3bcc958dda Remove the error return value and instead just throw error 2025-11-15 18:10:24 +01:00
11f0bbee53 "Rework" the csv parsing to cook metatable for header access instead of whatever the fuck I was doing 2025-11-15 18:06:21 +01:00
c145ad0900 Add headers to rowS (SSSSSSSSSSSSSSSS) and not individual row 2025-11-15 16:56:46 +01:00
e02c1f018f fixup! Add lua csv parser regression test(s) 2025-11-15 16:54:22 +01:00
07fea6238f Add lua csv parser regression test(s) 2025-11-15 16:53:06 +01:00
5f1fdfa6c1 Fix some go linter warnings 2025-11-15 16:53:06 +01:00
4fb25d0463 Add a complex test for csv parser 2025-11-15 16:37:38 +01:00
bf23894188 Check options passed to csv parser 2025-11-15 16:30:53 +01:00
aec0f9f171 Update readme for csv parsing 2025-11-15 16:01:36 +01:00
83fed68432 Fix csv header key assignment 2025-11-15 16:01:31 +01:00
4311533445 Lowercase the csv options 2025-11-15 15:57:21 +01:00
ce28b948d0 Add an options parameter to csv parser and comment support 2025-11-15 15:51:49 +01:00
efc602e0ba Add a few more lua tests 2025-11-15 15:43:50 +01:00
917063db0c Refactor lua helper script to separate file
And write a few tests for it
2025-11-15 15:42:57 +01:00
3e552428a5 Improve the csv parser by reading header and assigning values to their header
So we have things like row[1] AND ALSO row["foobar"]
And also row.foobar of course
2025-11-15 15:39:49 +01:00
50455c491d Add support for any SV like TSV 2025-11-15 15:25:18 +01:00
12ec399b09 Do some retarded shit idk what claude did here hopefully something good this will probably get dropped 2025-11-11 12:53:53 +01:00
5a49998c2c Add a toCSV to go along with the fromCSV 2025-11-03 19:28:00 +01:00
590f19603e Make regex matching set INTEGER keys not string 2025-11-03 19:07:15 +01:00
ee8c4b9aa5 Ah oops, can't be local oopsie!! 2025-11-03 16:35:53 +01:00
e8d6613ac8 Add ParseCSV as lua function 2025-11-03 16:33:44 +01:00
91ad9006fa Move the example generation to separate flag 2025-11-02 20:04:39 +01:00
60ba3ad417 Fix some tests that broke for some good reason I'm sure 2025-10-26 17:10:42 +01:00
48 changed files with 6882 additions and 689 deletions

2
.gitignore vendored
View File

@@ -1,4 +1,4 @@
*.exe
.qodo
*.sqlite
testfiles
.cursor/rules

View File

@@ -16,6 +16,7 @@ A Go-based tool for modifying XML, JSON, and text documents using XPath/JSONPath
- String manipulations
- Date conversions
- Structural changes
- CSV/TSV parsing with comments and headers
- Whole ass Lua environment
- **Error Handling**: Comprehensive error detection for:
- Invalid XML/JSON
@@ -101,6 +102,87 @@ chef -xml "//item" "if tonumber(v.stock) > 0 then v.price = v.price * 0.8 end" i
<item stock="5" price="8.00"/>
```
### 6. CSV/TSV Processing
The Lua environment includes CSV parsing functions that support comments, headers, and custom delimiters.
```lua
-- Basic CSV parsing
local rows = fromCSV(csvText)
-- With options
local rows = fromCSV(csvText, {
delimiter = "\t", -- Tab delimiter for TSV (default: ",")
hasHeaders = true, -- First row is headers (default: false)
hasComments = true -- Filter lines starting with # (default: false)
})
-- Access by index
local value = rows[1][2]
-- Access by header name (when hasHeaders = true)
local value = rows[1].Name
-- Convert back to CSV
local csv = toCSV(rows, "\t") -- Optional delimiter parameter
```
**Example with commented TSV file:**
```lua
-- Input file:
-- #mercenary_profiles
-- Id Name Value
-- 1 Test 100
-- 2 Test2 200
local csv = readFile("mercenaries.tsv")
local rows = fromCSV(csv, {
delimiter = "\t",
hasHeaders = true,
hasComments = true
})
-- Access data
rows[1].Name -- "Test"
rows[2].Value -- "200"
```
## Lua Helper Functions
The Lua environment includes many helper functions:
### Math Functions
- `min(a, b)`, `max(a, b)` - Min/max of two numbers
- `round(x, n)` - Round to n decimal places
- `floor(x)`, `ceil(x)` - Floor/ceiling functions
### String Functions
- `upper(s)`, `lower(s)` - Case conversion
- `trim(s)` - Remove leading/trailing whitespace
- `format(s, ...)` - String formatting
- `strsplit(inputstr, sep)` - Split string by separator
### CSV Functions
- `fromCSV(csv, options)` - Parse CSV/TSV text into table of rows
- Options: `delimiter` (default: ","), `hasHeaders` (default: false), `hasComments` (default: false)
- `toCSV(rows, delimiter)` - Convert table of rows back to CSV text
### Conversion Functions
- `num(str)` - Convert string to number (returns 0 if invalid)
- `str(num)` - Convert number to string
- `is_number(str)` - Check if string is numeric
### Table Functions
- `isArray(t)` - Check if table is a sequential array
- `dump(table, depth)` - Print table structure recursively
### HTTP Functions
- `fetch(url, options)` - Make HTTP request, returns response table
- Options: `method`, `headers`, `body`
- Returns: `{status, statusText, ok, body, headers}`
### Regex Functions
- `re(pattern, input)` - Apply regex pattern, returns table with matches
## Installation
```bash

View File

@@ -1,28 +0,0 @@
package main
import (
"time"
logger "git.site.quack-lab.dev/dave/cylogger"
)
func main() {
// Initialize logger with DEBUG level
logger.Init(logger.LevelDebug)
// Test different log levels
logger.Info("This is an info message")
logger.Debug("This is a debug message")
logger.Warning("This is a warning message")
logger.Error("This is an error message")
logger.Trace("This is a trace message (not visible at DEBUG level)")
// Test with a goroutine
logger.SafeGo(func() {
time.Sleep(10 * time.Millisecond)
logger.Info("Message from goroutine")
})
// Wait for goroutine to complete
time.Sleep(20 * time.Millisecond)
}

View File

@@ -1,6 +1,9 @@
# Global variables (no name/regex/lua/files - only modifiers)
[[commands]]
modifiers = { foobar = 4, multiply = 1.5, prefix = 'NEW_', enabled = true }
# Global variables - available to all commands
[variables]
foobar = 4
multiply = 1.5
prefix = 'NEW_'
enabled = true
# Multi-regex example using variable in Lua
[[commands]]
@@ -99,21 +102,443 @@ regex = '(?P<key>[A-Za-z0-9_]+)\s*='
lua = 'key = prefix .. key; return true'
files = ['**/*.properties']
# JSON mode examples
# HTTP fetch example - get version from API and update config
[[commands]]
name = 'JSONArrayMultiply'
name = 'UpdateVersionFromAPI'
regex = 'version\s*=\s*"(?P<version>[^"]+)"'
lua = '''
local response = fetch("https://api.example.com/version", {
method = "GET",
headers = { ["Accept"] = "application/json" }
})
if response and response.body then
local data = fromJSON(response.body)
if data.latest then
version = data.latest
return true
end
end
return false
'''
files = ['version.conf']
# Complex multiline block replacement with state machine
[[commands]]
name = 'ModifyConfigBlock'
regex = '''(?x)
\[server\]
\s+host\s*=\s*"(?P<host>[^"]+)"
\s+port\s*=\s*(?P<port>\d+)
\s+ssl\s*=\s*(?P<ssl>true|false)'''
lua = '''
port = num(port) + 1000
ssl = "true"
replacement = format('[server]\n host = "%s"\n port = %d\n ssl = %s', host, port, ssl)
return true
'''
files = ['server.conf']
# Regex with !any to capture entire sections
[[commands]]
name = 'WrapInComment'
regex = 'FEATURE_START\n(?P<feature>!any)\nFEATURE_END'
lua = '''
replacement = "FEATURE_START\n# " .. feature:gsub("\n", "\n# ") .. "\nFEATURE_END"
return true
'''
files = ['features/**/*.txt']
# Advanced capture groups with complex logic
[[commands]]
name = 'UpdateDependencies'
regex = 'dependency\("(?P<group>[^"]+)", "(?P<name>[^"]+)", "(?P<version>[^"]+)"\)'
lua = '''
local major, minor, patch = version:match("(%d+)%.(%d+)%.(%d+)")
if major and minor and patch then
-- Bump minor version
minor = num(minor) + 1
version = format("%s.%s.0", major, minor)
return true
end
return false
'''
files = ['build.gradle', 'build.gradle.kts']
# JSON mode examples - modify single field
[[commands]]
name = 'JSONModifyField'
json = true
lua = 'for i, item in ipairs(data.items) do data.items[i].value = item.value * 2 end; return true'
lua = '''
data.value = 84
modified = true
'''
files = ['data/**/*.json']
# JSON mode - add new field
[[commands]]
name = 'JSONObjectUpdate'
name = 'JSONAddField'
json = true
lua = 'data.version = "2.0.0"; data.enabled = true; return true'
lua = '''
data.newField = "added"
modified = true
'''
files = ['config/**/*.json']
# JSON mode - modify nested fields
[[commands]]
name = 'JSONNestedModify'
json = true
lua = 'if data.settings and data.settings.performance then data.settings.performance.multiplier = data.settings.performance.multiplier * 1.5 end; return true'
lua = '''
if data.config and data.config.settings then
data.config.settings.enabled = true
data.config.settings.timeout = 60
modified = true
end
'''
files = ['settings/**/*.json']
# JSON mode - modify array elements
[[commands]]
name = 'JSONArrayMultiply'
json = true
lua = '''
if data.items then
for i, item in ipairs(data.items) do
data.items[i].value = item.value * multiply
end
modified = true
end
'''
files = ['data/**/*.json']
# JSON mode - modify object version
[[commands]]
name = 'JSONObjectUpdate'
json = true
lua = '''
data.version = "2.0.0"
data.enabled = enabled
modified = true
'''
files = ['config/**/*.json']
# JSON mode - surgical editing of specific row
[[commands]]
name = 'JSONSurgicalEdit'
json = true
lua = '''
if data.Rows and data.Rows[1] then
data.Rows[1].Weight = 999
modified = true
end
'''
files = ['items/**/*.json']
# JSON mode - remove array elements conditionally
[[commands]]
name = 'JSONRemoveDisabled'
json = true
lua = '''
if data.features then
local i = 1
while i <= #data.features do
if data.features[i].enabled == false then
table.remove(data.features, i)
else
i = i + 1
end
end
modified = true
end
'''
files = ['config/**/*.json']
# JSON mode - deep nested object manipulation
[[commands]]
name = 'JSONDeepUpdate'
json = true
lua = '''
if data.game and data.game.balance and data.game.balance.economy then
local econ = data.game.balance.economy
econ.inflation = (econ.inflation or 1.0) * 1.05
econ.taxRate = 0.15
econ.lastUpdate = os.date("%Y-%m-%d")
modified = true
end
'''
files = ['settings/**/*.json']
# JSON mode - iterate and transform all matching objects
[[commands]]
name = 'JSONTransformItems'
json = true
lua = '''
local function processItem(item)
if item.type == "weapon" and item.damage then
item.damage = item.damage * multiply
item.modified = true
end
end
if data.items then
for _, item in ipairs(data.items) do
processItem(item)
end
modified = true
elseif data.inventory then
for _, item in ipairs(data.inventory) do
processItem(item)
end
modified = true
end
'''
files = ['data/**/*.json']
# CSV processing example - read, modify, write
[[commands]]
name = 'CSVProcess'
regex = '(?P<csv>!any)'
lua = '''
local rows = fromCSV(csv, { hasheader = true })
for i, row in ipairs(rows) do
if row.Value then
row.Value = num(row.Value) * multiply
end
end
replacement = toCSV(rows, { hasheader = true })
return true
'''
files = ['data/**/*.csv']
# CSV processing with custom delimiter (TSV)
[[commands]]
name = 'TSVProcess'
regex = '(?P<tsv>!any)'
lua = '''
local rows = fromCSV(tsv, { delimiter = "\t", hasheader = true, hascomments = true })
for i, row in ipairs(rows) do
if row.Price then
row.Price = num(row.Price) * 1.1
end
end
replacement = toCSV(rows, { delimiter = "\t", hasheader = true })
return true
'''
files = ['data/**/*.tsv']
# CSV processing - modify specific columns
[[commands]]
name = 'CSVModifyColumns'
regex = '(?P<csv>!any)'
lua = '''
local rows = fromCSV(csv, { hasheader = true })
for i, row in ipairs(rows) do
if row.Name then
row.Name = prefix .. row.Name
end
if row.Status then
row.Status = upper(row.Status)
end
end
replacement = toCSV(rows, { hasheader = true })
return true
'''
files = ['exports/**/*.csv']
# XML mode - multiply numeric attributes using helper functions
[[commands]]
name = 'XMLMultiplyAttributes'
regex = '(?P<xml>!any)'
lua = '''
visitElements(data, function(elem)
if elem._tag == "Item" then
modifyNumAttr(elem, "Weight", function(val) return val * multiply end)
modifyNumAttr(elem, "Value", function(val) return val * foobar end)
end
end)
modified = true
'''
files = ['game/**/*.xml']
# XML mode - modify specific element attributes
[[commands]]
name = 'XMLUpdateAfflictions'
regex = '(?P<xml>!any)'
lua = '''
local afflictions = findElements(data, "Affliction")
for _, affliction in ipairs(afflictions) do
local id = getAttr(affliction, "identifier")
if id == "burn" or id == "bleeding" then
modifyNumAttr(affliction, "strength", function(val) return val * 0.5 end)
setAttr(affliction, "description", "Weakened effect")
end
end
modified = true
'''
files = ['config/Afflictions.xml']
# XML mode - add new elements using helpers
[[commands]]
name = 'XMLAddItems'
regex = '(?P<xml>!any)'
lua = '''
local items = findFirstElement(data, "Items")
if items then
local newItem = {
_tag = "Item",
_attr = {
identifier = "new_item",
Weight = "10",
Value = "500"
}
}
addChild(items, newItem)
modified = true
end
'''
files = ['items/**/*.xml']
# XML mode - remove elements by attribute value
[[commands]]
name = 'XMLRemoveDisabled'
regex = '(?P<xml>!any)'
lua = '''
visitElements(data, function(elem)
if elem._tag == "Feature" and getAttr(elem, "enabled") == "false" then
-- Mark for removal (actual removal happens via parent)
elem._remove = true
end
end)
-- Remove marked children
visitElements(data, function(elem)
if elem._children then
local i = 1
while i <= #elem._children do
if elem._children[i]._remove then
table.remove(elem._children, i)
else
i = i + 1
end
end
end
end)
modified = true
'''
files = ['config/**/*.xml']
# XML mode - conditional attribute updates based on other attributes
[[commands]]
name = 'XMLConditionalUpdate'
regex = '(?P<xml>!any)'
lua = '''
visitElements(data, function(elem)
if elem._tag == "Weapon" then
local tier = getAttr(elem, "tier")
if tier and num(tier) >= 3 then
-- High tier weapons get damage boost
modifyNumAttr(elem, "damage", function(val) return val * 1.5 end)
setAttr(elem, "rarity", "legendary")
end
end
end)
modified = true
'''
files = ['weapons/**/*.xml']
# XML mode - modify nested elements
[[commands]]
name = 'XMLNestedModify'
regex = '(?P<xml>!any)'
lua = '''
local config = findFirstElement(data, "Configuration")
if config then
local settings = findFirstElement(config, "Settings")
if settings then
setAttr(settings, "timeout", "120")
setAttr(settings, "maxRetries", "5")
-- Add or update nested element
local logging = findFirstElement(settings, "Logging")
if not logging then
logging = {
_tag = "Logging",
_attr = { level = "DEBUG", enabled = "true" }
}
addChild(settings, logging)
else
setAttr(logging, "level", "INFO")
end
end
end
modified = true
'''
files = ['config/**/*.xml']
# XML mode - batch attribute operations
[[commands]]
name = 'XMLBatchAttributeUpdate'
regex = '(?P<xml>!any)'
lua = '''
-- Update all Price attributes across entire document
visitElements(data, function(elem)
if hasAttr(elem, "Price") then
modifyNumAttr(elem, "Price", function(val) return val * 1.1 end)
end
if hasAttr(elem, "Cost") then
modifyNumAttr(elem, "Cost", function(val) return val * 0.9 end)
end
end)
modified = true
'''
files = ['economy/**/*.xml']
# XML mode - clone and modify elements
[[commands]]
name = 'XMLCloneItems'
regex = '(?P<xml>!any)'
lua = '''
local items = findElements(data, "Item")
local newItems = {}
for _, item in ipairs(items) do
local id = getAttr(item, "identifier")
if id and id:match("^weapon_") then
-- Clone weapon as upgraded version
local upgraded = {
_tag = "Item",
_attr = {
identifier = id .. "_mk2",
Weight = getAttr(item, "Weight"),
Value = tostring(num(getAttr(item, "Value")) * 2)
}
}
table.insert(newItems, upgraded)
end
end
-- Add all new items
for _, newItem in ipairs(newItems) do
addChild(data, newItem)
end
if #newItems > 0 then
modified = true
end
'''
files = ['items/**/*.xml']
# XML mode - remove all children with specific tag
[[commands]]
name = 'XMLRemoveObsolete'
regex = '(?P<xml>!any)'
lua = '''
visitElements(data, function(elem)
-- Remove all "Deprecated" children
removeChildren(elem, "Deprecated")
removeChildren(elem, "Legacy")
end)
modified = true
'''
files = ['config/**/*.xml']

View File

@@ -82,7 +82,7 @@ func TestGlobExpansion(t *testing.T) {
for _, pattern := range tc.patterns {
patternMap[pattern] = struct{}{}
}
files, err := utils.ExpandGLobs(patternMap)
files, err := utils.ExpandGlobs(patternMap)
if err != nil {
t.Fatalf("ExpandGLobs failed: %v", err)
}

4
go.mod
View File

@@ -12,7 +12,6 @@ require (
)
require (
github.com/BurntSushi/toml v1.5.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/hexops/valast v1.5.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@@ -22,7 +21,6 @@ require (
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/spf13/cobra v1.10.1 // indirect
github.com/spf13/pflag v1.0.9 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
@@ -35,7 +33,9 @@ require (
)
require (
github.com/BurntSushi/toml v1.5.0
github.com/google/go-cmp v0.6.0
github.com/spf13/cobra v1.10.1
github.com/tidwall/gjson v1.18.0
gorm.io/driver/sqlite v1.6.0
)

View File

@@ -84,8 +84,8 @@ END`
assert.Len(t, association.Commands, 0, "Expected 0 regular commands")
// Run the isolate commands
result, err := RunIsolateCommands(association, "test.txt", testContent)
if err != nil && err != NothingToDo {
result, err := RunIsolateCommands(association, "test.txt", testContent, false)
if err != nil && err != ErrNothingToDo {
t.Fatalf("Failed to run isolate commands: %v", err)
}
@@ -162,8 +162,8 @@ END_SECTION2`
}
// Run the isolate commands
result, err := RunIsolateCommands(associations["test.txt"], "test.txt", testContent)
if err != nil && err != NothingToDo {
result, err := RunIsolateCommands(associations["test.txt"], "test.txt", testContent, false)
if err != nil && err != ErrNothingToDo {
t.Fatalf("Failed to run isolate commands: %v", err)
}
@@ -234,8 +234,8 @@ func TestIsolateCommandsWithJSONMode(t *testing.T) {
}
// Run the isolate commands
result, err := RunIsolateCommands(associations["test.json"], "test.json", testContent)
if err != nil && err != NothingToDo {
result, err := RunIsolateCommands(associations["test.json"], "test.json", testContent, false)
if err != nil && err != ErrNothingToDo {
t.Fatalf("Failed to run isolate commands: %v", err)
}
@@ -309,8 +309,8 @@ END_REGULAR`
assert.Len(t, association.Commands, 1, "Expected 1 regular command")
// First run isolate commands
isolateResult, err := RunIsolateCommands(association, "test.txt", testContent)
if err != nil && err != NothingToDo {
isolateResult, err := RunIsolateCommands(association, "test.txt", testContent, false)
if err != nil && err != ErrNothingToDo {
t.Fatalf("Failed to run isolate commands: %v", err)
}
@@ -320,8 +320,8 @@ END_REGULAR`
// Then run regular commands
commandLoggers := make(map[string]*logger.Logger)
finalResult, err := RunOtherCommands("test.txt", isolateResult, association, commandLoggers)
if err != nil && err != NothingToDo {
finalResult, err := RunOtherCommands("test.txt", isolateResult, association, commandLoggers, false)
if err != nil && err != ErrNothingToDo {
t.Fatalf("Failed to run regular commands: %v", err)
}
@@ -364,12 +364,12 @@ irons_spellbooks:chain_lightning
// Second command: targets all SpellPowerMultiplier with multiplier *4
commands := []utils.ModifyCommand{
{
Name: "healing",
Name: "healing",
Regexes: []string{
`irons_spellbooks:chain_creeper[\s\S]*?SpellPowerMultiplier = !num`,
`irons_spellbooks:chain_lightning[\s\S]*?SpellPowerMultiplier = !num`,
},
Lua: `v1 * 4`, // This should multiply by 4
Lua: `v1 * 4`, // This should multiply by 4
Files: []string{"irons_spellbooks-server.toml"},
Reset: true,
Isolate: true,
@@ -377,7 +377,7 @@ irons_spellbooks:chain_lightning
{
Name: "spellpower",
Regex: `SpellPowerMultiplier = !num`,
Lua: `v1 * 4`, // This should multiply by 4 again
Lua: `v1 * 4`, // This should multiply by 4 again
Files: []string{"irons_spellbooks-server.toml"},
Reset: true,
Isolate: true,
@@ -397,8 +397,8 @@ irons_spellbooks:chain_lightning
assert.Len(t, association.Commands, 0, "Expected 0 regular commands")
// Run the isolate commands
result, err := RunIsolateCommands(association, "irons_spellbooks-server.toml", testContent)
if err != nil && err != NothingToDo {
result, err := RunIsolateCommands(association, "irons_spellbooks-server.toml", testContent, false)
if err != nil && err != ErrNothingToDo {
t.Fatalf("Failed to run isolate commands: %v", err)
}
@@ -414,4 +414,4 @@ irons_spellbooks:chain_lightning
t.Logf("Original content:\n%s\n", testContent)
t.Logf("Result content:\n%s\n", result)
}
}

62
main.go
View File

@@ -4,6 +4,7 @@ import (
_ "embed"
"errors"
"os"
"path/filepath"
"sort"
"sync"
"sync/atomic"
@@ -12,8 +13,8 @@ import (
"cook/processor"
"cook/utils"
"github.com/spf13/cobra"
logger "git.site.quack-lab.dev/dave/cylogger"
"github.com/spf13/cobra"
)
//go:embed example_cook.toml
@@ -54,12 +55,30 @@ Features:
- Parallel file processing
- Command filtering and organization`,
PersistentPreRun: func(cmd *cobra.Command, args []string) {
CreateExampleConfig()
logger.InitFlag()
mainLogger.Info("Initializing with log level: %s", logger.GetLevel().String())
mainLogger.Trace("Full argv: %v", os.Args)
},
Run: func(cmd *cobra.Command, args []string) {
exampleFlag, _ := cmd.Flags().GetBool("example")
if exampleFlag {
CreateExampleConfig()
return
}
metaFlag, _ := cmd.Flags().GetBool("meta")
if metaFlag {
cwd, err := os.Getwd()
if err != nil {
mainLogger.Error("Failed to get current directory: %v", err)
os.Exit(1)
}
metaPath := filepath.Join(cwd, "meta.lua")
if err := processor.GenerateMetaFile(metaPath); err != nil {
mainLogger.Error("Failed to generate meta.lua: %v", err)
os.Exit(1)
}
return
}
if len(args) == 0 {
cmd.Usage()
return
@@ -76,6 +95,8 @@ Features:
rootCmd.Flags().StringP("filter", "f", "", "Filter commands before running them")
rootCmd.Flags().Bool("json", false, "Enable JSON mode for processing JSON files")
rootCmd.Flags().BoolP("conv", "c", false, "Convert YAML files to TOML format")
rootCmd.Flags().BoolP("example", "e", false, "Generate example_cook.toml and exit")
rootCmd.Flags().BoolP("meta", "m", false, "Generate meta.lua file for LuaLS autocomplete and exit")
// Set up examples in the help text
rootCmd.SetUsageTemplate(`Usage:{{if .Runnable}}
@@ -185,29 +206,16 @@ func runModifier(args []string, cmd *cobra.Command) {
// Load all commands
mainLogger.Debug("Loading commands from arguments")
mainLogger.Trace("Arguments: %v", args)
commands, err := utils.LoadCommands(args)
commands, variables, err := utils.LoadCommands(args)
if err != nil || len(commands) == 0 {
mainLogger.Error("Failed to load commands: %v", err)
cmd.Usage()
return
}
// Collect global modifiers from special entries and filter them out
vars := map[string]interface{}{}
filtered := make([]utils.ModifyCommand, 0, len(commands))
for _, c := range commands {
if len(c.Modifiers) > 0 && c.Name == "" && c.Regex == "" && len(c.Regexes) == 0 && c.Lua == "" && len(c.Files) == 0 {
for k, v := range c.Modifiers {
vars[k] = v
}
continue
}
filtered = append(filtered, c)
if len(variables) > 0 {
mainLogger.Info("Loaded %d global variables", len(variables))
processor.SetVariables(variables)
}
if len(vars) > 0 {
mainLogger.Info("Loaded %d global modifiers", len(vars))
processor.SetVariables(vars)
}
commands = filtered
mainLogger.Info("Loaded %d commands", len(commands))
if filterFlag != "" {
@@ -238,7 +246,7 @@ func runModifier(args []string, cmd *cobra.Command) {
// Resolve all the files for all the globs
mainLogger.Info("Found %d unique file patterns", len(globs))
mainLogger.Debug("Expanding glob patterns to files")
files, err := utils.ExpandGLobs(globs)
files, err := utils.ExpandGlobs(globs)
if err != nil {
mainLogger.Error("Failed to expand file patterns: %v", err)
return
@@ -335,23 +343,23 @@ func runModifier(args []string, cmd *cobra.Command) {
isChanged := false
mainLogger.Debug("Running isolate commands for file %q", file)
fileDataStr, err = RunIsolateCommands(association, file, fileDataStr, jsonFlag)
if err != nil && err != NothingToDo {
if err != nil && err != ErrNothingToDo {
mainLogger.Error("Failed to run isolate commands for file %q: %v", file, err)
atomic.AddInt64(&stats.FailedFiles, 1)
return
}
if err != NothingToDo {
if err != ErrNothingToDo {
isChanged = true
}
mainLogger.Debug("Running other commands for file %q", file)
fileDataStr, err = RunOtherCommands(file, fileDataStr, association, commandLoggers, jsonFlag)
if err != nil && err != NothingToDo {
if err != nil && err != ErrNothingToDo {
mainLogger.Error("Failed to run other commands for file %q: %v", file, err)
atomic.AddInt64(&stats.FailedFiles, 1)
return
}
if err != NothingToDo {
if err != ErrNothingToDo {
isChanged = true
}
@@ -469,7 +477,7 @@ func CreateExampleConfig() {
createExampleConfigLogger.Info("Wrote example_cook.toml")
}
var NothingToDo = errors.New("nothing to do")
var ErrNothingToDo = errors.New("nothing to do")
func RunOtherCommands(file string, fileDataStr string, association utils.FileCommandAssociation, commandLoggers map[string]*logger.Logger, jsonFlag bool) (string, error) {
runOtherCommandsLogger := mainLogger.WithPrefix("RunOtherCommands").WithField("file", file)
@@ -560,7 +568,7 @@ func RunOtherCommands(file string, fileDataStr string, association utils.FileCom
if len(modifications) == 0 {
runOtherCommandsLogger.Warning("No modifications found for file")
return fileDataStr, NothingToDo
return fileDataStr, ErrNothingToDo
}
runOtherCommandsLogger.Debug("Executing %d modifications for file", len(modifications))
@@ -658,7 +666,7 @@ func RunIsolateCommands(association utils.FileCommandAssociation, file string, f
}
if !anythingDone {
runIsolateCommandsLogger.Debug("No isolate modifications were made for file")
return fileDataStr, NothingToDo
return fileDataStr, ErrNothingToDo
}
return currentFileData, nil
}

View File

@@ -1,3 +1,5 @@
// Package processor provides JSON processing and Lua script execution capabilities
// for data transformation and manipulation.
package processor
import (
@@ -19,9 +21,9 @@ var jsonLogger = logger.Default.WithPrefix("processor/json")
// ProcessJSON applies Lua processing to JSON content
func ProcessJSON(content string, command utils.ModifyCommand, filename string) ([]utils.ReplaceCommand, error) {
processJsonLogger := jsonLogger.WithPrefix("ProcessJSON").WithField("commandName", command.Name).WithField("file", filename)
processJsonLogger.Debug("Starting JSON processing for file")
processJsonLogger.Trace("Initial file content length: %d", len(content))
processJSONLogger := jsonLogger.WithPrefix("ProcessJSON").WithField("commandName", command.Name).WithField("file", filename)
processJSONLogger.Debug("Starting JSON processing for file")
processJSONLogger.Trace("Initial file content length: %d", len(content))
var commands []utils.ReplaceCommand
startTime := time.Now()
@@ -30,15 +32,15 @@ func ProcessJSON(content string, command utils.ModifyCommand, filename string) (
var jsonData interface{}
err := json.Unmarshal([]byte(content), &jsonData)
if err != nil {
processJsonLogger.Error("Failed to parse JSON content: %v", err)
processJSONLogger.Error("Failed to parse JSON content: %v", err)
return commands, fmt.Errorf("failed to parse JSON: %v", err)
}
processJsonLogger.Debug("Successfully parsed JSON content")
processJSONLogger.Debug("Successfully parsed JSON content")
// Create Lua state
L, err := NewLuaState()
if err != nil {
processJsonLogger.Error("Error creating Lua state: %v", err)
processJSONLogger.Error("Error creating Lua state: %v", err)
return commands, fmt.Errorf("error creating Lua state: %v", err)
}
defer L.Close()
@@ -49,70 +51,58 @@ func ProcessJSON(content string, command utils.ModifyCommand, filename string) (
// Convert JSON data to Lua table
luaTable, err := ToLuaTable(L, jsonData)
if err != nil {
processJsonLogger.Error("Failed to convert JSON to Lua table: %v", err)
processJSONLogger.Error("Failed to convert JSON to Lua table: %v", err)
return commands, fmt.Errorf("failed to convert JSON to Lua table: %v", err)
}
// Set the JSON data as a global variable
L.SetGlobal("data", luaTable)
processJsonLogger.Debug("Set JSON data as Lua global 'data'")
processJSONLogger.Debug("Set JSON data as Lua global 'data'")
// Build and execute Lua script for JSON mode
luaExpr := BuildJSONLuaScript(command.Lua)
processJsonLogger.Debug("Built Lua script from expression: %q", command.Lua)
processJsonLogger.Trace("Full Lua script: %q", utils.LimitString(luaExpr, 200))
processJSONLogger.Debug("Built Lua script from expression: %q", command.Lua)
processJSONLogger.Trace("Full Lua script: %q", utils.LimitString(luaExpr, 200))
if err := L.DoString(luaExpr); err != nil {
processJsonLogger.Error("Lua script execution failed: %v\nScript: %s", err, utils.LimitString(luaExpr, 200))
processJSONLogger.Error("Lua script execution failed: %v\nScript: %s", err, utils.LimitString(luaExpr, 200))
return commands, fmt.Errorf("lua script execution failed: %v", err)
}
processJsonLogger.Debug("Lua script executed successfully")
processJSONLogger.Debug("Lua script executed successfully")
// Check if modification flag is set
modifiedVal := L.GetGlobal("modified")
if modifiedVal.Type() != lua.LTBool || !lua.LVAsBool(modifiedVal) {
processJsonLogger.Debug("Skipping - no modifications indicated by Lua script")
processJSONLogger.Debug("Skipping - no modifications indicated by Lua script")
return commands, nil
}
// Get the modified data from Lua
modifiedData := L.GetGlobal("data")
if modifiedData.Type() != lua.LTTable {
processJsonLogger.Error("Expected 'data' to be a table after Lua processing, got %s", modifiedData.Type().String())
processJSONLogger.Error("Expected 'data' to be a table after Lua processing, got %s", modifiedData.Type().String())
return commands, fmt.Errorf("expected 'data' to be a table after Lua processing")
}
// Convert back to Go interface
goData, err := FromLua(L, modifiedData)
if err != nil {
processJsonLogger.Error("Failed to convert Lua table back to Go: %v", err)
processJSONLogger.Error("Failed to convert Lua table back to Go: %v", err)
return commands, fmt.Errorf("failed to convert Lua table back to Go: %v", err)
}
processJsonLogger.Debug("About to call applyChanges with original data and modified data")
processJSONLogger.Debug("About to call applyChanges with original data and modified data")
commands, err = applyChanges(content, jsonData, goData)
if err != nil {
processJsonLogger.Error("Failed to apply surgical JSON changes: %v", err)
processJSONLogger.Error("Failed to apply surgical JSON changes: %v", err)
return commands, fmt.Errorf("failed to apply surgical JSON changes: %v", err)
}
processJsonLogger.Debug("Total JSON processing time: %v", time.Since(startTime))
processJsonLogger.Debug("Generated %d total modifications", len(commands))
processJSONLogger.Debug("Total JSON processing time: %v", time.Since(startTime))
processJSONLogger.Debug("Generated %d total modifications", len(commands))
return commands, nil
}
// applyJSONChanges compares original and modified data and produces surgical
// replace commands against the raw JSON content, preserving formatting.
//
// Fix: the previous version discarded the error returned by applyChanges and
// replaced it with a generic message, losing the root cause. The underlying
// error is now propagated; "succeeded but produced zero commands" still fails,
// with the original generic message.
func applyJSONChanges(content string, originalData, modifiedData interface{}) ([]utils.ReplaceCommand, error) {
	appliedCommands, err := applyChanges(content, originalData, modifiedData)
	if err != nil {
		// Propagate the underlying cause instead of swallowing it.
		return nil, fmt.Errorf("failed to make any changes to the json: %v", err)
	}
	if len(appliedCommands) == 0 {
		// applyChanges succeeded but found nothing to change.
		return nil, fmt.Errorf("failed to make any changes to the json")
	}
	return appliedCommands, nil
}
// applyChanges attempts to make surgical changes while preserving exact formatting
func applyChanges(content string, originalData, modifiedData interface{}) ([]utils.ReplaceCommand, error) {
var commands []utils.ReplaceCommand
@@ -199,12 +189,10 @@ func applyChanges(content string, originalData, modifiedData interface{}) ([]uti
// Convert the new value to JSON string
newValueStr := convertValueToJSONString(newValue)
// Insert the new field with pretty-printed formatting
// Format: ,"fieldName": { ... }
insertText := fmt.Sprintf(`,"%s": %s`, fieldName, newValueStr)
commands = append(commands, utils.ReplaceCommand{
From: startPos,
@@ -343,7 +331,14 @@ func convertValueToJSONString(value interface{}) string {
// findArrayElementRemovalRange finds the exact byte range to remove for an array element
func findArrayElementRemovalRange(content, arrayPath string, elementIndex int) (int, int) {
// Get the array using gjson
arrayResult := gjson.Get(content, arrayPath)
var arrayResult gjson.Result
if arrayPath == "" {
// Root-level array
arrayResult = gjson.Parse(content)
} else {
arrayResult = gjson.Get(content, arrayPath)
}
if !arrayResult.Exists() || !arrayResult.IsArray() {
return -1, -1
}
@@ -437,8 +432,6 @@ func findDeepChanges(basePath string, original, modified interface{}) map[string
}
changes[currentPath] = nil // Mark for removal
}
} else {
// Elements added - more complex, skip for now
}
} else {
// Same length - check individual elements for value changes
@@ -469,16 +462,9 @@ func findDeepChanges(basePath string, original, modified interface{}) map[string
}
}
}
default:
// For primitive types, compare directly
if !deepEqual(original, modified) {
if basePath == "" {
changes[""] = modified
} else {
changes[basePath] = modified
}
}
}
// Note: No default case needed - JSON data from unmarshaling is always
// map[string]interface{} or []interface{} at the top level
return changes
}
@@ -545,112 +531,61 @@ func deepEqual(a, b interface{}) bool {
}
}
// ToLuaTable converts a Go interface{} (map or array) to a Lua table.
// It must be called with the map[string]interface{} / []interface{} values
// produced by JSON unmarshaling; any other root type is rejected with an
// error, since the documents handled here always have an object or array root.
//
// Fix: this block contained interleaved pre/post-refactor lines (a duplicated
// doc comment plus both the error-returning and the plain ToLuaValue call
// sites, which cannot coexist); collapsed to the post-refactor version that
// uses the error-free ToLuaValue.
func ToLuaTable(L *lua.LState, data interface{}) (*lua.LTable, error) {
	toLuaTableLogger := jsonLogger.WithPrefix("ToLuaTable")
	toLuaTableLogger.Debug("Converting Go interface to Lua table")
	toLuaTableLogger.Trace("Input data type: %T", data)
	switch v := data.(type) {
	case map[string]interface{}:
		toLuaTableLogger.Debug("Converting map to Lua table")
		table := L.CreateTable(0, len(v))
		for key, value := range v {
			table.RawSetString(key, ToLuaValue(L, value))
		}
		return table, nil
	case []interface{}:
		toLuaTableLogger.Debug("Converting slice to Lua table")
		table := L.CreateTable(len(v), 0)
		for i, value := range v {
			table.RawSetInt(i+1, ToLuaValue(L, value)) // Lua arrays are 1-indexed
		}
		return table, nil
	case string:
		return nil, fmt.Errorf("expected table or array, got string")
	case float64:
		return nil, fmt.Errorf("expected table or array, got number")
	case bool:
		return nil, fmt.Errorf("expected table or array, got boolean")
	case nil:
		return nil, fmt.Errorf("expected table or array, got nil")
	default:
		// This should only happen with invalid JSON (root-level primitives)
		return nil, fmt.Errorf("expected table or array, got %T", v)
	}
}
// ToLuaValue converts a Go interface{} to a Lua value
func ToLuaValue(L *lua.LState, data interface{}) (lua.LValue, error) {
toLuaValueLogger := jsonLogger.WithPrefix("ToLuaValue")
toLuaValueLogger.Debug("Converting Go interface to Lua value")
toLuaValueLogger.Trace("Input data type: %T", data)
func ToLuaValue(L *lua.LState, data interface{}) lua.LValue {
switch v := data.(type) {
case map[string]interface{}:
toLuaValueLogger.Debug("Converting map to Lua table")
table := L.CreateTable(0, len(v))
for key, value := range v {
luaValue, err := ToLuaValue(L, value)
if err != nil {
toLuaValueLogger.Error("Failed to convert map value for key %q: %v", key, err)
return lua.LNil, err
}
table.RawSetString(key, luaValue)
table.RawSetString(key, ToLuaValue(L, value))
}
return table, nil
return table
case []interface{}:
toLuaValueLogger.Debug("Converting slice to Lua table")
table := L.CreateTable(len(v), 0)
for i, value := range v {
luaValue, err := ToLuaValue(L, value)
if err != nil {
toLuaValueLogger.Error("Failed to convert slice value at index %d: %v", i, err)
return lua.LNil, err
}
table.RawSetInt(i+1, luaValue) // Lua arrays are 1-indexed
table.RawSetInt(i+1, ToLuaValue(L, value)) // Lua arrays are 1-indexed
}
return table, nil
return table
case string:
toLuaValueLogger.Debug("Converting string to Lua string")
return lua.LString(v), nil
return lua.LString(v)
case float64:
toLuaValueLogger.Debug("Converting float64 to Lua number")
return lua.LNumber(v), nil
return lua.LNumber(v)
case bool:
toLuaValueLogger.Debug("Converting bool to Lua boolean")
return lua.LBool(v), nil
return lua.LBool(v)
case nil:
toLuaValueLogger.Debug("Converting nil to Lua nil")
return lua.LNil, nil
return lua.LNil
default:
toLuaValueLogger.Error("Unsupported type for Lua value conversion: %T", v)
return lua.LNil, fmt.Errorf("unsupported type for Lua value conversion: %T", v)
// This should never happen with JSON-unmarshaled data
return lua.LNil
}
}

View File

@@ -0,0 +1,283 @@
package processor
import (
"cook/utils"
"testing"
"github.com/stretchr/testify/assert"
)
// TestJSONFloatFormatting exercises float formatting for non-integer floats.
func TestJSONFloatFormatting(t *testing.T) {
	input := `{
"value": 10.5,
"another": 3.14159
}`
	cmd := utils.ModifyCommand{
		Name: "test_float",
		JSON: true,
		Lua: `
data.value = data.value * 2
data.another = data.another * 10
modified = true
`,
	}
	cmds, err := ProcessJSON(input, cmd, "test.json")
	assert.NoError(t, err)
	assert.NotEmpty(t, cmds)
	output, _ := utils.ExecuteModifications(cmds, input)
	assert.Contains(t, output, "21")      // 10.5 * 2
	assert.Contains(t, output, "31.4159") // 3.14159 * 10
}
// TestJSONNestedObjectAddition exercises adding a nested object value
// (the map[string]interface{} serialization path).
func TestJSONNestedObjectAddition(t *testing.T) {
	input := `{
"items": {}
}`
	cmd := utils.ModifyCommand{
		Name: "test_nested",
		JSON: true,
		Lua: `
data.items.newObject = {
name = "test",
value = 42,
enabled = true
}
modified = true
`,
	}
	cmds, err := ProcessJSON(input, cmd, "test.json")
	assert.NoError(t, err)
	assert.NotEmpty(t, cmds)
	output, _ := utils.ExecuteModifications(cmds, input)
	for _, want := range []string{`"newObject"`, `"name"`, `"test"`, `"value"`, "42"} {
		assert.Contains(t, output, want)
	}
}
// TestJSONKeyWithQuotes exercises quoting/escaping of keys that are not
// plain identifiers (e.g. contain dashes).
func TestJSONKeyWithQuotes(t *testing.T) {
	input := `{
"data": {}
}`
	cmd := utils.ModifyCommand{
		Name: "test_key_quotes",
		JSON: true,
		Lua: `
data.data["key-with-dash"] = "value1"
data.data.normalKey = "value2"
modified = true
`,
	}
	cmds, err := ProcessJSON(input, cmd, "test.json")
	assert.NoError(t, err)
	assert.NotEmpty(t, cmds)
	output, _ := utils.ExecuteModifications(cmds, input)
	for _, want := range []string{`"key-with-dash"`, `"normalKey"`} {
		assert.Contains(t, output, want)
	}
}
// TestJSONArrayInValue exercises serializing Lua array values into compact
// JSON arrays (the json.Marshal fallback path).
func TestJSONArrayInValue(t *testing.T) {
	input := `{
"data": {}
}`
	cmd := utils.ModifyCommand{
		Name: "test_array_value",
		JSON: true,
		Lua: `
data.data.items = {1, 2, 3, 4, 5}
data.data.strings = {"a", "b", "c"}
modified = true
`,
	}
	cmds, err := ProcessJSON(input, cmd, "test.json")
	assert.NoError(t, err)
	assert.NotEmpty(t, cmds)
	output, _ := utils.ExecuteModifications(cmds, input)
	for _, want := range []string{`"items"`, `[1,2,3,4,5]`, `"strings"`, `["a","b","c"]`} {
		assert.Contains(t, output, want)
	}
}
// TestJSONRootArrayElementRemoval exercises removing an element from a
// root-level JSON array.
func TestJSONRootArrayElementRemoval(t *testing.T) {
	input := `[
{"id": 1, "name": "first"},
{"id": 2, "name": "second"},
{"id": 3, "name": "third"}
]`
	cmd := utils.ModifyCommand{
		Name: "test_root_array_removal",
		JSON: true,
		Lua: `
-- Remove the second element
table.remove(data, 2)
modified = true
`,
	}
	cmds, err := ProcessJSON(input, cmd, "test.json")
	assert.NoError(t, err)
	assert.NotEmpty(t, cmds)
	output, _ := utils.ExecuteModifications(cmds, input)
	assert.Contains(t, output, `"first"`)
	assert.Contains(t, output, `"third"`)
	assert.NotContains(t, output, `"second"`)
}
// TestJSONRootArrayElementChange exercises changing primitive values inside a
// root-level JSON array.
func TestJSONRootArrayElementChange(t *testing.T) {
	input := `[10, 20, 30, 40, 50]`
	cmd := utils.ModifyCommand{
		Name: "test_root_array_change",
		JSON: true,
		Lua: `
-- Double all values
for i = 1, #data do
data[i] = data[i] * 2
end
modified = true
`,
	}
	cmds, err := ProcessJSON(input, cmd, "test.json")
	assert.NoError(t, err)
	assert.NotEmpty(t, cmds)
	output, _ := utils.ExecuteModifications(cmds, input)
	// Doubled values: 20, 40, 60, 80, 100.
	for _, doubled := range []string{"20", "40", "60", "80", "100"} {
		assert.Contains(t, output, doubled)
	}
	assert.NotContains(t, output, "10,")
}
// TestJSONRootArrayStringElements exercises deepEqual on string elements in a
// root-level JSON array.
func TestJSONRootArrayStringElements(t *testing.T) {
	input := `["apple", "banana", "cherry"]`
	cmd := utils.ModifyCommand{
		Name: "test_root_array_strings",
		JSON: true,
		Lua: `
data[2] = "orange"
modified = true
`,
	}
	cmds, err := ProcessJSON(input, cmd, "test.json")
	assert.NoError(t, err)
	assert.NotEmpty(t, cmds)
	output, _ := utils.ExecuteModifications(cmds, input)
	for _, kept := range []string{`"apple"`, `"orange"`, `"cherry"`} {
		assert.Contains(t, output, kept)
	}
	assert.NotContains(t, output, `"banana"`)
}
// TestJSONComplexNestedStructure exercises nested-object addition together
// with a float value change in a single pass.
func TestJSONComplexNestedStructure(t *testing.T) {
	input := `{
"config": {
"multiplier": 2.5
}
}`
	cmd := utils.ModifyCommand{
		Name: "test_complex",
		JSON: true,
		Lua: `
-- Add nested object with array
data.config.settings = {
enabled = true,
values = {1.5, 2.5, 3.5},
names = {"alpha", "beta"}
}
-- Change float
data.config.multiplier = 7.777
modified = true
`,
	}
	cmds, err := ProcessJSON(input, cmd, "test.json")
	assert.NoError(t, err)
	assert.NotEmpty(t, cmds)
	output, _ := utils.ExecuteModifications(cmds, input)
	for _, want := range []string{"7.777", `"settings"`, `"values"`, `[1.5,2.5,3.5]`} {
		assert.Contains(t, output, want)
	}
}
// TestJSONRemoveFirstArrayElement exercises comma handling when the first
// array element is removed.
func TestJSONRemoveFirstArrayElement(t *testing.T) {
	input := `{
"items": [1, 2, 3, 4, 5]
}`
	cmd := utils.ModifyCommand{
		Name: "test_remove_first",
		JSON: true,
		Lua: `
table.remove(data.items, 1)
modified = true
`,
	}
	cmds, err := ProcessJSON(input, cmd, "test.json")
	assert.NoError(t, err)
	assert.NotEmpty(t, cmds)
	output, _ := utils.ExecuteModifications(cmds, input)
	assert.NotContains(t, output, "[1,")
	assert.Contains(t, output, "2")
	assert.Contains(t, output, "5")
}
// TestJSONRemoveLastArrayElement exercises comma handling when the last
// array element is removed.
func TestJSONRemoveLastArrayElement(t *testing.T) {
	input := `{
"items": [1, 2, 3, 4, 5]
}`
	cmd := utils.ModifyCommand{
		Name: "test_remove_last",
		JSON: true,
		Lua: `
table.remove(data.items, 5)
modified = true
`,
	}
	cmds, err := ProcessJSON(input, cmd, "test.json")
	assert.NoError(t, err)
	assert.NotEmpty(t, cmds)
	output, _ := utils.ExecuteModifications(cmds, input)
	assert.Contains(t, output, "1")
	assert.Contains(t, output, "4")
	assert.NotContains(t, output, ", 5")
}

View File

@@ -0,0 +1,153 @@
package processor
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDeepEqual(t *testing.T) {
tests := []struct {
name string
a interface{}
b interface{}
expected bool
}{
{
name: "both nil",
a: nil,
b: nil,
expected: true,
},
{
name: "first nil",
a: nil,
b: "something",
expected: false,
},
{
name: "second nil",
a: "something",
b: nil,
expected: false,
},
{
name: "equal primitives",
a: 42,
b: 42,
expected: true,
},
{
name: "different primitives",
a: 42,
b: 43,
expected: false,
},
{
name: "equal strings",
a: "hello",
b: "hello",
expected: true,
},
{
name: "equal maps",
a: map[string]interface{}{
"key1": "value1",
"key2": 42,
},
b: map[string]interface{}{
"key1": "value1",
"key2": 42,
},
expected: true,
},
{
name: "maps different lengths",
a: map[string]interface{}{
"key1": "value1",
},
b: map[string]interface{}{
"key1": "value1",
"key2": 42,
},
expected: false,
},
{
name: "maps different values",
a: map[string]interface{}{
"key1": "value1",
},
b: map[string]interface{}{
"key1": "value2",
},
expected: false,
},
{
name: "map vs non-map",
a: map[string]interface{}{
"key1": "value1",
},
b: "not a map",
expected: false,
},
{
name: "equal arrays",
a: []interface{}{1, 2, 3},
b: []interface{}{1, 2, 3},
expected: true,
},
{
name: "arrays different lengths",
a: []interface{}{1, 2},
b: []interface{}{1, 2, 3},
expected: false,
},
{
name: "arrays different values",
a: []interface{}{1, 2, 3},
b: []interface{}{1, 2, 4},
expected: false,
},
{
name: "array vs non-array",
a: []interface{}{1, 2, 3},
b: "not an array",
expected: false,
},
{
name: "nested equal structures",
a: map[string]interface{}{
"outer": map[string]interface{}{
"inner": []interface{}{1, 2, 3},
},
},
b: map[string]interface{}{
"outer": map[string]interface{}{
"inner": []interface{}{1, 2, 3},
},
},
expected: true,
},
{
name: "nested different structures",
a: map[string]interface{}{
"outer": map[string]interface{}{
"inner": []interface{}{1, 2, 3},
},
},
b: map[string]interface{}{
"outer": map[string]interface{}{
"inner": []interface{}{1, 2, 4},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := deepEqual(tt.a, tt.b)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -88,8 +88,7 @@ func TestToLuaValue(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := ToLuaValue(L, tt.input)
assert.NoError(t, err)
result := ToLuaValue(L, tt.input)
assert.Equal(t, tt.expected, result.String())
})
}

View File

@@ -0,0 +1,43 @@
-- Load the helper script
dofile("luahelper.lua")
-- Test helper function
-- Assertion helper; shadows the built-in so failures carry a uniform prefix.
local function assert(condition, message)
    if condition then return end
    error("ASSERTION FAILED: " .. (message or "unknown error"))
end
-- Runs fn protected and prints a one-line PASS/FAIL verdict for the test.
local function test(name, fn)
    local ok, err = pcall(fn)
    if not ok then
        print("FAIL: " .. name .. " - " .. tostring(err))
    else
        print("PASS: " .. name)
    end
end
test("regression test 001", function()
local csv =
[[Id Enabled ModuleId DepartmentId IsDepartment PositionInGraph Parents Modifiers UpgradePrice
news_department TRUE navigation TRUE 2 0 NewsAnalyticsDepartment + 1 communication_relay communication_relay
nd_charge_bonus TRUE navigation news_department FALSE 1 0 news_department NDSkillChargeBonus + 1 expert_disk expert_disk
nd_cooldown_time_reduce TRUE navigation news_department FALSE 3 0 news_department NDCooldownTimeReduce - 2 communication_relay communication_relay]]
local rows, err = fromCSV(csv, { delimiter = "\t", hasheader = true, hascomments = true })
if err then error("fromCSV error: " .. err) end
assert(#rows == 3, "Should have 3 rows")
assert(rows[1].Id == "news_department", "First row Id should be 'news_department'")
assert(rows[1].Enabled == "TRUE", "First row Enabled should be 'TRUE'")
assert(rows[1].ModuleId == "navigation", "First row ModuleId should be 'navigation'")
assert(rows[1].DepartmentId == "", "First row DepartmentId should be ''")
assert(rows[1].IsDepartment == "TRUE", "First row IsDepartment should be 'TRUE'")
assert(rows.Headers[1] == "Id", "First row Headers should be 'Id'")
assert(rows.Headers[2] == "Enabled", "First row Headers should be 'Enabled'")
assert(rows.Headers[3] == "ModuleId", "First row Headers should be 'ModuleId'")
assert(rows.Headers[4] == "DepartmentId", "First row Headers should be 'DepartmentId'")
assert(rows.Headers[5] == "IsDepartment", "First row Headers should be 'IsDepartment'")
assert(rows.Headers[6] == "PositionInGraph", "First row Headers should be 'PositionInGraph'")
assert(rows.Headers[7] == "Parents", "First row Headers should be 'Parents'")
assert(rows.Headers[8] == "Modifiers", "First row Headers should be 'Modifiers'")
assert(rows.Headers[9] == "UpgradePrice", "First row Headers should be 'UpgradePrice'")
end)
print("\nAll tests completed!")

View File

@@ -0,0 +1,224 @@
-- Load the helper script
dofile("luahelper.lua")
-- Test helper function
-- Assertion helper; shadows the built-in so failures carry a uniform prefix.
local function assert(condition, message)
    if condition then return end
    error("ASSERTION FAILED: " .. (message or "unknown error"))
end
-- Runs fn protected and prints a one-line PASS/FAIL verdict for the test.
local function test(name, fn)
    local ok, err = pcall(fn)
    if not ok then
        print("FAIL: " .. name .. " - " .. tostring(err))
    else
        print("PASS: " .. name)
    end
end
-- NOTE(review): findElements, getNumAttr, setNumAttr, modifyNumAttr,
-- filterElements, visitElements, getText, setText, hasAttr, getAttr, setAttr,
-- findFirstElement, getChildren, countChildren, addChild and removeChildren
-- are all defined in luahelper.lua (loaded via dofile above).
-- Elements are plain tables in the parser's shape: _tag (string),
-- _attr (attribute map), _children (array of elements), _text (text content).
-- Test findElements (recursive: matches nested descendants too)
test("findElements finds all matching elements recursively", function()
local testXML = {
_tag = "root",
_children = {
{ _tag = "item", _attr = { name = "sword" } },
{ _tag = "item", _attr = { name = "shield" } },
{
_tag = "container",
_children = {
{ _tag = "item", _attr = { name = "potion" } },
},
},
},
}
local items = findElements(testXML, "item")
assert(#items == 3, "Should find 3 items total (recursive)")
assert(items[1]._attr.name == "sword", "First item should be sword")
assert(items[3]._attr.name == "potion", "Third item should be potion (from nested)")
end)
-- Test getNumAttr and setNumAttr (attributes are stored as strings,
-- converted to/from numbers by the helpers)
test("getNumAttr gets numeric attribute", function()
local elem = { _tag = "item", _attr = { damage = "10" } }
local damage = getNumAttr(elem, "damage")
assert(damage == 10, "Should get damage as number")
end)
test("getNumAttr returns nil for missing attribute", function()
local elem = { _tag = "item", _attr = {} }
local damage = getNumAttr(elem, "damage")
assert(damage == nil, "Should return nil for missing attribute")
end)
test("setNumAttr sets numeric attribute", function()
local elem = { _tag = "item", _attr = {} }
setNumAttr(elem, "damage", 20)
assert(elem._attr.damage == "20", "Should set damage as string")
end)
-- Test modifyNumAttr
-- NOTE(review): "11.0" implies non-integer results are re-serialized with a
-- decimal part -- confirm against the number formatting in luahelper.lua.
test("modifyNumAttr modifies numeric attribute", function()
local elem = { _tag = "item", _attr = { weight = "5.5" } }
local modified = modifyNumAttr(elem, "weight", function(val) return val * 2 end)
assert(modified == true, "Should return true when modified")
assert(elem._attr.weight == "11.0", "Should double weight")
end)
test("modifyNumAttr returns false for missing attribute", function()
local elem = { _tag = "item", _attr = {} }
local modified = modifyNumAttr(elem, "weight", function(val) return val * 2 end)
assert(modified == false, "Should return false when attribute missing")
end)
-- Test filterElements
test("filterElements filters by predicate", function()
local testXML = {
_tag = "root",
_children = {
{ _tag = "item", _attr = { healing = "20" } },
{ _tag = "item", _attr = { damage = "10" } },
{ _tag = "item", _attr = { healing = "50" } },
},
}
local healingItems = filterElements(testXML, function(elem) return hasAttr(elem, "healing") end)
assert(#healingItems == 2, "Should find 2 healing items")
end)
-- Test visitElements (visits the root element itself plus all descendants)
test("visitElements visits all elements", function()
local testXML = {
_tag = "root",
_children = {
{ _tag = "item" },
{ _tag = "container", _children = {
{ _tag = "item" },
} },
},
}
local count = 0
visitElements(testXML, function(elem) count = count + 1 end)
assert(count == 4, "Should visit 4 elements (root + 2 items + container)")
end)
-- Test getText and setText (_text field accessors)
test("getText gets text content", function()
local elem = { _tag = "item", _text = "Iron Sword" }
local text = getText(elem)
assert(text == "Iron Sword", "Should get text content")
end)
test("setText sets text content", function()
local elem = { _tag = "item" }
setText(elem, "New Text")
assert(elem._text == "New Text", "Should set text content")
end)
-- Test hasAttr and getAttr
test("hasAttr checks attribute existence", function()
local elem = { _tag = "item", _attr = { damage = "10" } }
assert(hasAttr(elem, "damage") == true, "Should have damage")
assert(hasAttr(elem, "magic") == false, "Should not have magic")
end)
test("getAttr gets attribute value", function()
local elem = { _tag = "item", _attr = { name = "sword" } }
assert(getAttr(elem, "name") == "sword", "Should get name attribute")
assert(getAttr(elem, "missing") == nil, "Should return nil for missing")
end)
test("setAttr sets attribute value", function()
local elem = { _tag = "item" }
setAttr(elem, "name", "sword")
assert(elem._attr.name == "sword", "Should set attribute")
end)
-- Test findFirstElement (direct children only, unlike recursive findElements)
test("findFirstElement finds first direct child", function()
local parent = {
_tag = "root",
_children = {
{ _tag = "item", _attr = { id = "1" } },
{ _tag = "item", _attr = { id = "2" } },
},
}
local first = findFirstElement(parent, "item")
assert(first._attr.id == "1", "Should find first item")
end)
test("findFirstElement returns nil when not found", function()
local parent = { _tag = "root", _children = {} }
local result = findFirstElement(parent, "item")
assert(result == nil, "Should return nil when not found")
end)
-- Test getChildren (direct children matching a tag, order preserved)
test("getChildren gets all direct children with tag", function()
local parent = {
_tag = "root",
_children = {
{ _tag = "item", _attr = { id = "1" } },
{ _tag = "config" },
{ _tag = "item", _attr = { id = "2" } },
},
}
local items = getChildren(parent, "item")
assert(#items == 2, "Should get 2 items")
assert(items[1]._attr.id == "1", "First should have id=1")
assert(items[2]._attr.id == "2", "Second should have id=2")
end)
-- Test countChildren
test("countChildren counts direct children with tag", function()
local parent = {
_tag = "root",
_children = {
{ _tag = "item" },
{ _tag = "config" },
{ _tag = "item" },
},
}
assert(countChildren(parent, "item") == 2, "Should count 2 items")
assert(countChildren(parent, "config") == 1, "Should count 1 config")
end)
-- Test addChild (lazily creates _children when absent)
test("addChild adds child element", function()
local parent = { _tag = "root", _children = {} }
addChild(parent, { _tag = "item" })
assert(#parent._children == 1, "Should have 1 child")
assert(parent._children[1]._tag == "item", "Child should be item")
end)
test("addChild creates children array if needed", function()
local parent = { _tag = "root" }
addChild(parent, { _tag = "item" })
assert(parent._children ~= nil, "Should create _children")
assert(#parent._children == 1, "Should have 1 child")
end)
-- Test removeChildren (returns the number of removed direct children)
test("removeChildren removes all matching children", function()
local parent = {
_tag = "root",
_children = {
{ _tag = "item" },
{ _tag = "config" },
{ _tag = "item" },
},
}
local removed = removeChildren(parent, "item")
assert(removed == 2, "Should remove 2 items")
assert(#parent._children == 1, "Should have 1 child left")
assert(parent._children[1]._tag == "config", "Remaining should be config")
end)
test("removeChildren returns 0 when none found", function()
local parent = {
_tag = "root",
_children = { { _tag = "item" } },
}
local removed = removeChildren(parent, "config")
assert(removed == 0, "Should remove 0")
assert(#parent._children == 1, "Should still have 1 child")
end)
print("\nAll tests completed!")

View File

@@ -0,0 +1,534 @@
-- Load the helpers under test into the global environment
dofile("luahelper.lua")

-- Assertion helper; intentionally shadows the global assert for the rest
-- of this chunk so failures carry a recognizable prefix.
local function assert(condition, message)
    if not condition then
        error("ASSERTION FAILED: " .. (message or "unknown error"))
    end
end

-- Runs one test case protected and reports the outcome on stdout.
local function test(name, fn)
    local ok, err = pcall(fn)
    if not ok then
        print("FAIL: " .. name .. " - " .. tostring(err))
    else
        print("PASS: " .. name)
    end
end
-- fromCSV must reject unrecognized option keys
test("fromCSV invalid option", function()
    local input = "a,b,c\n1,2,3"
    local ok, errMsg = pcall(function() fromCSV(input, { invalidOption = true }) end)
    assert(ok == false, "Should raise error")
    assert(string.find(errMsg, "unknown option"), "Error should mention unknown option")
end)

-- toCSV coerces whatever delimiter value it is handed into a string
test("toCSV invalid delimiter", function()
    local data = { { "a", "b", "c" } }
    local out = toCSV(data, { delimiter = 123 })
    -- the numeric delimiter 123 joins fields as the string "123"
    assert(out == "a123b123c", "Should convert delimiter to string")
end)
-- Plain row/field splitting
test("fromCSV basic", function()
    local input = "a,b,c\n1,2,3\n4,5,6"
    local parsed = fromCSV(input)
    assert(#parsed == 3, "Should have 3 rows")
    assert(parsed[1][1] == "a", "First row first field should be 'a'")
    assert(parsed[2][2] == "2", "Second row second field should be '2'")
end)

-- The header row becomes named keys on each data row
test("fromCSV with headers", function()
    local input = "foo,bar,baz\n1,2,3\n4,5,6"
    local parsed = fromCSV(input, { hasheader = true })
    assert(#parsed == 2, "Should have 2 data rows")
    assert(parsed[1][1] == "1", "First row first field should be '1'")
    assert(parsed[1].foo == "1", "First row foo should be '1'")
    assert(parsed[1].bar == "2", "First row bar should be '2'")
    assert(parsed[1].baz == "3", "First row baz should be '3'")
end)

-- Alternate delimiter (TSV)
test("fromCSV with tab delimiter", function()
    local input = "a\tb\tc\n1\t2\t3"
    local parsed = fromCSV(input, { delimiter = "\t" })
    assert(#parsed == 2, "Should have 2 rows")
    assert(parsed[1][1] == "a", "First row first field should be 'a'")
    assert(parsed[2][2] == "2", "Second row second field should be '2'")
end)

-- RFC 4180 style quoting keeps delimiters inside quoted fields
test("fromCSV with quoted fields", function()
    local input = '"hello,world","test"\n"foo","bar"'
    local parsed = fromCSV(input)
    assert(#parsed == 2, "Should have 2 rows")
    assert(parsed[1][1] == "hello,world", "Quoted field with comma should be preserved")
    assert(parsed[1][2] == "test", "Second field should be 'test'")
end)
-- Serialization back to delimited text
test("toCSV basic", function()
    local data = { { "a", "b", "c" }, { "1", "2", "3" } }
    local out = toCSV(data)
    assert(out == "a,b,c\n1,2,3", "CSV output should match expected")
end)

test("toCSV with tab delimiter", function()
    local data = { { "a", "b", "c" }, { "1", "2", "3" } }
    local out = toCSV(data, { delimiter = "\t" })
    assert(out == "a\tb\tc\n1\t2\t3", "TSV output should match expected")
end)

test("toCSV with quoted fields", function()
    local data = { { "hello,world", "test" }, { "foo", "bar" } }
    local out = toCSV(data)
    assert(out == '"hello,world",test\nfoo,bar', "Fields with commas should be quoted")
end)

-- parse -> serialize round trips
test("fromCSV toCSV round trip", function()
    local original = "a,b,c\n1,2,3\n4,5,6"
    local out = toCSV(fromCSV(original))
    assert(out == original, "Round trip should preserve original")
end)

-- With hasheader the header row is consumed, so only data rows survive
test("fromCSV toCSV round trip with headers", function()
    local original = "foo,bar,baz\n1,2,3\n4,5,6"
    local out = toCSV(fromCSV(original, { hasheader = true }))
    local expected = "1,2,3\n4,5,6"
    assert(out == expected, "Round trip with headers should preserve data rows")
end)
-- Comment filtering: with hascomments, rows whose first field starts
-- with '#' (after leading whitespace) are dropped
test("fromCSV with comments", function()
    local input = "# This is a comment\nfoo,bar,baz\n1,2,3\n# Another comment\n4,5,6"
    local parsed = fromCSV(input, { hascomments = true })
    assert(#parsed == 3, "Should have 3 rows (comments filtered, header + 2 data rows)")
    assert(parsed[1][1] == "foo", "First row should be header row")
    assert(parsed[2][1] == "1", "Second row first field should be '1'")
    assert(parsed[3][1] == "4", "Third row first field should be '4'")
end)

test("fromCSV with comments and headers", function()
    local input = "#mercenary_profiles\nId,Name,Value\n1,Test,100\n# End of data\n2,Test2,200"
    local parsed = fromCSV(input, { hasheader = true, hascomments = true })
    assert(#parsed == 2, "Should have 2 data rows")
    assert(parsed[1].Id == "1", "First row Id should be '1'")
    assert(parsed[1].Name == "Test", "First row Name should be 'Test'")
    assert(parsed[1].Value == "100", "First row Value should be '100'")
    assert(parsed[2].Id == "2", "Second row Id should be '2'")
end)

-- hascomments=false must leave '#' lines untouched
test("fromCSV without comments", function()
    local input = "# This should not be filtered\nfoo,bar\n1,2"
    local parsed = fromCSV(input, { hascomments = false })
    assert(#parsed == 3, "Should have 3 rows (including comment)")
    assert(parsed[1][1] == "# This should not be filtered", "Comment line should be preserved")
end)

test("fromCSV comment at start", function()
    local input = "# Header comment\nId,Name\n1,Test"
    local parsed = fromCSV(input, { hascomments = true })
    assert(#parsed == 2, "Should have 2 rows (comment filtered)")
    assert(parsed[1][1] == "Id", "First row should be header")
end)

test("fromCSV comment with whitespace", function()
    local input = " # Comment with spaces\nId,Name\n1,Test"
    local parsed = fromCSV(input, { hascomments = true })
    assert(#parsed == 2, "Should have 2 rows (comment with spaces filtered)")
    assert(parsed[1][1] == "Id", "First row should be header")
end)

test("fromCSV comment with tabs", function()
    local input = "\t# Comment with tab\nId,Name\n1,Test"
    local parsed = fromCSV(input, { hascomments = true })
    assert(#parsed == 2, "Should have 2 rows (comment with tab filtered)")
    assert(parsed[1][1] == "Id", "First row should be header")
end)

test("fromCSV multiple consecutive comments", function()
    local input = "# First comment\n# Second comment\n# Third comment\nId,Name\n1,Test"
    local parsed = fromCSV(input, { hascomments = true })
    assert(#parsed == 2, "Should have 2 rows (all comments filtered)")
    assert(parsed[1][1] == "Id", "First row should be header")
end)

test("fromCSV comment in middle", function()
    local input = "Id,Name\n1,Test\n# Middle comment\n2,Test2"
    local parsed = fromCSV(input, { hascomments = true })
    assert(#parsed == 3, "Should have 3 rows (comment filtered)")
    assert(parsed[1][1] == "Id", "First row should be header")
    assert(parsed[2][1] == "1", "Second row should be first data")
    assert(parsed[3][1] == "2", "Third row should be second data")
end)

test("fromCSV comment at end", function()
    local input = "Id,Name\n1,Test\n# End comment"
    local parsed = fromCSV(input, { hascomments = true })
    assert(#parsed == 2, "Should have 2 rows (end comment filtered)")
    assert(parsed[1][1] == "Id", "First row should be header")
    assert(parsed[2][1] == "1", "Second row should be data")
end)

test("fromCSV empty comment", function()
    local input = "#\nId,Name\n1,Test"
    local parsed = fromCSV(input, { hascomments = true })
    assert(#parsed == 2, "Should have 2 rows (empty comment filtered)")
    assert(parsed[1][1] == "Id", "First row should be header")
end)

test("fromCSV comment with headers enabled", function()
    local input = "#mercenary_profiles\nId,Name,Value\n1,Test,100\n2,Test2,200"
    local parsed = fromCSV(input, { hasheader = true, hascomments = true })
    assert(#parsed == 2, "Should have 2 data rows")
    assert(parsed[1].Id == "1", "First row Id should be '1'")
    assert(parsed[1].Name == "Test", "First row Name should be 'Test'")
    assert(parsed[2].Id == "2", "Second row Id should be '2'")
end)

test("fromCSV comment with tab delimiter", function()
    local input = "# Comment\nId\tName\n1\tTest"
    local parsed = fromCSV(input, { delimiter = "\t", hascomments = true })
    assert(#parsed == 2, "Should have 2 rows")
    assert(parsed[1][1] == "Id", "First row should be header")
    assert(parsed[2][1] == "1", "Second row first field should be '1'")
end)

test("fromCSV comment with headers and TSV", function()
    local input = "#mercenary_profiles\nId\tName\tValue\n1\tTest\t100"
    local parsed = fromCSV(input, { delimiter = "\t", hasheader = true, hascomments = true })
    assert(#parsed == 1, "Should have 1 data row")
    assert(parsed[1].Id == "1", "Row Id should be '1'")
    assert(parsed[1].Name == "Test", "Row Name should be 'Test'")
    assert(parsed[1].Value == "100", "Row Value should be '100'")
end)

-- '#' is only a comment marker at the start of a row's first field
test("fromCSV data field starting with hash", function()
    local input = "Id,Name\n1,#NotAComment\n2,Test"
    local parsed = fromCSV(input, { hascomments = true })
    assert(#parsed == 3, "Should have 3 rows (data with # not filtered)")
    assert(parsed[1][1] == "Id", "First row should be header")
    assert(parsed[2][2] == "#NotAComment", "Second row should have #NotAComment as data")
end)

test("fromCSV quoted field with hash", function()
    local input = 'Id,Name\n1,"#NotAComment"\n2,Test'
    local parsed = fromCSV(input, { hascomments = true })
    assert(#parsed == 3, "Should have 3 rows (quoted # not filtered)")
    assert(parsed[2][2] == "#NotAComment", "Quoted field with # should be preserved")
end)

test("fromCSV comment after quoted field", function()
    local input = 'Id,Name\n1,"Test"\n# This is a comment\n2,Test2'
    local parsed = fromCSV(input, { hascomments = true })
    assert(#parsed == 3, "Should have 3 rows (comment filtered)")
    assert(parsed[2][2] == "Test", "Quoted field should be preserved")
    assert(parsed[3][1] == "2", "Third row should be second data row")
end)
-- Math wrapper tests (table-driven: each case is {actual, expected, message})
test("min function", function()
    local cases = {
        { min(5, 3), 3, "min(5, 3) should be 3" },
        { min(-1, 0), -1, "min(-1, 0) should be -1" },
        { min(10, 10), 10, "min(10, 10) should be 10" },
    }
    for _, c in ipairs(cases) do
        assert(c[1] == c[2], c[3])
    end
end)

test("max function", function()
    local cases = {
        { max(5, 3), 5, "max(5, 3) should be 5" },
        { max(-1, 0), 0, "max(-1, 0) should be 0" },
        { max(10, 10), 10, "max(10, 10) should be 10" },
    }
    for _, c in ipairs(cases) do
        assert(c[1] == c[2], c[3])
    end
end)

test("round function", function()
    local cases = {
        { round(3.14159), 3, "round(3.14159) should be 3" },
        { round(3.14159, 2), 3.14, "round(3.14159, 2) should be 3.14" },
        { round(3.5), 4, "round(3.5) should be 4" },
        { round(3.4), 3, "round(3.4) should be 3" },
        { round(123.456, 1), 123.5, "round(123.456, 1) should be 123.5" },
    }
    for _, c in ipairs(cases) do
        assert(c[1] == c[2], c[3])
    end
end)

test("floor function", function()
    local cases = {
        { floor(3.7), 3, "floor(3.7) should be 3" },
        { floor(-3.7), -4, "floor(-3.7) should be -4" },
        { floor(5), 5, "floor(5) should be 5" },
    }
    for _, c in ipairs(cases) do
        assert(c[1] == c[2], c[3])
    end
end)

test("ceil function", function()
    local cases = {
        { ceil(3.2), 4, "ceil(3.2) should be 4" },
        { ceil(-3.2), -3, "ceil(-3.2) should be -3" },
        { ceil(5), 5, "ceil(5) should be 5" },
    }
    for _, c in ipairs(cases) do
        assert(c[1] == c[2], c[3])
    end
end)
-- String wrapper tests
test("upper function", function()
    local cases = {
        { upper("hello"), "HELLO", "upper('hello') should be 'HELLO'" },
        { upper("Hello World"), "HELLO WORLD", "upper('Hello World') should be 'HELLO WORLD'" },
        { upper("123abc"), "123ABC", "upper('123abc') should be '123ABC'" },
    }
    for _, c in ipairs(cases) do
        assert(c[1] == c[2], c[3])
    end
end)

test("lower function", function()
    local cases = {
        { lower("HELLO"), "hello", "lower('HELLO') should be 'hello'" },
        { lower("Hello World"), "hello world", "lower('Hello World') should be 'hello world'" },
        { lower("123ABC"), "123abc", "lower('123ABC') should be '123abc'" },
    }
    for _, c in ipairs(cases) do
        assert(c[1] == c[2], c[3])
    end
end)

test("format function", function()
    local cases = {
        { format("Hello %s", "World"), "Hello World", "format should work" },
        { format("Number: %d", 42), "Number: 42", "format with number should work" },
        { format("%.2f", 3.14159), "3.14", "format with float should work" },
    }
    for _, c in ipairs(cases) do
        assert(c[1] == c[2], c[3])
    end
end)

test("trim function", function()
    local cases = {
        { trim(" hello "), "hello", "trim should remove leading and trailing spaces" },
        { trim(" hello world "), "hello world", "trim should preserve internal spaces" },
        { trim("hello"), "hello", "trim should not affect strings without spaces" },
        { trim(" "), "", "trim should handle all spaces" },
    }
    for _, c in ipairs(cases) do
        assert(c[1] == c[2], c[3])
    end
end)

test("strsplit function", function()
    local parts = strsplit("a,b,c", ",")
    assert(#parts == 3, "strsplit should return 3 elements")
    assert(parts[1] == "a", "First element should be 'a'")
    assert(parts[2] == "b", "Second element should be 'b'")
    assert(parts[3] == "c", "Third element should be 'c'")
end)

test("strsplit with default separator", function()
    local parts = strsplit("a b c")
    assert(#parts == 3, "strsplit with default should return 3 elements")
    assert(parts[1] == "a", "First element should be 'a'")
    assert(parts[2] == "b", "Second element should be 'b'")
    assert(parts[3] == "c", "Third element should be 'c'")
end)

test("strsplit with custom separator", function()
    local parts = strsplit("a|b|c", "|")
    assert(#parts == 3, "strsplit with pipe should return 3 elements")
    assert(parts[1] == "a", "First element should be 'a'")
    assert(parts[2] == "b", "Second element should be 'b'")
    assert(parts[3] == "c", "Third element should be 'c'")
end)
-- Conversion helper tests
test("num function", function()
    local cases = {
        { num("123"), 123, "num('123') should be 123" },
        { num("45.67"), 45.67, "num('45.67') should be 45.67" },
        { num("invalid"), 0, "num('invalid') should be 0" },
        { num(""), 0, "num('') should be 0" },
    }
    for _, c in ipairs(cases) do
        assert(c[1] == c[2], c[3])
    end
end)

test("str function", function()
    local cases = {
        { str(123), "123", "str(123) should be '123'" },
        { str(45.67), "45.67", "str(45.67) should be '45.67'" },
        { str(0), "0", "str(0) should be '0'" },
    }
    for _, c in ipairs(cases) do
        assert(c[1] == c[2], c[3])
    end
end)

test("is_number function", function()
    local cases = {
        { is_number("123"), true, "is_number('123') should be true" },
        { is_number("45.67"), true, "is_number('45.67') should be true" },
        { is_number("invalid"), false, "is_number('invalid') should be false" },
        { is_number(""), false, "is_number('') should be false" },
        { is_number("123abc"), false, "is_number('123abc') should be false" },
    }
    for _, c in ipairs(cases) do
        assert(c[1] == c[2], c[3])
    end
end)
-- isArray: detects contiguous 1-indexed sequences
test("isArray function", function()
    local cases = {
        { { 1, 2, 3 }, true, "isArray should return true for sequential array" },
        { { "a", "b", "c" }, true, "isArray should return true for string array" },
        { {}, true, "isArray should return true for empty array" },
        { { a = 1, b = 2 }, false, "isArray should return false for map" },
        { { 1, 2, [4] = 4 }, false, "isArray should return false for sparse array" },
        { { [1] = 1, [2] = 2, [3] = 3 }, true, "isArray should return true for 1-indexed array" },
        { { [0] = 1, [1] = 2 }, false, "isArray should return false for 0-indexed array" },
        { { [1] = 1, [2] = 2, [4] = 4 }, false, "isArray should return false for non-sequential array" },
        { "not a table", false, "isArray should return false for non-table" },
        { 123, false, "isArray should return false for number" },
    }
    for _, c in ipairs(cases) do
        assert(isArray(c[1]) == c[2], c[3])
    end
end)
test("fromCSV assigns header keys correctly", function()
local teststr = [[
#mercenary_profiles
Id ModifyStartCost ModifyStep ModifyLevelLimit Health ResistSheet WoundSlots MeleeDamage MeleeAccuracy RangeAccuracy ReceiveAmputationChance ReceiveWoundChanceMult AttackWoundChanceMult Dodge Los StarvationLimit PainThresholdLimit PainThresholdRegen TalentPerkId ActorId SkinIndex HairType HairColorHex VoiceBank Immunity CreatureClass
john_hawkwood_boss 20 0.1 140 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 8 16 crit 1.60 critchance 0.05 0.5 0.5 0.03 0.5 1.2 0.3 8 2200 16 2 talent_the_man_who_sold_the_world human_male 0 hair1 #633D08 player Human
francis_reid_daly 20 0.1 130 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 7 14 crit 1.70 critchance 0.05 0.5 0.4 0.04 0.9 1 0.3 8 2000 10 1 talent_weapon_durability human_male 0 player Human
]]
local rows = fromCSV(teststr, { delimiter = "\t", hasheader = true, hascomments = true })
assert(#rows == 2, "Should have 2 data rows")
-- Test first row
assert(rows[1].Id == "john_hawkwood_boss", "First row Id should be 'john_hawkwood_boss'")
assert(rows[1].ModifyStartCost == "20", "First row ModifyStartCost should be '20'")
assert(rows[1].ModifyStep == "0.1", "First row ModifyStep should be '0.1'")
assert(rows[1].Health == "140", "First row Health should be '140'")
assert(rows[1].ActorId == "human_male", "First row ActorId should be 'human_male'")
assert(rows[1].HairColorHex == "#633D08", "First row HairColorHex should be '#633D08'")
-- Test second row
assert(rows[2].Id == "francis_reid_daly", "Second row Id should be 'francis_reid_daly'")
assert(rows[2].ModifyStartCost == "20", "Second row ModifyStartCost should be '20'")
assert(rows[2].ModifyStep == "0.1", "Second row ModifyStep should be '0.1'")
assert(rows[2].Health == "130", "Second row Health should be '130'")
assert(rows[2].ActorId == "human_male", "Second row ActorId should be 'human_male'")
-- Test that numeric indices still work
assert(rows[1][1] == "john_hawkwood_boss", "First row first field by index should work")
assert(rows[1][2] == "20", "First row second field by index should work")
end)
test("fromCSV debug header assignment", function()
local csv = "Id Name Value\n1 Test 100\n2 Test2 200"
local rows = fromCSV(csv, { delimiter = "\t", hasheader = true })
assert(rows[1].Id == "1", "Id should be '1'")
assert(rows[1].Name == "Test", "Name should be 'Test'")
assert(rows[1].Value == "100", "Value should be '100'")
end)
test("fromCSV real world mercenary file format", function()
local csv = [[#mercenary_profiles
Id ModifyStartCost ModifyStep ModifyLevelLimit Health ResistSheet WoundSlots MeleeDamage MeleeAccuracy RangeAccuracy ReceiveAmputationChance ReceiveWoundChanceMult AttackWoundChanceMult Dodge Los StarvationLimit PainThresholdLimit PainThresholdRegen TalentPerkId ActorId SkinIndex HairType HairColorHex VoiceBank Immunity CreatureClass
john_hawkwood_boss 20 0.1 140 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 8 16 crit 1.60 critchance 0.05 0.5 0.5 0.03 0.5 1.2 0.3 8 2200 16 2 talent_the_man_who_sold_the_world human_male 0 hair1 #633D08 player Human
francis_reid_daly 20 0.1 130 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 7 14 crit 1.70 critchance 0.05 0.5 0.4 0.04 0.9 1 0.3 8 2000 10 1 talent_weapon_durability human_male 0 player Human
]]
local rows = fromCSV(csv, { delimiter = "\t", hasheader = true, hascomments = true })
assert(#rows == 2, "Should have 2 data rows")
assert(rows[1].Id == "john_hawkwood_boss", "First row Id should be 'john_hawkwood_boss'")
assert(rows[1].ModifyStartCost == "20", "First row ModifyStartCost should be '20'")
assert(rows[2].Id == "francis_reid_daly", "Second row Id should be 'francis_reid_daly'")
end)
test("full CSV parser complex", function()
local original = [[
#mercenary_profiles
Id ModifyStartCost ModifyStep ModifyLevelLimit Health ResistSheet WoundSlots MeleeDamage MeleeAccuracy RangeAccuracy ReceiveAmputationChance ReceiveWoundChanceMult AttackWoundChanceMult Dodge Los StarvationLimit PainThresholdLimit PainThresholdRegen TalentPerkId ActorId SkinIndex HairType HairColorHex VoiceBank Immunity CreatureClass
john_hawkwood_boss 20 0.1 140 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 8 16 crit 1.60 critchance 0.05 0.5 0.5 0.03 0.5 1.2 0.3 8 2200 16 2 talent_the_man_who_sold_the_world human_male 0 hair1 #633D08 player Human
francis_reid_daly 20 0.1 130 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 7 14 crit 1.70 critchance 0.05 0.5 0.4 0.04 0.9 1 0.3 8 2000 10 1 talent_weapon_durability human_male 0 player Human
victoria_boudicca 20 0.1 90 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 5 10 crit 1.70 critchance 0.1 0.4 0.45 0.05 1 1.2 0.3 8 1800 8 1 talent_weapon_distance human_female 0 hair1 #633D08 player Human
persival_fawcett 20 0.1 150 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 6 12 crit 1.70 critchance 0.05 0.5 0.35 0.05 0.6 1 0.25 8 2100 16 1 talent_all_resists human_male 1 hair1 #633D08 player Human
Isabella_capet 20 0.1 100 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 7 14 crit 1.70 critchance 0.15 0.55 0.3 0.03 0.8 1.4 0.35 7 1700 14 2 talent_ignore_infection human_female 1 hair3 #FF3100 player Human
maximilian_rohr 20 0.1 120 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 8 16 crit 1.75 critchance 0.05 0.45 0.45 0.06 0.9 1 0.2 8 2000 14 1 talent_ignore_pain human_male 0 hair2 #FFC400 player Human
priya_marlon 20 0.1 110 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 5 10 crit 1.70 critchance 0.15 0.45 0.35 0.05 1 1.1 0.3 7 2200 12 1 talent_all_consumables_stack human_female 0 hair2 #FFC400 player Human
jacques_kennet 20 0.1 120 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 5 10 crit 1.70 critchance 0.05 0.45 0.35 0.04 0.9 1.2 0.3 8 2300 10 1 talent_reload_time human_male 0 hair1 #908E87 player Human
mirza_aishatu 20 0.1 110 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 7 14 crit 1.70 critchance 0.05 0.55 0.45 0.03 1 1.1 0.25 9 2000 10 1 talent_starving_slower human_female 1 hair2 #633D08 player Human
kenzie_yukio 20 0.1 100 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 5 10 crit 1.70 critchance 0.1 0.6 0.4 0.04 1 1 0.4 7 1600 12 1 talent_weight_dodge_affect human_male 0 hair2 #633D08 player Human
marika_wulfnod 20 0.1 100 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 6 12 crit 1.60 critchance 0.05 0.5 0.5 0.04 1 1 0.3 9 1900 12 1 talent_belt_slots human_female 0 hair1 #FFC400 player Human
auberon_lukas 20 0.1 120 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 4 8 crit 1.60 critchance 0.15 0.45 0.45 0.05 0.8 1 0.2 9 1900 8 2 talent_weapon_slot human_male 0 hair2 #633D08 player Human
niko_medich 20 0.1 120 blunt 0 pierce 0 lacer 0 fire 0 cold 0 poison 0 shock 0 beam 0 HumanHead HumanShoulder HumanArm HumanThigh HumanFeet HumanChest HumanBody HumanStomach HumanKnee blunt 5 10 crit 1.70 critchance 0.05 0.4 0.45 0.04 1 1.3 0.25 8 2000 10 1 talent_pistol_acc human_male 0 hair1 #908E87 player Human
#end
#mercenary_classes
Id ModifyStartCost ModifyStep PerkIds
scouts_of_hades 30 0.1 cqc_specialist_basic military_training_basic gear_maintenance_basic blind_fury_basic fire_transfer_basic assault_reflex_basic
ecclipse_blades 30 0.1 berserkgang_basic athletics_basic reaction_training_basic cold_weapon_wielding_basic cannibalism_basic carnage_basic
tifton_elite 30 0.1 heavy_weaponary_basic grenadier_basic selfhealing_basic stationary_defense_basic spray_and_pray_basic shock_awe_basic
tunnel_rats 30 0.1 cautious_basic handmade_shotgun_ammo_basic marauder_basic dirty_shot_basic vicious_symbiosis_basic covermaster_basic
phoenix_brigade 30 0.1 shielding_basic battle_physicist_basic reinforced_battery_basic revealing_flame_basic cauterize_basic scholar_basic
]]
-- Parse with headers and comments
local rows = fromCSV(original, { delimiter = "\t", hasheader = true, hascomments = true })
assert(#rows > 0, "Should have parsed rows")
-- Convert back to CSV with headers
local csv = toCSV(rows, { delimiter = "\t", hasheader = true })
-- Parse again
local rows2 = fromCSV(csv, { delimiter = "\t", hasheader = true, hascomments = false })
-- Verify identical - same number of rows
assert(#rows2 == #rows, "Round trip should have same number of rows")
-- Verify first row data is identical
assert(rows2[1].Id == rows[1].Id, "Round trip first row Id should match")
assert(
rows2[1].ModifyStartCost == rows[1].ModifyStartCost,
"Round trip first row ModifyStartCost should match"
)
assert(rows2[1].Health == rows[1].Health, "Round trip first row Health should match")
-- Verify headers are preserved
assert(rows2.Headers ~= nil, "Round trip rows should have Headers field")
assert(#rows2.Headers == #rows.Headers, "Headers should have same number of elements")
assert(rows2.Headers[1] == rows.Headers[1], "First header should match")
end)
-- Row metatable behavior: header-named keys alias the numeric slots
-- (tab-separated fixtures written with explicit \t escapes for readability)
test("metatable row[1] equals row.header", function()
    local parsed = fromCSV("Id\tName\tValue\n1\tTest\t100", { delimiter = "\t", hasheader = true })
    local row = parsed[1]
    assert(row[1] == row.Id, "row[1] should equal row.Id")
    assert(row[2] == row.Name, "row[2] should equal row.Name")
    assert(row[3] == row.Value, "row[3] should equal row.Value")
    assert(row.Id == "1", "row.Id should be '1'")
    assert(row[1] == "1", "row[1] should be '1'")
end)

test("metatable set via header name", function()
    local parsed = fromCSV("Id\tName\tValue\n1\tTest\t100", { delimiter = "\t", hasheader = true })
    parsed[1].Id = "999"
    assert(parsed[1][1] == "999", "Setting row.Id should update row[1]")
    assert(parsed[1].Id == "999", "row.Id should be '999'")
end)

test("metatable error on unknown header", function()
    local parsed = fromCSV("Id\tName\tValue\n1\tTest\t100", { delimiter = "\t", hasheader = true })
    local ok, errMsg = pcall(function() parsed[1].UnknownHeader = "test" end)
    assert(ok == false, "Should error on unknown header")
    assert(string.find(errMsg, "unknown header"), "Error should mention unknown header")
end)

test("metatable numeric indices work", function()
    local parsed = fromCSV("Id\tName\tValue\n1\tTest\t100", { delimiter = "\t", hasheader = true })
    parsed[1][1] = "999"
    assert(parsed[1].Id == "999", "Setting row[1] should update row.Id")
    assert(parsed[1][1] == "999", "row[1] should be '999'")
end)

test("metatable numeric keys work", function()
    local parsed = fromCSV("Id\tName\tValue\n1\tTest\t100", { delimiter = "\t", hasheader = true })
    parsed[1][100] = "hundred"
    assert(parsed[1][100] == "hundred", "Numeric keys should work")
end)

print("\nAll tests completed!")

624
processor/luahelper.lua Normal file
View File

@@ -0,0 +1,624 @@
-- Custom Lua helpers for math operations

--- Smaller of two numbers.
--- @param a number First number
--- @param b number Second number
--- @return number The lesser of `a` and `b`
function min(a, b)
    return math.min(a, b)
end

--- Larger of two numbers.
--- @param a number First number
--- @param b number Second number
--- @return number The greater of `a` and `b`
function max(a, b)
    return math.max(a, b)
end

--- Rounds a number half-up to `n` decimal places.
--- @param x number Value to round
--- @param n number? Decimal places (default: 0)
--- @return number `x` rounded to `n` decimals
function round(x, n)
    if n == nil then n = 0 end
    local scale = 10 ^ n
    return math.floor(x * scale + 0.5) / scale
end

--- Largest integer less than or equal to `x`.
--- @param x number Number to floor
--- @return number Floored number
function floor(x)
    return math.floor(x)
end

--- Smallest integer greater than or equal to `x`.
--- @param x number Number to ceil
--- @return number Ceiled number
function ceil(x)
    return math.ceil(x)
end
--- Uppercase copy of a string.
--- @param s string String to convert
--- @return string Uppercase string
function upper(s)
    return s:upper()
end

--- Lowercase copy of a string.
--- @param s string String to convert
--- @return string Lowercase string
function lower(s)
    return s:lower()
end

--- Formats a string via string.format.
--- @param s string Format string
--- @param ... any Values to interpolate
--- @return string Formatted string
function format(s, ...)
    return s:format(...)
end
--- Removes leading and trailing whitespace from string
--- @param s string String to trim
--- @return string Trimmed string
function trim(s)
    -- Parenthesized so only the substituted string escapes: string.gsub
    -- also returns the substitution count as a second value, which would
    -- otherwise leak into multi-value contexts like f(trim(s)) or {trim(s)}.
    return (string.gsub(s, "^%s*(.-)%s*$", "%1"))
end
--- Splits a string into the runs of characters not matching a separator
--- character class.
--- @param inputstr string String to split
--- @param sep string? Separator pattern class fragment (default: "%s", whitespace)
--- @return table Array of non-empty string parts
function strsplit(inputstr, sep)
    if sep == nil then sep = "%s" end
    local parts = {}
    -- sep is spliced into a [^...] set, so it may be any pattern class
    for piece in string.gmatch(inputstr, "([^" .. sep .. "]+)") do
        parts[#parts + 1] = piece
    end
    return parts
end
--- Recursively pretty-prints a table to stdout, one "key: value" pair per
--- line, indented one space per nesting level.
---@param tbl table Table to print
---@param depth number? Current nesting level (default: 0)
function dump(tbl, depth)
    -- note: parameter renamed from `table` — the old name shadowed the
    -- standard table library inside this function
    if depth == nil then depth = 0 end
    -- guard against cycles / runaway recursion
    if depth > 200 then
        print("Error: Depth > 200 in dump()")
        return
    end
    for k, v in pairs(tbl) do
        -- tostring(k) so non-string/number keys (booleans, tables, ...)
        -- don't crash the concatenation
        local prefix = string.rep(" ", depth) .. tostring(k)
        if type(v) == "table" then
            print(prefix .. ":")
            dump(v, depth + 1)
        else
            print(prefix .. ": ", v)
        end
    end
end
--- @class ParserOptions
--- @field delimiter string? Field separator (default: ",").
--- @field hasheader boolean? Treat the first non-comment row as headers (default: false).
--- @field hascomments boolean? Skip rows whose first field starts with '#' (default: false).

--- Default value for every recognized parser option; its key set also
--- defines which option names areOptionsValid() accepts.
--- @type ParserOptions
parserDefaultOptions = { hascomments = false, hasheader = false, delimiter = "," }
--- Validates an options table: raises when it is not a table or contains
--- a key that is not present in parserDefaultOptions. A nil argument is
--- accepted as "no options".
--- @param options ParserOptions? The options table to validate
function areOptionsValid(options)
    if options == nil then return end
    if type(options) ~= "table" then error("options must be a table") end
    -- human-readable list of accepted keys for the error message
    local validOptionsStr = ""
    for key in pairs(parserDefaultOptions) do
        validOptionsStr = validOptionsStr .. key .. ", "
    end
    for key in pairs(options) do
        if parserDefaultOptions[key] == nil then
            error("unknown option: " .. tostring(key) .. " (valid options: " .. validOptionsStr .. ")")
        end
    end
end
--- Parses CSV text into rows and fields using a minimal RFC 4180 state machine.
---
--- Requirements/assumptions:
--- - Input is a single string containing the entire CSV content.
--- - Field separators are specified by delimiter option (default: comma).
--- - Newlines between rows may be "\n" or "\r\n". "\r\n" is treated as one line break.
--- - Fields may be quoted with double quotes (").
--- - Inside quoted fields, doubled quotes ("") represent a literal quote character.
--- - No backslash escaping is supported (not part of RFC 4180).
--- - Newlines inside quoted fields are preserved as part of the field.
--- - Leading/trailing spaces are preserved; no trimming is performed.
--- - Empty fields and empty rows are preserved.
--- - The final row is emitted even if the text does not end with a newline.
--- - Lines starting with '#' (after optional leading whitespace) are treated as comments and skipped if hascomments is true.
---
--- @param csv string The CSV text to parse.
--- @param options ParserOptions? Options for the parser
--- @return table #A table (array) of rows; each row is a table with numeric indices and optionally header-named keys.
function fromCSV(csv, options)
    if options == nil then options = {} end
    -- Validate options
    areOptionsValid(options)
    -- Effective settings: explicit option wins, otherwise the default.
    -- NOTE(review): the `or` fallback makes an explicit `false` equal to
    -- "unset" -- harmless here because the defaults are false/",".
    local delimiter = options.delimiter or parserDefaultOptions.delimiter
    local hasheader = options.hasheader or parserDefaultOptions.hasheader
    local hascomments = options.hascomments or parserDefaultOptions.hascomments
    local allRows = {} -- completed rows (each an array of field strings)
    local fields = {} -- fields of the row currently being built
    local field = {} -- characters of the current field (joined via table.concat)
    -- Parser states: outside quotes, inside a quoted field, and "just saw a
    -- quote inside quotes" (disambiguates "" escapes from closing quotes).
    local STATE_DEFAULT = 1
    local STATE_IN_QUOTES = 2
    local STATE_QUOTE_IN_QUOTES = 3
    local state = STATE_DEFAULT
    local i = 1
    local len = #csv
    while i <= len do
        local c = csv:sub(i, i)
        if state == STATE_DEFAULT then
            if c == '"' then
                -- Opening quote: begin a quoted section.
                state = STATE_IN_QUOTES
                i = i + 1
            elseif c == delimiter then
                -- Field terminator: flush the accumulated characters.
                table.insert(fields, table.concat(field))
                field = {}
                i = i + 1
            elseif c == "\r" or c == "\n" then
                -- Row terminator: flush the last field, then possibly drop the
                -- whole row when comment skipping is on and it starts with '#'.
                table.insert(fields, table.concat(field))
                field = {}
                local shouldAdd = true
                if hascomments and #fields > 0 then
                    local firstField = fields[1]
                    local trimmed = trim(firstField)
                    if string.sub(trimmed, 1, 1) == "#" then shouldAdd = false end
                end
                if shouldAdd then table.insert(allRows, fields) end
                fields = {}
                -- Consume "\r\n" as a single line break.
                if c == "\r" and i < len and csv:sub(i + 1, i + 1) == "\n" then
                    i = i + 2
                else
                    i = i + 1
                end
            else
                table.insert(field, c)
                i = i + 1
            end
        elseif state == STATE_IN_QUOTES then
            if c == '"' then
                -- Either a closing quote or the first half of an escaped "".
                state = STATE_QUOTE_IN_QUOTES
                i = i + 1
            else
                -- Delimiters and newlines are literal inside quotes.
                table.insert(field, c)
                i = i + 1
            end
        else -- STATE_QUOTE_IN_QUOTES
            if c == '"' then
                -- Doubled quote: emit one literal quote, remain in quotes.
                table.insert(field, '"')
                state = STATE_IN_QUOTES
                i = i + 1
            elseif c == delimiter then
                -- Previous quote closed the field; this delimiter ends it.
                table.insert(fields, table.concat(field))
                field = {}
                state = STATE_DEFAULT
                i = i + 1
            elseif c == "\r" or c == "\n" then
                -- Previous quote closed the field; this newline ends the row.
                table.insert(fields, table.concat(field))
                field = {}
                local shouldAdd = true
                if hascomments and #fields > 0 then
                    local firstField = fields[1]
                    -- NOTE(review): inline gsub-trim here vs trim() in the
                    -- DEFAULT branch -- same effect, inconsistent style.
                    local trimmed = string.gsub(firstField, "^%s*(.-)%s*$", "%1")
                    if string.sub(trimmed, 1, 1) == "#" then shouldAdd = false end
                end
                if shouldAdd then table.insert(allRows, fields) end
                fields = {}
                state = STATE_DEFAULT
                if c == "\r" and i < len and csv:sub(i + 1, i + 1) == "\n" then
                    i = i + 2
                else
                    i = i + 1
                end
            else
                -- Stray character after a closing quote: fall back to DEFAULT
                -- and reprocess it there (lenient, not strict RFC 4180).
                state = STATE_DEFAULT
                -- Don't increment i, reprocess character in DEFAULT state
            end
        end
    end
    -- Emit the final row when the input did not end with a newline.
    if #field > 0 or #fields > 0 then
        table.insert(fields, table.concat(field))
        local shouldAdd = true
        if hascomments and #fields > 0 then
            local firstField = fields[1]
            local trimmed = string.gsub(firstField, "^%s*(.-)%s*$", "%1")
            if string.sub(trimmed, 1, 1) == "#" then shouldAdd = false end
        end
        if shouldAdd then table.insert(allRows, fields) end
    end
    if hasheader and #allRows > 0 then
        -- First row is the header row. Build a (trimmed) name -> column-index
        -- map; empty header cells get no name.
        local headers = allRows[1]
        local headerMap = {}
        for j = 1, #headers do
            if headers[j] ~= nil and headers[j] ~= "" then
                local headerName = trim(headers[j])
                headerMap[headerName] = j
            end
        end
        -- Shared metatable letting rows be indexed by header name as well as
        -- by column number. Writing to an unknown header name raises an
        -- error; numeric keys pass straight through via rawget/rawset.
        local header_mt = {
            headers = headerMap,
            __index = function(t, key)
                local mt = getmetatable(t)
                if type(key) == "string" and mt.headers and mt.headers[key] then
                    return rawget(t, mt.headers[key])
                end
                return rawget(t, key)
            end,
            __newindex = function(t, key, value)
                local mt = getmetatable(t)
                if type(key) == "string" and mt.headers then
                    if mt.headers[key] then
                        rawset(t, mt.headers[key], value)
                    else
                        error("unknown header: " .. tostring(key))
                    end
                else
                    rawset(t, key, value)
                end
            end,
        }
        -- Copy the data rows (header row excluded) and attach the metatable.
        local rows = {}
        for ii = 2, #allRows do
            local row = {}
            local dataRow = allRows[ii]
            for j = 1, #dataRow do
                row[j] = dataRow[j]
            end
            setmetatable(row, header_mt)
            table.insert(rows, row)
        end
        -- Raw (untrimmed) header row; toCSV reads this for hasheader output.
        rows.Headers = headers
        return rows
    end
    return allRows
end
--- Converts a table of rows back to CSV text format (RFC 4180 compliant).
---
--- Requirements:
--- - Input is a table (array) of rows, where each row is a table (array) of field values.
--- - Field values are converted to strings using tostring().
--- - Fields are quoted if they contain the delimiter, newlines, or double quotes.
--- - Double quotes inside quoted fields are doubled ("").
--- - Fields are joined with the specified delimiter; rows are joined with newlines.
--- - If includeHeaders is true and rows have a Headers field, headers are included as the first row.
---
--- @param rows table Array of rows, where each row is an array of field values.
--- @param options ParserOptions? Options for the parser
--- @return string #CSV-formatted text
function toCSV(rows, options)
    if options == nil then options = {} end
    -- Validate options
    areOptionsValid(options)
    local delimiter = options.delimiter or parserDefaultOptions.delimiter
    local includeHeaders = options.hasheader or parserDefaultOptions.hasheader
    -- Quote/escape a single value per RFC 4180. All find() calls use plain
    -- mode (4th argument true): the previous code passed the delimiter as a
    -- Lua *pattern*, so magic characters (e.g. "." or "%") caused spurious
    -- quoting or runtime errors. This helper also replaces the triplicated
    -- quoting logic that existed for headers and fields.
    local function encodeField(value)
        local s = tostring(value)
        if
            s:find(delimiter, 1, true)
            or s:find("\n", 1, true)
            or s:find("\r", 1, true)
            or s:find('"', 1, true)
        then
            s = '"' .. s:gsub('"', '""') .. '"'
        end
        return s
    end
    local rowStrings = {}
    -- Include headers row if requested and available
    if includeHeaders and #rows > 0 and rows.Headers ~= nil then
        local headerStrings = {}
        for _, header in ipairs(rows.Headers) do
            table.insert(headerStrings, encodeField(header))
        end
        table.insert(rowStrings, table.concat(headerStrings, delimiter))
    end
    for _, row in ipairs(rows) do
        local fieldStrings = {}
        for _, field in ipairs(row) do
            table.insert(fieldStrings, encodeField(field))
        end
        table.insert(rowStrings, table.concat(fieldStrings, delimiter))
    end
    return table.concat(rowStrings, "\n")
end
--- Coerces a string to a number, falling back to 0 when conversion fails.
--- @param str string String to convert
--- @return number Numeric value or 0
function num(str)
    local parsed = tonumber(str)
    if parsed == nil then return 0 end
    return parsed
end
--- Renders a number as its string representation via tostring().
--- @param num number Number to convert
--- @return string String representation
function str(num)
    local rendered = tostring(num)
    return rendered
end
--- Reports whether a string can be parsed as a number.
--- @param str string String to check
--- @return boolean True if string is numeric
function is_number(str)
    if tonumber(str) == nil then return false end
    return true
end
--- Determines whether a table is a proper sequence: every key is a positive
--- integer and the keys run contiguously from 1 to the element count.
--- @param t table Table to check
--- @return boolean True if table is an array
function isArray(t)
    if type(t) ~= "table" then return false end
    local highest, total = 0, 0
    for key in pairs(t) do
        local isPositiveInt = type(key) == "number" and key >= 1 and math.floor(key) == key
        if not isPositiveInt then return false end
        if key > highest then highest = key end
        total = total + 1
    end
    -- Contiguous 1..n keys iff the largest key equals the key count.
    return highest == total
end
-- Global change flag, initialized false for each script run; the generated
-- wrapper assigns it after the user expression executes (see
-- BuildJSONLuaScript: `modified = res == nil or res`).
-- NOTE(review): presumably read back by the Go host to decide whether to
-- write the document -- confirm against the processor package.
modified = false
-- ============================================================================
-- XML HELPER FUNCTIONS
-- ============================================================================
--- Collects every element in the tree (including the root itself) whose tag
--- matches tagName, in depth-first pre-order.
--- @param root table The root XML element (with _tag, _attr, _children fields)
--- @param tagName string The tag name to search for
--- @return table Array of matching elements
function findElements(root, tagName)
    local matches = {}
    local walk -- forward declaration so the closure can recurse
    walk = function(node)
        if node._tag == tagName then matches[#matches + 1] = node end
        local kids = node._children
        if kids then
            for _, kid in ipairs(kids) do
                walk(kid)
            end
        end
    end
    walk(root)
    return matches
end
--- Walks the element tree depth-first, invoking callback(element, depth, path)
--- for every element. The root is visited at depth 0 with path "/<tag>";
--- children append "/<tag>[<index>]" segments (1-based index among siblings).
--- @param root table The root XML element
--- @param callback function Function to call with each element: callback(element, depth, path)
function visitElements(root, callback)
    local function walk(node, level, trail)
        callback(node, level, trail)
        local kids = node._children
        if not kids then return end
        for index, kid in ipairs(kids) do
            walk(kid, level + 1, trail .. "/" .. kid._tag .. "[" .. index .. "]")
        end
    end
    walk(root, 0, "/" .. root._tag)
end
--- Reads an attribute and converts it to a number.
--- @param element table XML element with _attr field
--- @param attrName string Attribute name
--- @return number|nil The numeric value or nil if not found/not numeric
function getNumAttr(element, attrName)
    local attrs = element._attr
    if not attrs then return nil end
    local raw = attrs[attrName]
    if not raw then return nil end
    return tonumber(raw)
end
--- Stores a numeric value as a string attribute, creating the _attr table
--- on first use.
--- @param element table XML element with _attr field
--- @param attrName string Attribute name
--- @param value number Numeric value to set
function setNumAttr(element, attrName, value)
    element._attr = element._attr or {}
    element._attr[attrName] = tostring(value)
end
--- Applies func to the current numeric value of an attribute and stores the
--- result; a no-op when the attribute is missing or non-numeric.
--- @param element table XML element
--- @param attrName string Attribute name
--- @param func function Function that takes current value and returns new value
--- @return boolean True if modification was made
function modifyNumAttr(element, attrName, func)
    local existing = getNumAttr(element, attrName)
    if existing == nil then return false end
    setNumAttr(element, attrName, func(existing))
    return true
end
--- Walks the whole tree via visitElements and collects every element for
--- which the predicate returns a truthy value.
--- @param root table The root XML element
--- @param predicate function Function that takes element and returns true/false
--- @return table Array of matching elements
function filterElements(root, predicate)
    local hits = {}
    local function collect(el)
        if predicate(el) then hits[#hits + 1] = el end
    end
    visitElements(root, collect)
    return hits
end
--- Returns the element's text content (the _text field), or nil when unset.
--- @param element table XML element
--- @return string|nil The text content or nil
function getText(element)
    return element["_text"]
end
--- Assigns the element's text content (the _text field).
--- @param element table XML element
--- @param text string Text content to set
function setText(element, text)
    element["_text"] = text
end
--- Check if element has an attribute.
--- Always returns a real boolean: the previous one-liner
--- (`element._attr and element._attr[attrName] ~= nil`) evaluated to nil --
--- not false -- when the element had no _attr table, contradicting the
--- documented boolean return.
--- @param element table XML element
--- @param attrName string Attribute name
--- @return boolean True if attribute exists
function hasAttr(element, attrName)
    if not element._attr then return false end
    return element._attr[attrName] ~= nil
end
--- Looks up an attribute value as the raw string stored in _attr.
--- @param element table XML element
--- @param attrName string Attribute name
--- @return string|nil The attribute value or nil
function getAttr(element, attrName)
    local attrs = element._attr
    if not attrs then return nil end
    return attrs[attrName]
end
--- Stores an attribute value, stringified with tostring(), creating the
--- _attr table on first use.
--- @param element table XML element
--- @param attrName string Attribute name
--- @param value any Value to set (will be converted to string)
function setAttr(element, attrName, value)
    local attrs = element._attr
    if not attrs then
        attrs = {}
        element._attr = attrs
    end
    attrs[attrName] = tostring(value)
end
--- Returns the first direct child whose tag matches tagName; nil when the
--- parent has no children or no child matches. Does not recurse.
--- @param parent table The parent XML element
--- @param tagName string The tag name to search for
--- @return table|nil The first matching element or nil
function findFirstElement(parent, tagName)
    local kids = parent._children
    if kids then
        for index = 1, #kids do
            local candidate = kids[index]
            if candidate._tag == tagName then return candidate end
        end
    end
    return nil
end
--- Appends a child element, creating the parent's _children array on first use.
--- @param parent table The parent XML element
--- @param child table The child element to add
function addChild(parent, child)
    local kids = parent._children
    if not kids then
        kids = {}
        parent._children = kids
    end
    kids[#kids + 1] = child
end
--- Deletes every direct child with the given tag, preserving the relative
--- order of the remaining children.
--- @param parent table The parent XML element
--- @param tagName string The tag name to remove
--- @return number Count of removed children
function removeChildren(parent, tagName)
    local kids = parent._children
    if not kids then return 0 end
    local removed = 0
    -- Iterate backwards so table.remove never shifts an unvisited element
    -- past the cursor.
    for index = #kids, 1, -1 do
        if kids[index]._tag == tagName then
            table.remove(kids, index)
            removed = removed + 1
        end
    end
    return removed
end
--- Returns every direct child whose tag matches tagName, in document order.
--- Does not recurse into grandchildren.
--- @param parent table The parent XML element
--- @param tagName string The tag name to search for
--- @return table Array of matching children
function getChildren(parent, tagName)
    local matched = {}
    local kids = parent._children
    if kids then
        for index = 1, #kids do
            local kid = kids[index]
            if kid._tag == tagName then matched[#matched + 1] = kid end
        end
    end
    return matched
end
--- Tallies the direct children whose tag matches tagName (no recursion).
--- @param parent table The parent XML element
--- @param tagName string The tag name to count
--- @return number Count of matching children
function countChildren(parent, tagName)
    local kids = parent._children
    if not kids then return 0 end
    local total = 0
    for index = 1, #kids do
        if kids[index]._tag == tagName then total = total + 1 end
    end
    return total
end
-- ============================================================================
-- JSON HELPER FUNCTIONS
-- ============================================================================
--- Recursively visits every value in a JSON structure (depth-first). The top
--- value itself is visited first with key and parent both nil; table values
--- are then descended into via pairs (iteration order unspecified).
--- @param data table JSON data (nested tables)
--- @param callback function Function called with (value, key, parent)
function visitJSON(data, callback)
    local function descend(node, nodeKey, container)
        callback(node, nodeKey, container)
        if type(node) ~= "table" then return end
        for childKey, childValue in pairs(node) do
            descend(childValue, childKey, node)
        end
    end
    descend(data, nil, nil)
end
--- Collects every value in the JSON structure for which the predicate
--- returns a truthy value (visit order follows visitJSON).
--- @param data table JSON data
--- @param predicate function Function that takes (value, key, parent) and returns true/false
--- @return table Array of matching values
function findInJSON(data, predicate)
    local matched = {}
    local function keepIfWanted(value, key, parent)
        if predicate(value, key, parent) then matched[#matched + 1] = value end
    end
    visitJSON(data, keepIfWanted)
    return matched
end
--- Rewrites in place every numeric value matching the predicate by storing
--- modifier(value) back into its parent table. The top-level value itself is
--- never rewritten (it has no parent/key).
--- @param data table JSON data
--- @param predicate function Function that takes (value, key, parent) and returns true/false
--- @param modifier function Function that takes current value and returns new value
function modifyJSONNumbers(data, predicate, modifier)
    visitJSON(data, function(value, key, parent)
        local isTargetNumber = type(value) == "number" and predicate(value, key, parent)
        if isTargetNumber and parent and key then
            parent[key] = modifier(value)
        end
    end)
end

29
processor/meta.go Normal file
View File

@@ -0,0 +1,29 @@
package processor
import (
_ "embed"
"fmt"
"os"
logger "git.site.quack-lab.dev/dave/cylogger"
)
//go:embed meta.lua
var metaFileContent string
var metaLogger = logger.Default.WithPrefix("meta")
// GenerateMetaFile generates meta.lua with function signatures for LuaLS autocomplete
func GenerateMetaFile(outputPath string) error {
metaLogger.Info("Generating meta.lua file for LuaLS autocomplete")
// Write the embedded meta file
err := os.WriteFile(outputPath, []byte(metaFileContent), 0644)
if err != nil {
metaLogger.Error("Failed to write meta.lua: %v", err)
return fmt.Errorf("failed to write meta.lua: %w", err)
}
metaLogger.Info("Successfully generated meta.lua at %q", outputPath)
return nil
}

245
processor/meta.lua Normal file
View File

@@ -0,0 +1,245 @@
---@meta
---@class ParserOptions
---@field delimiter string? The field delimiter (default: ",").
---@field hasheader boolean? If true, first non-comment row is treated as headers (default: false).
---@field hascomments boolean? If true, lines starting with '#' are skipped (default: false).
---@class XMLElement
---@field _tag string The XML tag name
---@field _attr {[string]: string}? XML attributes as key-value pairs
---@field _text string? Text content of the element
---@field _children XMLElement[]? Child elements
---@class JSONNode
---@field [string] string | number | boolean | nil | JSONNode | JSONArray JSON object fields
---@alias JSONArray (string | number | boolean | nil | JSONNode)[]
---@class CSVRow
---@field [integer] string Numeric indices for field access
---@field Headers string[]? Header row if hasheader was true
--- Returns the minimum of two numbers
---@param a number First number
---@param b number Second number
---@return number #Minimum value
function min(a, b) end
--- Returns the maximum of two numbers
---@param a number First number
---@param b number Second number
---@return number #Maximum value
function max(a, b) end
--- Rounds a number to n decimal places
---@param x number Number to round
---@param n number? Number of decimal places (default: 0)
---@return number #Rounded number
function round(x, n) end
--- Returns the floor of a number
---@param x number Number to floor
---@return number #Floored number
function floor(x) end
--- Returns the ceiling of a number
---@param x number Number to ceil
---@return number #Ceiled number
function ceil(x) end
--- Converts string to uppercase
---@param s string String to convert
---@return string #Uppercase string
function upper(s) end
--- Converts string to lowercase
---@param s string String to convert
---@return string #Lowercase string
function lower(s) end
--- Formats a string using Lua string.format
---@param s string Format string
---@param ... any Values to format
---@return string #Formatted string
function format(s, ...) end
--- Removes leading and trailing whitespace from string
---@param s string String to trim
---@return string #Trimmed string
function trim(s) end
--- Splits a string by separator
---@param inputstr string String to split
---@param sep string? Separator pattern (default: whitespace)
---@return string[] #Array of string parts
function strsplit(inputstr, sep) end
--- Prints table structure recursively
---@param table {[any]: any} Table to dump
---@param depth number? Current depth (default: 0)
function dump(table, depth) end
--- Validates options against a set of valid option keys.
---@param options ParserOptions? The options table to validate
function areOptionsValid(options) end
--- Parses CSV text into rows and fields using a minimal RFC 4180 state machine.
--- Requirements/assumptions:<br>
--- Input is a single string containing the entire CSV content.<br>
--- Field separators are specified by delimiter option (default: comma).<br>
--- Newlines between rows may be "\n" or "\r\n". "\r\n" is treated as one line break.<br>
--- Fields may be quoted with double quotes (").<br>
--- Inside quoted fields, doubled quotes ("") represent a literal quote character.<br>
--- No backslash escaping is supported (not part of RFC 4180).<br>
--- Newlines inside quoted fields are preserved as part of the field.<br>
--- Leading/trailing spaces are preserved; no trimming is performed.<br>
--- Empty fields and empty rows are preserved.<br>
--- The final row is emitted even if the text does not end with a newline.<br>
--- Lines starting with '#' (after optional leading whitespace) are treated as comments and skipped if hascomments is true.<br>
---@param csv string The CSV text to parse.
---@param options ParserOptions? Options for the parser
---@return CSVRow[] #A table (array) of rows; each row is a table with numeric indices and optionally header-named keys.
function fromCSV(csv, options) end
--- Converts a table of rows back to CSV text format (RFC 4180 compliant).<br>
--- Requirements:<br>
--- Input is a table (array) of rows, where each row is a table (array) of field values.<br>
--- Field values are converted to strings using tostring().<br>
--- Fields are quoted if they contain the delimiter, newlines, or double quotes.<br>
--- Double quotes inside quoted fields are doubled ("").<br>
--- Fields are joined with the specified delimiter; rows are joined with newlines.<br>
--- If includeHeaders is true and rows have a Headers field, headers are included as the first row.<br>
---@param rows CSVRow[] Array of rows, where each row is an array of field values.
---@param options ParserOptions? Options for the parser
---@return string #CSV-formatted text
function toCSV(rows, options) end
--- Converts string to number, returns 0 if invalid
---@param str string String to convert
---@return number #Numeric value or 0
function num(str) end
--- Converts number to string
---@param num number Number to convert
---@return string #String representation
function str(num) end
--- Checks if string is numeric
---@param str string String to check
---@return boolean #True if string is numeric
function is_number(str) end
--- Checks if table is a sequential array (1-indexed with no gaps)
---@param t {[integer]: any} Table to check
---@return boolean #True if table is an array
function isArray(t) end
--- Find all elements with a specific tag name (recursive search)
---@param root XMLElement The root XML element (with _tag, _attr, _children fields)
---@param tagName string The tag name to search for
---@return XMLElement[] #Array of matching elements
function findElements(root, tagName) end
--- Visit all elements recursively and call a function on each
---@param root XMLElement The root XML element
---@param callback fun(element: XMLElement, depth: number, path: string) Function to call with each element
function visitElements(root, callback) end
--- Get numeric value from XML element attribute
---@param element XMLElement XML element with _attr field
---@param attrName string Attribute name
---@return number? #The numeric value or nil if not found/not numeric
function getNumAttr(element, attrName) end
--- Set numeric value to XML element attribute
---@param element XMLElement XML element with _attr field
---@param attrName string Attribute name
---@param value number Numeric value to set
function setNumAttr(element, attrName, value) end
--- Modify numeric attribute by applying a function
---@param element XMLElement XML element
---@param attrName string Attribute name
---@param func fun(currentValue: number): number Function that takes current value and returns new value
---@return boolean #True if modification was made
function modifyNumAttr(element, attrName, func) end
--- Find all elements matching a predicate function
---@param root XMLElement The root XML element
---@param predicate fun(element: XMLElement): boolean Function that takes element and returns true/false
---@return XMLElement[] #Array of matching elements
function filterElements(root, predicate) end
--- Get text content of an element
---@param element XMLElement XML element
---@return string? #The text content or nil
function getText(element) end
--- Set text content of an element
---@param element XMLElement XML element
---@param text string Text content to set
function setText(element, text) end
--- Check if element has an attribute
---@param element XMLElement XML element
---@param attrName string Attribute name
---@return boolean #True if attribute exists
function hasAttr(element, attrName) end
--- Get attribute value as string
---@param element XMLElement XML element
---@param attrName string Attribute name
---@return string? #The attribute value or nil
function getAttr(element, attrName) end
--- Set attribute value
---@param element XMLElement XML element
---@param attrName string Attribute name
---@param value string | number | boolean Value to set (will be converted to string)
function setAttr(element, attrName, value) end
--- Find first element with a specific tag name (searches direct children only)
---@param parent XMLElement The parent XML element
---@param tagName string The tag name to search for
---@return XMLElement? #The first matching element or nil
function findFirstElement(parent, tagName) end
--- Add a child element to a parent
---@param parent XMLElement The parent XML element
---@param child XMLElement The child element to add
function addChild(parent, child) end
--- Remove all children with a specific tag name
---@param parent XMLElement The parent XML element
---@param tagName string The tag name to remove
---@return number #Count of removed children
function removeChildren(parent, tagName) end
--- Get all direct children with a specific tag name
---@param parent XMLElement The parent XML element
---@param tagName string The tag name to search for
---@return XMLElement[] #Array of matching children
function getChildren(parent, tagName) end
--- Count children with a specific tag name
---@param parent XMLElement The parent XML element
---@param tagName string The tag name to count
---@return number #Count of matching children
function countChildren(parent, tagName) end
--- Recursively visit all values in a JSON structure
---@param data JSONNode | JSONArray JSON data (nested tables)
---@param callback fun(value: string | number | boolean | nil | JSONNode | JSONArray, key: string?, parent: JSONNode?): nil Function called with (value, key, parent)
function visitJSON(data, callback) end
--- Find all values in JSON matching a predicate
---@param data JSONNode | JSONArray JSON data
---@param predicate fun(value: string | number | boolean | nil | JSONNode | JSONArray, key: string?, parent: JSONNode?): boolean Function that takes (value, key, parent) and returns true/false
---@return (string | number | boolean | nil | JSONNode | JSONArray)[] #Array of matching values
function findInJSON(data, predicate) end
--- Modify all numeric values in JSON matching a condition
---@param data JSONNode | JSONArray JSON data
---@param predicate fun(value: string | number | boolean | nil | JSONNode | JSONArray, key: string?, parent: JSONNode?): boolean Function that takes (value, key, parent) and returns true/false
---@param modifier fun(currentValue: number): number Function that takes current value and returns new value
function modifyJSONNumbers(data, predicate, modifier) end

View File

@@ -1,6 +1,7 @@
package processor
import (
_ "embed"
"fmt"
"io"
"net/http"
@@ -13,6 +14,9 @@ import (
lua "github.com/yuin/gopher-lua"
)
//go:embed luahelper.lua
var helperScript string
// processorLogger is a scoped logger for the processor package.
var processorLogger = logger.Default.WithPrefix("processor")
@@ -160,84 +164,6 @@ func InitLuaHelpers(L *lua.LState) error {
initLuaHelpersLogger := processorLogger.WithPrefix("InitLuaHelpers")
initLuaHelpersLogger.Debug("Loading Lua helper functions")
helperScript := `
-- Custom Lua helpers for math operations
function min(a, b) return math.min(a, b) end
function max(a, b) return math.max(a, b) end
function round(x, n)
if n == nil then n = 0 end
return math.floor(x * 10^n + 0.5) / 10^n
end
function floor(x) return math.floor(x) end
function ceil(x) return math.ceil(x) end
function upper(s) return string.upper(s) end
function lower(s) return string.lower(s) end
function format(s, ...) return string.format(s, ...) end
function trim(s) return string.gsub(s, "^%s*(.-)%s*$", "%1") end
-- String split helper
function strsplit(inputstr, sep)
if sep == nil then
sep = "%s"
end
local t = {}
for str in string.gmatch(inputstr, "([^"..sep.."]+)") do
table.insert(t, str)
end
return t
end
---@param table table
---@param depth number?
function DumpTable(table, depth)
if depth == nil then
depth = 0
end
if (depth > 200) then
print("Error: Depth > 200 in dumpTable()")
return
end
for k, v in pairs(table) do
if (type(v) == "table") then
print(string.rep(" ", depth) .. k .. ":")
DumpTable(v, depth + 1)
else
print(string.rep(" ", depth) .. k .. ": ", v)
end
end
end
-- String to number conversion helper
function num(str)
return tonumber(str) or 0
end
-- Number to string conversion
function str(num)
return tostring(num)
end
-- Check if string is numeric
function is_number(str)
return tonumber(str) ~= nil
end
function isArray(t)
if type(t) ~= "table" then return false end
local max = 0
local count = 0
for k, _ in pairs(t) do
if type(k) ~= "number" or k < 1 or math.floor(k) ~= k then
return false
end
max = math.max(max, k)
count = count + 1
end
return max == count
end
modified = false
`
if err := L.DoString(helperScript); err != nil {
initLuaHelpersLogger.Error("Failed to load Lua helper functions: %v", err)
return fmt.Errorf("error loading helper functions: %v", err)
@@ -303,8 +229,8 @@ func BuildLuaScript(luaExpr string) string {
// BuildJSONLuaScript prepares a Lua expression for JSON mode
func BuildJSONLuaScript(luaExpr string) string {
buildJsonLuaScriptLogger := processorLogger.WithPrefix("BuildJSONLuaScript").WithField("inputLuaExpr", luaExpr)
buildJsonLuaScriptLogger.Debug("Building full Lua script for JSON mode from expression")
buildJSONLuaScriptLogger := processorLogger.WithPrefix("BuildJSONLuaScript").WithField("inputLuaExpr", luaExpr)
buildJSONLuaScriptLogger.Debug("Building full Lua script for JSON mode from expression")
// Perform $var substitutions from globalVariables
luaExpr = replaceVariables(luaExpr)
@@ -316,7 +242,7 @@ func BuildJSONLuaScript(luaExpr string) string {
local res = run()
modified = res == nil or res
`, luaExpr)
buildJsonLuaScriptLogger.Trace("Generated full JSON Lua script: %q", utils.LimitString(fullScript, 200))
buildJSONLuaScriptLogger.Trace("Generated full JSON Lua script: %q", utils.LimitString(fullScript, 200))
return fullScript
}
@@ -385,9 +311,9 @@ func fetch(L *lua.LState) int {
fetchLogger.Debug("Fetching URL: %q", url)
// Get options from second argument if provided
var method string = "GET"
var headers map[string]string = make(map[string]string)
var body string = ""
var method = "GET"
var headers = make(map[string]string)
var body = ""
if L.GetTop() > 1 {
options := L.ToTable(2)
@@ -501,8 +427,8 @@ func EvalRegex(L *lua.LState) int {
if len(matches) > 0 {
matchesTable := L.NewTable()
for i, match := range matches {
matchesTable.RawSetString(fmt.Sprintf("%d", i), lua.LString(match))
evalRegexLogger.Debug("Set table[%d] = %q", i, match)
matchesTable.RawSetInt(i+1, lua.LString(match))
evalRegexLogger.Debug("Set table[%d] = %q", i+1, match)
}
L.Push(matchesTable)
} else {
@@ -519,25 +445,59 @@ func GetLuaFunctionsHelp() string {
return `Lua Functions Available in Global Environment:
MATH FUNCTIONS:
min(a, b) - Returns the minimum of two numbers
max(a, b) - Returns the maximum of two numbers
round(x, n) - Rounds x to n decimal places (default 0)
floor(x) - Returns the floor of x
ceil(x) - Returns the ceiling of x
min(a, b) - Returns the minimum of two numbers
max(a, b) - Returns the maximum of two numbers
round(x, n) - Rounds x to n decimal places (default 0)
floor(x) - Returns the floor of x
ceil(x) - Returns the ceiling of x
STRING FUNCTIONS:
upper(s) - Converts string to uppercase
lower(s) - Converts string to lowercase
format(s, ...) - Formats string using Lua string.format
trim(s) - Removes leading/trailing whitespace
upper(s) - Converts string to uppercase
lower(s) - Converts string to lowercase
format(s, ...) - Formats string using Lua string.format
trim(s) - Removes leading/trailing whitespace
strsplit(inputstr, sep) - Splits string by separator (default: whitespace)
num(str) - Converts string to number (returns 0 if invalid)
str(num) - Converts number to string
is_number(str) - Returns true if string is numeric
fromCSV(csv, options) - Parses CSV text into rows of fields
options: {delimiter=",", hasheader=false, hascomments=false}
toCSV(rows, options) - Converts table of rows to CSV text format
options: {delimiter=",", hasheader=false}
num(str) - Converts string to number (returns 0 if invalid)
str(num) - Converts number to string
is_number(str) - Returns true if string is numeric
TABLE FUNCTIONS:
DumpTable(table, depth) - Prints table structure recursively
isArray(t) - Returns true if table is a sequential array
dump(table, depth) - Prints table structure recursively
isArray(t) - Returns true if table is a sequential array
XML HELPER FUNCTIONS:
findElements(root, tagName) - Find all elements with specific tag name (recursive)
findFirstElement(parent, tagName) - Find first direct child with specific tag name
visitElements(root, callback) - Visit all elements recursively
callback(element, depth, path)
filterElements(root, predicate) - Find elements matching condition
predicate(element) returns true/false
getNumAttr(element, attrName) - Get numeric attribute value
setNumAttr(element, attrName, value) - Set numeric attribute value
modifyNumAttr(element, attrName, func)- Modify numeric attribute with function
func(currentValue) returns newValue
hasAttr(element, attrName) - Check if attribute exists
getAttr(element, attrName) - Get attribute value as string
setAttr(element, attrName, value) - Set attribute value
getText(element) - Get element text content
setText(element, text) - Set element text content
addChild(parent, child) - Add child element to parent
removeChildren(parent, tagName) - Remove all children with specific tag name
getChildren(parent, tagName) - Get all direct children with specific tag name
countChildren(parent, tagName) - Count direct children with specific tag name
JSON HELPER FUNCTIONS:
visitJSON(data, callback) - Visit all values in JSON structure
callback(value, key, parent)
findInJSON(data, predicate) - Find values matching condition
predicate(value, key, parent) returns true/false
modifyJSONNumbers(data, predicate, modifier) - Modify numeric values
predicate(value, key, parent) returns true/false
modifier(currentValue) returns newValue
HTTP FUNCTIONS:
fetch(url, options) - Makes HTTP request, returns response table
@@ -552,12 +512,31 @@ UTILITY FUNCTIONS:
print(...) - Prints arguments to Go logger
EXAMPLES:
-- Math
round(3.14159, 2) -> 3.14
min(5, 3) -> 3
-- String
strsplit("a,b,c", ",") -> {"a", "b", "c"}
upper("hello") -> "HELLO"
num("123") -> 123
is_number("abc") -> false
-- Regex
re("(\\w+)@(\\w+)", "user@domain.com") -> {"user@domain.com", "user", "domain.com"}
-- XML (where root is XML element with _tag, _attr, _children fields)
local items = findElements(root, "Item")
for _, item in ipairs(items) do
modifyNumAttr(item, "Weight", function(w) return w * 2 end)
end
-- JSON (where data is parsed JSON object)
visitJSON(data, function(value, key, parent)
if type(value) == "number" and key == "price" then
parent[key] = value * 1.5
end
end)
-- HTTP
local response = fetch("https://api.example.com/data")
if response.ok then
print(response.body)
end`
}

View File

@@ -0,0 +1,218 @@
package processor
import (
	"io"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	lua "github.com/yuin/gopher-lua"
)
// TestReplaceVariables verifies $variable substitution in Lua snippets for
// numbers, strings, booleans, repeated occurrences, and undefined names
// (which must be left untouched).
func TestReplaceVariables(t *testing.T) {
	// Seed the package-level variable map and restore it afterwards so
	// other tests are unaffected.
	globalVariables = map[string]interface{}{
		"multiplier": 2.5,
		"prefix":     "TEST_",
		"enabled":    true,
		"disabled":   false,
		"count":      42,
	}
	defer func() { globalVariables = make(map[string]interface{}) }()

	cases := []struct {
		name string
		in   string
		want string
	}{
		{"Replace numeric variable", "v1 * $multiplier", "v1 * 2.5"},
		{"Replace string variable", `s1 = $prefix .. "value"`, `s1 = "TEST_" .. "value"`},
		{"Replace boolean true", "enabled = $enabled", "enabled = true"},
		{"Replace boolean false", "disabled = $disabled", "disabled = false"},
		{"Replace integer", "count = $count", "count = 42"},
		{"Multiple replacements", "$count * $multiplier", "42 * 2.5"},
		{"No variables", "v1 * 2", "v1 * 2"},
		{"Undefined variable", "v1 * $undefined", "v1 * $undefined"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, replaceVariables(tc.in))
		})
	}
}
// TestSetVariablesAllTypes covers the Go->Lua conversion branches of
// SetVariables: int, int64, float32, float64, both booleans, and string.
// Conversion results are observed through a fresh state from NewLuaState.
func TestSetVariablesAllTypes(t *testing.T) {
	vars := map[string]interface{}{
		"int_val":     42,
		"int64_val":   int64(100),
		"float32_val": float32(3.14),
		"float64_val": 2.718,
		"bool_true":   true,
		"bool_false":  false,
		"string_val":  "hello",
	}
	SetVariables(vars)
	// Create Lua state to verify
	L, err := NewLuaState()
	assert.NoError(t, err)
	defer L.Close()
	// Verify int64 arrives as a Lua number
	int64Val := L.GetGlobal("int64_val")
	assert.Equal(t, lua.LTNumber, int64Val.Type())
	assert.Equal(t, 100.0, float64(int64Val.(lua.LNumber)))
	// Verify float32; InDelta because float32->float64 widening is inexact
	float32Val := L.GetGlobal("float32_val")
	assert.Equal(t, lua.LTNumber, float32Val.Type())
	assert.InDelta(t, 3.14, float64(float32Val.(lua.LNumber)), 0.01)
	// Verify bool true
	boolTrue := L.GetGlobal("bool_true")
	assert.Equal(t, lua.LTBool, boolTrue.Type())
	assert.True(t, bool(boolTrue.(lua.LBool)))
	// Verify bool false
	boolFalse := L.GetGlobal("bool_false")
	assert.Equal(t, lua.LTBool, boolFalse.Type())
	assert.False(t, bool(boolFalse.(lua.LBool)))
	// Verify string
	stringVal := L.GetGlobal("string_val")
	assert.Equal(t, lua.LTString, stringVal.Type())
	assert.Equal(t, "hello", string(stringVal.(lua.LString)))
	// NOTE(review): int_val and float64_val are set but never asserted on.
}
// TestFetchWithTestServer exercises fetch() end-to-end against a local
// httptest server and checks the ok/status/body fields Lua receives.
func TestFetchWithTestServer(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Default request (no options) must be a GET.
		assert.Equal(t, "GET", r.Method)
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"status": "success"}`))
	}))
	defer srv.Close()

	state := lua.NewState()
	defer state.Close()
	state.SetGlobal("fetch", state.NewFunction(fetch))

	script := `
response = fetch("` + srv.URL + `")
assert(response ~= nil, "Expected response")
assert(response.ok == true, "Expected ok to be true")
assert(response.status == 200, "Expected status 200")
assert(response.body == '{"status": "success"}', "Expected correct body")
`
	assert.NoError(t, state.DoString(script))
}
// TestFetchWithTestServerPOST verifies that fetch() forwards method, headers
// and body from the Lua options table, and surfaces the 201 response.
func TestFetchWithTestServerPOST(t *testing.T) {
	// Create test HTTP server
	receivedBody := ""
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "POST", r.Method)
		assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
		// io.ReadAll drains the body completely; the previous single
		// r.Body.Read(buf) call could legally return a partial payload
		// and make the receivedBody assertion flaky.
		data, err := io.ReadAll(r.Body)
		assert.NoError(t, err)
		receivedBody = string(data)
		w.WriteHeader(http.StatusCreated)
		w.Write([]byte(`{"created": true}`))
	}))
	defer server.Close()
	L := lua.NewState()
	defer L.Close()
	L.SetGlobal("fetch", L.NewFunction(fetch))
	script := `
local opts = {
	method = "POST",
	headers = {["Content-Type"] = "application/json"},
	body = '{"test": "data"}'
}
response = fetch("` + server.URL + `", opts)
assert(response ~= nil, "Expected response")
assert(response.ok == true, "Expected ok to be true")
assert(response.status == 201, "Expected status 201")
`
	err := L.DoString(script)
	assert.NoError(t, err)
	// The server must have seen exactly the body the Lua script sent.
	assert.Equal(t, `{"test": "data"}`, receivedBody)
}
// TestFetchWithTestServer404 confirms a 404 response reaches Lua with
// ok == false while the fetch call itself still succeeds.
func TestFetchWithTestServer404(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte(`{"error": "not found"}`))
	}))
	defer srv.Close()

	state := lua.NewState()
	defer state.Close()
	state.SetGlobal("fetch", state.NewFunction(fetch))

	script := `
response = fetch("` + srv.URL + `")
assert(response ~= nil, "Expected response")
assert(response.ok == false, "Expected ok to be false for 404")
assert(response.status == 404, "Expected status 404")
`
	assert.NoError(t, state.DoString(script))
}

View File

@@ -0,0 +1,366 @@
package processor
import (
"testing"
"github.com/stretchr/testify/assert"
lua "github.com/yuin/gopher-lua"
)
// TestSetVariables verifies that SetVariables exposes Go values as Lua
// globals of the matching Lua type (number, string, boolean) in states
// created afterwards via NewLuaState.
func TestSetVariables(t *testing.T) {
	// Test with various variable types
	vars := map[string]interface{}{
		"multiplier": 2.5,
		"prefix":     "TEST_",
		"enabled":    true,
		"count":      42,
	}
	SetVariables(vars)
	// Create a new Lua state to verify variables are set
	L, err := NewLuaState()
	assert.NoError(t, err)
	defer L.Close()
	// Verify the variables are accessible
	multiplier := L.GetGlobal("multiplier")
	assert.Equal(t, lua.LTNumber, multiplier.Type())
	assert.Equal(t, 2.5, float64(multiplier.(lua.LNumber)))
	prefix := L.GetGlobal("prefix")
	assert.Equal(t, lua.LTString, prefix.Type())
	assert.Equal(t, "TEST_", string(prefix.(lua.LString)))
	enabled := L.GetGlobal("enabled")
	assert.Equal(t, lua.LTBool, enabled.Type())
	assert.True(t, bool(enabled.(lua.LBool)))
	// Go int 42 is observed as a Lua number (compared via float64)
	count := L.GetGlobal("count")
	assert.Equal(t, lua.LTNumber, count.Type())
	assert.Equal(t, 42.0, float64(count.(lua.LNumber)))
}
// TestSetVariablesEmpty ensures SetVariables accepts an empty map and that
// creating a Lua state afterwards still works.
func TestSetVariablesEmpty(t *testing.T) {
	SetVariables(map[string]interface{}{})

	state, err := NewLuaState()
	assert.NoError(t, err)
	defer state.Close()
}
// TestSetVariablesNil ensures SetVariables tolerates a nil map without
// panicking and that NewLuaState still succeeds afterwards.
func TestSetVariablesNil(t *testing.T) {
	SetVariables(nil)

	state, err := NewLuaState()
	assert.NoError(t, err)
	defer state.Close()
}
// TestGetLuaFunctionsHelp checks that the generated help text is non-empty,
// lists every documented category, and mentions key function signatures.
func TestGetLuaFunctionsHelp(t *testing.T) {
	help := GetLuaFunctionsHelp()
	assert.NotEmpty(t, help)

	// Every top-level section header must be present.
	sections := []string{
		"MATH FUNCTIONS",
		"STRING FUNCTIONS",
		"TABLE FUNCTIONS",
		"XML HELPER FUNCTIONS",
		"JSON HELPER FUNCTIONS",
		"HTTP FUNCTIONS",
		"REGEX FUNCTIONS",
		"UTILITY FUNCTIONS",
		"EXAMPLES",
	}
	for _, section := range sections {
		assert.Contains(t, help, section)
	}

	// Spot-check individual function signatures.
	signatures := []string{
		"min(a, b)",
		"max(a, b)",
		"round(x, n)",
		"fetch(url, options)",
		"findElements(root, tagName)",
		"visitJSON(data, callback)",
		"re(pattern, input)",
		"print(...)",
	}
	for _, sig := range signatures {
		assert.Contains(t, help, sig)
	}
}
// TestFetchFunction checks fetch()'s Lua-level error contract: an empty or
// malformed URL yields nil plus an error value instead of a Go-side failure.
func TestFetchFunction(t *testing.T) {
	state := lua.NewState()
	defer state.Close()
	state.SetGlobal("fetch", state.NewFunction(fetch))

	// Empty URL: nil result plus error.
	err := state.DoString(`
result, err = fetch("")
assert(result == nil, "Expected nil result for empty URL")
assert(err ~= nil, "Expected error for empty URL")
`)
	assert.NoError(t, err)

	// Malformed URL: same contract.
	err = state.DoString(`
result, err = fetch("not-a-valid-url")
assert(result == nil, "Expected nil result for invalid URL")
assert(err ~= nil, "Expected error for invalid URL")
`)
	assert.NoError(t, err)
}
// TestFetchFunctionWithOptions passes a full options table (method, headers,
// body) to fetch() for an unresolvable host and expects the nil-plus-error
// contract rather than a panic.
func TestFetchFunctionWithOptions(t *testing.T) {
	L := lua.NewState()
	defer L.Close()
	// Register the fetch function
	L.SetGlobal("fetch", L.NewFunction(fetch))
	// Test with options (should fail gracefully with invalid URL)
	// NOTE(review): assumes the ".local" host never resolves -- TODO confirm
	// on networks with mDNS/.local name resolution.
	err := L.DoString(`
local opts = {
	method = "POST",
	headers = {["Content-Type"] = "application/json"},
	body = '{"test": "data"}'
}
result, err = fetch("http://invalid-domain-that-does-not-exist.local", opts)
-- Should get error due to invalid domain
assert(result == nil, "Expected nil result for invalid domain")
assert(err ~= nil, "Expected error for invalid domain")
`)
	assert.NoError(t, err)
}
func TestPrependLuaAssignment(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "Simple assignment",
input: "10",
expected: "v1 = 10",
},
{
name: "Expression",
input: "v1 * 2",
expected: "v1 = v1 * 2",
},
{
name: "Assignment with equal sign",
input: "= 5",
expected: "v1 = 5",
},
{
name: "Complex expression",
input: "math.floor(v1 / 2)",
expected: "v1 = math.floor(v1 / 2)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := PrependLuaAssignment(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// TestBuildJSONLuaScript checks that BuildJSONLuaScript embeds the user's
// snippet verbatim in the generated script. Only substrings are asserted;
// the surrounding wrapper code is not pinned down here.
func TestBuildJSONLuaScript(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		contains []string // substrings that must survive into the output
	}{
		{
			name:  "Simple JSON modification",
			input: "data.value = data.value * 2; modified = true",
			contains: []string{
				"data.value = data.value * 2",
				"modified = true",
			},
		},
		{
			name:  "Complex JSON script",
			input: "for i, item in ipairs(data.items) do item.price = item.price * 1.5 end; modified = true",
			contains: []string{
				"for i, item in ipairs(data.items)",
				"modified = true",
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := BuildJSONLuaScript(tt.input)
			for _, substr := range tt.contains {
				assert.Contains(t, result, substr)
			}
		})
	}
}
// TestPrintToGo ensures the Go-backed print() accepts strings, integers,
// booleans and floats without raising a Lua error.
func TestPrintToGo(t *testing.T) {
	state := lua.NewState()
	defer state.Close()
	state.SetGlobal("print", state.NewFunction(printToGo))

	script := `
print("Hello, World!")
print(42)
print(true)
print(3.14)
`
	assert.NoError(t, state.DoString(script))
}
// TestEvalRegex exercises the Lua-facing re() wrapper: on a match it returns
// a 1-based table whose first entry is the full match followed by the capture
// groups; on no match it returns nil.
func TestEvalRegex(t *testing.T) {
	L := lua.NewState()
	defer L.Close()
	// Register the regex function
	L.SetGlobal("re", L.NewFunction(EvalRegex))
	// Test 1: Simple match -- matches[1] is the full match, matches[2] group 1
	err := L.DoString(`
matches = re("(\\d+)", "The answer is 42")
assert(matches ~= nil, "Expected matches")
assert(matches[1] == "42", "Expected full match to be 42")
assert(matches[2] == "42", "Expected capture group to be 42")
`)
	assert.NoError(t, err)
	// Test 2: No match returns nil, not an empty table
	err = L.DoString(`
matches = re("(\\d+)", "No numbers here")
assert(matches == nil, "Expected nil for no match")
`)
	assert.NoError(t, err)
	// Test 3: Multiple capture groups appear in order after the full match
	err = L.DoString(`
matches = re("(\\w+)\\s+(\\d+)", "item 123")
assert(matches ~= nil, "Expected matches")
assert(matches[1] == "item 123", "Expected full match")
assert(matches[2] == "item", "Expected first capture group")
assert(matches[3] == "123", "Expected second capture group")
`)
	assert.NoError(t, err)
}
// TestEstimatePatternComplexity checks that the complexity estimate grows
// with pattern features: plain literals score lowest, capture groups more,
// and named groups with classes the most.
func TestEstimatePatternComplexity(t *testing.T) {
	cases := []struct {
		name    string
		pattern string
		atLeast int
	}{
		{"Simple literal", "hello", 1},
		{"With capture group", "(\\d+)", 2},
		{"Complex pattern", "(?P<name>\\w+)\\s+(?P<value>\\d+\\.\\d+)", 3},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := estimatePatternComplexity(tc.pattern)
			assert.GreaterOrEqual(t, got, tc.atLeast)
		})
	}
}
// TestParseNumeric verifies string-to-float64 parsing and its ok flag for
// integers, floats, negatives, garbage, and the empty string.
func TestParseNumeric(t *testing.T) {
	cases := []struct {
		name  string
		input string
		want  float64
		ok    bool
	}{
		{"Integer", "42", 42.0, true},
		{"Float", "3.14", 3.14, true},
		{"Negative", "-10", -10.0, true},
		{"Invalid", "not a number", 0, false},
		{"Empty", "", 0, false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, ok := parseNumeric(tc.input)
			assert.Equal(t, tc.ok, ok)
			// The numeric value is only meaningful on success.
			if tc.ok {
				assert.Equal(t, tc.want, got)
			}
		})
	}
}
func TestFormatNumeric(t *testing.T) {
tests := []struct {
name string
input float64
expected string
}{
{"Integer value", 42.0, "42"},
{"Float value", 3.14, "3.14"},
{"Negative integer", -10.0, "-10"},
{"Negative float", -3.14, "-3.14"},
{"Zero", 0.0, "0"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatNumeric(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// TestLuaHelperFunctionsDocumentation spot-checks the generated help text:
// the main categories and a handful of key helper names must be mentioned.
func TestLuaHelperFunctionsDocumentation(t *testing.T) {
	help := GetLuaFunctionsHelp()

	for _, category := range []string{
		"MATH FUNCTIONS",
		"STRING FUNCTIONS",
		"XML HELPER FUNCTIONS",
		"JSON HELPER FUNCTIONS",
	} {
		assert.Contains(t, help, category, "Help should contain category: %s", category)
	}

	for _, fn := range []string{
		"findElements",
		"visitElements",
		"visitJSON",
		"round",
		"fetch",
	} {
		assert.Contains(t, help, fn, "Help should mention function: %s", fn)
	}
}

View File

@@ -1,7 +1,6 @@
package processor_test
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
@@ -30,8 +29,8 @@ func TestEvalRegex_CaptureGroupsReturned(t *testing.T) {
}
expected := []string{"test-42", "test", "42"}
for i, v := range expected {
val := tbl.RawGetString(fmt.Sprintf("%d", i))
assert.Equal(t, lua.LString(v), val, "Expected index %d to be %q", i, v)
val := tbl.RawGetInt(i + 1)
assert.Equal(t, lua.LString(v), val, "Expected index %d to be %q", i+1, v)
}
}
@@ -67,9 +66,9 @@ func TestEvalRegex_NoCaptureGroups(t *testing.T) {
if !ok {
t.Fatalf("Expected Lua table, got %T", out)
}
fullMatch := tbl.RawGetString("0")
fullMatch := tbl.RawGetInt(1)
assert.Equal(t, lua.LString("foo123"), fullMatch)
// There should be only the full match (index 0)
// There should be only the full match (index 1)
count := 0
tbl.ForEach(func(k, v lua.LValue) {
count++

View File

@@ -22,9 +22,9 @@ type CaptureGroup struct {
Range [2]int
}
// ProcessContent applies regex replacement with Lua processing
// The filename here exists ONLY so we can pass it to the lua environment
// It's not used for anything else
// ProcessRegex applies regex replacement with Lua processing.
// The filename here exists ONLY so we can pass it to the lua environment.
// It's not used for anything else.
func ProcessRegex(content string, command utils.ModifyCommand, filename string) ([]utils.ReplaceCommand, error) {
processRegexLogger := regexLogger.WithPrefix("ProcessRegex").WithField("commandName", command.Name).WithField("file", filename)
processRegexLogger.Debug("Starting regex processing for file")
@@ -216,9 +216,6 @@ func ProcessRegex(content string, command utils.ModifyCommand, filename string)
}
if replacement == "" {
// Apply the modifications to the original match
replacement = matchContent
// Count groups that were actually modified
modifiedGroupsCount := 0
for _, capture := range updatedCaptureGroups {

View File

@@ -0,0 +1,87 @@
package processor
import (
"cook/utils"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
)
// TestNamedCaptureGroupFallback runs a Lua script that modifies v1 but never
// touches the named group, and checks ProcessRegex completes without error.
func TestNamedCaptureGroupFallback(t *testing.T) {
	pattern := `value = (?P<myvalue>\d+)`
	input := `value = 42`
	luaScript := `v1 = v1 * 2 -- Set v1 but not myvalue, test fallback`

	cmd := utils.ModifyCommand{
		Name:  "test_fallback",
		Regex: pattern,
		Lua:   luaScript,
	}

	// Sanity check: the pattern really matches the input text.
	assert.NotNil(t, regexp.MustCompile(pattern).FindStringSubmatchIndex(input))

	replacements, err := ProcessRegex(input, cmd, "test.txt")
	assert.NoError(t, err)
	// NOTE(review): len(x) >= 0 always holds, so this assertion can never
	// fail; a stronger test would apply the replacements and inspect the
	// resulting text.
	if replacements != nil {
		assert.GreaterOrEqual(t, len(replacements), 0)
	}
}
// TestNamedCaptureGroupNilInLua sets the named group explicitly to nil in
// Lua while modifying v1, and checks ProcessRegex completes without error.
func TestNamedCaptureGroupNilInLua(t *testing.T) {
	cmd := utils.ModifyCommand{
		Name:  "test_nil",
		Regex: `value = (?P<num>\d+)`,
		Lua:   `v1 = v1 .. "_test"; num = nil -- v1 modified, num set to nil`,
	}

	replacements, err := ProcessRegex(`value = 123`, cmd, "test.txt")
	assert.NoError(t, err)
	// NOTE(review): len(x) >= 0 can never fail, so this only documents
	// intent; asserting on the applied output would be a stronger check.
	if replacements != nil {
		assert.GreaterOrEqual(t, len(replacements), 0)
	}
}
// TestMixedNamedCaptureGroups modifies one of two named groups in Lua and
// verifies, after applying the replacements, that the other group keeps its
// original text.
func TestMixedNamedCaptureGroups(t *testing.T) {
	input := `count = 100`
	cmd := utils.ModifyCommand{
		Name:  "test_mixed",
		Regex: `(?P<key>\w+) = (?P<value>\d+)`,
		Lua:   `key = key .. "_modified" -- Only modify key, leave value undefined`,
	}

	replacements, err := ProcessRegex(input, cmd, "test.txt")
	assert.NoError(t, err)
	assert.NotNil(t, replacements)

	// NOTE(review): the second return value of ExecuteModifications is
	// discarded here; worth asserting on once its meaning is confirmed.
	result, _ := utils.ExecuteModifications(replacements, input)
	assert.Contains(t, result, "count_modified")
	assert.Contains(t, result, "100")
}

View File

@@ -30,7 +30,7 @@ func normalizeWhitespace(s string) string {
return re.ReplaceAllString(strings.TrimSpace(s), " ")
}
func ApiAdaptor(content string, regex string, lua string) (string, int, int, error) {
func APIAdaptor(content string, regex string, lua string) (string, int, int, error) {
command := utils.ModifyCommand{
Regex: regex,
Lua: lua,
@@ -79,7 +79,7 @@ func TestSimpleValueMultiplication(t *testing.T) {
</item>
</config>`
result, mods, matches, err := ApiAdaptor(content, `(?s)<value>(\d+)</value>`, "v1 = v1*1.5")
result, mods, matches, err := APIAdaptor(content, `(?s)<value>(\d+)</value>`, "v1 = v1*1.5")
assert.NoError(t, err, "Error processing content: %v", err)
assert.Equal(t, 1, matches, "Expected 1 match, got %d", matches)
@@ -100,7 +100,7 @@ func TestShorthandNotation(t *testing.T) {
</item>
</config>`
result, mods, matches, err := ApiAdaptor(content, `(?s)<value>(\d+)</value>`, "v1*1.5")
result, mods, matches, err := APIAdaptor(content, `(?s)<value>(\d+)</value>`, "v1*1.5")
assert.NoError(t, err, "Error processing content: %v", err)
assert.Equal(t, 1, matches, "Expected 1 match, got %d", matches)
@@ -121,7 +121,7 @@ func TestShorthandNotationFloats(t *testing.T) {
</item>
</config>`
result, mods, matches, err := ApiAdaptor(content, `(?s)<value>(\d+\.\d+)</value>`, "v1*1.5")
result, mods, matches, err := APIAdaptor(content, `(?s)<value>(\d+\.\d+)</value>`, "v1*1.5")
assert.NoError(t, err, "Error processing content: %v", err)
assert.Equal(t, 1, matches, "Expected 1 match, got %d", matches)
@@ -146,7 +146,7 @@ func TestArrayNotation(t *testing.T) {
</prices>
</config>`
result, mods, matches, err := ApiAdaptor(content, `(?s)<price>(\d+)</price>`, "v1*2")
result, mods, matches, err := APIAdaptor(content, `(?s)<price>(\d+)</price>`, "v1*2")
assert.NoError(t, err, "Error processing content: %v", err)
assert.Equal(t, 3, matches, "Expected 3 matches, got %d", matches)
@@ -167,7 +167,7 @@ func TestMultipleNumericMatches(t *testing.T) {
<entry>400</entry>
</data>`
result, mods, matches, err := ApiAdaptor(content, `<entry>(\d+)</entry>`, "v1*2")
result, mods, matches, err := APIAdaptor(content, `<entry>(\d+)</entry>`, "v1*2")
assert.NoError(t, err, "Error processing content: %v", err)
assert.Equal(t, 3, matches, "Expected 3 matches, got %d", matches)
@@ -186,7 +186,7 @@ func TestMultipleStringMatches(t *testing.T) {
<name>Mary_modified</name>
</data>`
result, mods, matches, err := ApiAdaptor(content, `<name>([A-Za-z]+)</name>`, `s1 = s1 .. "_modified"`)
result, mods, matches, err := APIAdaptor(content, `<name>([A-Za-z]+)</name>`, `s1 = s1 .. "_modified"`)
assert.NoError(t, err, "Error processing content: %v", err)
assert.Equal(t, 2, matches, "Expected 2 matches, got %d", matches)
@@ -205,7 +205,7 @@ func TestStringUpperCase(t *testing.T) {
<user>MARY</user>
</users>`
result, mods, matches, err := ApiAdaptor(content, `<user>([A-Za-z]+)</user>`, `s1 = string.upper(s1)`)
result, mods, matches, err := APIAdaptor(content, `<user>([A-Za-z]+)</user>`, `s1 = string.upper(s1)`)
assert.NoError(t, err, "Error processing content: %v", err)
assert.Equal(t, 2, matches, "Expected 2 matches, got %d", matches)
@@ -224,7 +224,7 @@ func TestStringConcatenation(t *testing.T) {
<product>Banana_fruit</product>
</products>`
result, mods, matches, err := ApiAdaptor(content, `<product>([A-Za-z]+)</product>`, `s1 = s1 .. "_fruit"`)
result, mods, matches, err := APIAdaptor(content, `<product>([A-Za-z]+)</product>`, `s1 = s1 .. "_fruit"`)
assert.NoError(t, err, "Error processing content: %v", err)
assert.Equal(t, 2, matches, "Expected 2 matches, got %d", matches)
@@ -254,7 +254,7 @@ func TestDecimalValues(t *testing.T) {
regex := regexp.MustCompile(`(?s)<value>([0-9.]+)</value>.*?<multiplier>([0-9.]+)</multiplier>`)
luaExpr := BuildLuaScript("v1 = v1 * v2")
result, _, _, err := ApiAdaptor(content, regex.String(), luaExpr)
result, _, _, err := APIAdaptor(content, regex.String(), luaExpr)
assert.NoError(t, err, "Error processing content: %v", err)
normalizedModified := normalizeWhitespace(result)
@@ -282,7 +282,7 @@ func TestLuaMathFunctions(t *testing.T) {
regex := regexp.MustCompile(`(?s)<value>(\d+)</value>`)
luaExpr := BuildLuaScript("v1 = math.sqrt(v1)")
modifiedContent, _, _, err := ApiAdaptor(content, regex.String(), luaExpr)
modifiedContent, _, _, err := APIAdaptor(content, regex.String(), luaExpr)
assert.NoError(t, err, "Error processing content: %v", err)
normalizedModified := normalizeWhitespace(modifiedContent)
@@ -310,7 +310,7 @@ func TestDirectAssignment(t *testing.T) {
regex := regexp.MustCompile(`(?s)<value>(\d+)</value>`)
luaExpr := BuildLuaScript("=0")
modifiedContent, _, _, err := ApiAdaptor(content, regex.String(), luaExpr)
modifiedContent, _, _, err := APIAdaptor(content, regex.String(), luaExpr)
assert.NoError(t, err, "Error processing content: %v", err)
normalizedModified := normalizeWhitespace(modifiedContent)
@@ -369,7 +369,7 @@ func TestStringAndNumericOperations(t *testing.T) {
luaExpr := BuildLuaScript(tt.luaExpression)
// Process with our function
result, modCount, _, err := ApiAdaptor(tt.input, pattern, luaExpr)
result, modCount, _, err := APIAdaptor(tt.input, pattern, luaExpr)
assert.NoError(t, err, "Process function failed: %v", err)
// Check results
@@ -430,7 +430,7 @@ func TestEdgeCases(t *testing.T) {
luaExpr := BuildLuaScript(tt.luaExpression)
// Process with our function
result, modCount, _, err := ApiAdaptor(tt.input, pattern, luaExpr)
result, modCount, _, err := APIAdaptor(tt.input, pattern, luaExpr)
assert.NoError(t, err, "Process function failed: %v", err)
// Check results
@@ -453,7 +453,7 @@ func TestNamedCaptureGroups(t *testing.T) {
</item>
</config>`
result, mods, matches, err := ApiAdaptor(content, `(?s)<value>(?<amount>\d+)</value>`, "amount = amount * 2")
result, mods, matches, err := APIAdaptor(content, `(?s)<value>(?<amount>\d+)</value>`, "amount = amount * 2")
assert.NoError(t, err, "Error processing content: %v", err)
assert.Equal(t, 1, matches, "Expected 1 match, got %d", matches)
@@ -474,7 +474,7 @@ func TestNamedCaptureGroupsNum(t *testing.T) {
</item>
</config>`
result, mods, matches, err := ApiAdaptor(content, `(?s)<value>(?<amount>!num)</value>`, "amount = amount * 2")
result, mods, matches, err := APIAdaptor(content, `(?s)<value>(?<amount>!num)</value>`, "amount = amount * 2")
assert.NoError(t, err, "Error processing content: %v", err)
assert.Equal(t, 1, matches, "Expected 1 match, got %d", matches)
@@ -495,7 +495,7 @@ func TestMultipleNamedCaptureGroups(t *testing.T) {
<quantity>15</quantity>
</product>`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`(?s)<name>(?<prodName>[^<]+)</name>.*?<price>(?<prodPrice>\d+\.\d+)</price>.*?<quantity>(?<prodQty>\d+)</quantity>`,
`prodName = string.upper(prodName)
prodPrice = round(prodPrice + 8, 2)
@@ -518,7 +518,7 @@ func TestMixedIndexedAndNamedCaptures(t *testing.T) {
<data>VALUE</data>
</entry>`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`(?s)<id>(\d+)</id>.*?<data>(?<dataField>[^<]+)</data>`,
`v1 = v1 * 2
dataField = string.upper(dataField)`)
@@ -550,7 +550,7 @@ func TestComplexNestedNamedCaptures(t *testing.T) {
</contact>
</person>`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`(?s)<details>.*?<name>(?<fullName>[^<]+)</name>.*?<age>(?<age>\d+)</age>`,
`fullName = string.upper(fullName) .. " (" .. age .. ")"`)
@@ -571,7 +571,7 @@ func TestNamedCaptureWithVariableReadback(t *testing.T) {
<mana>300</mana>
</stats>`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`(?s)<health>(?<hp>\d+)</health>.*?<mana>(?<mp>\d+)</mana>`,
`hp = hp * 1.5
mp = mp * 1.5`)
@@ -587,7 +587,7 @@ func TestNamedCaptureWithSpecialCharsInName(t *testing.T) {
expected := `<data value="84" min="10" max="100" />`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`<data value="(?<val_1>\d+)"`,
`val_1 = val_1 * 2`)
@@ -602,7 +602,7 @@ func TestEmptyNamedCapture(t *testing.T) {
expected := `<tag attr="default" />`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`attr="(?<value>.*?)"`,
`value = value == "" and "default" or value`)
@@ -617,7 +617,7 @@ func TestMultipleNamedCapturesInSameLine(t *testing.T) {
expected := `<rect x="20" y="40" width="200" height="100" />`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`x="(?<x>\d+)" y="(?<y>\d+)" width="(?<w>\d+)" height="(?<h>\d+)"`,
`x = x * 2
y = y * 2
@@ -641,7 +641,7 @@ func TestConditionalNamedCapture(t *testing.T) {
<item status="inactive" count="10" />
`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`<item status="(?<status>[^"]+)" count="(?<count>\d+)"`,
`count = status == "active" and count * 2 or count`)
@@ -662,7 +662,7 @@ func TestLuaFunctionsOnNamedCaptures(t *testing.T) {
<user name="JANE SMITH" role="admin" />
`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`<user name="(?<name>[^"]+)" role="(?<role>[^"]+)"`,
`-- Capitalize first letters for regular users
if role == "user" then
@@ -692,7 +692,7 @@ func TestNamedCaptureWithMath(t *testing.T) {
<item price="19.99" quantity="3" total="59.97" />
`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`<item price="(?<price>\d+\.\d+)" quantity="(?<qty>\d+)"!any$`,
`-- Calculate and add total
replacement = string.format('<item price="%s" quantity="%s" total="%.2f" />',
@@ -712,7 +712,7 @@ func TestNamedCaptureWithGlobals(t *testing.T) {
expected := `<temp unit="F">77</temp>`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`<temp unit="(?<unit>[CF]?)">(?<value>\d+)</temp>`,
`if unit == "C" then
value = value * 9/5 + 32
@@ -739,7 +739,7 @@ func TestMixedDynamicAndNamedCaptures(t *testing.T) {
<color rgb="0,255,0" name="GREEN" hex="#00FF00" />
`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`<color rgb="(?<r>\d+),(?<g>\d+),(?<b>\d+)" name="(?<colorName>[^"]+)" />`,
`-- Uppercase the name
colorName = string.upper(colorName)
@@ -765,7 +765,7 @@ func TestNamedCapturesWithMultipleReferences(t *testing.T) {
expected := `<text format="uppercase" length="11">HELLO WORLD</text>`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`<text>(?<content>[^<]+)</text>`,
`local uppercaseContent = string.upper(content)
local contentLength = string.len(content)
@@ -783,7 +783,7 @@ func TestNamedCaptureWithJsonData(t *testing.T) {
expected := `<data>{"name":"JOHN","age":30}</data>`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`<data>(?<json>\{.*?\})</data>`,
`-- Parse JSON (simplified, assumes valid JSON)
local name = json:match('"name":"([^"]+)"')
@@ -813,7 +813,7 @@ func TestNamedCaptureInXML(t *testing.T) {
</product>
`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`(?s)<price currency="(?<currency>[^"]+)">(?<price>\d+\.\d+)</price>.*?<stock>(?<stock>\d+)</stock>`,
`-- Add 20% to price if USD
if currency == "USD" then
@@ -870,7 +870,7 @@ func TestComprehensiveNamedCaptures(t *testing.T) {
</products>
`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`(?s)<product sku="(?<sku>[^"]+)" status="(?<status>[^"]+)"[^>]*>\s*<name>(?<product_name>[^<]+)</name>\s*<price currency="(?<currency>[^"]+)">(?<price>\d+\.\d+)</price>\s*<quantity>(?<qty>\d+)</quantity>`,
`-- Only process in-stock items
if status == "in-stock" then
@@ -924,7 +924,7 @@ func TestVariousNamedCaptureFormats(t *testing.T) {
</data>
`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`<entry id="(?<id_num>\d+)" value="(?<val>\d+)"(?: status="(?<status>[^"]*)")? />`,
`-- Prefix the ID with "ID-"
id_num = "ID-" .. id_num
@@ -963,7 +963,7 @@ func TestSimpleNamedCapture(t *testing.T) {
expected := `<product name="WIDGET" price="19.99"/>`
result, mods, matches, err := ApiAdaptor(content,
result, mods, matches, err := APIAdaptor(content,
`name="(?<product_name>[^"]+)"`,
`product_name = string.upper(product_name)`)

View File

@@ -1,27 +0,0 @@
package processor
import (
"io"
"os"
logger "git.site.quack-lab.dev/dave/cylogger"
)
// init configures logging for test runs: when GO_TESTING=1 or TESTING=1 the
// level drops to errors only, and unless ENABLE_TEST_LOGS=1 all logger
// output is routed to io.Discard so it cannot interleave with test output.
func init() {
	underTest := os.Getenv("GO_TESTING") == "1" || os.Getenv("TESTING") == "1"
	if !underTest {
		// Not running under 'go test': leave the logger untouched.
		return
	}
	// Errors only, to minimize noise in test output.
	logger.Init(logger.LevelError)
	if os.Getenv("ENABLE_TEST_LOGS") != "1" {
		// Swap in a logger that writes to nowhere.
		logger.Default = logger.New(io.Discard, "", 0)
	}
}

533
processor/xml.go Normal file
View File

@@ -0,0 +1,533 @@
package processor
import (
"cook/utils"
"encoding/xml"
"fmt"
"io"
"sort"
"strconv"
"strings"
logger "git.site.quack-lab.dev/dave/cylogger"
lua "github.com/yuin/gopher-lua"
)
var xmlLogger = logger.Default.WithPrefix("processor/xml")
// XMLElement represents a parsed XML element with position tracking into the
// original source string, enabling in-place edits without re-serializing.
type XMLElement struct {
	Tag        string                  // local element name (namespace prefix stripped by the decoder)
	Attributes map[string]XMLAttribute // attribute values keyed by local attribute name
	Text       string                  // trimmed character data directly inside this element
	Children   []*XMLElement           // direct child elements, in document order
	StartPos   int64                   // byte offset of the '<' that opens this element
	EndPos     int64                   // byte offset of the element's end -- assignment not visible in this chunk; TODO confirm
	TextStart  int64                   // byte offset where the trimmed text begins in the source
	TextEnd    int64                   // byte offset where the trimmed text ends -- assignment not visible in this chunk; TODO confirm
}
// XMLAttribute represents an attribute with its position in the source
type XMLAttribute struct {
	Value      string // decoded attribute value
	ValueStart int64  // byte offset of the first character inside the quotes
	ValueEnd   int64  // byte offset just past the last character of the value (ValueStart + len(Value))
}
// parseXMLWithPositions parses XML while tracking byte positions of all elements and attributes
func parseXMLWithPositions(content string) (*XMLElement, error) {
decoder := xml.NewDecoder(strings.NewReader(content))
var root *XMLElement
var stack []*XMLElement
var lastPos int64
for {
token, err := decoder.Token()
if err == io.EOF {
break
}
if err != nil {
return nil, fmt.Errorf("failed to parse XML: %v", err)
}
offset := decoder.InputOffset()
switch t := token.(type) {
case xml.StartElement:
// Find the actual start position of this element by searching for "<tagname"
tagSearchPattern := "<" + t.Name.Local
startPos := int64(strings.LastIndex(content[:offset], tagSearchPattern))
element := &XMLElement{
Tag: t.Name.Local,
Attributes: make(map[string]XMLAttribute),
StartPos: startPos,
Children: []*XMLElement{},
}
// Parse attributes - search within the tag boundaries
if len(t.Attr) > 0 {
tagEnd := offset
tagSection := content[startPos:tagEnd]
for _, attr := range t.Attr {
// Find attribute in the tag section: attrname="value"
attrPattern := attr.Name.Local + `="`
attrIdx := strings.Index(tagSection, attrPattern)
if attrIdx >= 0 {
valueStart := startPos + int64(attrIdx) + int64(len(attrPattern))
valueEnd := valueStart + int64(len(attr.Value))
element.Attributes[attr.Name.Local] = XMLAttribute{
Value: attr.Value,
ValueStart: valueStart,
ValueEnd: valueEnd,
}
}
}
}
if len(stack) > 0 {
parent := stack[len(stack)-1]
parent.Children = append(parent.Children, element)
} else {
root = element
}
stack = append(stack, element)
lastPos = offset
case xml.CharData:
rawText := string(t)
text := strings.TrimSpace(rawText)
if len(stack) > 0 && text != "" {
current := stack[len(stack)-1]
current.Text = text
// The text content is between lastPos (after >) and offset (before </)
// Search for the trimmed text within the raw content
textInContent := content[lastPos:offset]
trimmedStart := strings.Index(textInContent, text)
if trimmedStart >= 0 {
current.TextStart = lastPos + int64(trimmedStart)
current.TextEnd = current.TextStart + int64(len(text))
}
}
lastPos = offset
case xml.EndElement:
if len(stack) > 0 {
current := stack[len(stack)-1]
current.EndPos = offset
stack = stack[:len(stack)-1]
}
lastPos = offset
}
}
return root, nil
}
// XMLChange represents a detected difference between original and modified XML structures.
type XMLChange struct {
	Type       string // "text", "attribute", "add_attribute", "remove_attribute", "add_element", "remove_element"
	Path       string // slash-separated node path, for diagnostics
	OldValue   string
	NewValue   string
	StartPos   int64 // byte range in the ORIGINAL document this change applies to
	EndPos     int64
	InsertText string // text to insert for add_* changes
}

// findXMLChanges walks the original tree and its modified copy in parallel and
// records the minimal set of changes, anchored to byte positions in the
// original document. Children are matched per tag name by index; children with
// tags present only in the modified tree are emitted as new elements.
func findXMLChanges(original, modified *XMLElement, path string) []XMLChange {
	var changes []XMLChange
	// Check text content changes
	if original.Text != modified.Text {
		changes = append(changes, XMLChange{
			Type:     "text",
			Path:     path,
			OldValue: original.Text,
			NewValue: modified.Text,
			StartPos: original.TextStart,
			EndPos:   original.TextEnd,
		})
	}
	// Check attribute changes and removals
	for attrName, origAttr := range original.Attributes {
		if modAttr, exists := modified.Attributes[attrName]; exists {
			if origAttr.Value != modAttr.Value {
				changes = append(changes, XMLChange{
					Type:     "attribute",
					Path:     path + "/@" + attrName,
					OldValue: origAttr.Value,
					NewValue: modAttr.Value,
					StartPos: origAttr.ValueStart,
					EndPos:   origAttr.ValueEnd,
				})
			}
		} else {
			// Attribute removed: delete ` attr="value"` including the
			// preceding whitespace character so no double space is left.
			changes = append(changes, XMLChange{
				Type:     "remove_attribute",
				Path:     path + "/@" + attrName,
				OldValue: origAttr.Value,
				StartPos: origAttr.ValueStart - int64(len(attrName)+3), // space + attr + ="
				EndPos:   origAttr.ValueEnd + 1,                       // include closing "
			})
		}
	}
	// Check for added attributes; insert directly after the tag name so the
	// result stays well-formed ("<tag" is StartPos..StartPos+len(Tag)+1).
	for attrName, modAttr := range modified.Attributes {
		if _, exists := original.Attributes[attrName]; !exists {
			changes = append(changes, XMLChange{
				Type:       "add_attribute",
				Path:       path + "/@" + attrName,
				NewValue:   modAttr.Value,
				StartPos:   original.StartPos + int64(len(original.Tag)+1),
				InsertText: fmt.Sprintf(` %s="%s"`, attrName, modAttr.Value),
			})
		}
	}
	// Index children by tag so same-tag elements are matched positionally.
	origChildMap := make(map[string][]*XMLElement)
	for _, child := range original.Children {
		origChildMap[child.Tag] = append(origChildMap[child.Tag], child)
	}
	modChildMap := make(map[string][]*XMLElement)
	for _, child := range modified.Children {
		modChildMap[child.Tag] = append(modChildMap[child.Tag], child)
	}
	processedTags := make(map[string]bool)
	for tag, origChildren := range origChildMap {
		processedTags[tag] = true
		modChildren := modChildMap[tag]
		maxLen := len(origChildren)
		if len(modChildren) > maxLen {
			maxLen = len(modChildren)
		}
		for i := 0; i < maxLen; i++ {
			childPath := fmt.Sprintf("%s/%s[%d]", path, tag, i)
			if i < len(origChildren) && i < len(modChildren) {
				// Both exist, compare recursively
				changes = append(changes, findXMLChanges(origChildren[i], modChildren[i], childPath)...)
			} else if i < len(origChildren) {
				// Child removed
				changes = append(changes, XMLChange{
					Type:     "remove_element",
					Path:     childPath,
					StartPos: origChildren[i].StartPos,
					EndPos:   origChildren[i].EndPos,
				})
			}
		}
		// Children appended beyond the original count are inserted just
		// before the parent's closing tag ("</tag>" is len(Tag)+3 bytes).
		for i := len(origChildren); i < len(modChildren); i++ {
			changes = append(changes, XMLChange{
				Type:       "add_element",
				Path:       fmt.Sprintf("%s/%s[%d]", path, tag, i),
				InsertText: serializeXMLElement(modChildren[i], " "),
				StartPos:   original.EndPos - int64(len(original.Tag)+3),
			})
		}
	}
	// Tags that exist only in the modified tree were never visited above;
	// previously they were silently dropped (processedTags was unused).
	for tag, modChildren := range modChildMap {
		if processedTags[tag] {
			continue
		}
		for i, child := range modChildren {
			changes = append(changes, XMLChange{
				Type:       "add_element",
				Path:       fmt.Sprintf("%s/%s[%d]", path, tag, i),
				InsertText: serializeXMLElement(child, " "),
				StartPos:   original.EndPos - int64(len(original.Tag)+3),
			})
		}
	}
	return changes
}
// serializeXMLElement renders an XMLElement (and its subtree) back to XML
// text. Attributes are emitted in sorted order for deterministic output;
// elements with neither text nor children become self-closing tags.
func serializeXMLElement(elem *XMLElement, indent string) string {
	var out strings.Builder
	out.WriteString(indent)
	out.WriteString("<")
	out.WriteString(elem.Tag)
	// Deterministic attribute ordering.
	names := make([]string, 0, len(elem.Attributes))
	for name := range elem.Attributes {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		out.WriteString(fmt.Sprintf(` %s="%s"`, name, elem.Attributes[name].Value))
	}
	// Self-closing form when the element is completely empty.
	if elem.Text == "" && len(elem.Children) == 0 {
		out.WriteString(" />")
		return out.String()
	}
	out.WriteString(">")
	// Writing an empty string is a no-op, so no guard is needed here.
	out.WriteString(elem.Text)
	if len(elem.Children) > 0 {
		out.WriteString("\n")
		for _, child := range elem.Children {
			out.WriteString(serializeXMLElement(child, indent+" "))
			out.WriteString("\n")
		}
		out.WriteString(indent)
	}
	out.WriteString("</")
	out.WriteString(elem.Tag)
	out.WriteString(">")
	return out.String()
}
// applyXMLChanges converts detected XML changes into surgical ReplaceCommands.
// Value rewrites and removals span [StartPos, EndPos); insertions use a
// zero-width range at StartPos.
func applyXMLChanges(changes []XMLChange) []utils.ReplaceCommand {
	var commands []utils.ReplaceCommand
	for _, change := range changes {
		start := int(change.StartPos)
		switch change.Type {
		case "text", "attribute":
			// Overwrite the old value in place.
			commands = append(commands, utils.ReplaceCommand{
				From: start,
				To:   int(change.EndPos),
				With: change.NewValue,
			})
		case "add_attribute":
			// Zero-width insertion of ` name="value"` inside the start tag.
			commands = append(commands, utils.ReplaceCommand{
				From: start,
				To:   start,
				With: change.InsertText,
			})
		case "add_element":
			// Zero-width insertion of the serialized element on a new line.
			commands = append(commands, utils.ReplaceCommand{
				From: start,
				To:   start,
				With: "\n" + change.InsertText,
			})
		case "remove_attribute", "remove_element":
			// Delete the covered byte range.
			commands = append(commands, utils.ReplaceCommand{
				From: start,
				To:   int(change.EndPos),
				With: "",
			})
		}
	}
	return commands
}
// deepCopyXMLElement returns a fully independent recursive copy of elem.
// A nil input yields nil. Attribute values (plain structs) copy by value.
func deepCopyXMLElement(elem *XMLElement) *XMLElement {
	if elem == nil {
		return nil
	}
	attrs := make(map[string]XMLAttribute, len(elem.Attributes))
	for name, attr := range elem.Attributes {
		attrs[name] = attr
	}
	children := make([]*XMLElement, 0, len(elem.Children))
	for _, child := range elem.Children {
		children = append(children, deepCopyXMLElement(child))
	}
	return &XMLElement{
		Tag:        elem.Tag,
		Text:       elem.Text,
		StartPos:   elem.StartPos,
		EndPos:     elem.EndPos,
		TextStart:  elem.TextStart,
		TextEnd:    elem.TextEnd,
		Attributes: attrs,
		Children:   children,
	}
}
// parseNumeric attempts to interpret s as a float64.
// It returns the parsed value and true on success, or (0, false) otherwise.
func parseNumeric(s string) (float64, bool) {
	f, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return 0, false
	}
	return f, true
}
// formatNumeric renders f compactly: integral values are printed without a
// decimal point ("2" not "2.0"); everything else uses the shortest exact
// decimal representation.
func formatNumeric(f float64) string {
	whole := int64(f)
	if float64(whole) == f {
		return strconv.FormatInt(whole, 10)
	}
	return strconv.FormatFloat(f, 'f', -1, 64)
}
// ProcessXML applies Lua processing to XML content with surgical editing.
// The XML is parsed with byte-position tracking, exposed to the Lua script as
// the global table 'root' (with 'file' holding the filename), and any
// differences between the original tree and the script-modified copy are
// converted into minimal ReplaceCommands so the original formatting survives.
// The script must set the global 'modified' to true to opt in to changes.
// Returns (nil, nil) when nothing changed; errors cover parse failures, Lua
// state creation, and script execution.
func ProcessXML(content string, command utils.ModifyCommand, filename string) ([]utils.ReplaceCommand, error) {
	processXMLLogger := xmlLogger.WithPrefix("ProcessXML").WithField("commandName", command.Name).WithField("file", filename)
	processXMLLogger.Debug("Starting XML processing for file")
	// Parse XML with position tracking
	originalElem, err := parseXMLWithPositions(content)
	if err != nil {
		processXMLLogger.Error("Failed to parse XML: %v", err)
		return nil, fmt.Errorf("failed to parse XML: %v", err)
	}
	processXMLLogger.Debug("Successfully parsed XML content")
	// Create Lua state
	L, err := NewLuaState()
	if err != nil {
		processXMLLogger.Error("Error creating Lua state: %v", err)
		return nil, fmt.Errorf("error creating Lua state: %v", err)
	}
	defer L.Close()
	// Expose the current filename to the script as the global 'file'.
	L.SetGlobal("file", lua.LString(filename))
	// Work on a deep copy so the original tree stays pristine for diffing.
	modifiedElem := deepCopyXMLElement(originalElem)
	// Convert to Lua table and set as global
	luaTable := xmlElementToLuaTable(L, modifiedElem)
	L.SetGlobal("root", luaTable)
	processXMLLogger.Debug("Set XML data as Lua global 'root'")
	// Build and execute the Lua script; the JSON script builder is reused
	// because the wrapping logic is identical for XML.
	luaExpr := BuildJSONLuaScript(command.Lua)
	processXMLLogger.Debug("Built Lua script from expression: %q", command.Lua)
	if err := L.DoString(luaExpr); err != nil {
		processXMLLogger.Error("Lua script execution failed: %v\nScript: %s", err, luaExpr)
		return nil, fmt.Errorf("lua script execution failed: %v", err)
	}
	processXMLLogger.Debug("Lua script executed successfully")
	// The script signals intent by setting 'modified = true'; anything else
	// (absent, false, wrong type) means no changes are emitted.
	modifiedVal := L.GetGlobal("modified")
	if modifiedVal.Type() != lua.LTBool || !lua.LVAsBool(modifiedVal) {
		processXMLLogger.Debug("Skipping - no modifications indicated by Lua script")
		return nil, nil
	}
	// Get the modified data back from Lua
	modifiedTable := L.GetGlobal("root")
	if modifiedTable.Type() != lua.LTTable {
		processXMLLogger.Error("Expected 'root' to be a table after Lua processing")
		return nil, fmt.Errorf("expected 'root' to be a table after Lua processing")
	}
	// Apply Lua modifications back onto the XMLElement copy.
	luaTableToXMLElement(L, modifiedTable.(*lua.LTable), modifiedElem)
	// Diff the original tree against the script-modified copy.
	changes := findXMLChanges(originalElem, modifiedElem, "")
	processXMLLogger.Debug("Found %d changes", len(changes))
	if len(changes) == 0 {
		return nil, nil
	}
	// Generate surgical replace commands
	commands := applyXMLChanges(changes)
	processXMLLogger.Debug("Generated %d replace commands", len(commands))
	return commands, nil
}
// xmlElementToLuaTable converts an XMLElement into the Lua table shape the
// user script sees:
//
//	_tag      string tag name (always set)
//	_text     trimmed text content (only when non-empty)
//	_attr     table of attribute name -> value (only when attributes exist)
//	_children array table of child element tables (only when children exist)
func xmlElementToLuaTable(L *lua.LState, elem *XMLElement) *lua.LTable {
	table := L.CreateTable(0, 4)
	table.RawSetString("_tag", lua.LString(elem.Tag))
	// Expose text content to Lua; previously scripts had no way to read or
	// rewrite it, so the "text" change type could never fire from Lua.
	if elem.Text != "" {
		table.RawSetString("_text", lua.LString(elem.Text))
	}
	if len(elem.Attributes) > 0 {
		attrs := L.CreateTable(0, len(elem.Attributes))
		for name, attr := range elem.Attributes {
			attrs.RawSetString(name, lua.LString(attr.Value))
		}
		table.RawSetString("_attr", attrs)
	}
	if len(elem.Children) > 0 {
		children := L.CreateTable(len(elem.Children), 0)
		for i, child := range elem.Children {
			children.RawSetInt(i+1, xmlElementToLuaTable(L, child))
		}
		table.RawSetString("_children", children)
	}
	return table
}
// luaTableToXMLElement applies Lua table modifications back to an XMLElement:
// '_text' replaces the text content (ignored when absent), '_attr' rebuilds
// the attribute map, and '_children' updates existing children positionally.
// Children appended past the original count are now materialized as new
// elements (previously they were silently dropped, so scripts could never
// add elements).
func luaTableToXMLElement(L *lua.LState, table *lua.LTable, elem *XMLElement) {
	// Update text content when the script provided one.
	if textVal := table.RawGetString("_text"); textVal.Type() == lua.LTString {
		elem.Text = string(textVal.(lua.LString))
	}
	// Update attributes
	if attrVal := table.RawGetString("_attr"); attrVal.Type() == lua.LTTable {
		attrTable := attrVal.(*lua.LTable)
		// Clear and rebuild attributes; positions are recovered later by
		// diffing against the original tree.
		elem.Attributes = make(map[string]XMLAttribute)
		attrTable.ForEach(func(key lua.LValue, value lua.LValue) {
			if key.Type() == lua.LTString && value.Type() == lua.LTString {
				attrName := string(key.(lua.LString))
				attrValue := string(value.(lua.LString))
				elem.Attributes[attrName] = XMLAttribute{Value: attrValue}
			}
		})
	}
	// Update children
	if childrenVal := table.RawGetString("_children"); childrenVal.Type() == lua.LTTable {
		childrenTable := childrenVal.(*lua.LTable)
		newChildren := []*XMLElement{}
		// Iterate array indices until the first nil (Lua sequence semantics).
		for i := 1; ; i++ {
			childVal := childrenTable.RawGetInt(i)
			if childVal.Type() == lua.LTNil {
				break
			}
			childTable, ok := childVal.(*lua.LTable)
			if !ok {
				continue
			}
			if i-1 < len(elem.Children) {
				// Update existing child in place.
				luaTableToXMLElement(L, childTable, elem.Children[i-1])
				newChildren = append(newChildren, elem.Children[i-1])
			} else {
				// Script appended a brand-new child: build it from the table.
				newChildren = append(newChildren, newXMLElementFromLuaTable(L, childTable))
			}
		}
		elem.Children = newChildren
	}
}

// newXMLElementFromLuaTable builds a fresh XMLElement (with no source
// positions) from a Lua table, used for children added by the Lua script.
func newXMLElementFromLuaTable(L *lua.LState, table *lua.LTable) *XMLElement {
	elem := &XMLElement{
		Attributes: make(map[string]XMLAttribute),
		Children:   []*XMLElement{},
	}
	if tagVal := table.RawGetString("_tag"); tagVal.Type() == lua.LTString {
		elem.Tag = string(tagVal.(lua.LString))
	}
	// Reuse the main conversion to populate text, attributes, and children.
	luaTableToXMLElement(L, table, elem)
	return elem
}

View File

@@ -0,0 +1,346 @@
package processor
import (
"strings"
"testing"
"cook/utils"
)
// TestRealWorldGameXML tests with game-like XML structure: three attribute
// values across two items are changed on a deep copy, the copy is diffed
// against the original, and the resulting surgical edits must update only
// those values while preserving the XML declaration and indentation.
func TestRealWorldGameXML(t *testing.T) {
	original := `<?xml version="1.0" encoding="utf-8"?>
<Items>
    <Item name="Fiber" identifier="Item_Fiber" category="Resource">
        <Icon texture="Items/Fiber.png" />
        <Weight value="0.01" />
        <MaxStack value="1000" />
        <Description text="Soft plant fibers useful for crafting." />
    </Item>
    <Item name="Wood" identifier="Item_Wood" category="Resource">
        <Icon texture="Items/Wood.png" />
        <Weight value="0.05" />
        <MaxStack value="500" />
        <Description text="Basic building material." />
    </Item>
</Items>`
	// Parse
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse: %v", err)
	}
	// Modify: Double all MaxStack values and change Wood weight
	modElem := deepCopyXMLElement(origElem)
	// Fiber MaxStack: 1000 → 2000 (child index 2 is <MaxStack>)
	fiberItem := modElem.Children[0]
	fiberMaxStack := fiberItem.Children[2]
	valueAttr := fiberMaxStack.Attributes["value"]
	valueAttr.Value = "2000"
	fiberMaxStack.Attributes["value"] = valueAttr
	// Wood MaxStack: 500 → 1000
	woodItem := modElem.Children[1]
	woodMaxStack := woodItem.Children[2]
	valueAttr2 := woodMaxStack.Attributes["value"]
	valueAttr2.Value = "1000"
	woodMaxStack.Attributes["value"] = valueAttr2
	// Wood Weight: 0.05 → 0.10 (child index 1 is <Weight>)
	woodWeight := woodItem.Children[1]
	weightAttr := woodWeight.Attributes["value"]
	weightAttr.Value = "0.10"
	woodWeight.Attributes["value"] = weightAttr
	// Generate changes: exactly one per edited attribute
	changes := findXMLChanges(origElem, modElem, "")
	if len(changes) != 3 {
		t.Fatalf("Expected 3 changes, got %d", len(changes))
	}
	// Apply
	commands := applyXMLChanges(changes)
	result, _ := utils.ExecuteModifications(commands, original)
	// Verify changes
	if !strings.Contains(result, `<MaxStack value="2000"`) {
		t.Errorf("Failed to update Fiber MaxStack")
	}
	if !strings.Contains(result, `<MaxStack value="1000"`) {
		t.Errorf("Failed to update Wood MaxStack")
	}
	if !strings.Contains(result, `<Weight value="0.10"`) {
		t.Errorf("Failed to update Wood Weight")
	}
	// Verify formatting preserved (check XML declaration and indentation)
	if !strings.HasPrefix(result, `<?xml version="1.0" encoding="utf-8"?>`) {
		t.Errorf("XML declaration not preserved")
	}
	if !strings.Contains(result, "\n    <Item") {
		t.Errorf("Indentation not preserved")
	}
}
// TestAddRemoveMultipleChildren tests adding and removing multiple elements
// in a single diff. Because children are matched per tag by index, the middle
// removals appear as in-place rewrites plus one trailing removal; only the
// final document content is asserted.
func TestAddRemoveMultipleChildren(t *testing.T) {
	original := `<inventory>
    <item name="sword" />
    <item name="shield" />
    <item name="potion" />
    <item name="scroll" />
</inventory>`
	// Parse
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse: %v", err)
	}
	// Remove middle two items, add a new one
	modElem := deepCopyXMLElement(origElem)
	// Remove shield and potion (indices 1 and 2)
	modElem.Children = []*XMLElement{
		modElem.Children[0], // sword
		modElem.Children[3], // scroll
	}
	// Add a new item
	newItem := &XMLElement{
		Tag: "item",
		Attributes: map[string]XMLAttribute{
			"name": {Value: "helmet"},
		},
		Children: []*XMLElement{},
	}
	modElem.Children = append(modElem.Children, newItem)
	// Generate changes
	changes := findXMLChanges(origElem, modElem, "")
	// The algorithm compares by matching indices:
	// orig[0]=sword vs mod[0]=sword (no change)
	// orig[1]=shield vs mod[1]=scroll (treated as replace - shows as attribute changes)
	// orig[2]=potion vs mod[2]=helmet (treated as replace)
	// orig[3]=scroll (removed)
	// This is fine - the actual edits will be correct
	if len(changes) == 0 {
		t.Fatalf("Expected changes, got none")
	}
	// Apply
	commands := applyXMLChanges(changes)
	result, _ := utils.ExecuteModifications(commands, original)
	// Verify final content: removed names gone, kept and added names present
	if strings.Contains(result, `name="shield"`) {
		t.Errorf("Shield not removed")
	}
	if strings.Contains(result, `name="potion"`) {
		t.Errorf("Potion not removed")
	}
	if !strings.Contains(result, `name="sword"`) {
		t.Errorf("Sword incorrectly removed")
	}
	if !strings.Contains(result, `name="scroll"`) {
		t.Errorf("Scroll incorrectly removed")
	}
	if !strings.Contains(result, `name="helmet"`) {
		t.Errorf("Helmet not added")
	}
}
// TestModifyAttributesAndText tests changing both attributes and text content
// on the same elements in one pass; the old text must be fully replaced.
func TestModifyAttributesAndText(t *testing.T) {
	original := `<weapon>
    <item type="sword" damage="10">Iron Sword</item>
    <item type="axe" damage="15">Battle Axe</item>
</weapon>`
	// Parse
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse: %v", err)
	}
	// Modify both items
	modElem := deepCopyXMLElement(origElem)
	// First item: change damage and text
	item1 := modElem.Children[0]
	dmgAttr := item1.Attributes["damage"]
	dmgAttr.Value = "20"
	item1.Attributes["damage"] = dmgAttr
	item1.Text = "Steel Sword"
	// Second item: change damage and type
	item2 := modElem.Children[1]
	dmgAttr2 := item2.Attributes["damage"]
	dmgAttr2.Value = "30"
	item2.Attributes["damage"] = dmgAttr2
	typeAttr := item2.Attributes["type"]
	typeAttr.Value = "greataxe"
	item2.Attributes["type"] = typeAttr
	// Generate and apply changes
	changes := findXMLChanges(origElem, modElem, "")
	commands := applyXMLChanges(changes)
	result, _ := utils.ExecuteModifications(commands, original)
	// Verify all four edits landed and the replaced text is gone
	if !strings.Contains(result, `damage="20"`) {
		t.Errorf("First item damage not updated")
	}
	if !strings.Contains(result, "Steel Sword") {
		t.Errorf("First item text not updated")
	}
	if !strings.Contains(result, `damage="30"`) {
		t.Errorf("Second item damage not updated")
	}
	if !strings.Contains(result, `type="greataxe"`) {
		t.Errorf("Second item type not updated")
	}
	if strings.Contains(result, "Iron Sword") {
		t.Errorf("Old text still present")
	}
}
// TestSelfClosingTagPreservation tests that self-closing tags work correctly:
// rewriting an attribute on a self-closing element must produce the new value
// without disturbing the element form.
func TestSelfClosingTagPreservation(t *testing.T) {
	original := `<root>
    <item name="test" />
    <empty></empty>
</root>`
	parsed, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse: %v", err)
	}
	// Rewrite the first item's name attribute on a deep copy.
	edited := deepCopyXMLElement(parsed)
	target := edited.Children[0]
	updated := target.Attributes["name"]
	updated.Value = "modified"
	target.Attributes["name"] = updated
	// Diff, convert to commands, and apply.
	result, _ := utils.ExecuteModifications(
		applyXMLChanges(findXMLChanges(parsed, edited, "")),
		original,
	)
	if !strings.Contains(result, `name="modified"`) {
		t.Errorf("Attribute not updated: %s", result)
	}
}
// TestNumericAttributeModification tests numeric attribute changes: all three
// numeric attributes are doubled via parseNumeric/formatNumeric and the
// formatted results (including integral rendering of 75.5*2 = 151) verified.
func TestNumericAttributeModification(t *testing.T) {
	original := `<stats health="100" mana="50" stamina="75.5" />`
	// Parse
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse: %v", err)
	}
	// Double all numeric values
	modElem := deepCopyXMLElement(origElem)
	// Helper to modify numeric attributes in place on modElem
	modifyNumericAttr := func(attrName string, multiplier float64) {
		if attr, exists := modElem.Attributes[attrName]; exists {
			if val, ok := parseNumeric(attr.Value); ok {
				attr.Value = formatNumeric(val * multiplier)
				modElem.Attributes[attrName] = attr
			}
		}
	}
	modifyNumericAttr("health", 2.0)
	modifyNumericAttr("mana", 2.0)
	modifyNumericAttr("stamina", 2.0)
	// Generate and apply changes
	changes := findXMLChanges(origElem, modElem, "")
	if len(changes) != 3 {
		t.Fatalf("Expected 3 changes, got %d", len(changes))
	}
	commands := applyXMLChanges(changes)
	result, _ := utils.ExecuteModifications(commands, original)
	// Verify numeric changes
	if !strings.Contains(result, `health="200"`) {
		t.Errorf("Health not doubled: %s", result)
	}
	if !strings.Contains(result, `mana="100"`) {
		t.Errorf("Mana not doubled: %s", result)
	}
	if !strings.Contains(result, `stamina="151"`) {
		t.Errorf("Stamina not doubled: %s", result)
	}
}
// TestMinimalGitDiff verifies that only changed parts are modified: a single
// attribute edit must yield exactly one change covering just the old/new
// value bytes, leaving every other setting untouched.
func TestMinimalGitDiff(t *testing.T) {
	original := `<config>
    <setting name="volume" value="50" />
    <setting name="brightness" value="75" />
    <setting name="contrast" value="100" />
</config>`
	// Parse
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse: %v", err)
	}
	// Change only brightness (child index 1)
	modElem := deepCopyXMLElement(origElem)
	brightnessItem := modElem.Children[1]
	valueAttr := brightnessItem.Attributes["value"]
	valueAttr.Value = "90"
	brightnessItem.Attributes["value"] = valueAttr
	// Generate changes
	changes := findXMLChanges(origElem, modElem, "")
	// Should be exactly 1 change
	if len(changes) != 1 {
		t.Fatalf("Expected exactly 1 change for minimal diff, got %d", len(changes))
	}
	if changes[0].OldValue != "75" || changes[0].NewValue != "90" {
		t.Errorf("Wrong change detected: %v", changes[0])
	}
	// Apply
	commands := applyXMLChanges(changes)
	result, _ := utils.ExecuteModifications(commands, original)
	// Calculate diff size (rough approximation of touched characters)
	diffChars := len(changes[0].OldValue) + len(changes[0].NewValue)
	if diffChars > 10 {
		t.Errorf("Diff too large: %d characters changed (expected < 10)", diffChars)
	}
	// Verify only brightness changed
	if !strings.Contains(result, `value="50"`) {
		t.Errorf("Volume incorrectly modified")
	}
	if !strings.Contains(result, `value="90"`) {
		t.Errorf("Brightness not modified")
	}
	if !strings.Contains(result, `value="100"`) {
		t.Errorf("Contrast incorrectly modified")
	}
}

165
processor/xml_real_test.go Normal file
View File

@@ -0,0 +1,165 @@
package processor
import (
"os"
"strings"
"testing"
"cook/utils"
)
// TestRealAfflictionsXML runs a Lua command against the real Afflictions.xml
// fixture: every Affliction's maxstrength is doubled via the Lua-side
// findElements/modifyNumAttr helpers, and the edits must be surgical — the
// file's line count must not change.
func TestRealAfflictionsXML(t *testing.T) {
	// Read the real Afflictions.xml file
	content, err := os.ReadFile("../testfiles/Afflictions.xml")
	if err != nil {
		t.Fatalf("Failed to read Afflictions.xml: %v", err)
	}
	original := string(content)
	// Test 1: Double all maxstrength values using helper functions
	command := utils.ModifyCommand{
		Name: "double_maxstrength",
		Lua: `
-- Double all maxstrength attributes in Affliction elements
local afflictions = findElements(root, "Affliction")
for _, affliction in ipairs(afflictions) do
    modifyNumAttr(affliction, "maxstrength", function(val) return val * 2 end)
end
modified = true
`,
	}
	commands, err := ProcessXML(original, command, "Afflictions.xml")
	if err != nil {
		t.Fatalf("ProcessXML failed: %v", err)
	}
	if len(commands) == 0 {
		t.Fatal("Expected modifications but got none")
	}
	t.Logf("Generated %d surgical modifications", len(commands))
	// Apply modifications
	result, count := utils.ExecuteModifications(commands, original)
	t.Logf("Applied %d modifications", count)
	// Verify specific doubled values known to exist in the fixture
	if !strings.Contains(result, `maxstrength="20"`) {
		t.Errorf("Expected to find maxstrength=\"20\" (doubled from 10)")
	}
	if !strings.Contains(result, `maxstrength="480"`) {
		t.Errorf("Expected to find maxstrength=\"480\" (doubled from 240)")
	}
	if !strings.Contains(result, `maxstrength="12"`) {
		t.Errorf("Expected to find maxstrength=\"12\" (doubled from 6)")
	}
	// Verify formatting preserved (XML declaration should be there)
	if !strings.Contains(result, `<?xml`) {
		t.Errorf("XML declaration not preserved")
	}
	// Count lines to ensure structure preserved
	origLines := len(strings.Split(original, "\n"))
	resultLines := len(strings.Split(result, "\n"))
	if origLines != resultLines {
		t.Errorf("Line count changed: original %d, result %d", origLines, resultLines)
	}
}
// TestRealAfflictionsAttributes increases every Effect's minresistance and
// maxresistance by 50% via Lua helpers against the real fixture; at least 10
// modifications are expected given the fixture's content.
func TestRealAfflictionsAttributes(t *testing.T) {
	// Read the real file
	content, err := os.ReadFile("../testfiles/Afflictions.xml")
	if err != nil {
		t.Fatalf("Failed to read Afflictions.xml: %v", err)
	}
	original := string(content)
	// Test 2: Modify resistance values using helper functions
	command := utils.ModifyCommand{
		Name: "increase_resistance",
		Lua: `
-- Increase all minresistance and maxresistance by 50%
local effects = findElements(root, "Effect")
for _, effect in ipairs(effects) do
    modifyNumAttr(effect, "minresistance", function(val) return val * 1.5 end)
    modifyNumAttr(effect, "maxresistance", function(val) return val * 1.5 end)
end
modified = true
`,
	}
	commands, err := ProcessXML(original, command, "Afflictions.xml")
	if err != nil {
		t.Fatalf("ProcessXML failed: %v", err)
	}
	if len(commands) == 0 {
		t.Fatal("Expected modifications but got none")
	}
	t.Logf("Generated %d surgical modifications", len(commands))
	// Apply modifications
	_, count := utils.ExecuteModifications(commands, original)
	t.Logf("Applied %d modifications", count)
	// Verify we made resistance modifications (fixture-dependent lower bound)
	if count < 10 {
		t.Errorf("Expected at least 10 resistance modifications, got %d", count)
	}
}
// TestRealAfflictionsNestedModifications doubles the amount attribute on
// deeply nested ReduceAffliction elements of the real fixture, verifying that
// nested elements are reachable and that float formatting stays compact
// (0.001 * 2 renders as "0.002").
func TestRealAfflictionsNestedModifications(t *testing.T) {
	// Read the real file
	content, err := os.ReadFile("../testfiles/Afflictions.xml")
	if err != nil {
		t.Fatalf("Failed to read Afflictions.xml: %v", err)
	}
	original := string(content)
	// Test 3: Modify nested Effect attributes using helper functions
	command := utils.ModifyCommand{
		Name: "modify_effects",
		Lua: `
-- Double all amount values in ReduceAffliction elements
local reduces = findElements(root, "ReduceAffliction")
for _, reduce in ipairs(reduces) do
    modifyNumAttr(reduce, "amount", function(val) return val * 2 end)
end
modified = true
`,
	}
	commands, err := ProcessXML(original, command, "Afflictions.xml")
	if err != nil {
		t.Fatalf("ProcessXML failed: %v", err)
	}
	if len(commands) == 0 {
		t.Fatal("Expected modifications but got none")
	}
	t.Logf("Generated %d surgical modifications for nested elements", len(commands))
	// Apply modifications
	result, count := utils.ExecuteModifications(commands, original)
	t.Logf("Applied %d modifications", count)
	// Verify nested changes (0.001 * 2 = 0.002)
	if !strings.Contains(result, `amount="0.002"`) {
		t.Errorf("Expected to find amount=\"0.002\" (0.001 * 2)")
	}
	// Verify we modified the nested elements (fixture-dependent lower bound)
	if count < 8 {
		t.Errorf("Expected at least 8 amount modifications, got %d", count)
	}
}

621
processor/xml_test.go Normal file
View File

@@ -0,0 +1,621 @@
package processor
import (
"strings"
"testing"
"cook/utils"
)
// TestParseXMLWithPositions verifies basic parsing of a tiny two-level
// document: root tag, child tag, attribute value, and text content.
func TestParseXMLWithPositions(t *testing.T) {
	doc := `<root><item name="test">Hello</item></root>`
	root, err := parseXMLWithPositions(doc)
	if err != nil {
		t.Fatalf("Failed to parse XML: %v", err)
	}
	if root.Tag != "root" {
		t.Errorf("Expected root tag 'root', got '%s'", root.Tag)
	}
	if len(root.Children) != 1 {
		t.Fatalf("Expected 1 child, got %d", len(root.Children))
	}
	item := root.Children[0]
	if item.Tag != "item" {
		t.Errorf("Expected child tag 'item', got '%s'", item.Tag)
	}
	if got := item.Attributes["name"].Value; got != "test" {
		t.Errorf("Expected attribute 'name' to be 'test', got '%s'", got)
	}
	if item.Text != "Hello" {
		t.Errorf("Expected text 'Hello', got '%s'", item.Text)
	}
}
// TestSurgicalTextChange verifies that changing only an element's text
// produces exactly one "text" change and an exact-match result document.
func TestSurgicalTextChange(t *testing.T) {
	original := `<root>
    <item name="sword" weight="10">A sword</item>
</root>`
	expected := `<root>
    <item name="sword" weight="10">A modified sword</item>
</root>`
	// Parse original
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse original XML: %v", err)
	}
	// Create modified version
	modElem := deepCopyXMLElement(origElem)
	modElem.Children[0].Text = "A modified sword"
	// Find changes
	changes := findXMLChanges(origElem, modElem, "")
	if len(changes) != 1 {
		t.Fatalf("Expected 1 change, got %d", len(changes))
	}
	if changes[0].Type != "text" {
		t.Errorf("Expected change type 'text', got '%s'", changes[0].Type)
	}
	// Apply changes and compare the full document byte-for-byte
	commands := applyXMLChanges(changes)
	result, _ := utils.ExecuteModifications(commands, original)
	if result != expected {
		t.Errorf("Text change failed.\nExpected:\n%s\n\nGot:\n%s", expected, result)
	}
}
// TestSurgicalAttributeChange verifies that changing one attribute value
// produces exactly one "attribute" change and an exact-match result.
func TestSurgicalAttributeChange(t *testing.T) {
	original := `<root>
    <item name="sword" weight="10" />
</root>`
	expected := `<root>
    <item name="sword" weight="20" />
</root>`
	// Parse original
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse original XML: %v", err)
	}
	// Create modified version (XMLAttribute is a value type, so read-modify-write)
	modElem := deepCopyXMLElement(origElem)
	attr := modElem.Children[0].Attributes["weight"]
	attr.Value = "20"
	modElem.Children[0].Attributes["weight"] = attr
	// Find changes
	changes := findXMLChanges(origElem, modElem, "")
	if len(changes) != 1 {
		t.Fatalf("Expected 1 change, got %d", len(changes))
	}
	if changes[0].Type != "attribute" {
		t.Errorf("Expected change type 'attribute', got '%s'", changes[0].Type)
	}
	// Apply changes
	commands := applyXMLChanges(changes)
	result, _ := utils.ExecuteModifications(commands, original)
	if result != expected {
		t.Errorf("Attribute change failed.\nExpected:\n%s\n\nGot:\n%s", expected, result)
	}
}
// TestSurgicalMultipleAttributeChanges verifies that several attribute edits
// on the same element each get their own change and combine to the exact
// expected document, including a value that grows in length ("sword" ->
// "greatsword").
func TestSurgicalMultipleAttributeChanges(t *testing.T) {
	original := `<item name="sword" weight="10" damage="5" />`
	expected := `<item name="greatsword" weight="20" damage="15" />`
	// Parse original
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse original XML: %v", err)
	}
	// Create modified version
	modElem := deepCopyXMLElement(origElem)
	nameAttr := modElem.Attributes["name"]
	nameAttr.Value = "greatsword"
	modElem.Attributes["name"] = nameAttr
	weightAttr := modElem.Attributes["weight"]
	weightAttr.Value = "20"
	modElem.Attributes["weight"] = weightAttr
	damageAttr := modElem.Attributes["damage"]
	damageAttr.Value = "15"
	modElem.Attributes["damage"] = damageAttr
	// Find changes
	changes := findXMLChanges(origElem, modElem, "")
	if len(changes) != 3 {
		t.Fatalf("Expected 3 changes, got %d", len(changes))
	}
	// Apply changes
	commands := applyXMLChanges(changes)
	result, _ := utils.ExecuteModifications(commands, original)
	if result != expected {
		t.Errorf("Multiple attribute changes failed.\nExpected:\n%s\n\nGot:\n%s", expected, result)
	}
}
// TestSurgicalAddAttribute verifies that introducing a new attribute on the
// modified copy yields an "add_attribute" change and that the serialized
// attribute appears in the output.
func TestSurgicalAddAttribute(t *testing.T) {
	original := `<item name="sword" />`
	// Parse original
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse original XML: %v", err)
	}
	// Create modified version with new attribute (no source positions needed)
	modElem := deepCopyXMLElement(origElem)
	modElem.Attributes["weight"] = XMLAttribute{
		Value: "10",
	}
	// Find changes
	changes := findXMLChanges(origElem, modElem, "")
	if len(changes) != 1 {
		t.Fatalf("Expected 1 change, got %d", len(changes))
	}
	if changes[0].Type != "add_attribute" {
		t.Errorf("Expected change type 'add_attribute', got '%s'", changes[0].Type)
	}
	// Apply changes
	commands := applyXMLChanges(changes)
	result, _ := utils.ExecuteModifications(commands, original)
	// Should contain the new attribute
	if !strings.Contains(result, `weight="10"`) {
		t.Errorf("Add attribute failed. Result doesn't contain weight=\"10\":\n%s", result)
	}
}
// TestSurgicalRemoveAttribute verifies that deleting an attribute from the
// modified copy yields a "remove_attribute" change that strips the attribute
// without touching its siblings.
func TestSurgicalRemoveAttribute(t *testing.T) {
	original := `<item name="sword" weight="10" damage="5" />`
	// Parse original
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse original XML: %v", err)
	}
	// Create modified version without weight attribute
	modElem := deepCopyXMLElement(origElem)
	delete(modElem.Attributes, "weight")
	// Find changes
	changes := findXMLChanges(origElem, modElem, "")
	if len(changes) != 1 {
		t.Fatalf("Expected 1 change, got %d", len(changes))
	}
	if changes[0].Type != "remove_attribute" {
		t.Errorf("Expected change type 'remove_attribute', got '%s'", changes[0].Type)
	}
	// Apply changes
	commands := applyXMLChanges(changes)
	result, _ := utils.ExecuteModifications(commands, original)
	// Should not contain weight attribute
	if strings.Contains(result, "weight=") {
		t.Errorf("Remove attribute failed. Result still contains 'weight=':\n%s", result)
	}
	// Should still contain other attributes
	if !strings.Contains(result, `name="sword"`) {
		t.Errorf("Remove attribute incorrectly removed other attributes:\n%s", result)
	}
}
// TestSurgicalAddElement checks that appending a new child element is
// reported as one add_element change and shows up in the modified output.
func TestSurgicalAddElement(t *testing.T) {
	original := `<root>
<item name="sword" />
</root>`
	// Parse the baseline tree.
	root, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse original XML: %v", err)
	}
	// Append a fresh <item name="shield"> child to a copy of the tree.
	edited := deepCopyXMLElement(root)
	shield := &XMLElement{
		Tag: "item",
		Attributes: map[string]XMLAttribute{
			"name": {Value: "shield"},
		},
		Children: []*XMLElement{},
	}
	edited.Children = append(edited.Children, shield)
	// Exactly one add_element change is expected from the diff.
	changes := findXMLChanges(root, edited, "")
	if len(changes) != 1 {
		t.Fatalf("Expected 1 change, got %d", len(changes))
	}
	if changes[0].Type != "add_element" {
		t.Errorf("Expected change type 'add_element', got '%s'", changes[0].Type)
	}
	// Apply the resulting commands to the raw document text.
	result, _ := utils.ExecuteModifications(applyXMLChanges(changes), original)
	// The output must now include the new element.
	if !strings.Contains(result, `<item name="shield"`) {
		t.Errorf("Add element failed. Result doesn't contain new item:\n%s", result)
	}
}
// TestSurgicalRemoveElement verifies that truncating the children list is
// detected as a single remove_element change, and that applying it removes
// only the targeted element while preserving siblings and layout.
func TestSurgicalRemoveElement(t *testing.T) {
	original := `<root>
<item name="sword" />
<item name="shield" />
</root>`
	expected := `<root>
<item name="sword" />
</root>`
	// Parse the baseline document.
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse original XML: %v", err)
	}
	// Drop the second child (shield) from a deep copy.
	modElem := deepCopyXMLElement(origElem)
	modElem.Children = modElem.Children[:1]
	// The diff should be a single remove_element change.
	changes := findXMLChanges(origElem, modElem, "")
	if len(changes) != 1 {
		t.Fatalf("Expected 1 change, got %d", len(changes))
	}
	if changes[0].Type != "remove_element" {
		t.Errorf("Expected change type 'remove_element', got '%s'", changes[0].Type)
	}
	// Apply the generated commands to the original text.
	commands := applyXMLChanges(changes)
	result, err := utils.ExecuteModifications(commands, original)
	// Fix: the error from ExecuteModifications was previously discarded,
	// hiding application failures behind content checks.
	if err != nil {
		t.Fatalf("ExecuteModifications failed: %v", err)
	}
	// The removed element must be gone...
	if strings.Contains(result, "shield") {
		t.Errorf("Remove element failed. Result still contains 'shield':\n%s", result)
	}
	// ...while its sibling survives.
	if !strings.Contains(result, "sword") {
		t.Errorf("Remove element incorrectly removed other elements:\n%s", result)
	}
	// Compare against the expected document, ignoring surrounding whitespace.
	resultNorm := strings.TrimSpace(result)
	expectedNorm := strings.TrimSpace(expected)
	if resultNorm != expectedNorm {
		t.Errorf("Remove element result mismatch.\nExpected:\n%s\n\nGot:\n%s", expectedNorm, resultNorm)
	}
}
// TestComplexNestedChanges verifies that several attribute edits spread across
// nested elements are all detected by findXMLChanges and applied together,
// without disturbing attributes that were not modified.
func TestComplexNestedChanges(t *testing.T) {
	original := `<root>
<inventory>
<item name="sword" weight="10">
<stats damage="5" speed="3" />
</item>
<item name="shield" weight="8">
<stats defense="7" />
</item>
</inventory>
</root>`
	// Parse original
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse original XML: %v", err)
	}
	// Create modified version with multiple changes
	modElem := deepCopyXMLElement(origElem)
	// Change first item's weight. XMLAttribute is held by value in the map,
	// so the idiom is: read, mutate the copy, write back.
	inventory := modElem.Children[0]
	item1 := inventory.Children[0]
	weightAttr := item1.Attributes["weight"]
	weightAttr.Value = "20"
	item1.Attributes["weight"] = weightAttr
	// Change nested stats damage (two levels deep in the tree)
	stats := item1.Children[0]
	damageAttr := stats.Attributes["damage"]
	damageAttr.Value = "10"
	stats.Attributes["damage"] = damageAttr
	// Change second item's name
	item2 := inventory.Children[1]
	nameAttr := item2.Attributes["name"]
	nameAttr.Value = "buckler"
	item2.Attributes["name"] = nameAttr
	// Find changes
	changes := findXMLChanges(origElem, modElem, "")
	// Should have 3 changes: weight, damage, name
	if len(changes) != 3 {
		t.Fatalf("Expected 3 changes, got %d: %+v", len(changes), changes)
	}
	// Apply changes. NOTE(review): the error is discarded here; the content
	// checks below would catch a failed application, but consider asserting
	// err == nil for a clearer failure message.
	commands := applyXMLChanges(changes)
	result, _ := utils.ExecuteModifications(commands, original)
	// Verify all changes were applied
	if !strings.Contains(result, `weight="20"`) {
		t.Errorf("Failed to update weight to 20:\n%s", result)
	}
	if !strings.Contains(result, `damage="10"`) {
		t.Errorf("Failed to update damage to 10:\n%s", result)
	}
	if !strings.Contains(result, `name="buckler"`) {
		t.Errorf("Failed to update name to buckler:\n%s", result)
	}
	// Verify unchanged elements remain
	if !strings.Contains(result, `speed="3"`) {
		t.Errorf("Incorrectly modified speed attribute:\n%s", result)
	}
	if !strings.Contains(result, `defense="7"`) {
		t.Errorf("Incorrectly modified defense attribute:\n%s", result)
	}
}
// TestFormattingPreservation verifies that a surgical single-attribute edit
// leaves the rest of the document -- layout, child elements, attribute
// order -- byte-for-byte untouched (checked via exact string equality).
func TestFormattingPreservation(t *testing.T) {
	original := `<root>
<item name="sword" weight="10">
<description>A sharp blade</description>
<stats damage="5" speed="3" />
</item>
</root>`
	expected := `<root>
<item name="sword" weight="20">
<description>A sharp blade</description>
<stats damage="5" speed="3" />
</item>
</root>`
	// Parse original
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse original XML: %v", err)
	}
	// Modify only weight (value-type attribute: read, mutate, write back)
	modElem := deepCopyXMLElement(origElem)
	item := modElem.Children[0]
	weightAttr := item.Attributes["weight"]
	weightAttr.Value = "20"
	item.Attributes["weight"] = weightAttr
	// Find changes
	changes := findXMLChanges(origElem, modElem, "")
	// Apply changes. NOTE(review): error discarded; the exact-match check
	// below would fail anyway, but asserting err == nil would be clearer.
	commands := applyXMLChanges(changes)
	result, _ := utils.ExecuteModifications(commands, original)
	// The output must equal the expected document exactly (no reformatting).
	if result != expected {
		t.Errorf("Formatting preservation failed.\nExpected:\n%s\n\nGot:\n%s", expected, result)
	}
}
// TestNumericHelpers exercises the numeric round-trip helpers: parseNumeric
// (string -> float64 + validity flag) and formatNumeric (float64 -> string
// without superfluous decimals).
func TestNumericHelpers(t *testing.T) {
	// Parsing: each input paired with its expected value and validity.
	parseCases := []struct {
		in      string
		want    float64
		numeric bool
	}{
		{"42", 42.0, true},
		{"3.14", 3.14, true},
		{"0", 0.0, true},
		{"-5", -5.0, true},
		{"abc", 0.0, false},
		{"", 0.0, false},
	}
	for _, tc := range parseCases {
		got, ok := parseNumeric(tc.in)
		if ok != tc.numeric {
			t.Errorf("parseNumeric(%q) isNum = %v, expected %v", tc.in, ok, tc.numeric)
		}
		if ok && got != tc.want {
			t.Errorf("parseNumeric(%q) = %v, expected %v", tc.in, got, tc.want)
		}
	}
	// Formatting: integral floats must render without a trailing ".0".
	formatCases := []struct {
		in   float64
		want string
	}{
		{42.0, "42"},
		{3.14, "3.14"},
		{0.0, "0"},
		{-5.0, "-5"},
		{100.5, "100.5"},
	}
	for _, tc := range formatCases {
		if got := formatNumeric(tc.in); got != tc.want {
			t.Errorf("formatNumeric(%v) = %q, expected %q", tc.in, got, tc.want)
		}
	}
}
// TestDeepCopyXMLElement verifies that deepCopyXMLElement produces a fully
// independent clone: mutating the copy's tag, attributes, or children must
// not leak back into the source tree.
func TestDeepCopyXMLElement(t *testing.T) {
	// Source tree with an attribute and one child to exercise every level.
	src := &XMLElement{
		Tag:  "item",
		Text: "content",
		Attributes: map[string]XMLAttribute{
			"name": {Value: "sword"},
		},
		Children: []*XMLElement{
			{Tag: "child", Text: "text"},
		},
	}
	dup := deepCopyXMLElement(src)
	// The clone must start out equal to the source.
	if dup.Tag != src.Tag {
		t.Errorf("Tag not copied correctly")
	}
	if dup.Text != src.Text {
		t.Errorf("Text not copied correctly")
	}
	// Mutate every level of the clone...
	dup.Tag = "modified"
	dup.Attributes["name"] = XMLAttribute{Value: "shield"}
	dup.Children[0].Text = "modified text"
	// ...and confirm the source is untouched (i.e. the copy is deep).
	if src.Tag != "item" {
		t.Errorf("Original was modified")
	}
	if src.Attributes["name"].Value != "sword" {
		t.Errorf("Original attributes were modified")
	}
	if src.Children[0].Text != "text" {
		t.Errorf("Original children were modified")
	}
}
// TestSerializeXMLElement checks that serializing a small tree emits the
// opening/closing tags, both attributes, and the nested child element.
func TestSerializeXMLElement(t *testing.T) {
	elem := &XMLElement{
		Tag: "item",
		Attributes: map[string]XMLAttribute{
			"name":   {Value: "sword"},
			"weight": {Value: "10"},
		},
		Children: []*XMLElement{
			{
				Tag: "stats",
				Attributes: map[string]XMLAttribute{
					"damage": {Value: "5"},
				},
				Children: []*XMLElement{},
			},
		},
	}
	out := serializeXMLElement(elem, "")
	// Each structural fragment we expect, paired with its failure message.
	expectations := []struct {
		substr string
		msg    string
	}{
		{"<item", "Missing opening tag"},
		{"</item>", "Missing closing tag"},
		{`name="sword"`, "Missing name attribute"},
		{`weight="10"`, "Missing weight attribute"},
		{"<stats", "Missing child element"},
	}
	for _, e := range expectations {
		if !strings.Contains(out, e.substr) {
			t.Error(e.msg)
		}
	}
}
// TestEmptyElements verifies both flavors of an empty element parse
// identically: self-closing (<item />) and explicitly closed (<item></item>).
func TestEmptyElements(t *testing.T) {
	original := `<root>
<item name="sword" />
<item name="shield"></item>
</root>`
	// Parse the document.
	root, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse XML: %v", err)
	}
	if got := len(root.Children); got != 2 {
		t.Errorf("Expected 2 children, got %d", got)
	}
	// Both forms should yield an <item> element.
	if root.Children[0].Tag != "item" {
		t.Errorf("First child tag incorrect")
	}
	if root.Children[1].Tag != "item" {
		t.Errorf("Second child tag incorrect")
	}
}
// TestAttributeOrderPreservation verifies that editing one attribute's value
// does not reorder the surviving attributes in the output text.
func TestAttributeOrderPreservation(t *testing.T) {
	original := `<item name="sword" weight="10" damage="5" speed="3" />`
	// Parse the baseline document.
	origElem, err := parseXMLWithPositions(original)
	if err != nil {
		t.Fatalf("Failed to parse original XML: %v", err)
	}
	// Modify just weight on a deep copy (read, mutate, write back).
	modElem := deepCopyXMLElement(origElem)
	weightAttr := modElem.Attributes["weight"]
	weightAttr.Value = "20"
	modElem.Attributes["weight"] = weightAttr
	// Diff, translate to commands, and apply to the original text.
	changes := findXMLChanges(origElem, modElem, "")
	commands := applyXMLChanges(changes)
	result, err := utils.ExecuteModifications(commands, original)
	// Fix: surface modification failures instead of silently discarding them.
	if err != nil {
		t.Fatalf("ExecuteModifications failed: %v", err)
	}
	// Locate each attribute in the output.
	weightIdx := strings.Index(result, "weight=")
	damageIdx := strings.Index(result, "damage=")
	speedIdx := strings.Index(result, "speed=")
	// Fix: strings.Index returns -1 when the substring is absent, and -1
	// compares "before" any real index, so a missing attribute could make the
	// ordering check pass vacuously. Require presence first.
	if weightIdx < 0 || damageIdx < 0 || speedIdx < 0 {
		t.Fatalf("Missing attribute(s) in result (weight=%d damage=%d speed=%d):\n%s",
			weightIdx, damageIdx, speedIdx, result)
	}
	// Verify attribute order is preserved (weight before damage before speed).
	if weightIdx > damageIdx || damageIdx > speedIdx {
		t.Errorf("Attribute order not preserved:\n%s", result)
	}
}

74
testfiles/Afflictions.xml Normal file
View File

@@ -0,0 +1,74 @@
<?xml version="1.0" encoding="utf-8"?>
<Afflictions>
<Affliction name="" identifier="Cozy_Fire" type="strengthbuff" limbspecific="false" isbuff="true" maxstrength="10" hideiconafterdelay="true">
<Effect minstrength="0" maxstrength="10" strengthchange="-1.0">
<ReduceAffliction type="damage" amount="0.001" />
<ReduceAffliction type="bleeding" amount="0.001" />
<ReduceAffliction type="burn" amount="0.001" />
<ReduceAffliction type="bloodloss" amount="0.001" />
</Effect>
<icon texture="%ModDir%/Placable/Cozy_Fire.png" sourcerect="0,0,64,64" origin="0,0" />
</Affliction>
<Affliction name="" identifier="The_Bast_Defense" type="strengthbuff" limbspecific="false" isbuff="true" maxstrength="10" hideiconafterdelay="true">
<Effect minstrength="0" maxstrength="10" strengthchange="-1.0" resistancefor="damage" minresistance="0.05" maxresistance="0.05"></Effect>
<icon texture="%ModDir%/Placable/The_Bast_Defense.png" sourcerect="0,0,64,64" origin="0,0" />
</Affliction>
<Affliction name="" identifier="Clairvoyance" type="strengthbuff" limbspecific="false" isbuff="true" maxstrength="240" hideiconafterdelay="true">
<Effect minstrength="0" maxstrength="240" strengthchange="-1.0" resistancefor="stun" minresistance="0.15" maxresistance="0.15"></Effect>
<icon texture="%ModDir%/Placable/Clairvoyance.png" sourcerect="0,0,64,64" origin="0,0" />
</Affliction>
<Affliction name="" identifier="Heart_Lamp" type="strengthbuff" limbspecific="false" isbuff="true" maxstrength="10" hideiconafterdelay="true">
<Effect minstrength="0" maxstrength="10" strengthchange="-1.0">
<ReduceAffliction type="damage" amount="0.001" />
<ReduceAffliction type="bleeding" amount="0.001" />
<ReduceAffliction type="burn" amount="0.001" />
<ReduceAffliction type="bloodloss" amount="0.001" />
</Effect>
<icon texture="%ModDir%/Placable/Heart_Lamp.png" sourcerect="0,0,64,64" origin="0,0" />
</Affliction>
<Affliction name="" identifier="Star_in_a_Bottle_buff" type="strengthbuff" limbspecific="false" isbuff="true" maxstrength="10" hideiconafterdelay="true">
<Effect minstrength="0" maxstrength="10" strengthchange="-1.0" resistancefor="stun" minresistance="0.1" maxresistance="0.1"></Effect>
<icon texture="%ModDir%/Placable/Star_in_a_Bottle_buff.png" sourcerect="0,0,64,64" origin="0,0" />
</Affliction>
<Affliction name="" identifier="HappyF" type="strengthbuff" limbspecific="false" isbuff="true" maxstrength="10" hideiconafterdelay="true">
<Effect minstrength="0" maxstrength="10" strengthchange="-1.0" resistancefor="stun" minresistance="0.05" maxresistance="0.05" minspeedmultiplier="1.1" maxspeedmultiplier="1.1"></Effect>
<icon texture="%ModDir%/Placable/Happy.png" sourcerect="0,0,64,64" origin="0,0" />
</Affliction>
<Affliction name="" identifier="SharpenedF" type="strengthbuff" limbspecific="false" isbuff="true" maxstrength="240" hideiconafterdelay="true">
<Effect minstrength="0" maxstrength="240" strengthchange="-1.0">
<StatValue stattype="MeleeAttackMultiplier" value="0.25" />
</Effect>
<icon texture="%ModDir%/Placable/Sharpened.png" sourcerect="0,0,64,64" origin="0,0" />
</Affliction>
<Affliction name="" identifier="Sugar_RushF" type="strengthbuff" limbspecific="false" isbuff="true" maxstrength="240" hideiconafterdelay="true">
<Effect minstrength="0" maxstrength="240" strengthchange="-1.0" minspeedmultiplier="1.2" maxspeedmultiplier="1.2">
<StatValue stattype="MeleeAttackSpeed" value="0.05" />
<StatValue stattype="RangedAttackSpeed" value="0.05" />
</Effect>
<icon texture="%ModDir%/Placable/Sugar_Rush.png" sourcerect="0,0,64,64" origin="0,0" />
</Affliction>
<Affliction name="" identifier="Crimson_Effigy_buff" type="strengthbuff" limbspecific="false" isbuff="true" maxstrength="240" hideiconafterdelay="true">
<Effect minstrength="0" maxstrength="240" strengthchange="-1.0" resistancefor="damage" minresistance="-0.1" maxresistance="-0.1">
<StatValue stattype="MeleeAttackSpeed" value="0.15" />
<StatValue stattype="RangedAttackSpeed" value="0.15" />
</Effect>
<icon texture="%ModDir%/Placable/Crimson_Effigy_buff.png" sourcerect="0,0,64,64" origin="0,0" />
</Affliction>
<Affliction name="" identifier="Corruption_Effigy_buff" type="strengthbuff" limbspecific="false" isbuff="true" maxstrength="240" hideiconafterdelay="true">
<Effect minstrength="0" maxstrength="240" strengthchange="-1.0" minvitalitydecrease="0.2" multiplybymaxvitality="true" maxvitalitydecrease="0.2" resistancefor="damage" minresistance="0.1" maxresistance="0.1">
<StatValue stattype="AttackMultiplier" value="0.2" />
</Effect>
<icon texture="%ModDir%/Placable/Corruption_Effigy_buff.png" sourcerect="0,0,64,64" origin="0,0" />
</Affliction>
<Affliction name="" identifier="Effigy_of_Decay_buff" type="strengthbuff" limbspecific="false" isbuff="true" maxstrength="240" hideiconafterdelay="true">
<Effect minstrength="0" maxstrength="240" strengthchange="-1.0" resistancefor="oxygenlow" minresistance="1" maxresistance="1">
<StatusEffect target="Character" SpeedMultiplier="1.1" OxygenAvailable="1000.0" setvalue="true" />
<AbilityFlag flagtype="ImmuneToPressure" />
</Effect>
<icon texture="%ModDir%/Placable/Effigy_of_Decay_buff.png" sourcerect="0,0,64,64" origin="0,0" />
</Affliction>
<Affliction name="" identifier="Chlorophyte_Extractinator" type="strengthbuff" limbspecific="false" isbuff="true" maxstrength="6" hideiconafterdelay="true">
<Effect minstrength="0" maxstrength="6" strengthchange="-1.0"></Effect>
<icon texture="%ModDir%/Extractinator/Chlorophyte_Extractinator.png" sourcerect="0,0,144,152" origin="0,0" />
</Affliction>
</Afflictions>

View File

@@ -0,0 +1,8 @@
- name: "JSONFormattingTest"
json: true
lua: |
data.version = "2.0.0"
data.enabled = true
data.settings.timeout = 60
return true
files: ["testfiles/test3.json"]

View File

@@ -0,0 +1,15 @@
# Test with global JSON flag (no json: true in commands)
- name: "JSONArrayMultiply"
lua: |
for i, item in ipairs(data.items) do
data.items[i].value = item.value * 2
end
return true
files: ["testfiles/test2.json"]
- name: "JSONObjectUpdate"
lua: |
data.version = "3.0.0"
data.enabled = false
return true
files: ["testfiles/test2.json"]

View File

@@ -0,0 +1,32 @@
# Global modifiers
- modifiers:
multiply: 2.0
new_version: "2.0.0"
# JSON mode examples
- name: "JSONArrayMultiply"
json: true
lua: |
for i, item in ipairs(data.items) do
data.items[i].value = item.value * $multiply
end
return true
files: ["testfiles/test.json"]
- name: "JSONObjectUpdate"
json: true
lua: |
data.version = $new_version
data.enabled = true
return true
files: ["testfiles/test.json"]
- name: "JSONNestedModify"
json: true
lua: |
if data.settings and data.settings.performance then
data.settings.performance.multiplier = data.settings.performance.multiplier * 1.5
data.settings.performance.enabled = true
end
return true
files: ["testfiles/test.json"]

30
testfiles/test.json Normal file
View File

@@ -0,0 +1,30 @@
{
"name": "test-config",
"version": "1.0.0",
"enabled": false,
"settings": {
"timeout": 30,
"retries": 3,
"performance": {
"multiplier": 1.0,
"enabled": false
}
},
"items": [
{
"id": 1,
"name": "item1",
"value": 10
},
{
"id": 2,
"name": "item2",
"value": 20
},
{
"id": 3,
"name": "item3",
"value": 30
}
]
}

30
testfiles/test2.json Normal file
View File

@@ -0,0 +1,30 @@
{
"enabled": false,
"items": [
{
"id": 1,
"name": "item1",
"value": 80
},
{
"id": 2,
"name": "item2",
"value": 160
},
{
"id": 3,
"name": "item3",
"value": 240
}
],
"name": "test-config",
"settings": {
"performance": {
"enabled": true,
"multiplier": 1.5
},
"retries": 3,
"timeout": 30
},
"version": "3.0.0"
}

25
testfiles/test3.json Normal file
View File

@@ -0,0 +1,25 @@
{
"enabled": true,
"items": [
{
"id": 1,
"name": "item1",
"value": 10
},
{
"id": 2,
"name": "item2",
"value": 20
}
],
"name": "test-config",
"settings": {
"performance": {
"enabled": false,
"multiplier": 1
},
"retries": 3,
"timeout": 60
},
"version": "2.0.0"
}

View File

@@ -0,0 +1,25 @@
{
"enabled": true,
"items": [
{
"id": 1,
"name": "item1",
"value": 10
},
{
"id": 2,
"name": "item2",
"value": 20
}
],
"name": "test-config",
"settings": {
"performance": {
"enabled": false,
"multiplier": 1
},
"retries": 3,
"timeout": 60
},
"version": "2.0.0"
}

25
testfiles/test4.json Normal file
View File

@@ -0,0 +1,25 @@
{
"name": "test-config",
"version": "1.0.0",
"enabled": false,
"settings": {
"timeout": 30,
"retries": 3,
"performance": {
"multiplier": 1.0,
"enabled": false
}
},
"items": [
{
"id": 1,
"name": "item1",
"value": 10
},
{
"id": 2,
"name": "item2",
"value": 20
}
]
}

View File

@@ -44,7 +44,7 @@ files = ["*.txt"]
os.Chdir(tmpDir)
// Test loading TOML commands
commands, err := utils.LoadCommandsFromTomlFiles("test.toml")
commands, _, err := utils.LoadCommandsFromTomlFiles("test.toml")
assert.NoError(t, err, "Should load TOML commands without error")
assert.Len(t, commands, 2, "Should load 2 commands from TOML")
@@ -69,9 +69,11 @@ func TestTOMLGlobalModifiers(t *testing.T) {
}
defer os.RemoveAll(tmpDir)
// Create TOML content with global modifiers
tomlContent := `[[commands]]
modifiers = { multiplier = 3, prefix = "TEST_", enabled = true }
// Create TOML content with global variables
tomlContent := `[variables]
multiplier = 3
prefix = "TEST_"
enabled = true
[[commands]]
name = "UseGlobalModifiers"
@@ -92,23 +94,19 @@ files = ["test.txt"]
os.Chdir(tmpDir)
// Test loading TOML commands
commands, err := utils.LoadCommandsFromTomlFiles("test.toml")
commands, variables, err := utils.LoadCommandsFromTomlFiles("test.toml")
assert.NoError(t, err, "Should load TOML commands without error")
assert.Len(t, commands, 2, "Should load 2 commands from TOML")
assert.Len(t, commands, 1, "Should load 1 command from TOML")
assert.Len(t, variables, 3, "Should load 3 variables")
// Verify global modifiers command (first command should have only modifiers)
assert.Empty(t, commands[0].Name, "Global modifiers command should have no name")
assert.Empty(t, commands[0].Regex, "Global modifiers command should have no regex")
assert.Empty(t, commands[0].Lua, "Global modifiers command should have no lua")
assert.Empty(t, commands[0].Files, "Global modifiers command should have no files")
assert.Len(t, commands[0].Modifiers, 3, "Global modifiers command should have 3 modifiers")
assert.Equal(t, int64(3), commands[0].Modifiers["multiplier"], "Multiplier should be 3")
assert.Equal(t, "TEST_", commands[0].Modifiers["prefix"], "Prefix should be TEST_")
assert.Equal(t, true, commands[0].Modifiers["enabled"], "Enabled should be true")
// Verify variables
assert.Equal(t, int64(3), variables["multiplier"], "Multiplier should be 3")
assert.Equal(t, "TEST_", variables["prefix"], "Prefix should be TEST_")
assert.Equal(t, true, variables["enabled"], "Enabled should be true")
// Verify regular command
assert.Equal(t, "UseGlobalModifiers", commands[1].Name, "Regular command name should match")
assert.Equal(t, "value = !num", commands[1].Regex, "Regular command regex should match")
assert.Equal(t, "UseGlobalModifiers", commands[0].Name, "Regular command name should match")
assert.Equal(t, "value = !num", commands[0].Regex, "Regular command regex should match")
}
func TestTOMLMultilineRegex(t *testing.T) {
@@ -120,8 +118,8 @@ func TestTOMLMultilineRegex(t *testing.T) {
defer os.RemoveAll(tmpDir)
// Create TOML content with multiline regex using literal strings
tomlContent := `[[commands]]
modifiers = { factor = 2.5 }
tomlContent := `[variables]
factor = 2.5
[[commands]]
name = "MultilineTest"
@@ -166,12 +164,13 @@ isolate = true
os.Chdir(tmpDir)
// Test loading TOML commands
commands, err := utils.LoadCommandsFromTomlFiles("test.toml")
commands, variables, err := utils.LoadCommandsFromTomlFiles("test.toml")
assert.NoError(t, err, "Should load TOML commands without error")
assert.Len(t, commands, 2, "Should load 2 commands from TOML")
assert.Len(t, commands, 1, "Should load 1 command from TOML")
assert.Len(t, variables, 1, "Should load 1 variable")
// Verify the multiline regex command
multilineCmd := commands[1]
multilineCmd := commands[0]
assert.Equal(t, "MultilineTest", multilineCmd.Name, "Command name should match")
assert.Contains(t, multilineCmd.Regex, "\\[config\\.settings\\]", "Regex should contain escaped brackets")
assert.Contains(t, multilineCmd.Regex, "depth = !num", "Regex should contain depth pattern")
@@ -225,7 +224,7 @@ files = ["*.conf", "*.ini"]
os.Chdir(tmpDir)
// Test loading TOML commands
commands, err := utils.LoadCommandsFromTomlFiles("test.toml")
commands, _, err := utils.LoadCommandsFromTomlFiles("test.toml")
assert.NoError(t, err, "Should load TOML commands without error")
assert.Len(t, commands, 1, "Should load 1 command from TOML")
@@ -276,7 +275,7 @@ files = ["config.json"]
os.Chdir(tmpDir)
// Test loading TOML commands
commands, err := utils.LoadCommandsFromTomlFiles("test.toml")
commands, _, err := utils.LoadCommandsFromTomlFiles("test.toml")
assert.NoError(t, err, "Should load TOML commands without error")
assert.Len(t, commands, 2, "Should load 2 commands from TOML")
@@ -304,8 +303,9 @@ func TestTOMLEndToEndIntegration(t *testing.T) {
defer os.RemoveAll(tmpDir)
// Create comprehensive TOML content
tomlContent := `[[commands]]
modifiers = { multiplier = 4, base_value = 100 }
tomlContent := `[variables]
multiplier = 4
base_value = 100
[[commands]]
name = "IntegrationTest"
@@ -358,9 +358,10 @@ some_other_setting = enabled = true
os.Chdir(tmpDir)
// Test the complete workflow using the main function
commands, err := utils.LoadCommands([]string{"test.toml"})
commands, variables, err := utils.LoadCommands([]string{"test.toml"})
assert.NoError(t, err, "Should load TOML commands without error")
assert.Len(t, commands, 3, "Should load 3 commands total (including global modifiers)")
assert.Len(t, commands, 2, "Should load 2 commands")
assert.Len(t, variables, 2, "Should load 2 variables")
// Associate files with commands
files := []string{"test.txt"}
@@ -404,24 +405,26 @@ files = ["test.txt"
invalidFile := filepath.Join(tmpDir, "invalid.toml")
err = os.WriteFile(invalidFile, []byte(invalidTOML), 0644)
assert.NoError(t, err, "Should write invalid TOML file")
commands, err := utils.LoadCommandsFromTomlFiles("invalid.toml")
commands, _, err := utils.LoadCommandsFromTomlFiles("invalid.toml")
assert.Error(t, err, "Should return error for invalid TOML syntax")
assert.Nil(t, commands, "Should return nil commands for invalid TOML")
assert.Contains(t, err.Error(), "failed to unmarshal TOML file", "Error should mention TOML unmarshaling")
// Test 2: Non-existent file
commands, err = utils.LoadCommandsFromTomlFiles("nonexistent.toml")
commands, _, err = utils.LoadCommandsFromTomlFiles("nonexistent.toml")
assert.NoError(t, err, "Should handle non-existent file without error")
assert.Empty(t, commands, "Should return empty commands for non-existent file")
// Test 3: Empty TOML file creates an error (this is expected behavior)
// Test 3: Empty TOML file returns no commands (not an error)
emptyFile := filepath.Join(tmpDir, "empty.toml")
err = os.WriteFile(emptyFile, []byte(""), 0644)
assert.NoError(t, err, "Should write empty TOML file")
commands, err = utils.LoadCommandsFromTomlFiles("empty.toml")
assert.Error(t, err, "Should return error for empty TOML file")
assert.Nil(t, commands, "Should return nil commands for empty TOML")
commands, _, err = utils.LoadCommandsFromTomlFiles("empty.toml")
assert.NoError(t, err, "Empty TOML should not return error")
assert.Empty(t, commands, "Should return empty commands for empty TOML")
}
func TestYAMLToTOMLConversion(t *testing.T) {
@@ -438,21 +441,20 @@ func TestYAMLToTOMLConversion(t *testing.T) {
os.Chdir(tmpDir)
// Create a test YAML file
yamlContent := `- name: "ConversionTest"
regex: "value = !num"
lua: "v1 * 3"
files: ["test.txt"]
loglevel: DEBUG
yamlContent := `variables:
multiplier: 2.5
prefix: "CONV_"
- name: "AnotherTest"
regex: "enabled = (true|false)"
lua: "= false"
files: ["*.conf"]
- name: "GlobalModifiers"
modifiers:
multiplier: 2.5
prefix: "CONV_"
commands:
- name: "ConversionTest"
regex: "value = !num"
lua: "v1 * 3"
files: ["test.txt"]
loglevel: DEBUG
- name: "AnotherTest"
regex: "enabled = (true|false)"
lua: "= false"
files: ["*.conf"]
`
yamlFile := filepath.Join(tmpDir, "test.yml")
@@ -475,28 +477,17 @@ func TestYAMLToTOMLConversion(t *testing.T) {
tomlContent := string(tomlData)
assert.Contains(t, tomlContent, `name = "ConversionTest"`, "TOML should contain first command name")
assert.Contains(t, tomlContent, `name = "AnotherTest"`, "TOML should contain second command name")
assert.Contains(t, tomlContent, `name = "GlobalModifiers"`, "TOML should contain global modifiers command")
assert.Contains(t, tomlContent, `[variables]`, "TOML should contain variables section")
assert.Contains(t, tomlContent, `multiplier = 2.5`, "TOML should contain multiplier")
assert.Contains(t, tomlContent, `prefix = "CONV_"`, "TOML should contain prefix")
// Test that converted TOML loads correctly
commands, err := utils.LoadCommandsFromTomlFiles("test.toml")
commands, variables, err := utils.LoadCommandsFromTomlFiles("test.toml")
assert.NoError(t, err, "Should load converted TOML without error")
assert.Len(t, commands, 3, "Should load 3 commands from converted TOML")
assert.Len(t, commands, 2, "Should load 2 commands from converted TOML")
assert.Len(t, variables, 2, "Should have 2 variables")
// Find global modifiers command (it might not be first)
var globalCmd utils.ModifyCommand
foundGlobal := false
for _, cmd := range commands {
if cmd.Name == "GlobalModifiers" {
globalCmd = cmd
foundGlobal = true
break
}
}
assert.True(t, foundGlobal, "Should find global modifiers command")
assert.Equal(t, 2.5, globalCmd.Modifiers["multiplier"], "Should preserve multiplier value")
assert.Equal(t, "CONV_", globalCmd.Modifiers["prefix"], "Should preserve prefix value")
// Variables are now loaded separately, not as part of commands
// Test skip functionality - run conversion again
err = utils.ConvertYAMLToTOML("test.yml")
@@ -508,4 +499,4 @@ func TestYAMLToTOMLConversion(t *testing.T) {
assert.Equal(t, tomlData, originalTomlData, "TOML file content should be unchanged")
t.Logf("YAML to TOML conversion test completed successfully")
}
}

View File

@@ -2,8 +2,6 @@ package utils
import (
"os"
"path/filepath"
"strconv"
"strings"
logger "git.site.quack-lab.dev/dave/cylogger"
@@ -12,37 +10,6 @@ import (
// fileLogger is a scoped logger for the utils/file package.
var fileLogger = logger.Default.WithPrefix("utils/file")
func CleanPath(path string) string {
cleanPathLogger := fileLogger.WithPrefix("CleanPath")
cleanPathLogger.Debug("Cleaning path: %q", path)
cleanPathLogger.Trace("Original path: %q", path)
path = filepath.Clean(path)
path = strings.ReplaceAll(path, "\\", "/")
cleanPathLogger.Trace("Cleaned path result: %q", path)
return path
}
func ToAbs(path string) string {
toAbsLogger := fileLogger.WithPrefix("ToAbs")
toAbsLogger.Debug("Converting path to absolute: %q", path)
toAbsLogger.Trace("Input path: %q", path)
if filepath.IsAbs(path) {
toAbsLogger.Debug("Path is already absolute, cleaning it.")
cleanedPath := CleanPath(path)
toAbsLogger.Trace("Already absolute path after cleaning: %q", cleanedPath)
return cleanedPath
}
cwd, err := os.Getwd()
if err != nil {
toAbsLogger.Error("Error getting current working directory: %v", err)
return CleanPath(path)
}
toAbsLogger.Trace("Current working directory: %q", cwd)
cleanedPath := CleanPath(filepath.Join(cwd, path))
toAbsLogger.Trace("Converted absolute path result: %q", cleanedPath)
return cleanedPath
}
// LimitString truncates a string to maxLen and adds "..." if truncated
func LimitString(s string, maxLen int) string {
limitStringLogger := fileLogger.WithPrefix("LimitString").WithField("originalLength", len(s)).WithField("maxLength", maxLen)
@@ -57,19 +24,6 @@ func LimitString(s string, maxLen int) string {
return limited
}
// StrToFloat converts a string to a float64, returning 0 on error.
func StrToFloat(s string) float64 {
strToFloatLogger := fileLogger.WithPrefix("StrToFloat").WithField("inputString", s)
strToFloatLogger.Debug("Attempting to convert string to float")
f, err := strconv.ParseFloat(s, 64)
if err != nil {
strToFloatLogger.Warning("Failed to convert string %q to float, returning 0: %v", s, err)
return 0
}
strToFloatLogger.Trace("Successfully converted %q to float: %f", s, f)
return f
}
func ResetWhereNecessary(associations map[string]FileCommandAssociation, db DB) error {
resetWhereNecessaryLogger := fileLogger.WithPrefix("ResetWhereNecessary")
resetWhereNecessaryLogger.Debug("Starting reset where necessary operation")

209
utils/file_test.go Normal file
View File

@@ -0,0 +1,209 @@
package utils
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLimitString(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
expected string
}{
{
name: "Short string",
input: "hello",
maxLen: 10,
expected: "hello",
},
{
name: "Exact length",
input: "hello",
maxLen: 5,
expected: "hello",
},
{
name: "Too long",
input: "hello world",
maxLen: 8,
expected: "hello...",
},
{
name: "With newlines",
input: "hello\nworld",
maxLen: 20,
expected: "hello\\nworld",
},
{
name: "With newlines truncated",
input: "hello\nworld\nfoo\nbar",
maxLen: 15,
expected: "hello\\nworld...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := LimitString(tt.input, tt.maxLen)
assert.Equal(t, tt.expected, result)
})
}
}
// TestResetWhereNecessary checks that files whose associated commands request
// a reset are restored from their DB snapshot, while files without a reset
// flag are left untouched on disk.
func TestResetWhereNecessary(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "reset-test-*")
	assert.NoError(t, err)
	defer os.RemoveAll(tmpDir)
	// Create test files with known "pristine" content
	file1 := filepath.Join(tmpDir, "file1.txt")
	file2 := filepath.Join(tmpDir, "file2.txt")
	file3 := filepath.Join(tmpDir, "file3.txt")
	err = os.WriteFile(file1, []byte("original1"), 0644)
	assert.NoError(t, err)
	err = os.WriteFile(file2, []byte("original2"), 0644)
	assert.NoError(t, err)
	err = os.WriteFile(file3, []byte("original3"), 0644)
	assert.NoError(t, err)
	// Modify files on disk so a successful reset is observable
	err = os.WriteFile(file1, []byte("modified1"), 0644)
	assert.NoError(t, err)
	err = os.WriteFile(file2, []byte("modified2"), 0644)
	assert.NoError(t, err)
	// Create mock DB holding the pristine snapshots
	db, err := GetDB()
	assert.NoError(t, err)
	err = db.SaveFile(file1, []byte("original1"))
	assert.NoError(t, err)
	err = db.SaveFile(file2, []byte("original2"))
	assert.NoError(t, err)
	// file3 not in DB
	// Create associations with reset commands. Both Commands and
	// IsolateCommands should trigger a reset when Reset is true.
	associations := map[string]FileCommandAssociation{
		file1: {
			File: file1,
			Commands: []ModifyCommand{
				{Name: "cmd1", Reset: true},
			},
		},
		file2: {
			File: file2,
			IsolateCommands: []ModifyCommand{
				{Name: "cmd2", Reset: true},
			},
		},
		file3: {
			File: file3,
			Commands: []ModifyCommand{
				{Name: "cmd3", Reset: false}, // No reset
			},
		},
	}
	// Run reset
	err = ResetWhereNecessary(associations, db)
	assert.NoError(t, err)
	// Verify file1 was reset (regular Commands path)
	data, _ := os.ReadFile(file1)
	assert.Equal(t, "original1", string(data))
	// Verify file2 was reset (IsolateCommands path)
	data, _ = os.ReadFile(file2)
	assert.Equal(t, "original2", string(data))
	// Verify file3 was NOT reset (Reset flag was false)
	data, _ = os.ReadFile(file3)
	assert.Equal(t, "original3", string(data))
}
// TestResetWhereNecessaryMissingFromDB covers the fallback path: when a file
// has no snapshot in the DB, reset keeps the on-disk content and records it
// in the DB so future resets have a baseline.
func TestResetWhereNecessaryMissingFromDB(t *testing.T) {
	tmpDir := t.TempDir() // auto-removed on test cleanup

	// A file that has been modified but was never snapshotted in the DB.
	file1 := filepath.Join(tmpDir, "file1.txt")
	assert.NoError(t, os.WriteFile(file1, []byte("modified_content"), 0644))

	// DB exists but deliberately has no entry for file1.
	db, err := GetDB()
	assert.NoError(t, err)

	associations := map[string]FileCommandAssociation{
		file1: {
			File: file1,
			Commands: []ModifyCommand{
				{Name: "cmd1", Reset: true},
			},
		},
	}

	// With no snapshot available, the current disk content is the fallback.
	assert.NoError(t, ResetWhereNecessary(associations, db))

	// File content is unchanged (check the read error instead of ignoring it).
	data, err := os.ReadFile(file1)
	assert.NoError(t, err)
	assert.Equal(t, "modified_content", string(data))

	// ...and that content was saved to the DB for subsequent resets.
	savedData, err := db.GetFile(file1)
	assert.NoError(t, err)
	assert.Equal(t, "modified_content", string(savedData))
}
// TestResetAllFiles verifies that ResetAllFiles restores every tracked file
// from its DB snapshot regardless of command associations.
func TestResetAllFiles(t *testing.T) {
	tmpDir := t.TempDir() // auto-removed on test cleanup

	// Create files and snapshot their original content in the DB.
	file1 := filepath.Join(tmpDir, "file1.txt")
	file2 := filepath.Join(tmpDir, "file2.txt")
	assert.NoError(t, os.WriteFile(file1, []byte("original1"), 0644))
	assert.NoError(t, os.WriteFile(file2, []byte("original2"), 0644))

	db, err := GetDB()
	assert.NoError(t, err)
	assert.NoError(t, db.SaveFile(file1, []byte("original1")))
	assert.NoError(t, db.SaveFile(file2, []byte("original2")))

	// Modify both files on disk.
	assert.NoError(t, os.WriteFile(file1, []byte("modified1"), 0644))
	assert.NoError(t, os.WriteFile(file2, []byte("modified2"), 0644))

	// Sanity check the modification actually took effect.
	data, err := os.ReadFile(file1)
	assert.NoError(t, err)
	assert.Equal(t, "modified1", string(data))

	// Reset everything back from the DB snapshots.
	assert.NoError(t, ResetAllFiles(db))

	data, err = os.ReadFile(file1)
	assert.NoError(t, err)
	assert.Equal(t, "original1", string(data))
	data, err = os.ReadFile(file2)
	assert.NoError(t, err)
	assert.Equal(t, "original2", string(data))
}

View File

@@ -7,8 +7,8 @@ import (
"strings"
logger "git.site.quack-lab.dev/dave/cylogger"
"github.com/bmatcuk/doublestar/v4"
"github.com/BurntSushi/toml"
"github.com/bmatcuk/doublestar/v4"
"gopkg.in/yaml.v3"
)
@@ -16,18 +16,17 @@ import (
var modifyCommandLogger = logger.Default.WithPrefix("utils/modifycommand")
type ModifyCommand struct {
Name string `yaml:"name,omitempty" toml:"name,omitempty"`
Regex string `yaml:"regex,omitempty" toml:"regex,omitempty"`
Regexes []string `yaml:"regexes,omitempty" toml:"regexes,omitempty"`
Lua string `yaml:"lua,omitempty" toml:"lua,omitempty"`
Files []string `yaml:"files,omitempty" toml:"files,omitempty"`
Reset bool `yaml:"reset,omitempty" toml:"reset,omitempty"`
LogLevel string `yaml:"loglevel,omitempty" toml:"loglevel,omitempty"`
Isolate bool `yaml:"isolate,omitempty" toml:"isolate,omitempty"`
NoDedup bool `yaml:"nodedup,omitempty" toml:"nodedup,omitempty"`
Disabled bool `yaml:"disable,omitempty" toml:"disable,omitempty"`
JSON bool `yaml:"json,omitempty" toml:"json,omitempty"`
Modifiers map[string]interface{} `yaml:"modifiers,omitempty" toml:"modifiers,omitempty"`
Name string `yaml:"name,omitempty" toml:"name,omitempty"`
Regex string `yaml:"regex,omitempty" toml:"regex,omitempty"`
Regexes []string `yaml:"regexes,omitempty" toml:"regexes,omitempty"`
Lua string `yaml:"lua,omitempty" toml:"lua,omitempty"`
Files []string `yaml:"files,omitempty" toml:"files,omitempty"`
Reset bool `yaml:"reset,omitempty" toml:"reset,omitempty"`
LogLevel string `yaml:"loglevel,omitempty" toml:"loglevel,omitempty"`
Isolate bool `yaml:"isolate,omitempty" toml:"isolate,omitempty"`
NoDedup bool `yaml:"nodedup,omitempty" toml:"nodedup,omitempty"`
Disabled bool `yaml:"disable,omitempty" toml:"disable,omitempty"`
JSON bool `yaml:"json,omitempty" toml:"json,omitempty"`
}
type CookFile []ModifyCommand
@@ -62,6 +61,7 @@ func (c *ModifyCommand) Validate() error {
// Ehh.. Not much better... Guess this wasn't the big deal
var matchesMemoTable map[string]bool = make(map[string]bool)
var globMemoTable map[string][]string = make(map[string][]string)
func Matches(path string, glob string) (bool, error) {
matchesLogger := modifyCommandLogger.WithPrefix("Matches").WithField("path", path).WithField("glob", glob)
@@ -85,27 +85,18 @@ func SplitPattern(pattern string) (string, string) {
splitPatternLogger := modifyCommandLogger.WithPrefix("SplitPattern").WithField("pattern", pattern)
splitPatternLogger.Debug("Splitting pattern")
splitPatternLogger.Trace("Original pattern: %q", pattern)
static, pattern := doublestar.SplitPattern(pattern)
cwd, err := os.Getwd()
if err != nil {
splitPatternLogger.Error("Error getting current working directory: %v", err)
return "", ""
}
splitPatternLogger.Trace("Current working directory: %q", cwd)
if static == "" {
splitPatternLogger.Debug("Static part is empty, defaulting to current working directory")
static = cwd
}
if !filepath.IsAbs(static) {
splitPatternLogger.Debug("Static part is not absolute, joining with current working directory")
static = filepath.Join(cwd, static)
static = filepath.Clean(static)
splitPatternLogger.Trace("Static path after joining and cleaning: %q", static)
}
static = strings.ReplaceAll(static, "\\", "/")
splitPatternLogger.Trace("Final static path: %q, Remaining pattern: %q", static, pattern)
return static, pattern
// Split the pattern first to separate static and wildcard parts
static, remainingPattern := doublestar.SplitPattern(pattern)
splitPatternLogger.Trace("After split: static=%q, pattern=%q", static, remainingPattern)
// Resolve the static part to handle ~ expansion and make it absolute
// ResolvePath already normalizes to forward slashes
static = ResolvePath(static)
splitPatternLogger.Trace("Resolved static part: %q", static)
splitPatternLogger.Trace("Final static path: %q, Remaining pattern: %q", static, remainingPattern)
return static, remainingPattern
}
type FileCommandAssociation struct {
@@ -123,33 +114,23 @@ func AssociateFilesWithCommands(files []string, commands []ModifyCommand) (map[s
fileCommands := make(map[string]FileCommandAssociation)
for _, file := range files {
file = strings.ReplaceAll(file, "\\", "/")
associateFilesLogger.Debug("Processing file: %q", file)
// Use centralized path resolution internally but keep original file as key
resolvedFile := ResolvePath(file)
associateFilesLogger.Debug("Processing file: %q (resolved: %q)", file, resolvedFile)
fileCommands[file] = FileCommandAssociation{
File: file,
File: resolvedFile,
IsolateCommands: []ModifyCommand{},
Commands: []ModifyCommand{},
}
for _, command := range commands {
associateFilesLogger.Debug("Checking command %q for file %q", command.Name, file)
for _, glob := range command.Files {
glob = strings.ReplaceAll(glob, "\\", "/")
// SplitPattern now handles tilde expansion and path resolution
static, pattern := SplitPattern(glob)
associateFilesLogger.Trace("Glob parts for %q → static=%q pattern=%q", glob, static, pattern)
// Build absolute path for the current file to compare with static
cwd, err := os.Getwd()
if err != nil {
associateFilesLogger.Warning("Failed to get CWD when matching %q for file %q: %v", glob, file, err)
continue
}
var absFile string
if filepath.IsAbs(file) {
absFile = filepath.Clean(file)
} else {
absFile = filepath.Clean(filepath.Join(cwd, file))
}
absFile = strings.ReplaceAll(absFile, "\\", "/")
// Use resolved file for matching (already normalized to forward slashes by ResolvePath)
absFile := resolvedFile
associateFilesLogger.Trace("Absolute file path resolved for matching: %q", absFile)
// Only match if the file is under the static root
@@ -200,9 +181,14 @@ func AggregateGlobs(commands []ModifyCommand) map[string]struct{} {
for _, command := range commands {
aggregateGlobsLogger.Debug("Processing command %q for glob patterns", command.Name)
for _, glob := range command.Files {
resolvedGlob := strings.Replace(glob, "~", os.Getenv("HOME"), 1)
resolvedGlob = strings.ReplaceAll(resolvedGlob, "\\", "/")
aggregateGlobsLogger.Trace("Adding glob: %q (resolved to %q)", glob, resolvedGlob)
// Split the glob into static and pattern parts, then resolve ONLY the static part
static, pattern := SplitPattern(glob)
// Reconstruct the glob with resolved static part
resolvedGlob := static
if pattern != "" {
resolvedGlob += "/" + pattern
}
aggregateGlobsLogger.Trace("Adding glob: %q (resolved to %q) [static=%s, pattern=%s]", glob, resolvedGlob, static, pattern)
globs[resolvedGlob] = struct{}{}
}
}
@@ -211,7 +197,7 @@ func AggregateGlobs(commands []ModifyCommand) map[string]struct{} {
return globs
}
func ExpandGLobs(patterns map[string]struct{}) ([]string, error) {
func ExpandGlobs(patterns map[string]struct{}) ([]string, error) {
expandGlobsLogger := modifyCommandLogger.WithPrefix("ExpandGLobs")
expandGlobsLogger.Debug("Expanding glob patterns to actual files")
expandGlobsLogger.Trace("Input patterns for expansion: %v", patterns)
@@ -228,23 +214,30 @@ func ExpandGLobs(patterns map[string]struct{}) ([]string, error) {
for pattern := range patterns {
expandGlobsLogger.Debug("Processing glob pattern: %q", pattern)
static, pattern := SplitPattern(pattern)
matches, err := doublestar.Glob(os.DirFS(static), pattern)
if err != nil {
expandGlobsLogger.Warning("Error expanding glob %q in %q: %v", pattern, static, err)
continue
key := static + "|" + pattern
matches, ok := globMemoTable[key]
if !ok {
var err error
matches, err = doublestar.Glob(os.DirFS(static), pattern)
if err != nil {
expandGlobsLogger.Warning("Error expanding glob %q in %q: %v", pattern, static, err)
continue
}
globMemoTable[key] = matches
}
expandGlobsLogger.Debug("Found %d matches for pattern %q", len(matches), pattern)
expandGlobsLogger.Trace("Raw matches for pattern %q: %v", pattern, matches)
for _, m := range matches {
m = filepath.Join(static, m)
info, err := os.Stat(m)
// Resolve the full path
fullPath := ResolvePath(filepath.Join(static, m))
info, err := os.Stat(fullPath)
if err != nil {
expandGlobsLogger.Warning("Error getting file info for %q: %v", m, err)
expandGlobsLogger.Warning("Error getting file info for %q: %v", fullPath, err)
continue
}
if !info.IsDir() && !filesMap[m] {
expandGlobsLogger.Trace("Adding unique file to list: %q", m)
filesMap[m], files = true, append(files, m)
if !info.IsDir() && !filesMap[fullPath] {
expandGlobsLogger.Trace("Adding unique file to list: %q", fullPath)
filesMap[fullPath], files = true, append(files, fullPath)
}
}
}
@@ -258,34 +251,39 @@ func ExpandGLobs(patterns map[string]struct{}) ([]string, error) {
return files, nil
}
func LoadCommands(args []string) ([]ModifyCommand, error) {
func LoadCommands(args []string) ([]ModifyCommand, map[string]interface{}, error) {
loadCommandsLogger := modifyCommandLogger.WithPrefix("LoadCommands")
loadCommandsLogger.Debug("Loading commands from arguments (cook files or direct patterns)")
loadCommandsLogger.Trace("Input arguments: %v", args)
commands := []ModifyCommand{}
variables := make(map[string]interface{})
for _, arg := range args {
loadCommandsLogger.Debug("Processing argument for commands: %q", arg)
var newCommands []ModifyCommand
var newVariables map[string]interface{}
var err error
// Check file extension to determine format
if strings.HasSuffix(arg, ".toml") {
loadCommandsLogger.Debug("Loading TOML commands from %q", arg)
newCommands, err = LoadCommandsFromTomlFiles(arg)
newCommands, newVariables, err = LoadCommandsFromTomlFiles(arg)
if err != nil {
loadCommandsLogger.Error("Failed to load TOML commands from argument %q: %v", arg, err)
return nil, fmt.Errorf("failed to load commands from TOML files: %w", err)
return nil, nil, fmt.Errorf("failed to load commands from TOML files: %w", err)
}
} else {
// Default to YAML for .yml, .yaml, or any other extension
loadCommandsLogger.Debug("Loading YAML commands from %q", arg)
newCommands, err = LoadCommandsFromCookFiles(arg)
newCommands, newVariables, err = LoadCommandsFromCookFiles(arg)
if err != nil {
loadCommandsLogger.Error("Failed to load YAML commands from argument %q: %v", arg, err)
return nil, fmt.Errorf("failed to load commands from cook files: %w", err)
return nil, nil, fmt.Errorf("failed to load commands from cook files: %w", err)
}
}
for k, v := range newVariables {
variables[k] = v
}
loadCommandsLogger.Debug("Successfully loaded %d commands from %q", len(newCommands), arg)
for _, cmd := range newCommands {
@@ -298,62 +296,71 @@ func LoadCommands(args []string) ([]ModifyCommand, error) {
}
}
loadCommandsLogger.Info("Finished loading commands. Total %d commands loaded", len(commands))
return commands, nil
loadCommandsLogger.Info("Finished loading commands. Total %d commands and %d variables loaded", len(commands), len(variables))
return commands, variables, nil
}
func LoadCommandsFromCookFiles(pattern string) ([]ModifyCommand, error) {
func LoadCommandsFromCookFiles(pattern string) ([]ModifyCommand, map[string]interface{}, error) {
loadCookFilesLogger := modifyCommandLogger.WithPrefix("LoadCommandsFromCookFiles").WithField("pattern", pattern)
loadCookFilesLogger.Debug("Loading commands from cook files based on pattern")
loadCookFilesLogger.Trace("Input pattern: %q", pattern)
static, pattern := SplitPattern(pattern)
commands := []ModifyCommand{}
variables := make(map[string]interface{})
cookFiles, err := doublestar.Glob(os.DirFS(static), pattern)
if err != nil {
loadCookFilesLogger.Error("Failed to glob cook files for pattern %q: %v", pattern, err)
return nil, fmt.Errorf("failed to glob cook files: %w", err)
return nil, nil, fmt.Errorf("failed to glob cook files: %w", err)
}
loadCookFilesLogger.Debug("Found %d cook files for pattern %q", len(cookFiles), pattern)
loadCookFilesLogger.Trace("Cook files found: %v", cookFiles)
for _, cookFile := range cookFiles {
cookFile = filepath.Join(static, cookFile)
cookFile = filepath.Clean(cookFile)
cookFile = strings.ReplaceAll(cookFile, "\\", "/")
// Use centralized path resolution
cookFile = ResolvePath(filepath.Join(static, cookFile))
loadCookFilesLogger.Debug("Loading commands from individual cook file: %q", cookFile)
cookFileData, err := os.ReadFile(cookFile)
if err != nil {
loadCookFilesLogger.Error("Failed to read cook file %q: %v", cookFile, err)
return nil, fmt.Errorf("failed to read cook file: %w", err)
return nil, nil, fmt.Errorf("failed to read cook file: %w", err)
}
loadCookFilesLogger.Trace("Read %d bytes from cook file %q", len(cookFileData), cookFile)
newCommands, err := LoadCommandsFromCookFile(cookFileData)
newCommands, newVariables, err := LoadCommandsFromCookFile(cookFileData)
if err != nil {
loadCookFilesLogger.Error("Failed to load commands from cook file data for %q: %v", cookFile, err)
return nil, fmt.Errorf("failed to load commands from cook file: %w", err)
return nil, nil, fmt.Errorf("failed to load commands from cook file: %w", err)
}
commands = append(commands, newCommands...)
loadCookFilesLogger.Debug("Added %d commands from cook file %q. Total commands now: %d", len(newCommands), cookFile, len(commands))
for k, v := range newVariables {
variables[k] = v
}
loadCookFilesLogger.Debug("Added %d commands and %d variables from cook file %q. Total commands now: %d", len(newCommands), len(newVariables), cookFile, len(commands))
}
loadCookFilesLogger.Debug("Finished loading commands from cook files. Total %d commands", len(commands))
return commands, nil
loadCookFilesLogger.Debug("Finished loading commands from cook files. Total %d commands and %d variables", len(commands), len(variables))
return commands, variables, nil
}
func LoadCommandsFromCookFile(cookFileData []byte) ([]ModifyCommand, error) {
func LoadCommandsFromCookFile(cookFileData []byte) ([]ModifyCommand, map[string]interface{}, error) {
loadCommandLogger := modifyCommandLogger.WithPrefix("LoadCommandsFromCookFile")
loadCommandLogger.Debug("Unmarshaling commands from cook file data")
loadCommandLogger.Trace("Cook file data length: %d", len(cookFileData))
commands := []ModifyCommand{}
err := yaml.Unmarshal(cookFileData, &commands)
var cookFile struct {
Variables map[string]interface{} `yaml:"variables,omitempty"`
Commands []ModifyCommand `yaml:"commands"`
}
err := yaml.Unmarshal(cookFileData, &cookFile)
if err != nil {
loadCommandLogger.Error("Failed to unmarshal cook file data: %v", err)
return nil, fmt.Errorf("failed to unmarshal cook file: %w", err)
return nil, nil, fmt.Errorf("failed to unmarshal cook file: %w", err)
}
loadCommandLogger.Debug("Successfully unmarshaled %d commands", len(commands))
loadCommandLogger.Trace("Unmarshaled commands: %v", commands)
return commands, nil
loadCommandLogger.Debug("Successfully unmarshaled %d commands and %d variables", len(cookFile.Commands), len(cookFile.Variables))
loadCommandLogger.Trace("Unmarshaled commands: %v", cookFile.Commands)
loadCommandLogger.Trace("Unmarshaled variables: %v", cookFile.Variables)
return cookFile.Commands, cookFile.Variables, nil
}
// CountGlobsBeforeDedup counts the total number of glob patterns across all commands before deduplication
@@ -391,53 +398,57 @@ func FilterCommands(commands []ModifyCommand, filter string) []ModifyCommand {
return filteredCommands
}
func LoadCommandsFromTomlFiles(pattern string) ([]ModifyCommand, error) {
func LoadCommandsFromTomlFiles(pattern string) ([]ModifyCommand, map[string]interface{}, error) {
loadTomlFilesLogger := modifyCommandLogger.WithPrefix("LoadCommandsFromTomlFiles").WithField("pattern", pattern)
loadTomlFilesLogger.Debug("Loading commands from TOML files based on pattern")
loadTomlFilesLogger.Trace("Input pattern: %q", pattern)
static, pattern := SplitPattern(pattern)
commands := []ModifyCommand{}
variables := make(map[string]interface{})
tomlFiles, err := doublestar.Glob(os.DirFS(static), pattern)
if err != nil {
loadTomlFilesLogger.Error("Failed to glob TOML files for pattern %q: %v", pattern, err)
return nil, fmt.Errorf("failed to glob TOML files: %w", err)
return nil, nil, fmt.Errorf("failed to glob TOML files: %w", err)
}
loadTomlFilesLogger.Debug("Found %d TOML files for pattern %q", len(tomlFiles), pattern)
loadTomlFilesLogger.Trace("TOML files found: %v", tomlFiles)
for _, tomlFile := range tomlFiles {
tomlFile = filepath.Join(static, tomlFile)
tomlFile = filepath.Clean(tomlFile)
tomlFile = strings.ReplaceAll(tomlFile, "\\", "/")
// Use centralized path resolution
tomlFile = ResolvePath(filepath.Join(static, tomlFile))
loadTomlFilesLogger.Debug("Loading commands from individual TOML file: %q", tomlFile)
tomlFileData, err := os.ReadFile(tomlFile)
if err != nil {
loadTomlFilesLogger.Error("Failed to read TOML file %q: %v", tomlFile, err)
return nil, fmt.Errorf("failed to read TOML file: %w", err)
return nil, nil, fmt.Errorf("failed to read TOML file: %w", err)
}
loadTomlFilesLogger.Trace("Read %d bytes from TOML file %q", len(tomlFileData), tomlFile)
newCommands, err := LoadCommandsFromTomlFile(tomlFileData)
newCommands, newVariables, err := LoadCommandsFromTomlFile(tomlFileData)
if err != nil {
loadTomlFilesLogger.Error("Failed to load commands from TOML file data for %q: %v", tomlFile, err)
return nil, fmt.Errorf("failed to load commands from TOML file: %w", err)
return nil, nil, fmt.Errorf("failed to load commands from TOML file: %w", err)
}
commands = append(commands, newCommands...)
loadTomlFilesLogger.Debug("Added %d commands from TOML file %q. Total commands now: %d", len(newCommands), tomlFile, len(commands))
for k, v := range newVariables {
variables[k] = v
}
loadTomlFilesLogger.Debug("Added %d commands and %d variables from TOML file %q. Total commands now: %d", len(newCommands), len(newVariables), tomlFile, len(commands))
}
loadTomlFilesLogger.Debug("Finished loading commands from TOML files. Total %d commands", len(commands))
return commands, nil
loadTomlFilesLogger.Debug("Finished loading commands from TOML files. Total %d commands and %d variables", len(commands), len(variables))
return commands, variables, nil
}
func LoadCommandsFromTomlFile(tomlFileData []byte) ([]ModifyCommand, error) {
func LoadCommandsFromTomlFile(tomlFileData []byte) ([]ModifyCommand, map[string]interface{}, error) {
loadTomlCommandLogger := modifyCommandLogger.WithPrefix("LoadCommandsFromTomlFile")
loadTomlCommandLogger.Debug("Unmarshaling commands from TOML file data")
loadTomlCommandLogger.Trace("TOML file data length: %d", len(tomlFileData))
// TOML structure for commands array
// TOML structure for commands array and top-level variables
var tomlData struct {
Commands []ModifyCommand `toml:"commands"`
Variables map[string]interface{} `toml:"variables,omitempty"`
Commands []ModifyCommand `toml:"commands"`
// Also support direct array without wrapper
DirectCommands []ModifyCommand `toml:"-"`
}
@@ -446,29 +457,26 @@ func LoadCommandsFromTomlFile(tomlFileData []byte) ([]ModifyCommand, error) {
err := toml.Unmarshal(tomlFileData, &tomlData)
if err != nil {
loadTomlCommandLogger.Error("Failed to unmarshal TOML file data: %v", err)
return nil, fmt.Errorf("failed to unmarshal TOML file: %w", err)
return nil, nil, fmt.Errorf("failed to unmarshal TOML file: %w", err)
}
var commands []ModifyCommand
variables := make(map[string]interface{})
// If we found commands in the wrapped structure, use those
if len(tomlData.Commands) > 0 {
commands = tomlData.Commands
loadTomlCommandLogger.Debug("Found %d commands in wrapped TOML structure", len(commands))
} else {
// Try to parse as direct array (similar to YAML format)
commands = []ModifyCommand{}
err = toml.Unmarshal(tomlFileData, &commands)
if err != nil {
loadTomlCommandLogger.Error("Failed to unmarshal TOML file data as direct array: %v", err)
return nil, fmt.Errorf("failed to unmarshal TOML file as direct array: %w", err)
// Extract top-level variables
if len(tomlData.Variables) > 0 {
loadTomlCommandLogger.Debug("Found %d top-level variables", len(tomlData.Variables))
for k, v := range tomlData.Variables {
variables[k] = v
}
loadTomlCommandLogger.Debug("Found %d commands in direct TOML array", len(commands))
}
loadTomlCommandLogger.Debug("Successfully unmarshaled %d commands", len(commands))
// Use commands from wrapped structure
commands = tomlData.Commands
loadTomlCommandLogger.Debug("Successfully unmarshaled %d commands and %d variables", len(commands), len(variables))
loadTomlCommandLogger.Trace("Unmarshaled commands: %v", commands)
return commands, nil
loadTomlCommandLogger.Trace("Unmarshaled variables: %v", variables)
return commands, variables, nil
}
// ConvertYAMLToTOML converts YAML files to TOML format
@@ -476,20 +484,6 @@ func ConvertYAMLToTOML(yamlPattern string) error {
convertLogger := modifyCommandLogger.WithPrefix("ConvertYAMLToTOML").WithField("pattern", yamlPattern)
convertLogger.Debug("Starting YAML to TOML conversion")
// Load YAML commands
yamlCommands, err := LoadCommandsFromCookFiles(yamlPattern)
if err != nil {
convertLogger.Error("Failed to load YAML commands: %v", err)
return fmt.Errorf("failed to load YAML commands: %w", err)
}
if len(yamlCommands) == 0 {
convertLogger.Info("No YAML commands found for pattern: %s", yamlPattern)
return nil
}
convertLogger.Debug("Loaded %d commands from YAML", len(yamlCommands))
// Find all YAML files matching the pattern
static, pattern := SplitPattern(yamlPattern)
yamlFiles, err := doublestar.Glob(os.DirFS(static), pattern)
@@ -500,13 +494,17 @@ func ConvertYAMLToTOML(yamlPattern string) error {
convertLogger.Debug("Found %d YAML files to convert", len(yamlFiles))
if len(yamlFiles) == 0 {
convertLogger.Info("No YAML files found for pattern: %s", yamlPattern)
return nil
}
conversionCount := 0
skippedCount := 0
for _, yamlFile := range yamlFiles {
yamlFilePath := filepath.Join(static, yamlFile)
yamlFilePath = filepath.Clean(yamlFilePath)
yamlFilePath = strings.ReplaceAll(yamlFilePath, "\\", "/")
// Use centralized path resolution
yamlFilePath := ResolvePath(filepath.Join(static, yamlFile))
// Generate corresponding TOML file path
tomlFilePath := strings.TrimSuffix(yamlFilePath, filepath.Ext(yamlFilePath)) + ".toml"
@@ -528,14 +526,14 @@ func ConvertYAMLToTOML(yamlPattern string) error {
}
// Load YAML commands from this specific file
fileCommands, err := LoadCommandsFromCookFile(yamlData)
fileCommands, fileVariables, err := LoadCommandsFromCookFile(yamlData)
if err != nil {
convertLogger.Error("Failed to parse YAML file %s: %v", yamlFilePath, err)
continue
}
// Convert to TOML structure
tomlData, err := convertCommandsToTOML(fileCommands)
tomlData, err := convertCommandsToTOML(fileCommands, fileVariables)
if err != nil {
convertLogger.Error("Failed to convert commands to TOML for %s: %v", yamlFilePath, err)
continue
@@ -557,15 +555,17 @@ func ConvertYAMLToTOML(yamlPattern string) error {
}
// convertCommandsToTOML converts a slice of ModifyCommand to TOML format
func convertCommandsToTOML(commands []ModifyCommand) ([]byte, error) {
func convertCommandsToTOML(commands []ModifyCommand, variables map[string]interface{}) ([]byte, error) {
convertLogger := modifyCommandLogger.WithPrefix("convertCommandsToTOML")
convertLogger.Debug("Converting %d commands to TOML format", len(commands))
// Create TOML structure
tomlData := struct {
Commands []ModifyCommand `toml:"commands"`
Variables map[string]interface{} `toml:"variables,omitempty"`
Commands []ModifyCommand `toml:"commands"`
}{
Commands: commands,
Variables: variables,
Commands: commands,
}
// Marshal to TOML
@@ -575,6 +575,6 @@ func convertCommandsToTOML(commands []ModifyCommand) ([]byte, error) {
return nil, fmt.Errorf("failed to marshal commands to TOML: %w", err)
}
convertLogger.Debug("Successfully converted %d commands to TOML (%d bytes)", len(commands), len(tomlBytes))
convertLogger.Debug("Successfully converted %d commands and %d variables to TOML (%d bytes)", len(commands), len(variables), len(tomlBytes))
return tomlBytes, nil
}

View File

@@ -0,0 +1,313 @@
package utils
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
// TestAggregateGlobsWithDuplicates checks that identical glob patterns from
// different commands collapse into a single entry in the aggregated set.
func TestAggregateGlobsWithDuplicates(t *testing.T) {
	commands := []ModifyCommand{
		{Files: []string{"*.txt", "*.md"}},
		{Files: []string{"*.txt", "*.go"}}, // *.txt is duplicate
		{Files: []string{"test/**/*.xml"}},
	}

	globs := AggregateGlobs(commands)

	// Five patterns in, four unique out.
	assert.Equal(t, 4, len(globs))
	// Aggregated globs are stored in resolved form (forward slashes),
	// so compare against the resolved version of each pattern.
	for _, pattern := range []string{"*.txt", "*.md", "*.go", "test/**/*.xml"} {
		assert.Contains(t, globs, ResolvePath(pattern))
	}
}
// TestExpandGlobsWithActualFiles creates real files on disk and checks that
// glob expansion finds exactly the matching ones.
func TestExpandGlobsWithActualFiles(t *testing.T) {
	tmpDir := t.TempDir() // auto-removed on test cleanup

	// Two .txt files should match "*.txt"; the .md file should not.
	// Check WriteFile errors instead of silently ignoring them.
	assert.NoError(t, os.WriteFile(filepath.Join(tmpDir, "test1.txt"), []byte("test"), 0644))
	assert.NoError(t, os.WriteFile(filepath.Join(tmpDir, "test2.txt"), []byte("test"), 0644))
	assert.NoError(t, os.WriteFile(filepath.Join(tmpDir, "test.md"), []byte("test"), 0644))

	// Relative patterns resolve against the CWD, so run from tmpDir.
	origDir, err := os.Getwd()
	assert.NoError(t, err)
	defer func() { assert.NoError(t, os.Chdir(origDir)) }()
	assert.NoError(t, os.Chdir(tmpDir))

	// Normalize the pattern the same way production code does.
	globs := map[string]struct{}{
		ResolvePath("*.txt"): {},
	}

	files, err := ExpandGlobs(globs)
	assert.NoError(t, err)
	assert.Equal(t, 2, len(files))
}
// TestSplitPatternWithTilde checks that SplitPattern expands a leading "~"
// in the static part while keeping the wildcard tail intact.
func TestSplitPatternWithTilde(t *testing.T) {
	static, pat := SplitPattern("~/test/*.txt")

	// The tilde must have been resolved away from the static part.
	assert.NotEqual(t, "~", static)
	// The wildcard portion survives the split.
	assert.Contains(t, pat, "*.txt")
}
// TestLoadCommandsWithDisabled verifies that commands flagged `disable: true`
// are filtered out at load time while top-level variables still load.
func TestLoadCommandsWithDisabled(t *testing.T) {
	tmpDir := t.TempDir() // auto-removed on test cleanup

	// NOTE(review): the YAML needs proper indentation to parse; the dump of
	// this file had leading whitespace stripped, reconstructed here.
	yamlContent := `
variables:
  test: "value"
commands:
  - name: "enabled_cmd"
    regex: "test"
    lua: "v1 * 2"
    files: ["*.txt"]
  - name: "disabled_cmd"
    regex: "test2"
    lua: "v1 * 3"
    files: ["*.txt"]
    disable: true
`
	yamlFile := filepath.Join(tmpDir, "test.yml")
	assert.NoError(t, os.WriteFile(yamlFile, []byte(yamlContent), 0644))

	// LoadCommands resolves relative patterns against the CWD.
	origDir, err := os.Getwd()
	assert.NoError(t, err)
	defer func() { assert.NoError(t, os.Chdir(origDir)) }()
	assert.NoError(t, os.Chdir(tmpDir))

	commands, variables, err := LoadCommands([]string{"test.yml"})
	assert.NoError(t, err)

	// Only the enabled command survives loading.
	assert.Equal(t, 1, len(commands))
	assert.Equal(t, "enabled_cmd", commands[0].Name)
	// Variables load regardless of disabled commands.
	assert.Equal(t, 1, len(variables))
}
// TestFilterCommandsByName exercises FilterCommands with a single filter
// term and with a comma-separated list of terms.
func TestFilterCommandsByName(t *testing.T) {
	commands := []ModifyCommand{
		{Name: "test_multiply"},
		{Name: "test_divide"},
		{Name: "other_command"},
		{Name: "test_add"},
	}

	// "test" matches the three test_* commands.
	assert.Equal(t, 3, len(FilterCommands(commands, "test")))
	// Comma-separated filters select commands matching any entry.
	assert.Equal(t, 2, len(FilterCommands(commands, "multiply,divide")))
}
// TestCountGlobsBeforeDedup checks the raw (pre-deduplication) glob count:
// 3 + 1 + 2 patterns across the three commands.
func TestCountGlobsBeforeDedup(t *testing.T) {
	commands := []ModifyCommand{
		{Files: []string{"*.txt", "*.md", "*.go"}},
		{Files: []string{"*.xml"}},
		{Files: []string{"test/**/*.txt", "data/**/*.json"}},
	}

	assert.Equal(t, 6, CountGlobsBeforeDedup(commands))
}
// TestMatchesWithMemoization calls Matches twice with identical arguments;
// the second call should be answered from the memo table and must agree
// with the first result.
func TestMatchesWithMemoization(t *testing.T) {
	const (
		path = "test/file.txt"
		glob = "**/*.txt"
	)

	// First call computes and caches the result.
	first, err := Matches(path, glob)
	assert.NoError(t, err)
	assert.True(t, first)

	// Repeat call — presumably served from the memo table; verify agreement.
	second, err := Matches(path, glob)
	assert.NoError(t, err)
	assert.Equal(t, first, second)
}
// TestValidateCommand covers ModifyCommand.Validate: regex is required
// unless JSON mode is on; Lua and Files are always required.
func TestValidateCommand(t *testing.T) {
	cases := []struct {
		name    string
		cmd     ModifyCommand
		wantErr bool
	}{
		{
			name: "Valid command",
			cmd:  ModifyCommand{Regex: "test", Lua: "v1 * 2", Files: []string{"*.txt"}},
		},
		{
			name: "Valid JSON mode without regex",
			cmd: ModifyCommand{
				JSON:  true,
				Lua:   "data.value = data.value * 2; modified = true",
				Files: []string{"*.json"},
			},
		},
		{
			name:    "Missing regex in non-JSON mode",
			cmd:     ModifyCommand{Lua: "v1 * 2", Files: []string{"*.txt"}},
			wantErr: true,
		},
		{
			name:    "Missing Lua",
			cmd:     ModifyCommand{Regex: "test", Files: []string{"*.txt"}},
			wantErr: true,
		},
		{
			name:    "Missing files",
			cmd:     ModifyCommand{Regex: "test", Lua: "v1 * 2"},
			wantErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := tc.cmd.Validate()
			if tc.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestLoadCommandsFromTomlWithVariables checks that a TOML cook file yields
// both its commands and its top-level [variables] table.
func TestLoadCommandsFromTomlWithVariables(t *testing.T) {
	tmpDir := t.TempDir() // auto-removed on test cleanup

	tomlContent := `[variables]
multiplier = 3
prefix = "PREFIX_"
[[commands]]
name = "test_cmd"
regex = "value = !num"
lua = "v1 * multiplier"
files = ["*.txt"]
`
	tomlFile := filepath.Join(tmpDir, "test.toml")
	assert.NoError(t, os.WriteFile(tomlFile, []byte(tomlContent), 0644))

	// Relative patterns resolve against the CWD, so run from tmpDir.
	origDir, err := os.Getwd()
	assert.NoError(t, err)
	defer func() { assert.NoError(t, os.Chdir(origDir)) }()
	assert.NoError(t, os.Chdir(tmpDir))

	commands, variables, err := LoadCommandsFromTomlFiles("test.toml")
	assert.NoError(t, err)
	assert.Equal(t, 1, len(commands))
	assert.Equal(t, 2, len(variables))
	// The TOML decoder yields integers as int64.
	assert.Equal(t, int64(3), variables["multiplier"])
	assert.Equal(t, "PREFIX_", variables["prefix"])
}
// TestConvertYAMLToTOMLSkipExisting verifies that ConvertYAMLToTOML does not
// overwrite an already-existing .toml file sitting next to the .yml source.
func TestConvertYAMLToTOMLSkipExisting(t *testing.T) {
	// t.TempDir handles creation and cleanup automatically.
	tmpDir := t.TempDir()
	// YAML input that would normally be converted.
	yamlContent := `
commands:
  - name: "test"
    regex: "value"
    lua: "v1 * 2"
    files: ["*.txt"]
`
	yamlFile := filepath.Join(tmpDir, "test.yml")
	assert.NoError(t, os.WriteFile(yamlFile, []byte(yamlContent), 0644))
	// Pre-existing TOML file: the conversion must leave it untouched.
	tomlFile := filepath.Join(tmpDir, "test.toml")
	assert.NoError(t, os.WriteFile(tomlFile, []byte("# existing"), 0644))
	origDir, err := os.Getwd()
	assert.NoError(t, err)
	defer os.Chdir(origDir)
	assert.NoError(t, os.Chdir(tmpDir))
	// Conversion should report success while skipping the existing target.
	assert.NoError(t, ConvertYAMLToTOML("test.yml"))
	// The TOML content must be exactly what we wrote before conversion.
	// The read error was previously discarded; check it explicitly.
	content, err := os.ReadFile(tomlFile)
	assert.NoError(t, err)
	assert.Equal(t, "# existing", string(content))
}
// TestLoadCommandsWithTomlExtension exercises the ".toml" suffix branch in
// LoadCommands and checks that both commands and variables come back.
func TestLoadCommandsWithTomlExtension(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "toml-ext-test-*")
	assert.NoError(t, err)
	defer os.RemoveAll(tmpDir)
	content := `
[variables]
test_var = "value"
[[commands]]
name = "TestCmd"
regex = "test"
lua = "return true"
files = ["*.txt"]
`
	assert.NoError(t, os.WriteFile(filepath.Join(tmpDir, "test.toml"), []byte(content), 0644))
	origDir, _ := os.Getwd()
	defer os.Chdir(origDir)
	os.Chdir(tmpDir)
	// This should trigger the .toml suffix check in LoadCommands.
	commands, variables, err := LoadCommands([]string{"test.toml"})
	assert.NoError(t, err)
	assert.Len(t, commands, 1)
	assert.Equal(t, "TestCmd", commands[0].Name)
	assert.Len(t, variables, 1)
	assert.Equal(t, "value", variables["test_var"])
}

View File

@@ -251,11 +251,19 @@ func TestAggregateGlobs(t *testing.T) {
globs := AggregateGlobs(commands)
// Now we properly resolve only the static part of globs
// *.xml has no static part (current dir), so it becomes resolved_dir/*.xml
// *.txt has no static part (current dir), so it becomes resolved_dir/*.txt
// *.json has no static part (current dir), so it becomes resolved_dir/*.json
// subdir/*.xml has static "subdir", so it becomes resolved_dir/subdir/*.xml
cwd, _ := os.Getwd()
resolvedCwd := ResolvePath(cwd)
expected := map[string]struct{}{
"*.xml": {},
"*.txt": {},
"*.json": {},
"subdir/*.xml": {},
resolvedCwd + "/*.xml": {},
resolvedCwd + "/*.txt": {},
resolvedCwd + "/*.json": {},
resolvedCwd + "/subdir/*.xml": {},
}
if len(globs) != len(expected) {
@@ -273,16 +281,17 @@ func TestAggregateGlobs(t *testing.T) {
func TestLoadCommandsFromCookFileSuccess(t *testing.T) {
// Arrange
yamlData := []byte(`
- name: command1
regex: "*.txt"
lua: replace
- name: command2
regex: "*.go"
lua: delete
commands:
- name: command1
regex: "*.txt"
lua: replace
- name: command2
regex: "*.go"
lua: delete
`)
// Act
commands, err := LoadCommandsFromCookFile(yamlData)
commands, _, err := LoadCommandsFromCookFile(yamlData)
// Assert
assert.NoError(t, err)
@@ -300,17 +309,18 @@ func TestLoadCommandsFromCookFileWithComments(t *testing.T) {
// Arrange
yamlData := []byte(`
# This is a comment
- name: command1
regex: "*.txt"
lua: replace
# Another comment
- name: command2
regex: "*.go"
lua: delete
commands:
- name: command1
regex: "*.txt"
lua: replace
# Another comment
- name: command2
regex: "*.go"
lua: delete
`)
// Act
commands, err := LoadCommandsFromCookFile(yamlData)
commands, _, err := LoadCommandsFromCookFile(yamlData)
// Assert
assert.NoError(t, err)
@@ -326,10 +336,10 @@ func TestLoadCommandsFromCookFileWithComments(t *testing.T) {
// Handle different YAML formatting styles (flow vs block)
func TestLoadCommandsFromCookFileWithFlowStyle(t *testing.T) {
// Arrange
yamlData := []byte(`[ { name: command1, regex: "*.txt", lua: replace }, { name: command2, regex: "*.go", lua: delete } ]`)
yamlData := []byte(`commands: [ { name: command1, regex: "*.txt", lua: replace }, { name: command2, regex: "*.go", lua: delete } ]`)
// Act
commands, err := LoadCommandsFromCookFile(yamlData)
commands, _, err := LoadCommandsFromCookFile(yamlData)
// Assert
assert.NoError(t, err)
@@ -349,8 +359,8 @@ func TestLoadCommandsFromCookFileNilOrEmptyData(t *testing.T) {
emptyData := []byte{}
// Act
commandsNil, errNil := LoadCommandsFromCookFile(nilData)
commandsEmpty, errEmpty := LoadCommandsFromCookFile(emptyData)
commandsNil, _, errNil := LoadCommandsFromCookFile(nilData)
commandsEmpty, _, errEmpty := LoadCommandsFromCookFile(emptyData)
// Assert
assert.Nil(t, errNil)
@@ -365,7 +375,7 @@ func TestLoadCommandsFromCookFileEmptyData(t *testing.T) {
yamlData := []byte(``)
// Act
commands, err := LoadCommandsFromCookFile(yamlData)
commands, _, err := LoadCommandsFromCookFile(yamlData)
// Assert
assert.NoError(t, err)
@@ -376,19 +386,20 @@ func TestLoadCommandsFromCookFileEmptyData(t *testing.T) {
func TestLoadCommandsFromCookFileWithMultipleEntries(t *testing.T) {
// Arrange
yamlData := []byte(`
- name: command1
regex: "*.txt"
lua: replace
- name: command2
regex: "*.go"
lua: delete
- name: command3
regex: "*.md"
lua: append
commands:
- name: command1
regex: "*.txt"
lua: replace
- name: command2
regex: "*.go"
lua: delete
- name: command3
regex: "*.md"
lua: append
`)
// Act
commands, err := LoadCommandsFromCookFile(yamlData)
commands, _, err := LoadCommandsFromCookFile(yamlData)
// Assert
assert.NoError(t, err)
@@ -407,26 +418,27 @@ func TestLoadCommandsFromCookFileWithMultipleEntries(t *testing.T) {
func TestLoadCommandsFromCookFileLegitExample(t *testing.T) {
// Arrange
yamlData := []byte(`
- name: crewlayabout
pattern: '<Talent identifier="crewlayabout">!anyvalue="(?<repairspeedpenalty>!num)"!anyvalue="(?<skillpenalty>!num)"!anyvalue="(?<repairspeedbonus>!num)"!anyvalue="(?<skillbonus>!num)"!anydistance="(?<distance>!num)"!anySkillBonus!anyvalue="(?<skillpenaltyv>!num)"!anyvalue="(?<skillpenaltyv1>!num)"!anyvalue="(?<skillpenaltyv2>!num)"!anyvalue="(?<skillpenaltyv3>!num)"!anyvalue="(?<skillpenaltyv4>!num)"!anyvalue="(?<repairspeedpenaltyv>!num)'
lua: |
repairspeedpenalty=round(repairspeedpenalty/2, 2)
skillpenalty=round(skillpenalty/2, 0)
repairspeedbonus=round(repairspeedbonus*2, 2)
skillbonus=round(skillbonus*2, 0)
distance=round(distance*2, 0)
skillpenaltyv=skillpenalty
skillpenaltyv1=skillpenalty
skillpenaltyv2=skillpenalty
skillpenaltyv3=skillpenalty
skillpenaltyv4=skillpenalty
repairspeedpenaltyv=round(-repairspeedpenalty/100, 2)
files:
- '**/TalentsAssistant.xml'
commands:
- name: crewlayabout
pattern: '<Talent identifier="crewlayabout">!anyvalue="(?<repairspeedpenalty>!num)"!anyvalue="(?<skillpenalty>!num)"!anyvalue="(?<repairspeedbonus>!num)"!anyvalue="(?<skillbonus>!num)"!anydistance="(?<distance>!num)"!anySkillBonus!anyvalue="(?<skillpenaltyv>!num)"!anyvalue="(?<skillpenaltyv1>!num)"!anyvalue="(?<skillpenaltyv2>!num)"!anyvalue="(?<skillpenaltyv3>!num)"!anyvalue="(?<skillpenaltyv4>!num)"!anyvalue="(?<repairspeedpenaltyv>!num)'
lua: |
repairspeedpenalty=round(repairspeedpenalty/2, 2)
skillpenalty=round(skillpenalty/2, 0)
repairspeedbonus=round(repairspeedbonus*2, 2)
skillbonus=round(skillbonus*2, 0)
distance=round(distance*2, 0)
skillpenaltyv=skillpenalty
skillpenaltyv1=skillpenalty
skillpenaltyv2=skillpenalty
skillpenaltyv3=skillpenalty
skillpenaltyv4=skillpenalty
repairspeedpenaltyv=round(-repairspeedpenalty/100, 2)
files:
- '**/TalentsAssistant.xml'
`)
// Act
commands, err := LoadCommandsFromCookFile(yamlData)
commands, _, err := LoadCommandsFromCookFile(yamlData)
// Assert
assert.NoError(t, err)
@@ -535,7 +547,7 @@ func TestLoadCommandsFromCookFilesNoYamlFiles(t *testing.T) {
}
// Execute function
commands, err := LoadCommandsFromCookFiles("")
commands, _, err := LoadCommandsFromCookFiles("")
// Assertions
if err != nil {
@@ -601,7 +613,7 @@ func TestLoadCommandsFromCookFilesNoYamlFiles(t *testing.T) {
// }
//
// // Execute function
// commands, err := LoadCommandsFromCookFiles("")
// commands, _, err := LoadCommandsFromCookFiles("")
//
// // Assertions
// if err != nil {
@@ -685,7 +697,7 @@ func TestLoadCommandsFromCookFilesNoYamlFiles(t *testing.T) {
// }
//
// // Execute function
// commands, err := LoadCommandsFromCookFiles("")
// commands, _, err := LoadCommandsFromCookFiles("")
//
// // Assertions
// if err != nil {
@@ -697,6 +709,58 @@ func TestLoadCommandsFromCookFilesNoYamlFiles(t *testing.T) {
// }
// }
// TestExpandGlobsMemoization verifies that ExpandGlobs caches results in the
// package-level globMemoTable: a second call with the same patterns returns
// the same file set without creating an additional memo entry.
func TestExpandGlobsMemoization(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "expand-globs-memo-test")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(tmpDir)
	// Two .go files so the glob below matches exactly two entries.
	err = os.WriteFile(filepath.Join(tmpDir, "test1.go"), []byte("test"), 0644)
	if err != nil {
		t.Fatalf("Failed to create test file: %v", err)
	}
	err = os.WriteFile(filepath.Join(tmpDir, "test2.go"), []byte("test"), 0644)
	if err != nil {
		t.Fatalf("Failed to create test file: %v", err)
	}
	origDir, _ := os.Getwd()
	os.Chdir(tmpDir)
	defer os.Chdir(origDir)
	// Build the pattern from the resolved working directory, matching how
	// AggregateGlobs produces absolute glob keys.
	cwd, _ := os.Getwd()
	resolvedCwd := ResolvePath(cwd)
	pattern1 := resolvedCwd + "/*.go"
	patterns := map[string]struct{}{pattern1: {}}
	// Reset the shared memo table so this test starts from a clean slate
	// regardless of what earlier tests expanded.
	globMemoTable = make(map[string][]string)
	files1, err := ExpandGlobs(patterns)
	if err != nil {
		t.Fatalf("ExpandGlobs failed: %v", err)
	}
	if len(files1) != 2 {
		t.Fatalf("Expected 2 files, got %d", len(files1))
	}
	// The first expansion must have populated exactly one memo entry.
	if len(globMemoTable) != 1 {
		t.Fatalf("Expected 1 entry in memo table, got %d", len(globMemoTable))
	}
	// Second call with identical patterns should be served from the memo.
	files2, err := ExpandGlobs(patterns)
	if err != nil {
		t.Fatalf("ExpandGlobs failed: %v", err)
	}
	if len(files2) != 2 {
		t.Fatalf("Expected 2 files, got %d", len(files2))
	}
	// No new entry means the cache was hit rather than recomputed.
	if len(globMemoTable) != 1 {
		t.Fatalf("Expected memo table to still have 1 entry, got %d", len(globMemoTable))
	}
}
// LoadCommandsFromCookFile returns an error for a malformed YAML file
// func TestLoadCommandsFromCookFilesMalformedYAML(t *testing.T) {
// // Setup test directory with mock YAML files
@@ -792,7 +856,7 @@ func TestLoadCommandsFromCookFilesNoYamlFiles(t *testing.T) {
// }
//
// // Execute function
// commands, err := LoadCommandsFromCookFiles("")
// commands, _, err := LoadCommandsFromCookFiles("")
//
// // Assertions
// if err == nil {
@@ -859,7 +923,7 @@ func TestLoadCommandsFromCookFilesNoYamlFiles(t *testing.T) {
// }
//
// // Execute function
// commands, err := LoadCommandsFromCookFiles("")
// commands, _, err := LoadCommandsFromCookFiles("")
//
// // Assertions
// if err != nil {
@@ -929,7 +993,7 @@ func TestLoadCommandsFromCookFilesNoYamlFiles(t *testing.T) {
// }
//
// // Execute function
// commands, err := LoadCommandsFromCookFiles("")
// commands, _, err := LoadCommandsFromCookFiles("")
//
// // Assertions
// if err != nil {
@@ -987,7 +1051,7 @@ func TestLoadCommandsFromCookFilesNoYamlFiles(t *testing.T) {
// }
//
// // Execute function
// commands, err := LoadCommandsFromCookFiles("")
// commands, _, err := LoadCommandsFromCookFiles("")
//
// // Assertions
// if err != nil {

View File

@@ -0,0 +1,93 @@
package utils
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
// TestConvertYAMLToTOMLReadError tests error handling when the YAML file
// cannot be read: ConvertYAMLToTOML should skip the file and report success.
func TestConvertYAMLToTOMLReadError(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "convert-read-error-*")
	assert.NoError(t, err)
	defer os.RemoveAll(tmpDir)
	// Strip all permissions so that reading the file fails.
	yamlFile := filepath.Join(tmpDir, "test.yml")
	err = os.WriteFile(yamlFile, []byte("commands:\n - name: test\n"), 0000)
	assert.NoError(t, err)
	// Mode 0000 does not block reads on Windows or when running as root;
	// in that case the read-error path cannot be exercised, so skip rather
	// than silently testing the happy path.
	if _, readErr := os.ReadFile(yamlFile); readErr == nil {
		os.Chmod(yamlFile, 0644)
		t.Skip("file mode 0000 does not prevent reads on this platform/user")
	}
	origDir, _ := os.Getwd()
	defer os.Chdir(origDir)
	os.Chdir(tmpDir)
	// The converter swallows per-file read errors and keeps going, so it
	// must not return an error here.
	err = ConvertYAMLToTOML("test.yml")
	assert.NoError(t, err)
	// Restore permissions so RemoveAll can clean up the directory.
	os.Chmod(yamlFile, 0644)
}
// TestConvertYAMLToTOMLParseError verifies that a YAML file which fails to
// parse is skipped without an error and without producing a TOML output file.
func TestConvertYAMLToTOMLParseError(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "convert-parse-error-*")
	assert.NoError(t, err)
	defer os.RemoveAll(tmpDir)
	// Deliberately malformed YAML.
	badYaml := filepath.Join(tmpDir, "invalid.yml")
	assert.NoError(t, os.WriteFile(badYaml, []byte("commands:\n - [this is not valid yaml}}"), 0644))
	origDir, _ := os.Getwd()
	defer os.Chdir(origDir)
	os.Chdir(tmpDir)
	// Parse failures are tolerated: the converter reports success...
	assert.NoError(t, ConvertYAMLToTOML("invalid.yml"))
	// ...and writes no TOML file for the bad input.
	_, statErr := os.Stat(filepath.Join(tmpDir, "invalid.toml"))
	assert.True(t, os.IsNotExist(statErr))
}
// TestConvertYAMLToTOMLWriteError verifies that a failure to write the
// converted TOML (read-only directory) is tolerated and not surfaced as an
// error from ConvertYAMLToTOML.
func TestConvertYAMLToTOMLWriteError(t *testing.T) {
	if os.Getenv("CI") != "" {
		t.Skip("Skipping write permission test in CI")
	}
	tmpDir, err := os.MkdirTemp("", "convert-write-error-*")
	assert.NoError(t, err)
	defer os.RemoveAll(tmpDir)
	// A valid YAML command file that would normally convert cleanly.
	srcYaml := filepath.Join(tmpDir, "test.yml")
	yamlBody := []byte("commands:\n - name: test\n regex: test\n lua: v1\n files: [test.txt]\n")
	assert.NoError(t, os.WriteFile(srcYaml, yamlBody, 0644))
	// Read-only directory: writing the converted TOML there must fail.
	readonlyDir := filepath.Join(tmpDir, "readonly")
	assert.NoError(t, os.Mkdir(readonlyDir, 0555))
	defer os.Chmod(readonlyDir, 0755) // restore perms so cleanup can delete it
	origDir, _ := os.Getwd()
	defer os.Chdir(origDir)
	os.Chdir(tmpDir)
	// Relocate the YAML into the read-only directory and convert from there.
	os.Rename(srcYaml, filepath.Join(readonlyDir, "test.yml"))
	os.Chdir(readonlyDir)
	// The write error is swallowed; the converter still reports success.
	assert.NoError(t, ConvertYAMLToTOML("test.yml"))
}

79
utils/path.go Normal file
View File

@@ -0,0 +1,79 @@
package utils
import (
"os"
"path/filepath"
"strings"
logger "git.site.quack-lab.dev/dave/cylogger"
)
// pathLogger is a scoped logger for the utils/path package.
var pathLogger = logger.Default.WithPrefix("utils/path")
// ResolvePath resolves a path to an absolute path, handling ~ expansion and
// cleaning. The result is normalized to forward slashes on every platform.
//
// Resolution order:
//  1. "" resolves to "" so empty config values pass through unchanged.
//  2. An already-absolute path is only cleaned and slash-normalized.
//  3. A leading "~" is expanded to the user's home directory: "~/x" and
//     "~\x" join the remainder under home, a bare "~" becomes home itself,
//     and any other "~rest" becomes home + "rest" verbatim.
//     NOTE(review): the error from os.UserHomeDir is ignored, so when home
//     cannot be determined the "~" is silently replaced with the empty
//     string and the path resolves relative to the CWD — confirm intended.
//  4. Anything still relative is made absolute against the current working
//     directory; if that fails, the cleaned relative path is returned.
func ResolvePath(path string) string {
	resolvePathLogger := pathLogger.WithPrefix("ResolvePath").WithField("inputPath", path)
	resolvePathLogger.Trace("Resolving path: %q", path)
	// Handle empty path
	if path == "" {
		resolvePathLogger.Trace("Empty path, returning empty string")
		return ""
	}
	// Check if path is absolute: nothing to resolve, just clean and normalize.
	if filepath.IsAbs(path) {
		resolvePathLogger.Trace("Path is already absolute: %q", path)
		cleaned := filepath.ToSlash(filepath.Clean(path))
		resolvePathLogger.Trace("Cleaned absolute path: %q", cleaned)
		return cleaned
	}
	// Handle ~ expansion
	if strings.HasPrefix(path, "~") {
		homeDir, _ := os.UserHomeDir()
		if strings.HasPrefix(path, "~/") || strings.HasPrefix(path, "~\\") {
			// "~/sub" or "~\sub": join the remainder under the home directory.
			path = filepath.Join(homeDir, path[2:])
		} else if path == "~" {
			path = homeDir
		} else {
			// ~something (like ~~), treat first ~ as home expansion, rest as literal
			path = homeDir + path[1:]
		}
		resolvePathLogger.Trace("Expanded ~ to home directory: %q", path)
	}
	// Make absolute if not already (covers plain relative input and the
	// relative result of ~ expansion when home was empty).
	if !filepath.IsAbs(path) {
		absPath, err := filepath.Abs(path)
		if err != nil {
			// Best effort: fall back to the cleaned relative path.
			resolvePathLogger.Error("Failed to get absolute path: %v", err)
			return filepath.ToSlash(filepath.Clean(path))
		}
		resolvePathLogger.Trace("Made path absolute: %q -> %q", path, absPath)
		path = absPath
	}
	// Clean the path and normalize to forward slashes for consistency
	cleaned := filepath.ToSlash(filepath.Clean(path))
	resolvePathLogger.Trace("Final cleaned path: %q", cleaned)
	return cleaned
}
// GetRelativePath computes the path of target relative to base, returned with
// forward slashes for cross-platform consistency. The error from
// filepath.Rel (e.g. when the paths cannot be related) is passed through.
func GetRelativePath(base, target string) (string, error) {
	log := pathLogger.WithPrefix("GetRelativePath")
	log.Debug("Getting relative path from %q to %q", base, target)
	rel, err := filepath.Rel(base, target)
	if err != nil {
		log.Error("Failed to get relative path: %v", err)
		return "", err
	}
	// Normalize separators so results are stable across platforms.
	rel = filepath.ToSlash(rel)
	log.Debug("Relative path: %q", rel)
	return rel, nil
}

386
utils/path_test.go Normal file
View File

@@ -0,0 +1,386 @@
package utils
import (
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestResolvePath(t *testing.T) {
// Save original working directory
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
// Create a temporary directory for testing
tmpDir, err := os.MkdirTemp("", "path_test")
assert.NoError(t, err)
defer os.RemoveAll(tmpDir)
tests := []struct {
name string
input string
expected string
setup func() // Optional setup function
}{
{
name: "Empty path",
input: "",
expected: "",
},
{
name: "Already absolute path",
input: func() string {
if runtime.GOOS == "windows" {
return "C:/absolute/path/file.txt"
}
return "/absolute/path/file.txt"
}(),
expected: func() string {
if runtime.GOOS == "windows" {
return "C:/absolute/path/file.txt"
}
return "/absolute/path/file.txt"
}(),
},
{
name: "Relative path",
input: "relative/file.txt",
expected: func() string {
abs, _ := filepath.Abs("relative/file.txt")
return strings.ReplaceAll(abs, "\\", "/")
}(),
},
{
name: "Tilde expansion - home only",
input: "~",
expected: func() string {
home := os.Getenv("HOME")
if home == "" && runtime.GOOS == "windows" {
home = os.Getenv("USERPROFILE")
}
return strings.ReplaceAll(filepath.Clean(home), "\\", "/")
}(),
},
{
name: "Tilde expansion - with subpath",
input: "~/Documents/file.txt",
expected: func() string {
home := os.Getenv("HOME")
if home == "" && runtime.GOOS == "windows" {
home = os.Getenv("USERPROFILE")
}
expected := filepath.Join(home, "Documents", "file.txt")
return strings.ReplaceAll(filepath.Clean(expected), "\\", "/")
}(),
},
{
name: "Path normalization - double slashes",
input: "path//to//file.txt",
expected: func() string {
abs, _ := filepath.Abs("path/to/file.txt")
return strings.ReplaceAll(abs, "\\", "/")
}(),
},
{
name: "Path normalization - . and ..",
input: "path/./to/../file.txt",
expected: func() string {
abs, _ := filepath.Abs("path/file.txt")
return strings.ReplaceAll(abs, "\\", "/")
}(),
},
{
name: "Windows backslash normalization",
input: "path\\to\\file.txt",
expected: func() string {
abs, _ := filepath.Abs("path/to/file.txt")
return strings.ReplaceAll(abs, "\\", "/")
}(),
},
{
name: "Mixed separators with tilde",
input: "~/Documents\\file.txt",
expected: func() string {
home := os.Getenv("HOME")
if home == "" && runtime.GOOS == "windows" {
home = os.Getenv("USERPROFILE")
}
expected := filepath.Join(home, "Documents", "file.txt")
return strings.ReplaceAll(filepath.Clean(expected), "\\", "/")
}(),
},
{
name: "Relative path from current directory",
input: "./file.txt",
expected: func() string {
abs, _ := filepath.Abs("file.txt")
return strings.ReplaceAll(abs, "\\", "/")
}(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setup != nil {
tt.setup()
}
result := ResolvePath(tt.input)
assert.Equal(t, tt.expected, result, "ResolvePath(%q) = %q, want %q", tt.input, result, tt.expected)
})
}
}
// TestResolvePathWithWorkingDirectoryChange checks that relative inputs are
// resolved against the working directory in effect at call time.
func TestResolvePathWithWorkingDirectoryChange(t *testing.T) {
	origDir, _ := os.Getwd()
	defer os.Chdir(origDir)
	// Build tmpDir/subdir and make subdir the working directory.
	tmpDir, err := os.MkdirTemp("", "path_test")
	assert.NoError(t, err)
	defer os.RemoveAll(tmpDir)
	subDir := filepath.Join(tmpDir, "subdir")
	assert.NoError(t, os.MkdirAll(subDir, 0755))
	assert.NoError(t, os.Chdir(subDir))
	// "../test.txt" from subdir must land directly in tmpDir.
	want := strings.ReplaceAll(filepath.Clean(filepath.Join(tmpDir, "test.txt")), "\\", "/")
	assert.Equal(t, want, ResolvePath("../test.txt"))
}
func TestResolvePathComplexTilde(t *testing.T) {
// Test complex tilde patterns
home := os.Getenv("HOME")
if home == "" && runtime.GOOS == "windows" {
home = os.Getenv("USERPROFILE")
}
if home == "" {
t.Skip("Cannot determine home directory for tilde expansion tests")
}
tests := []struct {
input string
expected string
}{
{
input: "~",
expected: strings.ReplaceAll(filepath.Clean(home), "\\", "/"),
},
{
input: "~/",
expected: strings.ReplaceAll(filepath.Clean(home), "\\", "/"),
},
{
input: "~~",
expected: func() string {
// ~~ should be treated as ~ followed by ~ (tilde expansion)
home := os.Getenv("HOME")
if home == "" && runtime.GOOS == "windows" {
home = os.Getenv("USERPROFILE")
}
if home != "" {
// First ~ gets expanded, second ~ remains
return strings.ReplaceAll(filepath.Clean(home+"~"), "\\", "/")
}
abs, _ := filepath.Abs("~~")
return strings.ReplaceAll(abs, "\\", "/")
}(),
},
{
input: func() string {
if runtime.GOOS == "windows" {
return "C:/not/tilde/path"
}
return "/not/tilde/path"
}(),
expected: func() string {
if runtime.GOOS == "windows" {
return "C:/not/tilde/path"
}
return "/not/tilde/path"
}(),
},
}
for _, tt := range tests {
t.Run("Complex tilde: "+tt.input, func(t *testing.T) {
result := ResolvePath(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// TestGetRelativePath verifies GetRelativePath across the common geometries:
// child directory, parent, sibling, identical directories, and a file
// directly inside base.
func TestGetRelativePath(t *testing.T) {
	// Create temporary directories for testing.
	tmpDir, err := os.MkdirTemp("", "relative_path_test")
	assert.NoError(t, err)
	defer os.RemoveAll(tmpDir)
	baseDir := filepath.Join(tmpDir, "base")
	targetDir := filepath.Join(tmpDir, "target")
	subDir := filepath.Join(targetDir, "subdir")
	err = os.MkdirAll(baseDir, 0755)
	assert.NoError(t, err)
	err = os.MkdirAll(subDir, 0755)
	assert.NoError(t, err)
	tests := []struct {
		name     string
		base     string
		target   string
		expected string
		wantErr  bool
	}{
		{
			name:     "Target is subdirectory of base",
			base:     baseDir,
			target:   filepath.Join(baseDir, "subdir"),
			expected: "subdir",
			wantErr:  false,
		},
		{
			name:     "Target is parent of base",
			base:     filepath.Join(baseDir, "subdir"),
			target:   baseDir,
			expected: "..",
			wantErr:  false,
		},
		{
			name:     "Target is sibling directory",
			base:     baseDir,
			target:   targetDir,
			expected: "../target",
			wantErr:  false,
		},
		{
			name:     "Same directory",
			base:     baseDir,
			target:   baseDir,
			expected: ".",
			wantErr:  false,
		},
		{
			// Renamed from "With tilde expansion": the case never contained
			// a tilde; it checks a plain file directly under base.
			name:     "Target is file inside base",
			base:     baseDir,
			target:   filepath.Join(baseDir, "file.txt"),
			expected: "file.txt",
			wantErr:  false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := GetRelativePath(tt.base, tt.target)
			if tt.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
				assert.Equal(t, tt.expected, result)
			}
		})
	}
}
// TestResolvePathRegression pins the fix for the original bug report where a
// leading "~" was left unresolved in paths like "~/Seafile/.../sync.yml".
func TestResolvePathRegression(t *testing.T) {
	home := os.Getenv("HOME")
	if home == "" && runtime.GOOS == "windows" {
		home = os.Getenv("USERPROFILE")
	}
	if home == "" {
		t.Skip("Cannot determine home directory for regression test")
	}
	// The exact path shape from the bug report.
	got := ResolvePath("~/Seafile/activitywatch/sync.yml")
	want := strings.ReplaceAll(filepath.Clean(filepath.Join(home, "Seafile", "activitywatch", "sync.yml")), "\\", "/")
	assert.Equal(t, want, got, "Tilde expansion bug not fixed!")
	assert.NotContains(t, got, "~", "Tilde still present in resolved path!")
	// The resolved path must sit under the slash-normalized home directory.
	assert.Contains(t, got, strings.ReplaceAll(home, "\\", "/"), "Home directory not found in resolved path!")
}
// TestResolvePathEdgeCases feeds ResolvePath degenerate inputs (dot paths,
// deep "..", spaces) and asserts it never panics and always yields an
// absolute (or empty) result.
func TestResolvePathEdgeCases(t *testing.T) {
	origDir, _ := os.Getwd()
	defer os.Chdir(origDir)
	cases := []struct {
		name        string
		input       string
		setup       func()
		shouldPanic bool
	}{
		{name: "Just dot", input: "."},
		{name: "Just double dot", input: ".."},
		{name: "Triple dot", input: "..."},
		{name: "Multiple leading dots", input: "./.././../file.txt"},
		{name: "Path with spaces", input: "path with spaces/file.txt"},
		{name: "Very long relative path", input: strings.Repeat("../", 10) + "file.txt"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if tc.setup != nil {
				tc.setup()
			}
			if tc.shouldPanic {
				assert.Panics(t, func() { ResolvePath(tc.input) })
				return
			}
			// Must not panic on any degenerate input.
			assert.NotPanics(t, func() { ResolvePath(tc.input) })
			// A non-empty input must resolve to an absolute path.
			got := ResolvePath(tc.input)
			if tc.input != "" {
				assert.True(t, filepath.IsAbs(got) || got == "", "Result should be absolute or empty")
			}
		})
	}
}