Files
2025-03-28 18:51:57 +01:00

297 lines
6.7 KiB
Go

package main
import (
"bufio"
"bytes"
"crypto/sha256"
"flag"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"github.com/bmatcuk/doublestar/v4"
)
// Error and Warning are the application's leveled loggers. Each writes to the
// console (Error to stderr, Warning to stdout) and mirrors every line into
// updater.log in the working directory. The label is wrapped in ANSI colour
// escape sequences.
var (
	Error   *log.Logger
	Warning *log.Logger
)

// init sets the standard logger's flags and builds Error and Warning with the
// same flags. It aborts the process via log.Fatalf if updater.log cannot be
// opened for appending.
func init() {
	const flags = log.Lmicroseconds | log.Lshortfile
	log.SetFlags(flags)

	logFile, err := os.OpenFile("updater.log", os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0666)
	if err != nil {
		log.Fatalf("error opening log file: %v", err)
	}

	const reset = "\033[0m"
	errorPrefix := fmt.Sprintf("%sERROR:%s ", "\033[0;101m", reset)
	warningPrefix := fmt.Sprintf("%sWarning:%s ", "\033[0;93m", reset)

	Error = log.New(io.MultiWriter(os.Stderr, logFile), errorPrefix, flags)
	Warning = log.New(io.MultiWriter(os.Stdout, logFile), warningPrefix, flags)
}

// remoteUrl is the base URL of the remote Content tree. It serves hashes.txt
// and every file listed in it.
const remoteUrl = "https://git.site.quack-lab.dev/dave/barotrauma-gamefiles/raw/branch/cooked/Content"
// main checks the local Content directory against the remote hash list and
// downloads every listed file that is missing locally or whose SHA-256 differs.
//
// Flags:
//
//	-root      base directory; Content is resolved inside it (relative to the
//	           working directory unless absolute). Default ".".
//	-hashfile  name of the local hash cache inside Content. Default "hashes.txt".
//	-hash      only recompute local hashes, save them and exit.
//	-nodl      report outdated files without downloading them.
func main() {
	nodl := flag.Bool("nodl", false, "nodl")
	hashfile := flag.String("hashfile", "hashes.txt", "hashfile")
	hash := flag.Bool("hash", false, "hash")
	root := flag.String("root", ".", "root")
	flag.Parse()

	cwd, err := os.Getwd()
	if err != nil {
		Error.Printf("error getting cwd: %v", err)
		return
	}
	log.Printf("cwd: %s", cwd)
	log.Printf("root: %s", *root)
	// filepath.Join does not treat an absolute element specially, so anchor
	// -root at the working directory only when it is relative.
	if !filepath.IsAbs(*root) {
		*root = filepath.Join(cwd, *root)
	}
	*root = filepath.Join(*root, "Content")
	log.Printf("root: %s", *root)
	*hashfile = filepath.Join(*root, *hashfile)

	hashes, err := LoadLocalHashes(*hashfile)
	if err != nil {
		Error.Printf("error loading hashes: %v", err)
		return
	}
	log.Printf("loaded hashes")
	// The cache may list files deleted since it was written. Drop them so they
	// count as missing instead of "matching" through a stale cached hash.
	pruneMissingHashes(*root, hashes)

	remoteHashes, err := LoadRemoteHashes(remoteUrl + "/hashes.txt")
	if err != nil {
		Error.Printf("error loading remote hashes: %v", err)
		return
	}
	log.Printf("loaded remote hashes")

	files, err := doublestar.Glob(os.DirFS(*root), "**/*.xml")
	if err != nil {
		Error.Printf("error globbing files: %v", err)
		return
	}
	hashLocalFiles(*root, files, hashes)

	if *hash {
		err = SaveLocalHashes(*hashfile, hashes)
		if err != nil {
			Error.Printf("error saving hashes: %v", err)
			return
		}
		log.Printf("saved hashes")
		return
	}

	toDownload, checked := findOutdated(remoteHashes, hashes)
	log.Printf("Hashes checked: %d, mismatched: %d", checked, len(toDownload))
	if len(toDownload) > 0 {
		log.Printf("Downloading %d files", len(toDownload))
		downloadFiles(*root, toDownload, hashes, *nodl)
	}

	// Persist the cache so hashes of newly downloaded files are recorded.
	err = SaveLocalHashes(*hashfile, hashes)
	if err != nil {
		Error.Printf("error saving hashes: %v", err)
	}
}

// maxConcurrency bounds how many files are hashed or downloaded at once. This
// keeps a large Content tree from exhausting file descriptors or flooding the
// remote server with simultaneous requests.
const maxConcurrency = 16

// pruneMissingHashes removes cached entries (slash-separated paths relative
// to root) whose file no longer exists on disk.
func pruneMissingHashes(root string, hashes *sync.Map) {
	hashes.Range(func(key, _ any) bool {
		rel := key.(string)
		if _, err := os.Stat(filepath.Join(root, filepath.FromSlash(rel))); os.IsNotExist(err) {
			// sync.Map permits Delete from inside Range.
			hashes.Delete(key)
		}
		return true
	})
}

// hashLocalFiles hashes each file (slash-separated io/fs path relative to
// root, as returned by doublestar.Glob over os.DirFS) and stores the digest
// in hashes under that same relative path. Failures are logged and skipped.
func hashLocalFiles(root string, files []string, hashes *sync.Map) {
	var wg sync.WaitGroup
	sem := make(chan struct{}, maxConcurrency)
	for _, file := range files {
		wg.Add(1)
		sem <- struct{}{}
		go func(rel string) {
			defer wg.Done()
			defer func() { <-sem }()
			sum, err := GetLocalHash(filepath.Join(root, filepath.FromSlash(rel)))
			if err != nil {
				Error.Printf("error getting hash: %v", err)
				return
			}
			hashes.Store(rel, sum)
		}(file)
	}
	wg.Wait()
}

// findOutdated returns every remote path that is missing from local or whose
// local hash differs. It also returns how many remote paths were present
// locally, i.e. actually compared.
func findOutdated(remote, local *sync.Map) ([]string, int) {
	var outdated []string
	checked := 0
	remote.Range(func(key, value any) bool {
		localHash, ok := local.Load(key)
		if !ok {
			Warning.Printf("local hash not found: %s", key)
			outdated = append(outdated, key.(string))
			return true
		}
		if localHash != value {
			Warning.Printf("hash mismatch: %s", key)
			outdated = append(outdated, key.(string))
		}
		checked++
		return true
	})
	return outdated, checked
}

// downloadFiles fetches each slash-separated path from remoteUrl into root,
// at most maxConcurrency at a time, and records the file's new hash. With
// nodl set it only logs which files would be downloaded.
func downloadFiles(root string, files []string, hashes *sync.Map, nodl bool) {
	var wg sync.WaitGroup
	sem := make(chan struct{}, maxConcurrency)
	for _, file := range files {
		wg.Add(1)
		sem <- struct{}{}
		go func(file string) {
			defer wg.Done()
			defer func() { <-sem }()
			file = strings.ReplaceAll(file, "\\", "/")
			local := filepath.FromSlash(file)
			// Paths come from the remote list; never let one escape root.
			if !filepath.IsLocal(local) {
				Error.Printf("refusing to download non-local path: %s", file)
				return
			}
			path := filepath.Join(root, local)
			log.Printf("Downloading %s", file)
			if nodl {
				log.Printf("Skipping download for %s", file)
				return
			}
			if err := UpdateLocalFile(path, remoteUrl+"/"+file); err != nil {
				Error.Printf("error updating local file: %v", err)
				return
			}
			newhash, err := GetLocalHash(path)
			if err != nil {
				Error.Printf("error getting local hash: %v", err)
				return
			}
			hashes.Store(file, newhash)
		}(file)
	}
	wg.Wait()
}
// GetLocalHash returns the lowercase hex SHA-256 digest of the file at path.
// Errors from opening or reading the file are returned. A read that fails
// part-way through is reported rather than yielding a digest of partial
// content.
func GetLocalHash(path string) (string, error) {
	file, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer file.Close()

	hasher := sha256.New()
	if _, err := io.Copy(hasher, file); err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", hasher.Sum(nil)), nil
}
// LoadLocalHashes reads the local hash cache at path into a map of
// relative path -> hex digest (both strings).
//
// Each line must be "<path> <hash>" separated by a single space. Malformed
// lines are logged as warnings and skipped. A missing file is not an error:
// a warning is logged and an empty map is returned. Any other open or read
// error is returned.
func LoadLocalHashes(path string) (*sync.Map, error) {
	hashes := &sync.Map{}
	file, err := os.Open(path)
	if err != nil {
		if os.IsNotExist(err) {
			Warning.Printf("hashes file not found: %s", path)
			return hashes, nil
		}
		return nil, err
	}
	defer file.Close()

	count := 0
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		// The format cannot represent paths containing spaces: such lines split
		// into more than two fields and are rejected below.
		line := scanner.Text()
		parts := strings.Split(line, " ")
		if len(parts) != 2 {
			Warning.Printf("invalid line: %s", line)
			continue
		}
		hashes.Store(parts[0], parts[1])
		count++
	}
	// Scan stops silently on I/O errors and over-long lines; surface them
	// instead of returning a truncated cache.
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("reading hashes file %s: %w", path, err)
	}
	log.Printf("loaded %d local hashes", count)
	return hashes, nil
}
// FileHash pairs a relative file path with its hex digest. It is the unit
// SaveLocalHashes sorts and writes.
type FileHash struct {
	Path string
	Hash string
}

// SaveLocalHashes writes hashes (string path -> string hash) to path, one
// "<path> <hash>" line per entry. Entries are sorted by path so the file is
// identical across runs for the same contents. The file is created or
// truncated. Errors from creating, writing, flushing or closing it are
// returned.
func SaveLocalHashes(path string, hashes *sync.Map) (err error) {
	file, err := os.Create(path)
	if err != nil {
		return err
	}
	defer func() {
		// Close can report a deferred write failure; don't mask an earlier error.
		if cerr := file.Close(); err == nil {
			err = cerr
		}
	}()

	var entries []FileHash
	hashes.Range(func(key, value any) bool {
		entries = append(entries, FileHash{Path: key.(string), Hash: value.(string)})
		return true
	})
	// Sorted so the resulting hashes file is deterministic and diff-friendly.
	slices.SortFunc(entries, func(a, b FileHash) int {
		return strings.Compare(a.Path, b.Path)
	})

	w := bufio.NewWriter(file)
	for _, e := range entries {
		if _, err := fmt.Fprintf(w, "%s %s\n", e.Path, e.Hash); err != nil {
			return err
		}
	}
	return w.Flush()
}
// LoadRemoteHashes fetches the hash list at url and parses it into a map of
// relative path -> hex digest (both strings).
//
// The response must be 200 OK. Every line must be exactly "<path> <hash>"
// separated by a single space; any malformed line fails the whole load.
func LoadRemoteHashes(url string) (*sync.Map, error) {
	resp, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	// Without this check an error page would be parsed as if it were a hash
	// list (an empty body yields an empty, "valid" map).
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching %q: unexpected status %s", url, resp.Status)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading %q: %w", url, err)
	}

	count := 0
	hashes := &sync.Map{}
	scanner := bufio.NewScanner(bytes.NewReader(body))
	for scanner.Scan() {
		line := scanner.Text()
		parts := strings.Split(line, " ")
		if len(parts) != 2 {
			return nil, fmt.Errorf("invalid line: %s", line)
		}
		count++
		hashes.Store(parts[0], parts[1])
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("parsing %q: %w", url, err)
	}
	log.Printf("loaded %d remote hashes", count)
	return hashes, nil
}
// UpdateLocalFile downloads remoteUrl and writes the body to path with mode
// 0644, creating any missing parent directories. A non-200 response is an
// error. The whole body is read before the file is touched, so a failed
// download leaves any existing file unchanged.
func UpdateLocalFile(path string, remoteUrl string) error {
	resp, err := http.Get(remoteUrl)
	if err != nil {
		return err
	}
	// Close on every path, including non-200 responses, so the body is not leaked.
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed downloading %s on url %q: %d", path, remoteUrl, resp.StatusCode)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed reading body for %s on url %q: %w", path, remoteUrl, err)
	}
	// Files that are new upstream may live in directories that do not exist locally yet.
	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return fmt.Errorf("failed creating directory for %s: %w", path, err)
	}
	if err := os.WriteFile(path, body, 0644); err != nil {
		return fmt.Errorf("failed writing file for %s on url %q: %w", path, remoteUrl, err)
	}
	return nil
}