diff --git a/main.go b/main.go index fc49eb4..c3771c3 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "io" "log" @@ -38,7 +39,8 @@ func main() { log.Printf("Please set to a comma separated list of process names to forbid") return } - delay := time.Duration(3) * time.Second + + delay := time.Duration(2) * time.Second scanDelay, exists := os.LookupEnv("HITMAN_SCAN_DELAY") if !exists { log.Printf("No scan delay is set, defaulting to %vs", delay.Seconds()) @@ -52,7 +54,22 @@ func main() { } } + timeout := delay / 2 + etimeout, exists := os.LookupEnv("HITMAN_TIMEOUT") + if !exists { + log.Printf("No timeout is set, defaulting to %vs", timeout.Seconds()) + log.Printf("Set HITMAN_TIMEOUT to change this") + } else { + var err error + timeout, err = time.ParseDuration(etimeout) + if err != nil { + Error.Printf("Error parsing timeout: %v", err) + return + } + } + procs := strings.Split(forbidden, ",") + workers := make(chan struct{}, 10) for { log.Printf("Running") @@ -63,36 +80,57 @@ func main() { } for _, proc := range procs { - log.Printf("Checking %s", proc) - res, ok := procmap.findByName(proc) - if ok { - log.Printf("Forbidden process %s found (x%d)", proc, len(res)) - for _, node := range res { - log.Printf("Killing forbidden process %d", node.Proc.ProcessID) - err := Kill(node.Proc.ProcessID) - if err != nil { - Error.Printf("Error terminating process %d: %v", node.Proc.ProcessID, err) + workers <- struct{}{} + go func(proc string) { + defer func() { <-workers }() + log.Printf("Checking %s", proc) + res, ok := procmap.findByName(proc) + if ok { + log.Printf("Forbidden process %s found (x%d)", proc, len(res)) + for _, node := range res { + log.Printf("Killing forbidden process %d", node.Proc.ProcessID) + err := KillWithTimeout(int(node.Proc.ProcessID), timeout) + if err != nil { + Error.Printf("Error terminating process %d: %v", node.Proc.ProcessID, err) + } } + } else { + log.Printf("No forbidden process %s found", proc) } - } else { - log.Printf("No forbidden process %s found", proc) - } + }(proc) } time.Sleep(delay) } } -func Kill(pid uint32) error { - handle, err := syscall.OpenProcess(syscall.PROCESS_TERMINATE, false, uint32(pid)) - if err != nil { - return fmt.Errorf("error opening process: %v", err) - } - defer syscall.CloseHandle(handle) - - err = syscall.TerminateProcess(handle, 7172) - if err != nil { - return fmt.Errorf("error terminating process: %v", err) - } - - return nil +func KillWithTimeout(pid int, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return Kill(ctx, pid) +} + +func Kill(ctx context.Context, pid int) error { + process, err := os.FindProcess(pid) + if err != nil { + return err + } + + err = process.Signal(syscall.SIGKILL) + if err != nil { + return err + } + + done := make(chan error, 1) + go func() { + _, err := process.Wait() + done <- err + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + process.Release() + return ctx.Err() + } } diff --git a/procmap.go b/procmap.go index 305f83e..cb11929 100644 --- a/procmap.go +++ b/procmap.go @@ -83,7 +83,7 @@ func BuildProcessMap() (*ProcessMap, error) { for err == nil { tree.add(&pe32) i++ - if i > 500 { + if i > 5000 { break } err = windows.Process32Next(snapshot, &pe32)