// Command updater synchronizes a local Barotrauma Content directory with a
// remote hash manifest, downloading any missing or mismatched game files.
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/bmatcuk/doublestar/v4"
|
|
)
|
|
|
|
// Error writes to stderr and updater.log with a red "ERROR:" prefix.
var Error *log.Logger

// Warning writes to stdout and updater.log with a yellow "Warning:" prefix.
var Warning *log.Logger
|
|
|
|
func init() {
|
|
log.SetFlags(log.Lmicroseconds | log.Lshortfile)
|
|
logFile, err := os.OpenFile("updater.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
|
if err != nil {
|
|
log.Fatalf("error opening log file: %v", err)
|
|
}
|
|
Error = log.New(io.MultiWriter(os.Stderr, logFile),
|
|
fmt.Sprintf("%sERROR:%s ", "\033[0;101m", "\033[0m"),
|
|
log.Lmicroseconds|log.Lshortfile)
|
|
Warning = log.New(io.MultiWriter(os.Stdout, logFile),
|
|
fmt.Sprintf("%sWarning:%s ", "\033[0;93m", "\033[0m"),
|
|
log.Lmicroseconds|log.Lshortfile)
|
|
}
|
|
|
|
const remoteUrl = "https://git.site.quack-lab.dev/dave/barotrauma-gamefiles/raw/branch/cooked/Content"
|
|
|
|
func main() {
|
|
nodl := flag.Bool("nodl", false, "nodl")
|
|
hashfile := flag.String("hashfile", "hashes.txt", "hashfile")
|
|
hash := flag.Bool("hash", false, "hash")
|
|
root := flag.String("root", ".", "root")
|
|
flag.Parse()
|
|
|
|
cwd, err := os.Getwd()
|
|
if err != nil {
|
|
Error.Printf("error getting cwd: %v", err)
|
|
return
|
|
}
|
|
log.Printf("cwd: %s", cwd)
|
|
log.Printf("root: %s", *root)
|
|
|
|
*root = filepath.Join(cwd, *root, "Content")
|
|
log.Printf("root: %s", *root)
|
|
|
|
*hashfile = filepath.Join(*root, *hashfile)
|
|
hashes, err := LoadLocalHashes(*hashfile)
|
|
if err != nil {
|
|
Error.Printf("error loading hashes: %v", err)
|
|
return
|
|
}
|
|
log.Printf("loaded hashes")
|
|
|
|
remoteHashes, err := LoadRemoteHashes(remoteUrl + "/hashes.txt")
|
|
if err != nil {
|
|
Error.Printf("error loading remote hashes: %v", err)
|
|
return
|
|
}
|
|
log.Printf("loaded remote hashes")
|
|
|
|
files, err := doublestar.Glob(os.DirFS(*root), "**/*.xml")
|
|
if err != nil {
|
|
Error.Printf("error globbing files: %v", err)
|
|
return
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for _, file := range files {
|
|
wg.Add(1)
|
|
go func(file string) {
|
|
defer wg.Done()
|
|
file = strings.ReplaceAll(file, "\\", "/")
|
|
path := filepath.Join(*root, file)
|
|
path = strings.ReplaceAll(path, "\\", "/")
|
|
hash, err := GetLocalHash(path)
|
|
if err != nil {
|
|
Error.Printf("error getting hash: %v", err)
|
|
return
|
|
}
|
|
relativepath, err := filepath.Rel(*root, path)
|
|
if err != nil {
|
|
Error.Printf("error getting relative path: %v", err)
|
|
return
|
|
}
|
|
relativepath = strings.ReplaceAll(relativepath, "\\", "/")
|
|
hashes.Store(relativepath, hash)
|
|
}(file)
|
|
}
|
|
wg.Wait()
|
|
|
|
if *hash {
|
|
err = SaveLocalHashes(*hashfile, hashes)
|
|
if err != nil {
|
|
Error.Printf("error saving hashes: %v", err)
|
|
return
|
|
}
|
|
log.Printf("saved hashes")
|
|
return
|
|
}
|
|
|
|
mismatched := 0
|
|
checked := 0
|
|
toDownload := []string{}
|
|
remoteHashes.Range(func(key, value interface{}) bool {
|
|
localhash, ok := hashes.Load(key)
|
|
if !ok {
|
|
Warning.Printf("local hash not found: %s", key)
|
|
mismatched++
|
|
toDownload = append(toDownload, key.(string))
|
|
return true
|
|
}
|
|
if localhash != value {
|
|
Warning.Printf("hash mismatch: %s", key)
|
|
mismatched++
|
|
toDownload = append(toDownload, key.(string))
|
|
}
|
|
checked++
|
|
return true
|
|
})
|
|
log.Printf("Hashes checked: %d, mismatched: %d", checked, mismatched)
|
|
|
|
if mismatched > 0 {
|
|
log.Printf("Downloading %d files", len(toDownload))
|
|
wg := sync.WaitGroup{}
|
|
for _, file := range toDownload {
|
|
wg.Add(1)
|
|
go func(file string) {
|
|
defer wg.Done()
|
|
file = strings.ReplaceAll(file, "\\", "/")
|
|
path := filepath.Join(*root, file)
|
|
log.Printf("Downloading %s", file)
|
|
if *nodl {
|
|
log.Printf("Skipping download for %s", file)
|
|
return
|
|
}
|
|
err := UpdateLocalFile(path, remoteUrl+"/"+file)
|
|
if err != nil {
|
|
Error.Printf("error updating local file: %v", err)
|
|
}
|
|
newhash, err := GetLocalHash(path)
|
|
if err != nil {
|
|
Error.Printf("error getting local hash: %v", err)
|
|
return
|
|
}
|
|
hashes.Store(file, newhash)
|
|
}(file)
|
|
}
|
|
wg.Wait()
|
|
}
|
|
// We want to update the newly downloaded files, if any
|
|
err = SaveLocalHashes(*hashfile, hashes)
|
|
if err != nil {
|
|
Error.Printf("error saving hashes: %v", err)
|
|
}
|
|
}
|
|
|
|
func GetLocalHash(path string) (string, error) {
|
|
hash := sha256.New()
|
|
|
|
file, err := os.Open(path)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer file.Close()
|
|
|
|
io.Copy(hash, file)
|
|
return fmt.Sprintf("%x", hash.Sum(nil)), nil
|
|
}
|
|
|
|
func LoadLocalHashes(path string) (*sync.Map, error) {
|
|
hashes := &sync.Map{}
|
|
|
|
file, err := os.Open(path)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
Warning.Printf("hashes file not found: %s", path)
|
|
return hashes, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
defer file.Close()
|
|
|
|
count := 0
|
|
scanner := bufio.NewScanner(file)
|
|
for scanner.Scan() {
|
|
// Hopefully none of the files have spaces in the name.................
|
|
line := scanner.Text()
|
|
parts := strings.Split(line, " ")
|
|
if len(parts) != 2 {
|
|
return nil, fmt.Errorf("invalid line: %s", line)
|
|
}
|
|
hashes.Store(parts[0], parts[1])
|
|
count++
|
|
}
|
|
log.Printf("loaded %d local hashes", count)
|
|
|
|
return hashes, nil
|
|
}
|
|
|
|
type FileHash struct {
|
|
Path string
|
|
Hash string
|
|
}
|
|
|
|
func SaveLocalHashes(path string, hashes *sync.Map) error {
|
|
file, err := os.Create(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer file.Close()
|
|
|
|
hashesarr := []FileHash{}
|
|
hashes.Range(func(key, value interface{}) bool {
|
|
hashesarr = append(hashesarr, FileHash{
|
|
Path: key.(string),
|
|
Hash: value.(string),
|
|
})
|
|
return true
|
|
})
|
|
|
|
slices.SortFunc(hashesarr, func(a, b FileHash) int {
|
|
return strings.Compare(a.Path, b.Path)
|
|
})
|
|
|
|
// We are sorting this shit so that our resulting hashes file is consistent
|
|
for _, hash := range hashesarr {
|
|
file.WriteString(fmt.Sprintf("%s %s\n", hash.Path, hash.Hash))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func LoadRemoteHashes(url string) (*sync.Map, error) {
|
|
resp, err := http.Get(url)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
count := 0
|
|
hashes := &sync.Map{}
|
|
scanner := bufio.NewScanner(bytes.NewReader(body))
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
parts := strings.Split(line, " ")
|
|
if len(parts) != 2 {
|
|
return nil, fmt.Errorf("invalid line: %s", line)
|
|
}
|
|
count++
|
|
hashes.Store(parts[0], parts[1])
|
|
}
|
|
log.Printf("loaded %d remote hashes", count)
|
|
|
|
return hashes, nil
|
|
}
|
|
|
|
func UpdateLocalFile(path string, remoteUrl string) error {
|
|
resp, err := http.Get(remoteUrl)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if resp.StatusCode != 200 {
|
|
return fmt.Errorf("failed downloading %s on url %q: %d", path, remoteUrl, resp.StatusCode)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed reading body for %s on url %q: %v", path, remoteUrl, err)
|
|
}
|
|
|
|
err = os.WriteFile(path, body, 0644)
|
|
if err != nil {
|
|
return fmt.Errorf("failed writing file for %s on url %q: %v", path, remoteUrl, err)
|
|
}
|
|
|
|
return nil
|
|
}
|