diff --git a/go.mod b/go.mod index 829682f..879a77b 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,18 @@ module hitman -go 1.23.0 +go 1.23.6 -require golang.org/x/sys v0.25.0 +require ( + git.site.quack-lab.dev/dave/cylogger v1.3.0 + git.site.quack-lab.dev/dave/cyutils v1.1.3 + golang.org/x/sys v0.25.0 +) + +require ( + github.com/google/go-cmp v0.5.9 // indirect + github.com/hexops/valast v1.5.0 // indirect + golang.org/x/mod v0.7.0 // indirect + golang.org/x/time v0.12.0 // indirect + golang.org/x/tools v0.4.0 // indirect + mvdan.cc/gofumpt v0.4.0 // indirect +) diff --git a/go.sum b/go.sum index c9930ff..3f0464d 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,32 @@ +git.site.quack-lab.dev/dave/cylogger v1.3.0 h1:eTWPUD+ThVi8kGIsRcE0XDeoH3yFb5miFEODyKUdWJw= +git.site.quack-lab.dev/dave/cylogger v1.3.0/go.mod h1:wctgZplMvroA4X6p8f4B/LaCKtiBcT1Pp+L14kcS8jk= +git.site.quack-lab.dev/dave/cyutils v1.1.3 h1:9Y1GhrPrVLut36hceZwuFm0IMlAFerl6ATRPa9tGHFM= +git.site.quack-lab.dev/dave/cyutils v1.1.3/go.mod h1:fBjALu2Cp2u2bDr+E4zbGVMBeIgFzROg+4TCcTNAiQU= +github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= +github.com/frankban/quicktest v1.14.3/go.mod h1:mgiwOwqx65TmIk1wJ6Q7wvnVMocbUorkibMOrVTHZps= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hexops/autogold v0.8.1 h1:wvyd/bAJ+Dy+DcE09BoLk6r4Fa5R5W+O+GUzmR985WM= +github.com/hexops/autogold v0.8.1/go.mod h1:97HLDXyG23akzAoRYJh/2OBs3kd80eHyKPvZw0S5ZBY= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= +github.com/hexops/valast v1.5.0 h1:FBTuvVi0wjTngtXJRZXMbkN/Dn6DgsUsBwch2DUJU8Y= +github.com/hexops/valast v1.5.0/go.mod h1:Jcy1pNH7LNraVaAZDLyv21hHg2WBv9Nf9FL6fGxU7o4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= +golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +golang.org/x/tools v0.4.0 h1:7mTAgkunk3fr4GAloyyCasadO6h9zSsQZbwvcaIciV4= +golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= +mvdan.cc/gofumpt v0.4.0 h1:JVf4NN1mIpHogBj7ABpgOyZc65/UUOkKQFkoURsz4MM= +mvdan.cc/gofumpt v0.4.0/go.mod h1:PljLOHDeZqgS8opHRKLzp2It2VBuSdteAgqUfzMTxlQ= diff --git a/main.go b/main.go index 61f0b51..844a9c2 100644 --- a/main.go +++ b/main.go @@ -2,105 +2,175 @@ package main import ( "context" - "fmt" - "io" - "log" "os" + "strconv" "strings" "syscall" "time" + + logger "git.site.quack-lab.dev/dave/cylogger" + utils "git.site.quack-lab.dev/dave/cyutils" ) -var Error *log.Logger -var Warning *log.Logger - -func init() { - log.SetFlags(log.Lmicroseconds | log.Lshortfile) - logFile, err := os.Create("main.log") - if err != nil { - log.Printf("Error creating log file: %v", err) - os.Exit(1) - } - logger := io.MultiWriter(os.Stdout, logFile) - log.SetOutput(logger) - - Error = log.New(io.MultiWriter(logFile, os.Stderr, os.Stdout), - fmt.Sprintf("%sERROR:%s ", "\033[0;101m", "\033[0m"), - log.Lmicroseconds|log.Lshortfile) - Warning = log.New(io.MultiWriter(logFile, os.Stdout), - fmt.Sprintf("%sWarning:%s ", "\033[0;93m", "\033[0m"), - log.Lmicroseconds|log.Lshortfile) +type Config struct { + Forbidden []string + ScanInterval time.Duration + Timeout time.Duration + Workers int } -func main() { - forbidden, exists := os.LookupEnv("HITMAN_FORBIDDEN") - if !exists { - Error.Println("HITMAN_FORBIDDEN environment variable not set") - log.Printf("Please set to a comma separated list of process names to forbid") - return +func getenv(key, def string) string { + if v, ok := os.LookupEnv(key); ok { + return v + } + return def +} + +var timeUnits = map[string]int64{ + "ms": 1, + "s": 1000, + "m": 60_000, + "h": 3_600_000, + "d": 86_400_000, + "M": 2_592_000_000, + "y": 31_536_000_000, +} + +// parseDurationMS supports "1s", "500ms", and compound "1s_500ms" +func parseDurationMS(expr string) int64 { + expr = strings.TrimSpace(expr) + if expr == "" { + return 0 + } + var total int64 + var val strings.Builder + var unit strings.Builder + + flush := func() { + if val.Len() == 0 || unit.Len() == 0 { + return + } + v, err := strconv.ParseInt(val.String(), 10, 64) + if err != nil { + logger.Warning("Invalid duration value: %q: %v", val.String(), err) + val.Reset() + unit.Reset() + return + } + u := unit.String() + mul, ok := timeUnits[u] + if !ok { + logger.Warning("Invalid duration unit: %q", u) + val.Reset() + unit.Reset() + return + } + total += v * mul + val.Reset() + unit.Reset() } - delay := time.Duration(2) * time.Second - scanDelay, exists := os.LookupEnv("HITMAN_SCAN_DELAY") - if !exists { - log.Printf("No scan delay is set, defaulting to %vs", delay.Seconds()) - log.Printf("Set HITMAN_SCAN_DELAY to change this") + for _, part := range strings.Split(expr, "_") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + val.Reset() + unit.Reset() + for _, r := range part { + if r >= '0' && r <= '9' { + val.WriteRune(r) + } else if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { + unit.WriteRune(r) + } + } + flush() + } + return total +} + +func loadConfig() Config { + logger.InitFlag() + log := logger.Default.WithPrefix("loadConfig") + + // Forbidden names + forbidden := []string{} + if env := strings.TrimSpace(getenv("FORBIDDEN", "")); env != "" { + for _, p := range strings.Split(env, ",") { + p = strings.TrimSpace(p) + if p != "" { + forbidden = append(forbidden, p) + } + } + } + + // Default scan interval: 2s + scan := time.Duration(parseDurationMS(getenv("SCAN_INTERVAL", "2s"))) * time.Millisecond + if scan <= 0 { + scan = 2 * time.Second + } + + // Default timeout: half the scan interval + timeout := time.Duration(parseDurationMS(getenv("TIMEOUT", ""))) * time.Millisecond + if timeout <= 0 { + timeout = scan / 2 + } + + workers := 10 + if w := strings.TrimSpace(getenv("WORKERS", "")); w != "" { + if n, err := strconv.Atoi(w); err == nil && n > 0 { + workers = n + } else if err != nil { + log.Warning("Invalid WORKERS value %q: %v (using default %d)", w, err, workers) + } + } + + cfg := Config{ + Forbidden: forbidden, + ScanInterval: scan, + Timeout: timeout, + Workers: workers, + } + + // Config dump + log.Info("Configuration loaded") + log.Info("SCAN_INTERVAL(ms): %d", cfg.ScanInterval.Milliseconds()) + log.Info("TIMEOUT(ms): %d", cfg.Timeout.Milliseconds()) + log.Info("WORKERS: %d", cfg.Workers) + if len(cfg.Forbidden) == 0 { + log.Warning("FORBIDDEN is empty - nothing to kill") } else { - var err error - delay, err = time.ParseDuration(scanDelay) - if err != nil { - Error.Printf("Error parsing scan delay: %v", err) - return - } + log.Info("Forbidden process names: %d", len(cfg.Forbidden)) + log.Trace("Forbidden list: %v", cfg.Forbidden) } - timeout := delay / 2 - etimeout, exists := os.LookupEnv("HITMAN_TIMEOUT") - if !exists { - log.Printf("No timeout is set, defaulting to %vs", timeout.Seconds()) - log.Printf("Set HITMAN_TIMEOUT to change this") - } else { - var err error - timeout, err = time.ParseDuration(etimeout) - if err != nil { - Error.Printf("Error parsing timeout: %v", err) - return - } + return cfg +} + +// Kill sends SIGKILL (-9) immediately and waits until the process is reaped or timeout elapses. +func Kill(ctx context.Context, pid int) error { + proc, err := os.FindProcess(pid) + if err != nil { + return err } - procs := strings.Split(forbidden, ",") - workers := make(chan struct{}, 10) + // Unconditional -9 + if err := proc.Signal(syscall.SIGKILL); err != nil { + return err + } - for { - log.Printf("Running") - procmap, err := BuildProcessMap() - if err != nil { - Error.Printf("Error building process map: %v", err) - return - } + done := make(chan error, 1) + go func() { + _, err := proc.Wait() + done <- err + }() - for _, proc := range procs { - workers <- struct{}{} - go func(proc string) { - defer func() { <-workers }() - proc = strings.Trim(proc, " ") - log.Printf("Checking %s", proc) - res, ok := procmap.findByName(proc) - if ok { - log.Printf("Forbidden process %s found (x%d)", proc, len(res)) - for _, node := range res { - log.Printf("Killing forbidden process %d", node.Proc.ProcessID) - err := KillWithTimeout(int(node.Proc.ProcessID), timeout) - if err != nil { - Error.Printf("Error terminating process %d: %v", node.Proc.ProcessID, err) - } - } - } else { - log.Printf("No forbidden process %s found", proc) - } - }(proc) - } - time.Sleep(delay) + select { + case err := <-done: + return err + case <-ctx.Done(): + _ = proc.Release() + return ctx.Err() } } @@ -110,28 +180,78 @@ func KillWithTimeout(pid int, timeout time.Duration) error { return Kill(ctx, pid) } -func Kill(ctx context.Context, pid int) error { - process, err := os.FindProcess(pid) +func processCycle(cfg Config) { + log := logger.Default.WithPrefix("processCycle") + log.Info("Starting process cleanup cycle") + log.Debug("Timeout(ms): %d", cfg.Timeout.Milliseconds()) + + procmap, err := BuildProcessMap() if err != nil { - return err + log.Error("BuildProcessMap failed: %v", err) + return } - err = process.Signal(syscall.SIGKILL) - if err != nil { - return err + if len(cfg.Forbidden) == 0 { + log.Warning("No forbidden processes defined; skipping") + return } - done := make(chan error, 1) - go func() { - _, err := process.Wait() - done <- err - }() + // Parallel pass through forbidden names + utils.WithWorkers(cfg.Workers, cfg.Forbidden, func(worker int, name string) { + name = strings.TrimSpace(name) + ilog := log.WithPrefix("forbidden").WithPrefix(name).WithField("worker", worker) - select { - case err := <-done: - return err - case <-ctx.Done(): - process.Release() - return ctx.Err() + if name == "" { + ilog.Warning("Empty name, skipping") + return + } + + ilog.Debug("Searching processes by name") + ilog.Trace("Query: %s", name) + + results, ok := procmap.findByName(name) + if !ok || len(results) == 0 { + ilog.Info("No matching processes found") + return + } + + ilog.Info("Found %d matching processes", len(results)) + for _, node := range results { + pid := int(node.Proc.ProcessID) + plog := ilog.WithPrefix("pid").WithPrefix(strconv.Itoa(pid)) + + plog.Debug("Killing with SIGKILL (-9)") + plog.Trace("PID: %d", pid) + + if err := KillWithTimeout(pid, cfg.Timeout); err != nil { + plog.Error("Kill failed: %v", err) + } else { + plog.Info("Process killed") + } + } + }) + log.Info("Process cleanup cycle complete") +} + +func main() { + app := logger.Default.WithPrefix("main") + app.Info("Starting hitman (no-questions-asked)") + + cfg := loadConfig() + + // Initial cycle + processCycle(cfg) + + // Ticker loop + t := time.NewTicker(cfg.ScanInterval) + defer t.Stop() + + for { + ts := <-t.C + tlog := app.WithPrefix("tick").WithPrefix(strconv.FormatInt(ts.UnixMilli(), 10)) + tlog.Info("Timer tick") + tlog.Trace("Timestamp(ms): %d", ts.UnixMilli()) + + processCycle(cfg) } }