Files
barotrauma-gamefiles/updater/main.go

230 lines
4.7 KiB
Go

package main
import (
	"bufio"
	"bytes"
	"crypto/sha256"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"sync"

	"github.com/bmatcuk/doublestar/v4"
)
// Error and Warning are the tagged loggers used throughout the updater.
// Error writes to both stderr and stdout; Warning writes to stdout only.
// Both prefix each line with an ANSI-colored tag.
var (
	Error   *log.Logger
	Warning *log.Logger
)

// init configures the default logger and builds Error and Warning. All three
// loggers print microsecond timestamps and the caller's file:line.
func init() {
	const flags = log.Lmicroseconds | log.Lshortfile
	log.SetFlags(flags)

	errorTag := fmt.Sprintf("%sERROR:%s ", "\033[0;101m", "\033[0m")
	Error = log.New(io.MultiWriter(os.Stderr, os.Stdout), errorTag, flags)

	warningTag := fmt.Sprintf("%sWarning:%s ", "\033[0;93m", "\033[0m")
	Warning = log.New(io.MultiWriter(os.Stdout), warningTag, flags)
}

// remoteUrl is the raw-file base URL of the upstream Content directory;
// relative file paths (and "hashes.txt") are appended after a "/".
const remoteUrl = "https://git.site.quack-lab.dev/dave/barotrauma-gamefiles/raw/branch/master/Content"
// maxParallel caps how many files are hashed or downloaded concurrently, so
// a large Content tree cannot exhaust file descriptors or flood the server.
const maxParallel = 16

// main compares the local Content/**/*.xml files against the hash list
// published at remoteUrl and re-downloads every file whose hash differs.
//
// The Content directory is resolved as ../Content relative to the working
// directory. Flags:
//
//	-savehash  write the local hashes (including freshly downloaded files)
//	           back to the hash file after the update
//	-hashfile  name of the hash file inside the Content directory
func main() {
	savehash := flag.Bool("savehash", false, "save hash")
	hashfile := flag.String("hashfile", "hashes.txt", "hashfile")
	flag.Parse()

	cwd, err := os.Getwd()
	if err != nil {
		Error.Printf("error getting cwd: %v", err)
		return
	}
	root := filepath.Join(cwd, "..", "Content")
	*hashfile = filepath.Join(root, *hashfile)

	hashes, err := LoadLocalHashes(*hashfile)
	if err != nil {
		Error.Printf("error loading hashes: %v", err)
		return
	}
	log.Printf("loaded hashes")

	remoteHashes, err := LoadRemoteHashes(remoteUrl + "/hashes.txt")
	if err != nil {
		Error.Printf("error loading remote hashes: %v", err)
		return
	}
	log.Printf("loaded remote hashes")

	files, err := doublestar.Glob(os.DirFS(root), "**/*.xml")
	if err != nil {
		Error.Printf("error globbing files: %v", err)
		return
	}
	// Freshly computed digests override whatever the hash file claimed.
	hashLocalFiles(root, files, hashes)

	toDownload, checked := findMismatches(hashes, remoteHashes)
	log.Printf("Hashes checked: %d, mismatched: %d", checked, len(toDownload))
	if len(toDownload) > 0 {
		log.Printf("Downloading %d files", len(toDownload))
		downloadFiles(root, toDownload, hashes)
	}

	if *savehash {
		if err := SaveLocalHashes(*hashfile, hashes); err != nil {
			Error.Printf("error saving hashes: %v", err)
		}
	}
}

// hashLocalFiles hashes each file in files (slash-separated paths relative
// to root) and stores the digest in hashes under that relative path. Files
// that cannot be hashed are logged and keep any entry hashes already had.
func hashLocalFiles(root string, files []string, hashes *sync.Map) {
	var wg sync.WaitGroup
	sem := make(chan struct{}, maxParallel)
	for _, file := range files {
		wg.Add(1)
		sem <- struct{}{}
		go func(file string) {
			defer wg.Done()
			defer func() { <-sem }()
			hash, err := GetLocalHash(filepath.Join(root, filepath.FromSlash(file)))
			if err != nil {
				Error.Printf("error getting hash: %v", err)
				return
			}
			hashes.Store(file, hash)
		}(file)
	}
	wg.Wait()
}

// findMismatches returns the remote keys whose hash differs from the local
// one, plus how many remote entries had a local counterpart to compare.
// Remote entries with no local hash are logged and skipped.
func findMismatches(local, remote *sync.Map) ([]string, int) {
	var toDownload []string
	checked := 0
	remote.Range(func(key, value interface{}) bool {
		localHash, ok := local.Load(key)
		if !ok {
			Error.Printf("local hash not found: %s", key)
			return true
		}
		if localHash != value {
			Warning.Printf("hash mismatch: %s", key)
			toDownload = append(toDownload, key.(string))
		}
		checked++
		return true
	})
	return toDownload, checked
}

// downloadFiles fetches each file (slash-separated path relative to root)
// from remoteUrl/<file> into root/<file>, then re-hashes it so hashes reflects
// the new content. Failures are logged per file and do not stop the others.
func downloadFiles(root string, files []string, hashes *sync.Map) {
	var wg sync.WaitGroup
	sem := make(chan struct{}, maxParallel)
	for _, file := range files {
		wg.Add(1)
		sem <- struct{}{}
		go func(file string) {
			defer wg.Done()
			defer func() { <-sem }()
			log.Printf("Downloading %s", file)
			path := filepath.Join(root, filepath.FromSlash(file))
			if err := UpdateLocalFile(path, remoteUrl+"/"+file); err != nil {
				Error.Printf("error updating local file %s: %v", file, err)
				return
			}
			// Keep hashes in sync so -savehash does not persist the stale digest.
			hash, err := GetLocalHash(path)
			if err != nil {
				Error.Printf("error getting hash: %v", err)
				return
			}
			hashes.Store(file, hash)
		}(file)
	}
	wg.Wait()
}
// GetLocalHash returns the lowercase hex SHA-256 digest of the file at path.
// It returns an error if the file cannot be opened or cannot be read in full;
// a partial read never yields a (wrong) digest.
func GetLocalHash(path string) (string, error) {
	file, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer file.Close()

	hash := sha256.New()
	if _, err := io.Copy(hash, file); err != nil {
		return "", fmt.Errorf("hashing %s: %w", path, err)
	}
	return fmt.Sprintf("%x", hash.Sum(nil)), nil
}
// LoadLocalHashes reads a hash file (as written by SaveLocalHashes) and
// returns its entries keyed by relative file path.
//
// Each non-blank line must hold exactly two whitespace-separated fields,
// "<path> <hash>". Blank lines are skipped and CRLF line endings are
// tolerated. File names containing whitespace cannot be represented.
// On any error the returned map is nil.
func LoadLocalHashes(path string) (*sync.Map, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	hashes := &sync.Map{}
	count := 0
	scanner := bufio.NewScanner(file)
	for lineNo := 1; scanner.Scan(); lineNo++ {
		line := scanner.Text()
		// Fields also strips a trailing '\r', which would otherwise end up
		// inside the hash and make every comparison fail.
		fields := strings.Fields(line)
		if len(fields) == 0 {
			continue
		}
		if len(fields) != 2 {
			return nil, fmt.Errorf("%s:%d: invalid line: %s", path, lineNo, line)
		}
		hashes.Store(fields[0], fields[1])
		count++
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("reading %s: %w", path, err)
	}
	log.Printf("loaded %d local hashes", count)
	return hashes, nil
}
// SaveLocalHashes writes hashes to path (creating or truncating it) as one
// "<path> <hash>" line per entry, sorted by path so regenerated files diff
// cleanly. Any create, write, flush or close error is returned.
func SaveLocalHashes(path string, hashes *sync.Map) error {
	type hashEntry struct{ name, hash string }
	var entries []hashEntry
	hashes.Range(func(key, value interface{}) bool {
		entries = append(entries, hashEntry{fmt.Sprint(key), fmt.Sprint(value)})
		return true
	})
	// sync.Map iteration order is random; sort for deterministic output.
	sort.Slice(entries, func(i, j int) bool { return entries[i].name < entries[j].name })

	file, err := os.Create(path)
	if err != nil {
		return err
	}
	w := bufio.NewWriter(file)
	for _, e := range entries {
		// bufio.Writer errors are sticky: any failed write is reported by Flush.
		fmt.Fprintf(w, "%s %s\n", e.name, e.hash)
	}
	if err := w.Flush(); err != nil {
		file.Close()
		return err
	}
	// Close can surface deferred write errors, so its result matters.
	return file.Close()
}
// LoadRemoteHashes fetches a hash file over HTTP and returns its entries
// keyed by relative file path, using the same "<path> <hash>" line format
// as LoadLocalHashes (blank lines skipped, CRLF tolerated).
//
// A non-200 response is an error, so an HTML error page is never mistaken
// for a hash list. On any error the returned map is nil.
func LoadRemoteHashes(url string) (*sync.Map, error) {
	resp, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching %s: unexpected status %s", url, resp.Status)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading %s: %w", url, err)
	}

	hashes := &sync.Map{}
	count := 0
	scanner := bufio.NewScanner(bytes.NewReader(body))
	for scanner.Scan() {
		line := scanner.Text()
		fields := strings.Fields(line)
		if len(fields) == 0 {
			continue
		}
		if len(fields) != 2 {
			return nil, fmt.Errorf("invalid line: %s", line)
		}
		hashes.Store(fields[0], fields[1])
		count++
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("parsing %s: %w", url, err)
	}
	log.Printf("loaded %d remote hashes", count)
	return hashes, nil
}
// UpdateLocalFile downloads remoteUrl and overwrites the file at path with the
// response body (creating it with mode 0644 if it does not exist).
//
// Only a 200 OK response is written. Any other status is returned as an error
// and path is left untouched, so a server error page never replaces game data.
func UpdateLocalFile(path string, remoteUrl string) error {
	resp, err := http.Get(remoteUrl)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("downloading %s: unexpected status %s", remoteUrl, resp.Status)
	}
	// Read the whole body first so a failed transfer cannot truncate path.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("downloading %s: %w", remoteUrl, err)
	}
	return os.WriteFile(path, body, 0644)
}