258 lines
5.7 KiB
Go
258 lines
5.7 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
logger "git.site.quack-lab.dev/dave/cylogger"
|
|
utils "git.site.quack-lab.dev/dave/cyutils"
|
|
)
|
|
|
|
// Config holds the runtime settings assembled by loadConfig.
type Config struct {
	// Forbidden lists the process names searched for and killed each cycle.
	Forbidden []string
	// ScanInterval is the period of the cleanup ticker; Timeout bounds how
	// long a single kill waits for its target to exit.
	ScanInterval, Timeout time.Duration
	// Workers is the worker count handed to utils.WithWorkers when the
	// Forbidden names are processed.
	Workers int
}
|
|
|
|
// getenv returns the value of the environment variable key, or def when the
// variable is not present at all. A variable that is set to the empty string
// is returned as "" rather than being replaced by def.
func getenv(key, def string) string {
	value, present := os.LookupEnv(key)
	if !present {
		return def
	}
	return value
}
|
|
|
|
// timeUnits maps each unit suffix accepted by parseDurationMS to its length
// in milliseconds. Keys are case-sensitive: "m" is a minute while "M" is a
// 30-day month, and "y" is a 365-day year.
var timeUnits = map[string]int64{
	"ms": 1,
	"s":  1000,
	"m":  60 * 1000,
	"h":  60 * 60 * 1000,
	"d":  24 * 60 * 60 * 1000,
	"M":  30 * 24 * 60 * 60 * 1000,
	"y":  365 * 24 * 60 * 60 * 1000,
}
|
|
|
|
// parseDurationMS supports "1s", "500ms", and compound "1s_500ms"
|
|
func parseDurationMS(expr string) int64 {
|
|
expr = strings.TrimSpace(expr)
|
|
if expr == "" {
|
|
return 0
|
|
}
|
|
var total int64
|
|
var val strings.Builder
|
|
var unit strings.Builder
|
|
|
|
flush := func() {
|
|
if val.Len() == 0 || unit.Len() == 0 {
|
|
return
|
|
}
|
|
v, err := strconv.ParseInt(val.String(), 10, 64)
|
|
if err != nil {
|
|
logger.Warning("Invalid duration value: %q: %v", val.String(), err)
|
|
val.Reset()
|
|
unit.Reset()
|
|
return
|
|
}
|
|
u := unit.String()
|
|
mul, ok := timeUnits[u]
|
|
if !ok {
|
|
logger.Warning("Invalid duration unit: %q", u)
|
|
val.Reset()
|
|
unit.Reset()
|
|
return
|
|
}
|
|
total += v * mul
|
|
val.Reset()
|
|
unit.Reset()
|
|
}
|
|
|
|
for _, part := range strings.Split(expr, "_") {
|
|
part = strings.TrimSpace(part)
|
|
if part == "" {
|
|
continue
|
|
}
|
|
val.Reset()
|
|
unit.Reset()
|
|
for _, r := range part {
|
|
if r >= '0' && r <= '9' {
|
|
val.WriteRune(r)
|
|
} else if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') {
|
|
unit.WriteRune(r)
|
|
}
|
|
}
|
|
flush()
|
|
}
|
|
return total
|
|
}
|
|
|
|
func loadConfig() Config {
|
|
logger.InitFlag()
|
|
log := logger.Default.WithPrefix("loadConfig")
|
|
|
|
// Forbidden names
|
|
forbidden := []string{}
|
|
if env := strings.TrimSpace(getenv("FORBIDDEN", "")); env != "" {
|
|
for _, p := range strings.Split(env, ",") {
|
|
p = strings.TrimSpace(p)
|
|
if p != "" {
|
|
forbidden = append(forbidden, p)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Default scan interval: 2s
|
|
scan := time.Duration(parseDurationMS(getenv("SCAN_INTERVAL", "2s"))) * time.Millisecond
|
|
if scan <= 0 {
|
|
scan = 2 * time.Second
|
|
}
|
|
|
|
// Default timeout: half the scan interval
|
|
timeout := time.Duration(parseDurationMS(getenv("TIMEOUT", ""))) * time.Millisecond
|
|
if timeout <= 0 {
|
|
timeout = scan / 2
|
|
}
|
|
|
|
workers := 10
|
|
if w := strings.TrimSpace(getenv("WORKERS", "")); w != "" {
|
|
if n, err := strconv.Atoi(w); err == nil && n > 0 {
|
|
workers = n
|
|
} else if err != nil {
|
|
log.Warning("Invalid WORKERS value %q: %v (using default %d)", w, err, workers)
|
|
}
|
|
}
|
|
|
|
cfg := Config{
|
|
Forbidden: forbidden,
|
|
ScanInterval: scan,
|
|
Timeout: timeout,
|
|
Workers: workers,
|
|
}
|
|
|
|
// Config dump
|
|
log.Info("Configuration loaded")
|
|
log.Info("SCAN_INTERVAL(ms): %d", cfg.ScanInterval.Milliseconds())
|
|
log.Info("TIMEOUT(ms): %d", cfg.Timeout.Milliseconds())
|
|
log.Info("WORKERS: %d", cfg.Workers)
|
|
if len(cfg.Forbidden) == 0 {
|
|
log.Warning("FORBIDDEN is empty - nothing to kill")
|
|
} else {
|
|
log.Info("Forbidden process names: %d", len(cfg.Forbidden))
|
|
log.Trace("Forbidden list: %v", cfg.Forbidden)
|
|
}
|
|
|
|
return cfg
|
|
}
|
|
|
|
// Kill sends SIGKILL (-9) to pid immediately, then waits for the process to
// exit or for ctx to be done, whichever comes first.
//
// On Unix, os.Process.Wait only works for children of the calling process
// and fails with ECHILD for any other pid. The targets killed here are
// normally not our children, so that specific failure is reported as
// success: the signal was delivered (Signal returned nil), but the exit
// cannot be observed.
//
// It returns the error from FindProcess, Signal or Wait, or ctx.Err() if
// ctx is done before Wait returns.
func Kill(ctx context.Context, pid int) error {
	proc, err := os.FindProcess(pid)
	if err != nil {
		return err
	}

	// Unconditional -9
	if err := proc.Signal(syscall.SIGKILL); err != nil {
		// Best effort: the Signal error is the one worth reporting.
		_ = proc.Release()
		return err
	}

	// Buffered so the waiter can always deliver its result and exit, even
	// after Kill has returned on the ctx.Done() path.
	done := make(chan error, 1)
	go func() {
		_, err := proc.Wait()
		done <- err
	}()

	select {
	case err := <-done:
		// Wait has returned, so proc is no longer in use. The Release error
		// is ignored because Wait may already have released the handle.
		_ = proc.Release()
		if se, ok := err.(*os.SyscallError); ok && se.Err == syscall.ECHILD {
			return nil
		}
		return err
	case <-ctx.Done():
		// NOTE(review): the waiter goroutine may still be blocked in
		// proc.Wait here; confirm Release concurrent with Wait is safe for
		// the Go version in use.
		_ = proc.Release()
		return ctx.Err()
	}
}
|
|
|
|
func KillWithTimeout(pid int, timeout time.Duration) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
return Kill(ctx, pid)
|
|
}
|
|
|
|
func processCycle(cfg Config) {
|
|
log := logger.Default.WithPrefix("processCycle")
|
|
log.Info("Starting process cleanup cycle")
|
|
log.Debug("Timeout(ms): %d", cfg.Timeout.Milliseconds())
|
|
|
|
procmap, err := BuildProcessMap()
|
|
if err != nil {
|
|
log.Error("BuildProcessMap failed: %v", err)
|
|
return
|
|
}
|
|
|
|
if len(cfg.Forbidden) == 0 {
|
|
log.Warning("No forbidden processes defined; skipping")
|
|
return
|
|
}
|
|
|
|
// Parallel pass through forbidden names
|
|
utils.WithWorkers(cfg.Workers, cfg.Forbidden, func(worker int, name string) {
|
|
name = strings.TrimSpace(name)
|
|
ilog := log.WithPrefix("forbidden").WithPrefix(name).WithField("worker", worker)
|
|
|
|
if name == "" {
|
|
ilog.Warning("Empty name, skipping")
|
|
return
|
|
}
|
|
|
|
ilog.Debug("Searching processes by name")
|
|
ilog.Trace("Query: %s", name)
|
|
|
|
results, ok := procmap.findByName(name)
|
|
if !ok || len(results) == 0 {
|
|
ilog.Info("No matching processes found")
|
|
return
|
|
}
|
|
|
|
ilog.Info("Found %d matching processes", len(results))
|
|
for _, node := range results {
|
|
pid := int(node.Proc.ProcessID)
|
|
plog := ilog.WithPrefix("pid").WithPrefix(strconv.Itoa(pid))
|
|
|
|
plog.Debug("Killing with SIGKILL (-9)")
|
|
plog.Trace("PID: %d", pid)
|
|
|
|
if err := KillWithTimeout(pid, cfg.Timeout); err != nil {
|
|
plog.Error("Kill failed: %v", err)
|
|
} else {
|
|
plog.Info("Process killed")
|
|
}
|
|
}
|
|
})
|
|
log.Info("Process cleanup cycle complete")
|
|
}
|
|
|
|
func main() {
|
|
app := logger.Default.WithPrefix("main")
|
|
app.Info("Starting hitman (no-questions-asked)")
|
|
|
|
cfg := loadConfig()
|
|
|
|
// Initial cycle
|
|
processCycle(cfg)
|
|
|
|
// Ticker loop
|
|
t := time.NewTicker(cfg.ScanInterval)
|
|
defer t.Stop()
|
|
|
|
for {
|
|
ts := <-t.C
|
|
tlog := app.WithPrefix("tick").WithPrefix(strconv.FormatInt(ts.UnixMilli(), 10))
|
|
tlog.Info("Timer tick")
|
|
tlog.Trace("Timestamp(ms): %d", ts.UnixMilli())
|
|
|
|
processCycle(cfg)
|
|
}
|
|
}
|