Remove some unused shit and write tests for coverage

This commit is contained in:
2025-12-19 12:12:42 +01:00
parent 1df0263a42
commit da5b621cb6
19 changed files with 1892 additions and 390 deletions

View File

@@ -1,28 +0,0 @@
package main
import (
"time"
logger "git.site.quack-lab.dev/dave/cylogger"
)
func main() {
// Initialize logger with DEBUG level
logger.Init(logger.LevelDebug)
// Test different log levels
logger.Info("This is an info message")
logger.Debug("This is a debug message")
logger.Warning("This is a warning message")
logger.Error("This is an error message")
logger.Trace("This is a trace message (not visible at DEBUG level)")
// Test with a goroutine
logger.SafeGo(func() {
time.Sleep(10 * time.Millisecond)
logger.Info("Message from goroutine")
})
// Wait for goroutine to complete
time.Sleep(20 * time.Millisecond)
}

View File

@@ -331,7 +331,14 @@ func convertValueToJSONString(value interface{}) string {
// findArrayElementRemovalRange finds the exact byte range to remove for an array element // findArrayElementRemovalRange finds the exact byte range to remove for an array element
func findArrayElementRemovalRange(content, arrayPath string, elementIndex int) (int, int) { func findArrayElementRemovalRange(content, arrayPath string, elementIndex int) (int, int) {
// Get the array using gjson // Get the array using gjson
arrayResult := gjson.Get(content, arrayPath) var arrayResult gjson.Result
if arrayPath == "" {
// Root-level array
arrayResult = gjson.Parse(content)
} else {
arrayResult = gjson.Get(content, arrayPath)
}
if !arrayResult.Exists() || !arrayResult.IsArray() { if !arrayResult.Exists() || !arrayResult.IsArray() {
return -1, -1 return -1, -1
} }
@@ -455,16 +462,9 @@ func findDeepChanges(basePath string, original, modified interface{}) map[string
} }
} }
} }
default:
// For primitive types, compare directly
if !deepEqual(original, modified) {
if basePath == "" {
changes[""] = modified
} else {
changes[basePath] = modified
}
}
} }
// Note: No default case needed - JSON data from unmarshaling is always
// map[string]interface{} or []interface{} at the top level
return changes return changes
} }
@@ -531,112 +531,61 @@ func deepEqual(a, b interface{}) bool {
} }
} }
// ToLuaTable converts a Go interface{} to a Lua table recursively // ToLuaTable converts a Go interface{} (map or array) to a Lua table
// This should only be called with map[string]interface{} or []interface{} from JSON unmarshaling
func ToLuaTable(L *lua.LState, data interface{}) (*lua.LTable, error) { func ToLuaTable(L *lua.LState, data interface{}) (*lua.LTable, error) {
toLuaTableLogger := jsonLogger.WithPrefix("ToLuaTable")
toLuaTableLogger.Debug("Converting Go interface to Lua table")
toLuaTableLogger.Trace("Input data type: %T", data)
switch v := data.(type) { switch v := data.(type) {
case map[string]interface{}: case map[string]interface{}:
toLuaTableLogger.Debug("Converting map to Lua table")
table := L.CreateTable(0, len(v)) table := L.CreateTable(0, len(v))
for key, value := range v { for key, value := range v {
luaValue, err := ToLuaValue(L, value) table.RawSetString(key, ToLuaValue(L, value))
if err != nil {
toLuaTableLogger.Error("Failed to convert map value for key %q: %v", key, err)
return nil, err
}
table.RawSetString(key, luaValue)
} }
return table, nil return table, nil
case []interface{}: case []interface{}:
toLuaTableLogger.Debug("Converting slice to Lua table")
table := L.CreateTable(len(v), 0) table := L.CreateTable(len(v), 0)
for i, value := range v { for i, value := range v {
luaValue, err := ToLuaValue(L, value) table.RawSetInt(i+1, ToLuaValue(L, value)) // Lua arrays are 1-indexed
if err != nil {
toLuaTableLogger.Error("Failed to convert slice value at index %d: %v", i, err)
return nil, err
}
table.RawSetInt(i+1, luaValue) // Lua arrays are 1-indexed
} }
return table, nil return table, nil
case string:
toLuaTableLogger.Debug("Converting string to Lua string")
return nil, fmt.Errorf("expected table or array, got string")
case float64:
toLuaTableLogger.Debug("Converting float64 to Lua number")
return nil, fmt.Errorf("expected table or array, got number")
case bool:
toLuaTableLogger.Debug("Converting bool to Lua boolean")
return nil, fmt.Errorf("expected table or array, got boolean")
case nil:
toLuaTableLogger.Debug("Converting nil to Lua nil")
return nil, fmt.Errorf("expected table or array, got nil")
default: default:
toLuaTableLogger.Error("Unsupported type for Lua table conversion: %T", v) // This should only happen with invalid JSON (root-level primitives)
return nil, fmt.Errorf("unsupported type for Lua table conversion: %T", v) return nil, fmt.Errorf("expected table or array, got %T", v)
} }
} }
// ToLuaValue converts a Go interface{} to a Lua value // ToLuaValue converts a Go interface{} to a Lua value
func ToLuaValue(L *lua.LState, data interface{}) (lua.LValue, error) { func ToLuaValue(L *lua.LState, data interface{}) lua.LValue {
toLuaValueLogger := jsonLogger.WithPrefix("ToLuaValue")
toLuaValueLogger.Debug("Converting Go interface to Lua value")
toLuaValueLogger.Trace("Input data type: %T", data)
switch v := data.(type) { switch v := data.(type) {
case map[string]interface{}: case map[string]interface{}:
toLuaValueLogger.Debug("Converting map to Lua table")
table := L.CreateTable(0, len(v)) table := L.CreateTable(0, len(v))
for key, value := range v { for key, value := range v {
luaValue, err := ToLuaValue(L, value) table.RawSetString(key, ToLuaValue(L, value))
if err != nil {
toLuaValueLogger.Error("Failed to convert map value for key %q: %v", key, err)
return lua.LNil, err
}
table.RawSetString(key, luaValue)
} }
return table, nil return table
case []interface{}: case []interface{}:
toLuaValueLogger.Debug("Converting slice to Lua table")
table := L.CreateTable(len(v), 0) table := L.CreateTable(len(v), 0)
for i, value := range v { for i, value := range v {
luaValue, err := ToLuaValue(L, value) table.RawSetInt(i+1, ToLuaValue(L, value)) // Lua arrays are 1-indexed
if err != nil {
toLuaValueLogger.Error("Failed to convert slice value at index %d: %v", i, err)
return lua.LNil, err
}
table.RawSetInt(i+1, luaValue) // Lua arrays are 1-indexed
} }
return table, nil return table
case string: case string:
toLuaValueLogger.Debug("Converting string to Lua string") return lua.LString(v)
return lua.LString(v), nil
case float64: case float64:
toLuaValueLogger.Debug("Converting float64 to Lua number") return lua.LNumber(v)
return lua.LNumber(v), nil
case bool: case bool:
toLuaValueLogger.Debug("Converting bool to Lua boolean") return lua.LBool(v)
return lua.LBool(v), nil
case nil: case nil:
toLuaValueLogger.Debug("Converting nil to Lua nil") return lua.LNil
return lua.LNil, nil
default: default:
toLuaValueLogger.Error("Unsupported type for Lua value conversion: %T", v) // This should never happen with JSON-unmarshaled data
return lua.LNil, fmt.Errorf("unsupported type for Lua value conversion: %T", v) return lua.LNil
} }
} }

View File

@@ -0,0 +1,283 @@
package processor
import (
"cook/utils"
"testing"
"github.com/stretchr/testify/assert"
)
// TestJSONFloat tests line 298 - float formatting for non-integer floats
func TestJSONFloatFormatting(t *testing.T) {
jsonContent := `{
"value": 10.5,
"another": 3.14159
}`
command := utils.ModifyCommand{
Name: "test_float",
JSON: true,
Lua: `
data.value = data.value * 2
data.another = data.another * 10
modified = true
`,
}
commands, err := ProcessJSON(jsonContent, command, "test.json")
assert.NoError(t, err)
assert.NotEmpty(t, commands)
result, _ := utils.ExecuteModifications(commands, jsonContent)
assert.Contains(t, result, "21") // 10.5 * 2
assert.Contains(t, result, "31.4159") // 3.14159 * 10
}
// TestJSONNestedObjectAddition tests lines 303-320 - map[string]interface{} case
func TestJSONNestedObjectAddition(t *testing.T) {
jsonContent := `{
"items": {}
}`
command := utils.ModifyCommand{
Name: "test_nested",
JSON: true,
Lua: `
data.items.newObject = {
name = "test",
value = 42,
enabled = true
}
modified = true
`,
}
commands, err := ProcessJSON(jsonContent, command, "test.json")
assert.NoError(t, err)
assert.NotEmpty(t, commands)
result, _ := utils.ExecuteModifications(commands, jsonContent)
assert.Contains(t, result, `"newObject"`)
assert.Contains(t, result, `"name"`)
assert.Contains(t, result, `"test"`)
assert.Contains(t, result, `"value"`)
assert.Contains(t, result, "42")
}
// TestJSONKeyWithQuotes tests line 315 - key escaping with quotes
func TestJSONKeyWithQuotes(t *testing.T) {
jsonContent := `{
"data": {}
}`
command := utils.ModifyCommand{
Name: "test_key_quotes",
JSON: true,
Lua: `
data.data["key-with-dash"] = "value1"
data.data.normalKey = "value2"
modified = true
`,
}
commands, err := ProcessJSON(jsonContent, command, "test.json")
assert.NoError(t, err)
assert.NotEmpty(t, commands)
result, _ := utils.ExecuteModifications(commands, jsonContent)
assert.Contains(t, result, `"key-with-dash"`)
assert.Contains(t, result, `"normalKey"`)
}
// TestJSONArrayInValue tests lines 321-327 - default case with json.Marshal for arrays
func TestJSONArrayInValue(t *testing.T) {
jsonContent := `{
"data": {}
}`
command := utils.ModifyCommand{
Name: "test_array_value",
JSON: true,
Lua: `
data.data.items = {1, 2, 3, 4, 5}
data.data.strings = {"a", "b", "c"}
modified = true
`,
}
commands, err := ProcessJSON(jsonContent, command, "test.json")
assert.NoError(t, err)
assert.NotEmpty(t, commands)
result, _ := utils.ExecuteModifications(commands, jsonContent)
assert.Contains(t, result, `"items"`)
assert.Contains(t, result, `[1,2,3,4,5]`)
assert.Contains(t, result, `"strings"`)
assert.Contains(t, result, `["a","b","c"]`)
}
// TestJSONRootArrayElementRemoval tests line 422 - removing from root-level array
func TestJSONRootArrayElementRemoval(t *testing.T) {
jsonContent := `[
{"id": 1, "name": "first"},
{"id": 2, "name": "second"},
{"id": 3, "name": "third"}
]`
command := utils.ModifyCommand{
Name: "test_root_array_removal",
JSON: true,
Lua: `
-- Remove the second element
table.remove(data, 2)
modified = true
`,
}
commands, err := ProcessJSON(jsonContent, command, "test.json")
assert.NoError(t, err)
assert.NotEmpty(t, commands)
result, _ := utils.ExecuteModifications(commands, jsonContent)
assert.Contains(t, result, `"first"`)
assert.Contains(t, result, `"third"`)
assert.NotContains(t, result, `"second"`)
}
// TestJSONRootArrayElementChange tests lines 434 and 450 - changing primitive values in root array
func TestJSONRootArrayElementChange(t *testing.T) {
jsonContent := `[10, 20, 30, 40, 50]`
command := utils.ModifyCommand{
Name: "test_root_array_change",
JSON: true,
Lua: `
-- Double all values
for i = 1, #data do
data[i] = data[i] * 2
end
modified = true
`,
}
commands, err := ProcessJSON(jsonContent, command, "test.json")
assert.NoError(t, err)
assert.NotEmpty(t, commands)
result, _ := utils.ExecuteModifications(commands, jsonContent)
assert.Contains(t, result, "20")
assert.Contains(t, result, "40")
assert.Contains(t, result, "60")
assert.Contains(t, result, "80")
assert.Contains(t, result, "100")
assert.NotContains(t, result, "10,")
}
// TestJSONRootArrayStringElements tests deepEqual with strings in root array
func TestJSONRootArrayStringElements(t *testing.T) {
jsonContent := `["apple", "banana", "cherry"]`
command := utils.ModifyCommand{
Name: "test_root_array_strings",
JSON: true,
Lua: `
data[2] = "orange"
modified = true
`,
}
commands, err := ProcessJSON(jsonContent, command, "test.json")
assert.NoError(t, err)
assert.NotEmpty(t, commands)
result, _ := utils.ExecuteModifications(commands, jsonContent)
assert.Contains(t, result, `"apple"`)
assert.Contains(t, result, `"orange"`)
assert.Contains(t, result, `"cherry"`)
assert.NotContains(t, result, `"banana"`)
}
// TestJSONComplexNestedStructure tests multiple untested paths together
func TestJSONComplexNestedStructure(t *testing.T) {
jsonContent := `{
"config": {
"multiplier": 2.5
}
}`
command := utils.ModifyCommand{
Name: "test_complex",
JSON: true,
Lua: `
-- Add nested object with array
data.config.settings = {
enabled = true,
values = {1.5, 2.5, 3.5},
names = {"alpha", "beta"}
}
-- Change float
data.config.multiplier = 7.777
modified = true
`,
}
commands, err := ProcessJSON(jsonContent, command, "test.json")
assert.NoError(t, err)
assert.NotEmpty(t, commands)
result, _ := utils.ExecuteModifications(commands, jsonContent)
assert.Contains(t, result, "7.777")
assert.Contains(t, result, `"settings"`)
assert.Contains(t, result, `"values"`)
assert.Contains(t, result, `[1.5,2.5,3.5]`)
}
// TestJSONRemoveFirstArrayElement tests line 358-365 - removing first element with comma handling
func TestJSONRemoveFirstArrayElement(t *testing.T) {
jsonContent := `{
"items": [1, 2, 3, 4, 5]
}`
command := utils.ModifyCommand{
Name: "test_remove_first",
JSON: true,
Lua: `
table.remove(data.items, 1)
modified = true
`,
}
commands, err := ProcessJSON(jsonContent, command, "test.json")
assert.NoError(t, err)
assert.NotEmpty(t, commands)
result, _ := utils.ExecuteModifications(commands, jsonContent)
assert.NotContains(t, result, "[1,")
assert.Contains(t, result, "2")
assert.Contains(t, result, "5")
}
// TestJSONRemoveLastArrayElement tests line 366-374 - removing last element with comma handling
func TestJSONRemoveLastArrayElement(t *testing.T) {
jsonContent := `{
"items": [1, 2, 3, 4, 5]
}`
command := utils.ModifyCommand{
Name: "test_remove_last",
JSON: true,
Lua: `
table.remove(data.items, 5)
modified = true
`,
}
commands, err := ProcessJSON(jsonContent, command, "test.json")
assert.NoError(t, err)
assert.NotEmpty(t, commands)
result, _ := utils.ExecuteModifications(commands, jsonContent)
assert.Contains(t, result, "1")
assert.Contains(t, result, "4")
assert.NotContains(t, result, ", 5")
}

View File

@@ -0,0 +1,153 @@
package processor
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDeepEqual(t *testing.T) {
tests := []struct {
name string
a interface{}
b interface{}
expected bool
}{
{
name: "both nil",
a: nil,
b: nil,
expected: true,
},
{
name: "first nil",
a: nil,
b: "something",
expected: false,
},
{
name: "second nil",
a: "something",
b: nil,
expected: false,
},
{
name: "equal primitives",
a: 42,
b: 42,
expected: true,
},
{
name: "different primitives",
a: 42,
b: 43,
expected: false,
},
{
name: "equal strings",
a: "hello",
b: "hello",
expected: true,
},
{
name: "equal maps",
a: map[string]interface{}{
"key1": "value1",
"key2": 42,
},
b: map[string]interface{}{
"key1": "value1",
"key2": 42,
},
expected: true,
},
{
name: "maps different lengths",
a: map[string]interface{}{
"key1": "value1",
},
b: map[string]interface{}{
"key1": "value1",
"key2": 42,
},
expected: false,
},
{
name: "maps different values",
a: map[string]interface{}{
"key1": "value1",
},
b: map[string]interface{}{
"key1": "value2",
},
expected: false,
},
{
name: "map vs non-map",
a: map[string]interface{}{
"key1": "value1",
},
b: "not a map",
expected: false,
},
{
name: "equal arrays",
a: []interface{}{1, 2, 3},
b: []interface{}{1, 2, 3},
expected: true,
},
{
name: "arrays different lengths",
a: []interface{}{1, 2},
b: []interface{}{1, 2, 3},
expected: false,
},
{
name: "arrays different values",
a: []interface{}{1, 2, 3},
b: []interface{}{1, 2, 4},
expected: false,
},
{
name: "array vs non-array",
a: []interface{}{1, 2, 3},
b: "not an array",
expected: false,
},
{
name: "nested equal structures",
a: map[string]interface{}{
"outer": map[string]interface{}{
"inner": []interface{}{1, 2, 3},
},
},
b: map[string]interface{}{
"outer": map[string]interface{}{
"inner": []interface{}{1, 2, 3},
},
},
expected: true,
},
{
name: "nested different structures",
a: map[string]interface{}{
"outer": map[string]interface{}{
"inner": []interface{}{1, 2, 3},
},
},
b: map[string]interface{}{
"outer": map[string]interface{}{
"inner": []interface{}{1, 2, 4},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := deepEqual(tt.a, tt.b)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -88,8 +88,7 @@ func TestToLuaValue(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result, err := ToLuaValue(L, tt.input) result := ToLuaValue(L, tt.input)
assert.NoError(t, err)
assert.Equal(t, tt.expected, result.String()) assert.Equal(t, tt.expected, result.String())
}) })
} }

View File

@@ -485,6 +485,79 @@ function setAttr(element, attrName, value)
element._attr[attrName] = tostring(value) element._attr[attrName] = tostring(value)
end end
--- Find first element with a specific tag name (searches direct children only)
--- @param parent table The parent XML element
--- @param tagName string The tag name to search for
--- @return table|nil The first matching element or nil
function findFirstElement(parent, tagName)
if not parent._children then return nil end
for _, child in ipairs(parent._children) do
if child._tag == tagName then
return child
end
end
return nil
end
--- Add a child element to a parent
--- @param parent table The parent XML element
--- @param child table The child element to add
function addChild(parent, child)
if not parent._children then
parent._children = {}
end
table.insert(parent._children, child)
end
--- Remove all children with a specific tag name
--- @param parent table The parent XML element
--- @param tagName string The tag name to remove
--- @return number Count of removed children
function removeChildren(parent, tagName)
if not parent._children then return 0 end
local removed = 0
local i = 1
while i <= #parent._children do
if parent._children[i]._tag == tagName then
table.remove(parent._children, i)
removed = removed + 1
else
i = i + 1
end
end
return removed
end
--- Get all direct children with a specific tag name
--- @param parent table The parent XML element
--- @param tagName string The tag name to search for
--- @return table Array of matching children
function getChildren(parent, tagName)
local results = {}
if not parent._children then return results end
for _, child in ipairs(parent._children) do
if child._tag == tagName then
table.insert(results, child)
end
end
return results
end
--- Count children with a specific tag name
--- @param parent table The parent XML element
--- @param tagName string The tag name to count
--- @return number Count of matching children
function countChildren(parent, tagName)
if not parent._children then return 0 end
local count = 0
for _, child in ipairs(parent._children) do
if child._tag == tagName then
count = count + 1
end
end
return count
end
-- ============================================================================ -- ============================================================================
-- JSON HELPER FUNCTIONS -- JSON HELPER FUNCTIONS
-- ============================================================================ -- ============================================================================

View File

@@ -0,0 +1,218 @@
package processor
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
lua "github.com/yuin/gopher-lua"
)
// Test replaceVariables function
func TestReplaceVariables(t *testing.T) {
// Setup global variables
globalVariables = map[string]interface{}{
"multiplier": 2.5,
"prefix": "TEST_",
"enabled": true,
"disabled": false,
"count": 42,
}
defer func() {
globalVariables = make(map[string]interface{})
}()
tests := []struct {
name string
input string
expected string
}{
{
name: "Replace numeric variable",
input: "v1 * $multiplier",
expected: "v1 * 2.5",
},
{
name: "Replace string variable",
input: `s1 = $prefix .. "value"`,
expected: `s1 = "TEST_" .. "value"`,
},
{
name: "Replace boolean true",
input: "enabled = $enabled",
expected: "enabled = true",
},
{
name: "Replace boolean false",
input: "disabled = $disabled",
expected: "disabled = false",
},
{
name: "Replace integer",
input: "count = $count",
expected: "count = 42",
},
{
name: "Multiple replacements",
input: "$count * $multiplier",
expected: "42 * 2.5",
},
{
name: "No variables",
input: "v1 * 2",
expected: "v1 * 2",
},
{
name: "Undefined variable",
input: "v1 * $undefined",
expected: "v1 * $undefined",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := replaceVariables(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// Test SetVariables with all type cases
func TestSetVariablesAllTypes(t *testing.T) {
vars := map[string]interface{}{
"int_val": 42,
"int64_val": int64(100),
"float32_val": float32(3.14),
"float64_val": 2.718,
"bool_true": true,
"bool_false": false,
"string_val": "hello",
}
SetVariables(vars)
// Create Lua state to verify
L, err := NewLuaState()
assert.NoError(t, err)
defer L.Close()
// Verify int64
int64Val := L.GetGlobal("int64_val")
assert.Equal(t, lua.LTNumber, int64Val.Type())
assert.Equal(t, 100.0, float64(int64Val.(lua.LNumber)))
// Verify float32
float32Val := L.GetGlobal("float32_val")
assert.Equal(t, lua.LTNumber, float32Val.Type())
assert.InDelta(t, 3.14, float64(float32Val.(lua.LNumber)), 0.01)
// Verify bool true
boolTrue := L.GetGlobal("bool_true")
assert.Equal(t, lua.LTBool, boolTrue.Type())
assert.True(t, bool(boolTrue.(lua.LBool)))
// Verify bool false
boolFalse := L.GetGlobal("bool_false")
assert.Equal(t, lua.LTBool, boolFalse.Type())
assert.False(t, bool(boolFalse.(lua.LBool)))
// Verify string
stringVal := L.GetGlobal("string_val")
assert.Equal(t, lua.LTString, stringVal.Type())
assert.Equal(t, "hello", string(stringVal.(lua.LString)))
}
// Test HTTP fetch with test server
func TestFetchWithTestServer(t *testing.T) {
// Create test HTTP server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request
assert.Equal(t, "GET", r.Method)
// Send response
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"status": "success"}`))
}))
defer server.Close()
// Test fetch
L := lua.NewState()
defer L.Close()
L.SetGlobal("fetch", L.NewFunction(fetch))
script := `
response = fetch("` + server.URL + `")
assert(response ~= nil, "Expected response")
assert(response.ok == true, "Expected ok to be true")
assert(response.status == 200, "Expected status 200")
assert(response.body == '{"status": "success"}', "Expected correct body")
`
err := L.DoString(script)
assert.NoError(t, err)
}
func TestFetchWithTestServerPOST(t *testing.T) {
// Create test HTTP server
receivedBody := ""
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
// Read body
buf := make([]byte, 1024)
n, _ := r.Body.Read(buf)
receivedBody = string(buf[:n])
w.WriteHeader(http.StatusCreated)
w.Write([]byte(`{"created": true}`))
}))
defer server.Close()
L := lua.NewState()
defer L.Close()
L.SetGlobal("fetch", L.NewFunction(fetch))
script := `
local opts = {
method = "POST",
headers = {["Content-Type"] = "application/json"},
body = '{"test": "data"}'
}
response = fetch("` + server.URL + `", opts)
assert(response ~= nil, "Expected response")
assert(response.ok == true, "Expected ok to be true")
assert(response.status == 201, "Expected status 201")
`
err := L.DoString(script)
assert.NoError(t, err)
assert.Equal(t, `{"test": "data"}`, receivedBody)
}
func TestFetchWithTestServer404(t *testing.T) {
// Create test HTTP server that returns 404
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`{"error": "not found"}`))
}))
defer server.Close()
L := lua.NewState()
defer L.Close()
L.SetGlobal("fetch", L.NewFunction(fetch))
script := `
response = fetch("` + server.URL + `")
assert(response ~= nil, "Expected response")
assert(response.ok == false, "Expected ok to be false for 404")
assert(response.status == 404, "Expected status 404")
`
err := L.DoString(script)
assert.NoError(t, err)
}

View File

@@ -0,0 +1,366 @@
package processor
import (
"testing"
"github.com/stretchr/testify/assert"
lua "github.com/yuin/gopher-lua"
)
func TestSetVariables(t *testing.T) {
// Test with various variable types
vars := map[string]interface{}{
"multiplier": 2.5,
"prefix": "TEST_",
"enabled": true,
"count": 42,
}
SetVariables(vars)
// Create a new Lua state to verify variables are set
L, err := NewLuaState()
assert.NoError(t, err)
defer L.Close()
// Verify the variables are accessible
multiplier := L.GetGlobal("multiplier")
assert.Equal(t, lua.LTNumber, multiplier.Type())
assert.Equal(t, 2.5, float64(multiplier.(lua.LNumber)))
prefix := L.GetGlobal("prefix")
assert.Equal(t, lua.LTString, prefix.Type())
assert.Equal(t, "TEST_", string(prefix.(lua.LString)))
enabled := L.GetGlobal("enabled")
assert.Equal(t, lua.LTBool, enabled.Type())
assert.True(t, bool(enabled.(lua.LBool)))
count := L.GetGlobal("count")
assert.Equal(t, lua.LTNumber, count.Type())
assert.Equal(t, 42.0, float64(count.(lua.LNumber)))
}
func TestSetVariablesEmpty(t *testing.T) {
// Test with empty map
vars := map[string]interface{}{}
SetVariables(vars)
// Should not panic
L, err := NewLuaState()
assert.NoError(t, err)
defer L.Close()
}
func TestSetVariablesNil(t *testing.T) {
// Test with nil map
SetVariables(nil)
// Should not panic
L, err := NewLuaState()
assert.NoError(t, err)
defer L.Close()
}
func TestGetLuaFunctionsHelp(t *testing.T) {
help := GetLuaFunctionsHelp()
// Verify help is not empty
assert.NotEmpty(t, help)
// Verify it contains documentation for key functions
assert.Contains(t, help, "MATH FUNCTIONS")
assert.Contains(t, help, "STRING FUNCTIONS")
assert.Contains(t, help, "TABLE FUNCTIONS")
assert.Contains(t, help, "XML HELPER FUNCTIONS")
assert.Contains(t, help, "JSON HELPER FUNCTIONS")
assert.Contains(t, help, "HTTP FUNCTIONS")
assert.Contains(t, help, "REGEX FUNCTIONS")
assert.Contains(t, help, "UTILITY FUNCTIONS")
assert.Contains(t, help, "EXAMPLES")
// Verify specific functions are documented
assert.Contains(t, help, "min(a, b)")
assert.Contains(t, help, "max(a, b)")
assert.Contains(t, help, "round(x, n)")
assert.Contains(t, help, "fetch(url, options)")
assert.Contains(t, help, "findElements(root, tagName)")
assert.Contains(t, help, "visitJSON(data, callback)")
assert.Contains(t, help, "re(pattern, input)")
assert.Contains(t, help, "print(...)")
}
func TestFetchFunction(t *testing.T) {
L := lua.NewState()
defer L.Close()
// Register the fetch function
L.SetGlobal("fetch", L.NewFunction(fetch))
// Test 1: Missing URL should return nil and error
err := L.DoString(`
result, err = fetch("")
assert(result == nil, "Expected nil result for empty URL")
assert(err ~= nil, "Expected error for empty URL")
`)
assert.NoError(t, err)
// Test 2: Invalid URL should return error
err = L.DoString(`
result, err = fetch("not-a-valid-url")
assert(result == nil, "Expected nil result for invalid URL")
assert(err ~= nil, "Expected error for invalid URL")
`)
assert.NoError(t, err)
}
func TestFetchFunctionWithOptions(t *testing.T) {
L := lua.NewState()
defer L.Close()
// Register the fetch function
L.SetGlobal("fetch", L.NewFunction(fetch))
// Test with options (should fail gracefully with invalid URL)
err := L.DoString(`
local opts = {
method = "POST",
headers = {["Content-Type"] = "application/json"},
body = '{"test": "data"}'
}
result, err = fetch("http://invalid-domain-that-does-not-exist.local", opts)
-- Should get error due to invalid domain
assert(result == nil, "Expected nil result for invalid domain")
assert(err ~= nil, "Expected error for invalid domain")
`)
assert.NoError(t, err)
}
func TestPrependLuaAssignment(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "Simple assignment",
input: "10",
expected: "v1 = 10",
},
{
name: "Expression",
input: "v1 * 2",
expected: "v1 = v1 * 2",
},
{
name: "Assignment with equal sign",
input: "= 5",
expected: "v1 = 5",
},
{
name: "Complex expression",
input: "math.floor(v1 / 2)",
expected: "v1 = math.floor(v1 / 2)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := PrependLuaAssignment(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestBuildJSONLuaScript(t *testing.T) {
tests := []struct {
name string
input string
contains []string
}{
{
name: "Simple JSON modification",
input: "data.value = data.value * 2; modified = true",
contains: []string{
"data.value = data.value * 2",
"modified = true",
},
},
{
name: "Complex JSON script",
input: "for i, item in ipairs(data.items) do item.price = item.price * 1.5 end; modified = true",
contains: []string{
"for i, item in ipairs(data.items)",
"modified = true",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := BuildJSONLuaScript(tt.input)
for _, substr := range tt.contains {
assert.Contains(t, result, substr)
}
})
}
}
func TestPrintToGo(t *testing.T) {
L := lua.NewState()
defer L.Close()
// Register the print function
L.SetGlobal("print", L.NewFunction(printToGo))
// Test printing various types
err := L.DoString(`
print("Hello, World!")
print(42)
print(true)
print(3.14)
`)
assert.NoError(t, err)
}
func TestEvalRegex(t *testing.T) {
L := lua.NewState()
defer L.Close()
// Register the regex function
L.SetGlobal("re", L.NewFunction(EvalRegex))
// Test 1: Simple match
err := L.DoString(`
matches = re("(\\d+)", "The answer is 42")
assert(matches ~= nil, "Expected matches")
assert(matches[1] == "42", "Expected full match to be 42")
assert(matches[2] == "42", "Expected capture group to be 42")
`)
assert.NoError(t, err)
// Test 2: No match
err = L.DoString(`
matches = re("(\\d+)", "No numbers here")
assert(matches == nil, "Expected nil for no match")
`)
assert.NoError(t, err)
// Test 3: Multiple capture groups
err = L.DoString(`
matches = re("(\\w+)\\s+(\\d+)", "item 123")
assert(matches ~= nil, "Expected matches")
assert(matches[1] == "item 123", "Expected full match")
assert(matches[2] == "item", "Expected first capture group")
assert(matches[3] == "123", "Expected second capture group")
`)
assert.NoError(t, err)
}
func TestEstimatePatternComplexity(t *testing.T) {
tests := []struct {
name string
pattern string
minExpected int
}{
{
name: "Simple literal",
pattern: "hello",
minExpected: 1,
},
{
name: "With capture group",
pattern: "(\\d+)",
minExpected: 2,
},
{
name: "Complex pattern",
pattern: "(?P<name>\\w+)\\s+(?P<value>\\d+\\.\\d+)",
minExpected: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
complexity := estimatePatternComplexity(tt.pattern)
assert.GreaterOrEqual(t, complexity, tt.minExpected)
})
}
}
func TestParseNumeric(t *testing.T) {
tests := []struct {
name string
input string
expected float64
shouldOk bool
}{
{"Integer", "42", 42.0, true},
{"Float", "3.14", 3.14, true},
{"Negative", "-10", -10.0, true},
{"Invalid", "not a number", 0, false},
{"Empty", "", 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, ok := parseNumeric(tt.input)
assert.Equal(t, tt.shouldOk, ok)
if tt.shouldOk {
assert.Equal(t, tt.expected, result)
}
})
}
}
func TestFormatNumeric(t *testing.T) {
tests := []struct {
name string
input float64
expected string
}{
{"Integer value", 42.0, "42"},
{"Float value", 3.14, "3.14"},
{"Negative integer", -10.0, "-10"},
{"Negative float", -3.14, "-3.14"},
{"Zero", 0.0, "0"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatNumeric(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestLuaHelperFunctionsDocumentation(t *testing.T) {
help := GetLuaFunctionsHelp()
// All main function categories should be documented
expectedCategories := []string{
"MATH FUNCTIONS",
"STRING FUNCTIONS",
"XML HELPER FUNCTIONS",
"JSON HELPER FUNCTIONS",
}
for _, category := range expectedCategories {
assert.Contains(t, help, category, "Help should contain category: %s", category)
}
// Verify some key functions are mentioned
keyFunctions := []string{
"findElements",
"visitElements",
"visitJSON",
"round",
"fetch",
}
for _, fn := range keyFunctions {
assert.Contains(t, help, fn, "Help should mention function: %s", fn)
}
}

View File

@@ -0,0 +1,87 @@
package processor
import (
"cook/utils"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
)
// Test named capture group fallback when value is not in Lua
func TestNamedCaptureGroupFallback(t *testing.T) {
pattern := `value = (?P<myvalue>\d+)`
input := `value = 42`
// Don't set myvalue in Lua, but do something else so we get a match
lua := `v1 = v1 * 2 -- Set v1 but not myvalue, test fallback`
cmd := utils.ModifyCommand{
Name: "test_fallback",
Regex: pattern,
Lua: lua,
}
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatchIndex(input)
assert.NotNil(t, matches)
replacements, err := ProcessRegex(input, cmd, "test.txt")
// Should not error
assert.NoError(t, err)
// Since only v1 is set, myvalue should keep original
// Should have 1 replacement for v1
if replacements != nil {
assert.GreaterOrEqual(t, len(replacements), 0)
}
}
// Test named capture groups with nil value in Lua
func TestNamedCaptureGroupNilInLua(t *testing.T) {
pattern := `value = (?P<num>\d+)`
input := `value = 123`
// Set num to nil explicitly, and also set v1 to get a modification
lua := `v1 = v1 .. "_test"; num = nil -- v1 modified, num set to nil`
cmd := utils.ModifyCommand{
Name: "test_nil",
Regex: pattern,
Lua: lua,
}
replacements, err := ProcessRegex(input, cmd, "test.txt")
// Should not error
assert.NoError(t, err)
// Should have replacements for v1, num should fallback to original
if replacements != nil {
assert.GreaterOrEqual(t, len(replacements), 0)
}
}
// Test multiple named capture groups with some undefined
func TestMixedNamedCaptureGroups(t *testing.T) {
pattern := `(?P<key>\w+) = (?P<value>\d+)`
input := `count = 100`
lua := `key = key .. "_modified" -- Only modify key, leave value undefined`
cmd := utils.ModifyCommand{
Name: "test_mixed",
Regex: pattern,
Lua: lua,
}
replacements, err := ProcessRegex(input, cmd, "test.txt")
assert.NoError(t, err)
assert.NotNil(t, replacements)
// Apply replacements
result, _ := utils.ExecuteModifications(replacements, input)
// key should be modified, value should remain unchanged
assert.Contains(t, result, "count_modified")
assert.Contains(t, result, "100")
}

View File

@@ -1,27 +0,0 @@
package processor
import (
"io"
"os"
logger "git.site.quack-lab.dev/dave/cylogger"
)
func init() {
// Only modify logger in test mode
// This checks if we're running under 'go test'
if os.Getenv("GO_TESTING") == "1" || os.Getenv("TESTING") == "1" {
// Initialize logger with ERROR level for tests
// to minimize noise in test output
logger.Init(logger.LevelError)
// Optionally redirect logger output to discard
// This prevents logger output from interfering with test output
disableTestLogs := os.Getenv("ENABLE_TEST_LOGS") != "1"
if disableTestLogs {
// Create a new logger that writes to nowhere
silentLogger := logger.New(io.Discard, "", 0)
logger.Default = silentLogger
}
}
}

View File

@@ -29,7 +29,7 @@ type XMLElement struct {
// XMLAttribute represents an attribute with its position in the source // XMLAttribute represents an attribute with its position in the source
type XMLAttribute struct { type XMLAttribute struct {
Value string Value string
ValueStart int64 ValueStart int64
ValueEnd int64 ValueEnd int64
} }
@@ -127,34 +127,6 @@ func parseXMLWithPositions(content string) (*XMLElement, error) {
return root, nil return root, nil
} }
// xmlElementToMap converts XMLElement to a map for comparison
func xmlElementToMap(elem *XMLElement) map[string]interface{} {
result := make(map[string]interface{})
result["_tag"] = elem.Tag
if len(elem.Attributes) > 0 {
attrs := make(map[string]interface{})
for k, v := range elem.Attributes {
attrs[k] = v.Value
}
result["_attr"] = attrs
}
if elem.Text != "" {
result["_text"] = elem.Text
}
if len(elem.Children) > 0 {
children := make([]interface{}, len(elem.Children))
for i, child := range elem.Children {
children[i] = xmlElementToMap(child)
}
result["_children"] = children
}
return result
}
// XMLChange represents a detected difference between original and modified XML structures // XMLChange represents a detected difference between original and modified XML structures
type XMLChange struct { type XMLChange struct {
Type string // "text", "attribute", "add_element", "remove_element" Type string // "text", "attribute", "add_element", "remove_element"
@@ -276,22 +248,6 @@ func findXMLChanges(original, modified *XMLElement, path string) []XMLChange {
} }
} }
// Handle completely new tag types
for tag, modChildren := range modChildMap {
if !processedTags[tag] {
for i, child := range modChildren {
childPath := fmt.Sprintf("%s/%s[%d]", path, tag, i)
xmlText := serializeXMLElement(child, " ")
changes = append(changes, XMLChange{
Type: "add_element",
Path: childPath,
InsertText: xmlText,
StartPos: original.EndPos - int64(len(original.Tag)+3),
})
}
}
}
return changes return changes
} }
@@ -395,14 +351,6 @@ func applyXMLChanges(changes []XMLChange) []utils.ReplaceCommand {
return commands return commands
} }
// modifyXMLElement applies modifications to an XMLElement based on a modification function
func modifyXMLElement(elem *XMLElement, modifyFunc func(*XMLElement)) *XMLElement {
// Deep copy the element
copied := deepCopyXMLElement(elem)
modifyFunc(copied)
return copied
}
// deepCopyXMLElement creates a deep copy of an XMLElement // deepCopyXMLElement creates a deep copy of an XMLElement
func deepCopyXMLElement(elem *XMLElement) *XMLElement { func deepCopyXMLElement(elem *XMLElement) *XMLElement {
if elem == nil { if elem == nil {
@@ -410,12 +358,12 @@ func deepCopyXMLElement(elem *XMLElement) *XMLElement {
} }
copied := &XMLElement{ copied := &XMLElement{
Tag: elem.Tag, Tag: elem.Tag,
Text: elem.Text, Text: elem.Text,
StartPos: elem.StartPos, StartPos: elem.StartPos,
EndPos: elem.EndPos, EndPos: elem.EndPos,
TextStart: elem.TextStart, TextStart: elem.TextStart,
TextEnd: elem.TextEnd, TextEnd: elem.TextEnd,
Attributes: make(map[string]XMLAttribute), Attributes: make(map[string]XMLAttribute),
Children: make([]*XMLElement, len(elem.Children)), Children: make([]*XMLElement, len(elem.Children)),
} }
@@ -534,10 +482,6 @@ func xmlElementToLuaTable(L *lua.LState, elem *XMLElement) *lua.LTable {
table.RawSetString("_attr", attrs) table.RawSetString("_attr", attrs)
} }
if elem.Text != "" {
table.RawSetString("_text", lua.LString(elem.Text))
}
if len(elem.Children) > 0 { if len(elem.Children) > 0 {
children := L.CreateTable(len(elem.Children), 0) children := L.CreateTable(len(elem.Children), 0)
for i, child := range elem.Children { for i, child := range elem.Children {
@@ -551,11 +495,6 @@ func xmlElementToLuaTable(L *lua.LState, elem *XMLElement) *lua.LTable {
// luaTableToXMLElement applies Lua table modifications back to XMLElement // luaTableToXMLElement applies Lua table modifications back to XMLElement
func luaTableToXMLElement(L *lua.LState, table *lua.LTable, elem *XMLElement) { func luaTableToXMLElement(L *lua.LState, table *lua.LTable, elem *XMLElement) {
// Update text
if textVal := table.RawGetString("_text"); textVal.Type() == lua.LTString {
elem.Text = string(textVal.(lua.LString))
}
// Update attributes // Update attributes
if attrVal := table.RawGetString("_attr"); attrVal.Type() == lua.LTTable { if attrVal := table.RawGetString("_attr"); attrVal.Type() == lua.LTTable {
attrTable := attrVal.(*lua.LTable) attrTable := attrVal.(*lua.LTable)

View File

@@ -417,14 +417,14 @@ files = ["test.txt"
assert.NoError(t, err, "Should handle non-existent file without error") assert.NoError(t, err, "Should handle non-existent file without error")
assert.Empty(t, commands, "Should return empty commands for non-existent file") assert.Empty(t, commands, "Should return empty commands for non-existent file")
// Test 3: Empty TOML file creates an error (this is expected behavior) // Test 3: Empty TOML file returns no commands (not an error)
emptyFile := filepath.Join(tmpDir, "empty.toml") emptyFile := filepath.Join(tmpDir, "empty.toml")
err = os.WriteFile(emptyFile, []byte(""), 0644) err = os.WriteFile(emptyFile, []byte(""), 0644)
assert.NoError(t, err, "Should write empty TOML file") assert.NoError(t, err, "Should write empty TOML file")
commands, _, err = utils.LoadCommandsFromTomlFiles("empty.toml") commands, _, err = utils.LoadCommandsFromTomlFiles("empty.toml")
assert.Error(t, err, "Should return error for empty TOML file") assert.NoError(t, err, "Empty TOML should not return error")
assert.Nil(t, commands, "Should return nil commands for empty TOML") assert.Empty(t, commands, "Should return empty commands for empty TOML")
} }
func TestYAMLToTOMLConversion(t *testing.T) { func TestYAMLToTOMLConversion(t *testing.T) {

View File

@@ -2,7 +2,6 @@ package utils
import ( import (
"os" "os"
"strconv"
"strings" "strings"
logger "git.site.quack-lab.dev/dave/cylogger" logger "git.site.quack-lab.dev/dave/cylogger"
@@ -11,16 +10,6 @@ import (
// fileLogger is a scoped logger for the utils/file package. // fileLogger is a scoped logger for the utils/file package.
var fileLogger = logger.Default.WithPrefix("utils/file") var fileLogger = logger.Default.WithPrefix("utils/file")
func CleanPath(path string) string {
// Use the centralized ResolvePath function
return ResolvePath(path)
}
func ToAbs(path string) string {
// Use the centralized ResolvePath function
return ResolvePath(path)
}
// LimitString truncates a string to maxLen and adds "..." if truncated // LimitString truncates a string to maxLen and adds "..." if truncated
func LimitString(s string, maxLen int) string { func LimitString(s string, maxLen int) string {
limitStringLogger := fileLogger.WithPrefix("LimitString").WithField("originalLength", len(s)).WithField("maxLength", maxLen) limitStringLogger := fileLogger.WithPrefix("LimitString").WithField("originalLength", len(s)).WithField("maxLength", maxLen)
@@ -35,19 +24,6 @@ func LimitString(s string, maxLen int) string {
return limited return limited
} }
// StrToFloat converts a string to a float64, returning 0 on error.
func StrToFloat(s string) float64 {
strToFloatLogger := fileLogger.WithPrefix("StrToFloat").WithField("inputString", s)
strToFloatLogger.Debug("Attempting to convert string to float")
f, err := strconv.ParseFloat(s, 64)
if err != nil {
strToFloatLogger.Warning("Failed to convert string %q to float, returning 0: %v", s, err)
return 0
}
strToFloatLogger.Trace("Successfully converted %q to float: %f", s, f)
return f
}
func ResetWhereNecessary(associations map[string]FileCommandAssociation, db DB) error { func ResetWhereNecessary(associations map[string]FileCommandAssociation, db DB) error {
resetWhereNecessaryLogger := fileLogger.WithPrefix("ResetWhereNecessary") resetWhereNecessaryLogger := fileLogger.WithPrefix("ResetWhereNecessary")
resetWhereNecessaryLogger.Debug("Starting reset where necessary operation") resetWhereNecessaryLogger.Debug("Starting reset where necessary operation")

209
utils/file_test.go Normal file
View File

@@ -0,0 +1,209 @@
package utils
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLimitString(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
expected string
}{
{
name: "Short string",
input: "hello",
maxLen: 10,
expected: "hello",
},
{
name: "Exact length",
input: "hello",
maxLen: 5,
expected: "hello",
},
{
name: "Too long",
input: "hello world",
maxLen: 8,
expected: "hello...",
},
{
name: "With newlines",
input: "hello\nworld",
maxLen: 20,
expected: "hello\\nworld",
},
{
name: "With newlines truncated",
input: "hello\nworld\nfoo\nbar",
maxLen: 15,
expected: "hello\\nworld...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := LimitString(tt.input, tt.maxLen)
assert.Equal(t, tt.expected, result)
})
}
}
func TestResetWhereNecessary(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "reset-test-*")
assert.NoError(t, err)
defer os.RemoveAll(tmpDir)
// Create test files
file1 := filepath.Join(tmpDir, "file1.txt")
file2 := filepath.Join(tmpDir, "file2.txt")
file3 := filepath.Join(tmpDir, "file3.txt")
err = os.WriteFile(file1, []byte("original1"), 0644)
assert.NoError(t, err)
err = os.WriteFile(file2, []byte("original2"), 0644)
assert.NoError(t, err)
err = os.WriteFile(file3, []byte("original3"), 0644)
assert.NoError(t, err)
// Modify files
err = os.WriteFile(file1, []byte("modified1"), 0644)
assert.NoError(t, err)
err = os.WriteFile(file2, []byte("modified2"), 0644)
assert.NoError(t, err)
// Create mock DB
db, err := GetDB()
assert.NoError(t, err)
err = db.SaveFile(file1, []byte("original1"))
assert.NoError(t, err)
err = db.SaveFile(file2, []byte("original2"))
assert.NoError(t, err)
// file3 not in DB
// Create associations with reset commands
associations := map[string]FileCommandAssociation{
file1: {
File: file1,
Commands: []ModifyCommand{
{Name: "cmd1", Reset: true},
},
},
file2: {
File: file2,
IsolateCommands: []ModifyCommand{
{Name: "cmd2", Reset: true},
},
},
file3: {
File: file3,
Commands: []ModifyCommand{
{Name: "cmd3", Reset: false}, // No reset
},
},
}
// Run reset
err = ResetWhereNecessary(associations, db)
assert.NoError(t, err)
// Verify file1 was reset
data, _ := os.ReadFile(file1)
assert.Equal(t, "original1", string(data))
// Verify file2 was reset
data, _ = os.ReadFile(file2)
assert.Equal(t, "original2", string(data))
// Verify file3 was NOT reset
data, _ = os.ReadFile(file3)
assert.Equal(t, "original3", string(data))
}
func TestResetWhereNecessaryMissingFromDB(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "reset-missing-test-*")
assert.NoError(t, err)
defer os.RemoveAll(tmpDir)
// Create a test file that's been modified
file1 := filepath.Join(tmpDir, "file1.txt")
err = os.WriteFile(file1, []byte("modified_content"), 0644)
assert.NoError(t, err)
// Create DB but DON'T save file to it
db, err := GetDB()
assert.NoError(t, err)
// Create associations with reset command
associations := map[string]FileCommandAssociation{
file1: {
File: file1,
Commands: []ModifyCommand{
{Name: "cmd1", Reset: true},
},
},
}
// Run reset - should use current disk content as fallback
err = ResetWhereNecessary(associations, db)
assert.NoError(t, err)
// Verify file was "reset" to current content (saved to DB for next time)
data, _ := os.ReadFile(file1)
assert.Equal(t, "modified_content", string(data))
// Verify it was saved to DB
savedData, err := db.GetFile(file1)
assert.NoError(t, err)
assert.Equal(t, "modified_content", string(savedData))
}
func TestResetAllFiles(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "reset-all-test-*")
assert.NoError(t, err)
defer os.RemoveAll(tmpDir)
// Create test files
file1 := filepath.Join(tmpDir, "file1.txt")
file2 := filepath.Join(tmpDir, "file2.txt")
err = os.WriteFile(file1, []byte("original1"), 0644)
assert.NoError(t, err)
err = os.WriteFile(file2, []byte("original2"), 0644)
assert.NoError(t, err)
// Create mock DB and save originals
db, err := GetDB()
assert.NoError(t, err)
err = db.SaveFile(file1, []byte("original1"))
assert.NoError(t, err)
err = db.SaveFile(file2, []byte("original2"))
assert.NoError(t, err)
// Modify files
err = os.WriteFile(file1, []byte("modified1"), 0644)
assert.NoError(t, err)
err = os.WriteFile(file2, []byte("modified2"), 0644)
assert.NoError(t, err)
// Verify they're modified
data, _ := os.ReadFile(file1)
assert.Equal(t, "modified1", string(data))
// Reset all
err = ResetAllFiles(db)
assert.NoError(t, err)
// Verify both were reset
data, _ = os.ReadFile(file1)
assert.Equal(t, "original1", string(data))
data, _ = os.ReadFile(file2)
assert.Equal(t, "original2", string(data))
}

View File

@@ -86,28 +86,17 @@ func SplitPattern(pattern string) (string, string) {
splitPatternLogger.Debug("Splitting pattern") splitPatternLogger.Debug("Splitting pattern")
splitPatternLogger.Trace("Original pattern: %q", pattern) splitPatternLogger.Trace("Original pattern: %q", pattern)
// Resolve the pattern first to handle ~ expansion and make it absolute // Split the pattern first to separate static and wildcard parts
resolvedPattern := ResolvePath(pattern) static, remainingPattern := doublestar.SplitPattern(pattern)
splitPatternLogger.Trace("Resolved pattern: %q", resolvedPattern) splitPatternLogger.Trace("After split: static=%q, pattern=%q", static, remainingPattern)
static, pattern := doublestar.SplitPattern(resolvedPattern) // Resolve the static part to handle ~ expansion and make it absolute
// ResolvePath already normalizes to forward slashes
static = ResolvePath(static)
splitPatternLogger.Trace("Resolved static part: %q", static)
// Ensure static part is properly resolved splitPatternLogger.Trace("Final static path: %q, Remaining pattern: %q", static, remainingPattern)
if static == "" { return static, remainingPattern
cwd, err := os.Getwd()
if err != nil {
splitPatternLogger.Error("Error getting current working directory: %v", err)
return "", ""
}
static = cwd
splitPatternLogger.Debug("Static part is empty, defaulting to current working directory: %q", static)
} else {
// Static part should already be resolved by ResolvePath
static = strings.ReplaceAll(static, "\\", "/")
}
splitPatternLogger.Trace("Final static path: %q, Remaining pattern: %q", static, pattern)
return static, pattern
} }
type FileCommandAssociation struct { type FileCommandAssociation struct {
@@ -140,7 +129,7 @@ func AssociateFilesWithCommands(files []string, commands []ModifyCommand) (map[s
static, pattern := SplitPattern(glob) static, pattern := SplitPattern(glob)
associateFilesLogger.Trace("Glob parts for %q → static=%q pattern=%q", glob, static, pattern) associateFilesLogger.Trace("Glob parts for %q → static=%q pattern=%q", glob, static, pattern)
// Use resolved file for matching // Use resolved file for matching (already normalized to forward slashes by ResolvePath)
absFile := resolvedFile absFile := resolvedFile
associateFilesLogger.Trace("Absolute file path resolved for matching: %q", absFile) associateFilesLogger.Trace("Absolute file path resolved for matching: %q", absFile)
@@ -283,9 +272,6 @@ func LoadCommands(args []string) ([]ModifyCommand, map[string]interface{}, error
loadCommandsLogger.Error("Failed to load TOML commands from argument %q: %v", arg, err) loadCommandsLogger.Error("Failed to load TOML commands from argument %q: %v", arg, err)
return nil, nil, fmt.Errorf("failed to load commands from TOML files: %w", err) return nil, nil, fmt.Errorf("failed to load commands from TOML files: %w", err)
} }
for k, v := range newVariables {
variables[k] = v
}
} else { } else {
// Default to YAML for .yml, .yaml, or any other extension // Default to YAML for .yml, .yaml, or any other extension
loadCommandsLogger.Debug("Loading YAML commands from %q", arg) loadCommandsLogger.Debug("Loading YAML commands from %q", arg)
@@ -294,9 +280,9 @@ func LoadCommands(args []string) ([]ModifyCommand, map[string]interface{}, error
loadCommandsLogger.Error("Failed to load YAML commands from argument %q: %v", arg, err) loadCommandsLogger.Error("Failed to load YAML commands from argument %q: %v", arg, err)
return nil, nil, fmt.Errorf("failed to load commands from cook files: %w", err) return nil, nil, fmt.Errorf("failed to load commands from cook files: %w", err)
} }
for k, v := range newVariables { }
variables[k] = v for k, v := range newVariables {
} variables[k] = v
} }
loadCommandsLogger.Debug("Successfully loaded %d commands from %q", len(newCommands), arg) loadCommandsLogger.Debug("Successfully loaded %d commands from %q", len(newCommands), arg)
@@ -485,24 +471,8 @@ func LoadCommandsFromTomlFile(tomlFileData []byte) ([]ModifyCommand, map[string]
} }
} }
// If we found commands in the wrapped structure, use those // Use commands from wrapped structure
if len(tomlData.Commands) > 0 { commands = tomlData.Commands
commands = tomlData.Commands
loadTomlCommandLogger.Debug("Found %d commands in wrapped TOML structure", len(commands))
} else {
// Try to parse as direct array (similar to YAML format)
directCommands := []ModifyCommand{}
err = toml.Unmarshal(tomlFileData, &directCommands)
if err != nil {
loadTomlCommandLogger.Error("Failed to unmarshal TOML file data as direct array: %v", err)
return nil, nil, fmt.Errorf("failed to unmarshal TOML file as direct array: %w", err)
}
if len(directCommands) > 0 {
commands = directCommands
loadTomlCommandLogger.Debug("Found %d commands in direct TOML array", len(directCommands))
}
}
loadTomlCommandLogger.Debug("Successfully unmarshaled %d commands and %d variables", len(commands), len(variables)) loadTomlCommandLogger.Debug("Successfully unmarshaled %d commands and %d variables", len(commands), len(variables))
loadTomlCommandLogger.Trace("Unmarshaled commands: %v", commands) loadTomlCommandLogger.Trace("Unmarshaled commands: %v", commands)
loadTomlCommandLogger.Trace("Unmarshaled variables: %v", variables) loadTomlCommandLogger.Trace("Unmarshaled variables: %v", variables)

View File

@@ -0,0 +1,313 @@
package utils
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAggregateGlobsWithDuplicates(t *testing.T) {
commands := []ModifyCommand{
{Files: []string{"*.txt", "*.md"}},
{Files: []string{"*.txt", "*.go"}}, // *.txt is duplicate
{Files: []string{"test/**/*.xml"}},
}
globs := AggregateGlobs(commands)
// Should deduplicate
assert.Equal(t, 4, len(globs))
// AggregateGlobs resolves paths, which uses forward slashes internally
assert.Contains(t, globs, ResolvePath("*.txt"))
assert.Contains(t, globs, ResolvePath("*.md"))
assert.Contains(t, globs, ResolvePath("*.go"))
assert.Contains(t, globs, ResolvePath("test/**/*.xml"))
}
func TestExpandGlobsWithActualFiles(t *testing.T) {
// Create temp dir with test files
tmpDir, err := os.MkdirTemp("", "glob-test-*")
assert.NoError(t, err)
defer os.RemoveAll(tmpDir)
// Create test files
testFile1 := filepath.Join(tmpDir, "test1.txt")
testFile2 := filepath.Join(tmpDir, "test2.txt")
testFile3 := filepath.Join(tmpDir, "test.md")
os.WriteFile(testFile1, []byte("test"), 0644)
os.WriteFile(testFile2, []byte("test"), 0644)
os.WriteFile(testFile3, []byte("test"), 0644)
// Change to temp directory so glob pattern can find files
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(tmpDir)
// Test expanding globs using ResolvePath to normalize the pattern
globs := map[string]struct{}{
ResolvePath("*.txt"): {},
}
files, err := ExpandGlobs(globs)
assert.NoError(t, err)
assert.Equal(t, 2, len(files))
}
func TestSplitPatternWithTilde(t *testing.T) {
pattern := "~/test/*.txt"
static, pat := SplitPattern(pattern)
// Should expand ~
assert.NotEqual(t, "~", static)
assert.Contains(t, pat, "*.txt")
}
func TestLoadCommandsWithDisabled(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "disabled-test-*")
assert.NoError(t, err)
defer os.RemoveAll(tmpDir)
yamlContent := `
variables:
test: "value"
commands:
- name: "enabled_cmd"
regex: "test"
lua: "v1 * 2"
files: ["*.txt"]
- name: "disabled_cmd"
regex: "test2"
lua: "v1 * 3"
files: ["*.txt"]
disable: true
`
yamlFile := filepath.Join(tmpDir, "test.yml")
err = os.WriteFile(yamlFile, []byte(yamlContent), 0644)
assert.NoError(t, err)
// Change to temp directory so LoadCommands can find the file with a simple pattern
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(tmpDir)
commands, variables, err := LoadCommands([]string{"test.yml"})
assert.NoError(t, err)
// Should only load enabled command
assert.Equal(t, 1, len(commands))
assert.Equal(t, "enabled_cmd", commands[0].Name)
// Should still load variables
assert.Equal(t, 1, len(variables))
}
func TestFilterCommandsByName(t *testing.T) {
commands := []ModifyCommand{
{Name: "test_multiply"},
{Name: "test_divide"},
{Name: "other_command"},
{Name: "test_add"},
}
// Filter by "test"
filtered := FilterCommands(commands, "test")
assert.Equal(t, 3, len(filtered))
// Filter by multiple
filtered = FilterCommands(commands, "multiply,divide")
assert.Equal(t, 2, len(filtered))
}
func TestCountGlobsBeforeDedup(t *testing.T) {
commands := []ModifyCommand{
{Files: []string{"*.txt", "*.md", "*.go"}},
{Files: []string{"*.xml"}},
{Files: []string{"test/**/*.txt", "data/**/*.json"}},
}
count := CountGlobsBeforeDedup(commands)
assert.Equal(t, 6, count)
}
func TestMatchesWithMemoization(t *testing.T) {
path := "test/file.txt"
glob := "**/*.txt"
// First call
matches1, err1 := Matches(path, glob)
assert.NoError(t, err1)
assert.True(t, matches1)
// Second call should use memo
matches2, err2 := Matches(path, glob)
assert.NoError(t, err2)
assert.Equal(t, matches1, matches2)
}
func TestValidateCommand(t *testing.T) {
tests := []struct {
name string
cmd ModifyCommand
wantErr bool
}{
{
name: "Valid command",
cmd: ModifyCommand{
Regex: "test",
Lua: "v1 * 2",
Files: []string{"*.txt"},
},
wantErr: false,
},
{
name: "Valid JSON mode without regex",
cmd: ModifyCommand{
JSON: true,
Lua: "data.value = data.value * 2; modified = true",
Files: []string{"*.json"},
},
wantErr: false,
},
{
name: "Missing regex in non-JSON mode",
cmd: ModifyCommand{
Lua: "v1 * 2",
Files: []string{"*.txt"},
},
wantErr: true,
},
{
name: "Missing Lua",
cmd: ModifyCommand{
Regex: "test",
Files: []string{"*.txt"},
},
wantErr: true,
},
{
name: "Missing files",
cmd: ModifyCommand{
Regex: "test",
Lua: "v1 * 2",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.cmd.Validate()
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestLoadCommandsFromTomlWithVariables(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "toml-vars-test-*")
assert.NoError(t, err)
defer os.RemoveAll(tmpDir)
tomlContent := `[variables]
multiplier = 3
prefix = "PREFIX_"
[[commands]]
name = "test_cmd"
regex = "value = !num"
lua = "v1 * multiplier"
files = ["*.txt"]
`
tomlFile := filepath.Join(tmpDir, "test.toml")
err = os.WriteFile(tomlFile, []byte(tomlContent), 0644)
assert.NoError(t, err)
// Change to temp directory so glob pattern can find the file
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(tmpDir)
commands, variables, err := LoadCommandsFromTomlFiles("test.toml")
assert.NoError(t, err)
assert.Equal(t, 1, len(commands))
assert.Equal(t, 2, len(variables))
assert.Equal(t, int64(3), variables["multiplier"])
assert.Equal(t, "PREFIX_", variables["prefix"])
}
func TestConvertYAMLToTOMLSkipExisting(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "convert-skip-test-*")
assert.NoError(t, err)
defer os.RemoveAll(tmpDir)
// Create YAML file
yamlContent := `
commands:
- name: "test"
regex: "value"
lua: "v1 * 2"
files: ["*.txt"]
`
yamlFile := filepath.Join(tmpDir, "test.yml")
err = os.WriteFile(yamlFile, []byte(yamlContent), 0644)
assert.NoError(t, err)
// Create TOML file (should skip conversion)
tomlFile := filepath.Join(tmpDir, "test.toml")
err = os.WriteFile(tomlFile, []byte("# existing"), 0644)
assert.NoError(t, err)
// Change to temp dir
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(tmpDir)
// Should skip existing TOML
err = ConvertYAMLToTOML("test.yml")
assert.NoError(t, err)
// TOML content should be unchanged
content, _ := os.ReadFile(tomlFile)
assert.Equal(t, "# existing", string(content))
}
func TestLoadCommandsWithTomlExtension(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "toml-ext-test-*")
assert.NoError(t, err)
defer os.RemoveAll(tmpDir)
tomlContent := `
[variables]
test_var = "value"
[[commands]]
name = "TestCmd"
regex = "test"
lua = "return true"
files = ["*.txt"]
`
tomlFile := filepath.Join(tmpDir, "test.toml")
err = os.WriteFile(tomlFile, []byte(tomlContent), 0644)
assert.NoError(t, err)
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(tmpDir)
// This should trigger the .toml suffix check in LoadCommands
commands, variables, err := LoadCommands([]string{"test.toml"})
assert.NoError(t, err)
assert.Len(t, commands, 1)
assert.Equal(t, "TestCmd", commands[0].Name)
assert.Len(t, variables, 1)
assert.Equal(t, "value", variables["test_var"])
}

View File

@@ -0,0 +1,93 @@
package utils
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
// TestConvertYAMLToTOMLReadError tests error handling when YAML file can't be read
func TestConvertYAMLToTOMLReadError(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "convert-read-error-*")
assert.NoError(t, err)
defer os.RemoveAll(tmpDir)
// Create YAML file with no read permissions (on Unix) or delete it after creation
yamlFile := filepath.Join(tmpDir, "test.yml")
err = os.WriteFile(yamlFile, []byte("commands:\n - name: test\n"), 0000)
assert.NoError(t, err)
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(tmpDir)
// This should fail to read but not crash
err = ConvertYAMLToTOML("test.yml")
// Function continues on error, doesn't return error
assert.NoError(t, err)
// Fix permissions for cleanup
os.Chmod(yamlFile, 0644)
}
// TestConvertYAMLToTOMLParseError tests error handling when YAML is invalid
func TestConvertYAMLToTOMLParseError(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "convert-parse-error-*")
assert.NoError(t, err)
defer os.RemoveAll(tmpDir)
// Create invalid YAML
yamlFile := filepath.Join(tmpDir, "invalid.yml")
err = os.WriteFile(yamlFile, []byte("commands:\n - [this is not valid yaml}}"), 0644)
assert.NoError(t, err)
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(tmpDir)
// This should fail to parse but not crash
err = ConvertYAMLToTOML("invalid.yml")
assert.NoError(t, err)
// TOML file should not exist
_, statErr := os.Stat(filepath.Join(tmpDir, "invalid.toml"))
assert.True(t, os.IsNotExist(statErr))
}
// TestConvertYAMLToTOMLWriteError tests error handling when TOML file can't be written
func TestConvertYAMLToTOMLWriteError(t *testing.T) {
if os.Getenv("CI") != "" {
t.Skip("Skipping write permission test in CI")
}
tmpDir, err := os.MkdirTemp("", "convert-write-error-*")
assert.NoError(t, err)
defer os.RemoveAll(tmpDir)
// Create valid YAML
yamlFile := filepath.Join(tmpDir, "test.yml")
err = os.WriteFile(yamlFile, []byte("commands:\n - name: test\n regex: test\n lua: v1\n files: [test.txt]\n"), 0644)
assert.NoError(t, err)
// Create output directory with no write permissions
outputDir := filepath.Join(tmpDir, "readonly")
err = os.Mkdir(outputDir, 0555)
assert.NoError(t, err)
defer os.Chmod(outputDir, 0755) // Fix for cleanup
origDir, _ := os.Getwd()
defer os.Chdir(origDir)
os.Chdir(tmpDir)
// Move YAML into readonly dir
newYamlFile := filepath.Join(outputDir, "test.yml")
os.Rename(yamlFile, newYamlFile)
os.Chdir(outputDir)
// This should fail to write but not crash
err = ConvertYAMLToTOML("test.yml")
assert.NoError(t, err)
}

View File

@@ -3,7 +3,6 @@ package utils
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
logger "git.site.quack-lab.dev/dave/cylogger" logger "git.site.quack-lab.dev/dave/cylogger"
@@ -12,93 +11,69 @@ import (
// pathLogger is a scoped logger for the utils/path package. // pathLogger is a scoped logger for the utils/path package.
var pathLogger = logger.Default.WithPrefix("utils/path") var pathLogger = logger.Default.WithPrefix("utils/path")
// ResolvePath resolves a file path by: // ResolvePath resolves a path to an absolute path, handling ~ expansion and cleaning
// 1. Expanding ~ to the user's home directory
// 2. Making the path absolute if it's relative
// 3. Normalizing path separators to forward slashes
// 4. Cleaning the path
func ResolvePath(path string) string { func ResolvePath(path string) string {
resolvePathLogger := pathLogger.WithPrefix("ResolvePath").WithField("inputPath", path) resolvePathLogger := pathLogger.WithPrefix("ResolvePath").WithField("inputPath", path)
resolvePathLogger.Debug("Resolving path") resolvePathLogger.Trace("Resolving path: %q", path)
// Handle empty path
if path == "" { if path == "" {
resolvePathLogger.Warning("Empty path provided") resolvePathLogger.Trace("Empty path, returning empty string")
return "" return ""
} }
// Step 1: Expand ~ to home directory // Check if path is absolute
originalPath := path if filepath.IsAbs(path) {
resolvePathLogger.Trace("Path is already absolute: %q", path)
cleaned := filepath.ToSlash(filepath.Clean(path))
resolvePathLogger.Trace("Cleaned absolute path: %q", cleaned)
return cleaned
}
// Handle ~ expansion
if strings.HasPrefix(path, "~") { if strings.HasPrefix(path, "~") {
home := os.Getenv("HOME") homeDir, _ := os.UserHomeDir()
if home == "" { if strings.HasPrefix(path, "~/") || strings.HasPrefix(path, "~\\") {
// Fallback for Windows path = filepath.Join(homeDir, path[2:])
if runtime.GOOS == "windows" { } else if path == "~" {
home = os.Getenv("USERPROFILE") path = homeDir
}
}
if home != "" {
if path == "~" {
path = home
} else if strings.HasPrefix(path, "~/") {
path = filepath.Join(home, path[2:])
} else {
// Handle cases like ~username
// For now, just replace ~ with home directory
path = strings.Replace(path, "~", home, 1)
}
resolvePathLogger.Debug("Expanded tilde to home directory: home=%s, result=%s", home, path)
} else { } else {
resolvePathLogger.Warning("Could not determine home directory for tilde expansion") // ~something (like ~~), treat first ~ as home expansion, rest as literal
path = homeDir + path[1:]
} }
resolvePathLogger.Trace("Expanded ~ to home directory: %q", path)
} }
// Step 2: Make path absolute if it's not already // Make absolute if not already
if !filepath.IsAbs(path) { if !filepath.IsAbs(path) {
cwd, err := os.Getwd() absPath, err := filepath.Abs(path)
if err != nil { if err != nil {
resolvePathLogger.Error("Failed to get current working directory: %v", err) resolvePathLogger.Error("Failed to get absolute path: %v", err)
return path // Return as-is if we can't get CWD return filepath.ToSlash(filepath.Clean(path))
} }
path = filepath.Join(cwd, path) resolvePathLogger.Trace("Made path absolute: %q -> %q", path, absPath)
resolvePathLogger.Debug("Made relative path absolute: cwd=%s, result=%s", cwd, path) path = absPath
} }
// Step 3: Clean the path // Clean the path and normalize to forward slashes for consistency
path = filepath.Clean(path) cleaned := filepath.ToSlash(filepath.Clean(path))
resolvePathLogger.Debug("Cleaned path: result=%s", path) resolvePathLogger.Trace("Final cleaned path: %q", cleaned)
return cleaned
// Step 4: Normalize path separators to forward slashes for consistency
path = strings.ReplaceAll(path, "\\", "/")
resolvePathLogger.Debug("Final resolved path: original=%s, final=%s", originalPath, path)
return path
}
// ResolvePathForLogging is the same as ResolvePath but includes more detailed logging
// for debugging purposes
func ResolvePathForLogging(path string) string {
return ResolvePath(path)
}
// IsAbsolutePath checks if a path is absolute (including tilde expansion)
func IsAbsolutePath(path string) bool {
// Check for tilde expansion first
if strings.HasPrefix(path, "~") {
return true // Tilde paths become absolute after expansion
}
return filepath.IsAbs(path)
} }
// GetRelativePath returns the relative path from base to target // GetRelativePath returns the relative path from base to target
func GetRelativePath(base, target string) (string, error) { func GetRelativePath(base, target string) (string, error) {
resolvedBase := ResolvePath(base) getRelativePathLogger := pathLogger.WithPrefix("GetRelativePath")
resolvedTarget := ResolvePath(target) getRelativePathLogger.Debug("Getting relative path from %q to %q", base, target)
relPath, err := filepath.Rel(resolvedBase, resolvedTarget) relPath, err := filepath.Rel(base, target)
if err != nil { if err != nil {
getRelativePathLogger.Error("Failed to get relative path: %v", err)
return "", err return "", err
} }
// Normalize to forward slashes // Use forward slashes for consistency
return strings.ReplaceAll(relPath, "\\", "/"), nil relPath = filepath.ToSlash(relPath)
getRelativePathLogger.Debug("Relative path: %q", relPath)
return relPath, nil
} }

View File

@@ -224,52 +224,6 @@ func TestResolvePathComplexTilde(t *testing.T) {
} }
} }
func TestIsAbsolutePath(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "Empty path",
input: "",
expected: false,
},
{
name: "Absolute Unix path",
input: "/absolute/path",
expected: func() bool {
if runtime.GOOS == "windows" {
// On Windows, paths starting with / are not considered absolute
return false
}
return true
}(),
},
{
name: "Relative path",
input: "relative/path",
expected: false,
},
{
name: "Tilde expansion (becomes absolute)",
input: "~/path",
expected: true,
},
{
name: "Windows absolute path",
input: "C:\\Windows\\System32",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsAbsolutePath(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetRelativePath(t *testing.T) { func TestGetRelativePath(t *testing.T) {
// Create temporary directories for testing // Create temporary directories for testing