package main import ( "bufio" "bytes" "crypto/sha256" "flag" "fmt" "io" "log" "net/http" "os" "path/filepath" "slices" "strings" "sync" "github.com/bmatcuk/doublestar/v4" ) var Error *log.Logger var Warning *log.Logger func init() { log.SetFlags(log.Lmicroseconds | log.Lshortfile) logFile, err := os.OpenFile("updater.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) if err != nil { log.Fatalf("error opening log file: %v", err) } Error = log.New(io.MultiWriter(os.Stderr, logFile), fmt.Sprintf("%sERROR:%s ", "\033[0;101m", "\033[0m"), log.Lmicroseconds|log.Lshortfile) Warning = log.New(io.MultiWriter(os.Stdout, logFile), fmt.Sprintf("%sWarning:%s ", "\033[0;93m", "\033[0m"), log.Lmicroseconds|log.Lshortfile) } const remoteUrl = "https://git.site.quack-lab.dev/dave/barotrauma-gamefiles/raw/branch/cooked/Content" func main() { nodl := flag.Bool("nodl", false, "nodl") hashfile := flag.String("hashfile", "hashes.txt", "hashfile") hash := flag.Bool("hash", false, "hash") root := flag.String("root", ".", "root") flag.Parse() cwd, err := os.Getwd() if err != nil { Error.Printf("error getting cwd: %v", err) return } log.Printf("cwd: %s", cwd) log.Printf("root: %s", *root) *root = filepath.Join(cwd, *root, "Content") log.Printf("root: %s", *root) *hashfile = filepath.Join(*root, *hashfile) hashes, err := LoadLocalHashes(*hashfile) if err != nil { Error.Printf("error loading hashes: %v", err) return } log.Printf("loaded hashes") remoteHashes, err := LoadRemoteHashes(remoteUrl + "/hashes.txt") if err != nil { Error.Printf("error loading remote hashes: %v", err) return } log.Printf("loaded remote hashes") files, err := doublestar.Glob(os.DirFS(*root), "**/*.xml") if err != nil { Error.Printf("error globbing files: %v", err) return } var wg sync.WaitGroup for _, file := range files { wg.Add(1) go func(file string) { defer wg.Done() file = strings.ReplaceAll(file, "\\", "/") path := filepath.Join(*root, file) path = strings.ReplaceAll(path, "\\", "/") hash, err := GetLocalHash(path) if err != nil { Error.Printf("error getting hash: %v", err) return } relativepath, err := filepath.Rel(*root, path) if err != nil { Error.Printf("error getting relative path: %v", err) return } relativepath = strings.ReplaceAll(relativepath, "\\", "/") hashes.Store(relativepath, hash) }(file) } wg.Wait() if *hash { err = SaveLocalHashes(*hashfile, hashes) if err != nil { Error.Printf("error saving hashes: %v", err) return } log.Printf("saved hashes") return } mismatched := 0 checked := 0 toDownload := []string{} remoteHashes.Range(func(key, value interface{}) bool { localhash, ok := hashes.Load(key) if !ok { Warning.Printf("local hash not found: %s", key) mismatched++ toDownload = append(toDownload, key.(string)) return true } if localhash != value { Warning.Printf("hash mismatch: %s", key) mismatched++ toDownload = append(toDownload, key.(string)) } checked++ return true }) log.Printf("Hashes checked: %d, mismatched: %d", checked, mismatched) if mismatched > 0 { log.Printf("Downloading %d files", len(toDownload)) wg := sync.WaitGroup{} for _, file := range toDownload { wg.Add(1) go func(file string) { defer wg.Done() file = strings.ReplaceAll(file, "\\", "/") path := filepath.Join(*root, file) log.Printf("Downloading %s", file) if *nodl { log.Printf("Skipping download for %s", file) return } err := UpdateLocalFile(path, remoteUrl+"/"+file) if err != nil { Error.Printf("error updating local file: %v", err) } newhash, err := GetLocalHash(path) if err != nil { Error.Printf("error getting local hash: %v", err) return } hashes.Store(file, newhash) }(file) } wg.Wait() } // We want to update the newly downloaded files, if any err = SaveLocalHashes(*hashfile, hashes) if err != nil { Error.Printf("error saving hashes: %v", err) } } func GetLocalHash(path string) (string, error) { hash := sha256.New() file, err := os.Open(path) if err != nil { return "", err } defer file.Close() io.Copy(hash, file) return fmt.Sprintf("%x", hash.Sum(nil)), nil } func LoadLocalHashes(path string) (*sync.Map, error) { hashes := &sync.Map{} file, err := os.Open(path) if err != nil { if os.IsNotExist(err) { Warning.Printf("hashes file not found: %s", path) return hashes, nil } return nil, err } defer file.Close() count := 0 scanner := bufio.NewScanner(file) for scanner.Scan() { // Hopefully none of the files have spaces in the name................. line := scanner.Text() parts := strings.Split(line, " ") if len(parts) != 2 { Warning.Printf("invalid line: %s", line) continue } hashes.Store(parts[0], parts[1]) count++ } log.Printf("loaded %d local hashes", count) return hashes, nil } type FileHash struct { Path string Hash string } func SaveLocalHashes(path string, hashes *sync.Map) error { file, err := os.Create(path) if err != nil { return err } defer file.Close() hashesarr := []FileHash{} hashes.Range(func(key, value interface{}) bool { hashesarr = append(hashesarr, FileHash{ Path: key.(string), Hash: value.(string), }) return true }) slices.SortFunc(hashesarr, func(a, b FileHash) int { return strings.Compare(a.Path, b.Path) }) // We are sorting this shit so that our resulting hashes file is consistent for _, hash := range hashesarr { file.WriteString(fmt.Sprintf("%s %s\n", hash.Path, hash.Hash)) } return nil } func LoadRemoteHashes(url string) (*sync.Map, error) { resp, err := http.Get(url) if err != nil { return nil, err } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } count := 0 hashes := &sync.Map{} scanner := bufio.NewScanner(bytes.NewReader(body)) for scanner.Scan() { line := scanner.Text() parts := strings.Split(line, " ") if len(parts) != 2 { return nil, fmt.Errorf("invalid line: %s", line) } count++ hashes.Store(parts[0], parts[1]) } log.Printf("loaded %d remote hashes", count) return hashes, nil } func UpdateLocalFile(path string, remoteUrl string) error { resp, err := http.Get(remoteUrl) if err != nil { return err } if resp.StatusCode != 200 { return fmt.Errorf("failed downloading %s on url %q: %d", path, remoteUrl, resp.StatusCode) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed reading body for %s on url %q: %v", path, remoteUrl, err) } err = os.WriteFile(path, body, 0644) if err != nil { return fmt.Errorf("failed writing file for %s on url %q: %v", path, remoteUrl, err) } return nil }