Implement some sort of timeout

This commit is contained in:
2024-10-01 23:13:28 +02:00
parent 69e053d5da
commit 9196bda1b1
2 changed files with 65 additions and 27 deletions

90
main.go
View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"log" "log"
@@ -38,7 +39,8 @@ func main() {
log.Printf("Please set to a comma separated list of process names to forbid") log.Printf("Please set to a comma separated list of process names to forbid")
return return
} }
delay := time.Duration(3) * time.Second
delay := time.Duration(2) * time.Second
scanDelay, exists := os.LookupEnv("HITMAN_SCAN_DELAY") scanDelay, exists := os.LookupEnv("HITMAN_SCAN_DELAY")
if !exists { if !exists {
log.Printf("No scan delay is set, defaulting to %vs", delay.Seconds()) log.Printf("No scan delay is set, defaulting to %vs", delay.Seconds())
@@ -52,7 +54,22 @@ func main() {
} }
} }
timeout := delay / 2
etimeout, exists := os.LookupEnv("HITMAN_TIMEOUT")
if !exists {
log.Printf("No timeout is set, defaulting to %vs", timeout.Seconds())
log.Printf("Set HITMAN_TIMEOUT to change this")
} else {
var err error
timeout, err = time.ParseDuration(etimeout)
if err != nil {
Error.Printf("Error parsing timeout: %v", err)
return
}
}
procs := strings.Split(forbidden, ",") procs := strings.Split(forbidden, ",")
workers := make(chan struct{}, 10)
for { for {
log.Printf("Running") log.Printf("Running")
@@ -63,36 +80,57 @@ func main() {
} }
for _, proc := range procs { for _, proc := range procs {
log.Printf("Checking %s", proc) workers <- struct{}{}
res, ok := procmap.findByName(proc) go func(proc string) {
if ok { defer func() { <-workers }()
log.Printf("Forbidden process %s found (x%d)", proc, len(res)) log.Printf("Checking %s", proc)
for _, node := range res { res, ok := procmap.findByName(proc)
log.Printf("Killing forbidden process %d", node.Proc.ProcessID) if ok {
err := Kill(node.Proc.ProcessID) log.Printf("Forbidden process %s found (x%d)", proc, len(res))
if err != nil { for _, node := range res {
Error.Printf("Error terminating process %d: %v", node.Proc.ProcessID, err) log.Printf("Killing forbidden process %d", node.Proc.ProcessID)
err := KillWithTimeout(int(node.Proc.ProcessID), timeout)
if err != nil {
Error.Printf("Error terminating process %d: %v", node.Proc.ProcessID, err)
}
} }
} else {
log.Printf("No forbidden process %s found", proc)
} }
} else { }(proc)
log.Printf("No forbidden process %s found", proc)
}
} }
time.Sleep(delay) time.Sleep(delay)
} }
} }
func Kill(pid uint32) error { func KillWithTimeout(pid int, timeout time.Duration) error {
handle, err := syscall.OpenProcess(syscall.PROCESS_TERMINATE, false, uint32(pid)) ctx, cancel := context.WithTimeout(context.Background(), timeout)
if err != nil { defer cancel()
return fmt.Errorf("error opening process: %v", err) return Kill(ctx, pid)
} }
defer syscall.CloseHandle(handle)
func Kill(ctx context.Context, pid int) error {
err = syscall.TerminateProcess(handle, 7172) process, err := os.FindProcess(pid)
if err != nil { if err != nil {
return fmt.Errorf("error terminating process: %v", err) return err
} }
return nil err = process.Signal(syscall.SIGKILL)
if err != nil {
return err
}
done := make(chan error, 1)
go func() {
_, err := process.Wait()
done <- err
}()
select {
case err := <-done:
return err
case <-ctx.Done():
process.Release()
return ctx.Err()
}
} }

View File

@@ -83,7 +83,7 @@ func BuildProcessMap() (*ProcessMap, error) {
for err == nil { for err == nil {
tree.add(&pe32) tree.add(&pe32)
i++ i++
if i > 500 { if i > 5000 {
break break
} }
err = windows.Process32Next(snapshot, &pe32) err = windows.Process32Next(snapshot, &pe32)