From 774ac0f0ca27950c9e69409e328b09cee204aab3 Mon Sep 17 00:00:00 2001 From: PhatPhuckDave Date: Sun, 20 Jul 2025 11:43:25 +0200 Subject: [PATCH] Implement proper "reset" that reads snapshots from database --- main.go | 18 ++++++++++++------ utils/db.go | 13 +++++++++++++ utils/file.go | 38 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/main.go b/main.go index c7be14d..40db827 100644 --- a/main.go +++ b/main.go @@ -56,6 +56,12 @@ func main() { logger.InitFlag() logger.Info("Initializing with log level: %s", logger.GetLevel().String()) + db, err := utils.GetDB() + if err != nil { + logger.Error("Failed to get database: %v", err) + return + } + // The plan is: // Load all commands commands, err := utils.LoadCommands(args) @@ -105,6 +111,12 @@ func main() { return } + err = utils.ResetWhereNecessary(associations, db) + if err != nil { + logger.Error("Failed to reset files where necessary: %v", err) + return + } + // Then for each file run all commands associated with the file workers := make(chan struct{}, *utils.ParallelFiles) wg := sync.WaitGroup{} @@ -137,12 +149,6 @@ func main() { logger.Debug("Created logger for command %q with log level %s", cmdName, cmdLogLevel.String()) } - db, err := utils.GetDB() - if err != nil { - logger.Error("Failed to get database: %v", err) - return - } - for file, association := range associations { workers <- struct{}{} wg.Add(1) diff --git a/utils/db.go b/utils/db.go index bb4e662..29ede0d 100644 --- a/utils/db.go +++ b/utils/db.go @@ -14,6 +14,7 @@ type DB interface { DB() *gorm.DB Raw(sql string, args ...any) *gorm.DB SaveFile(filePath string, fileData []byte) error + GetFile(filePath string) ([]byte, error) } type FileSnapshot struct { @@ -80,3 +81,15 @@ func (db *DBWrapper) SaveFile(filePath string, fileData []byte) error { FileData: fileData, }).Error } + +func (db *DBWrapper) GetFile(filePath string) ([]byte, error) { + log := cylogger.Default.WithPrefix(fmt.Sprintf("GetFile: %q", filePath)) + log.Debug("Getting file from database") + var fileSnapshot FileSnapshot + err := db.db.Model(&FileSnapshot{}).Where("file_path = ?", filePath).First(&fileSnapshot).Error + if err != nil { + return nil, err + } + log.Debug("File found in database") + return fileSnapshot.FileData, nil +} diff --git a/utils/file.go b/utils/file.go index 027d6ca..b576537 100644 --- a/utils/file.go +++ b/utils/file.go @@ -33,3 +33,41 @@ func ToAbs(path string) string { log.Trace("Cwd: %q", cwd) return CleanPath(filepath.Join(cwd, path)) } + +func ResetWhereNecessary(associations map[string]FileCommandAssociation, db DB) error { + log := cylogger.Default.WithPrefix("ResetWhereNecessary") + log.Debug("Start") + dirtyFiles := make(map[string]struct{}) + for _, association := range associations { + for _, command := range association.Commands { + log.Debug("Checking command %q for file %q", command.Name, association.File) + if command.Reset { + log.Debug("Command %q requires reset for file %q", command.Name, association.File) + dirtyFiles[association.File] = struct{}{} + } + } + for _, command := range association.IsolateCommands { + log.Debug("Checking isolate command %q for file %q", command.Name, association.File) + if command.Reset { + log.Debug("Isolate command %q requires reset for file %q", command.Name, association.File) + dirtyFiles[association.File] = struct{}{} + } + } + } + log.Debug("Dirty files: %v", dirtyFiles) + for file := range dirtyFiles { + log.Debug("Resetting file %q", file) + fileData, err := db.GetFile(file) + if err != nil { + return err + } + log.Debug("Writing file %q to disk", file) + err = os.WriteFile(file, fileData, 0644) + if err != nil { + return err + } + log.Debug("File %q written to disk", file) + } + log.Debug("Done") + return nil +}