Files
hitman/main.go

137 lines
3.2 KiB
Go

package main
import (
"context"
"fmt"
"io"
"log"
"os"
"strings"
"syscall"
"time"
)
// Error and Warning are levelled loggers set up by init. Error writes to
// main.log, stderr and stdout; Warning writes to main.log and stdout. Each
// line starts with an ANSI-coloured level tag.
var (
	Error   *log.Logger
	Warning *log.Logger
)

// init configures logging before main runs. It creates main.log in the
// working directory (truncating any previous contents) and sends the
// standard logger's output to both stdout and that file. If the file cannot
// be created the program exits with status 1.
func init() {
	const (
		flags = log.Lmicroseconds | log.Lshortfile
		reset = "\033[0m"
	)
	log.SetFlags(flags)

	logFile, err := os.Create("main.log")
	if err != nil {
		log.Fatalf("Error creating log file: %v", err)
	}
	log.SetOutput(io.MultiWriter(os.Stdout, logFile))

	errorSinks := io.MultiWriter(logFile, os.Stderr, os.Stdout)
	warningSinks := io.MultiWriter(logFile, os.Stdout)
	Error = log.New(errorSinks, fmt.Sprintf("%sERROR:%s ", "\033[0;101m", reset), flags)
	Warning = log.New(warningSinks, fmt.Sprintf("%sWarning:%s ", "\033[0;93m", reset), flags)
}
// main runs the hitman daemon. It reads its configuration from the
// environment and then, forever, builds a process map and kills every
// running process whose name is on the forbidden list, pausing between
// passes.
//
// Environment variables:
//   - HITMAN_FORBIDDEN (required): comma separated process names. Whitespace
//     around each name is ignored, as are empty entries.
//   - HITMAN_SCAN_DELAY (optional, default 2s): pause between passes, in
//     time.ParseDuration syntax. Must be positive.
//   - HITMAN_TIMEOUT (optional, default half the scan delay): how long to
//     wait for each killed process to go away. Must be positive.
//
// Invalid configuration, or a failure to build the process map, ends the
// program with exit status 1.
func main() {
	forbidden, exists := os.LookupEnv("HITMAN_FORBIDDEN")
	if !exists {
		Error.Println("HITMAN_FORBIDDEN environment variable not set")
		log.Printf("Please set to a comma separated list of process names to forbid")
		os.Exit(1)
	}
	procs := splitProcessNames(forbidden)
	if len(procs) == 0 {
		// Without this check an empty value would look up the empty name.
		Error.Println("HITMAN_FORBIDDEN does not contain any process names")
		log.Printf("Please set to a comma separated list of process names to forbid")
		os.Exit(1)
	}

	delay, ok := durationFromEnv("HITMAN_SCAN_DELAY", "scan delay", 2*time.Second)
	if !ok {
		os.Exit(1)
	}
	timeout, ok := durationFromEnv("HITMAN_TIMEOUT", "timeout", delay/2)
	if !ok {
		os.Exit(1)
	}

	// workers is a counting semaphore bounding concurrent checks to 10.
	workers := make(chan struct{}, 10)
	for {
		log.Printf("Running")
		procmap, err := BuildProcessMap()
		if err != nil {
			Error.Printf("Error building process map: %v", err)
			os.Exit(1)
		}
		for _, proc := range procs {
			workers <- struct{}{}
			go func(proc string) {
				defer func() { <-workers }()
				log.Printf("Checking %s", proc)
				res, ok := procmap.findByName(proc)
				if !ok {
					log.Printf("No forbidden process %s found", proc)
					return
				}
				log.Printf("Forbidden process %s found (x%d)", proc, len(res))
				for _, node := range res {
					log.Printf("Killing forbidden process %d", node.Proc.ProcessID)
					err := KillWithTimeout(int(node.Proc.ProcessID), timeout)
					if err != nil {
						Error.Printf("Error terminating process %d: %v", node.Proc.ProcessID, err)
					}
				}
			}(proc)
		}
		// Wait for every worker of this pass before sleeping: once all of the
		// semaphore's slots are held here, no worker can still be running.
		// Without this, a slow kill overlaps the next pass, which would
		// target the same process again.
		for i := 0; i < cap(workers); i++ {
			workers <- struct{}{}
		}
		for i := 0; i < cap(workers); i++ {
			<-workers
		}
		time.Sleep(delay)
	}
}

// splitProcessNames splits a comma separated list of process names,
// trimming surrounding whitespace and dropping empty entries, so that
// "a, b,," yields ["a" "b"].
func splitProcessNames(list string) []string {
	var names []string
	for _, name := range strings.Split(list, ",") {
		if name = strings.TrimSpace(name); name != "" {
			names = append(names, name)
		}
	}
	return names
}

// durationFromEnv returns the duration stored in the environment variable
// env, or def when env is unset (logging a hint about the default). label is
// the human-readable name of the setting used in log messages. The second
// result is false, after an error has been logged, when the value cannot be
// parsed by time.ParseDuration or is not positive. A non-positive scan delay
// would make the main loop spin, and a non-positive timeout would leave no
// time to confirm a kill.
func durationFromEnv(env, label string, def time.Duration) (time.Duration, bool) {
	d := def
	if raw, exists := os.LookupEnv(env); !exists {
		log.Printf("No %s is set, defaulting to %vs", label, def.Seconds())
		log.Printf("Set %s to change this", env)
	} else {
		parsed, err := time.ParseDuration(raw)
		if err != nil {
			Error.Printf("Error parsing %s: %v", label, err)
			return 0, false
		}
		d = parsed
	}
	if d <= 0 {
		Error.Printf("%s must be positive, got %v", env, d)
		return 0, false
	}
	return d, true
}
// KillWithTimeout calls Kill for pid with a context that expires after
// timeout, bounding how long Kill may wait for the process to go away. The
// context is cancelled before returning.
func KillWithTimeout(pid int, timeout time.Duration) error {
	ctx, stop := context.WithTimeout(context.Background(), timeout)
	err := Kill(ctx, pid)
	stop()
	return err
}
// Kill sends SIGKILL to the process with the given pid and then waits, until
// ctx is done, for the process to disappear.
//
// A process that has already exited by the time the signal is sent counts as
// successfully killed. Kill first tries os.Process.Wait to observe the exit.
// On Unix, wait(2) only works for children of the calling process, and the
// processes hitman kills are normally not its children. When Wait fails,
// Kill instead probes the process with signal 0 every pollInterval until the
// OS reports it gone. A killed process that its parent has not reaped yet (a
// zombie) still answers that probe, so it is reported as a timeout.
//
// Kill returns nil once the process is gone, ctx.Err() if ctx ends first,
// and otherwise the error from finding, signalling or probing the process.
func Kill(ctx context.Context, pid int) error {
	const pollInterval = 10 * time.Millisecond

	process, err := os.FindProcess(pid)
	if err != nil {
		return err
	}
	defer process.Release()

	if err := process.Signal(syscall.SIGKILL); err != nil {
		if err == os.ErrProcessDone {
			// The process exited on its own between the scan and this kill.
			return nil
		}
		return err
	}

	// Wait may block (it does for a live child), so run it in the
	// background. The buffered channel lets it finish even if we time out.
	waited := make(chan error, 1)
	go func() {
		_, err := process.Wait()
		waited <- err
	}()
	select {
	case err := <-waited:
		if err == nil {
			return nil
		}
		// On Unix this is typically ECHILD (pid is not our child); fall back
		// to polling for the process's disappearance below.
	case <-ctx.Done():
		return ctx.Err()
	}

	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()
	for {
		switch err := process.Signal(syscall.Signal(0)); {
		case err == os.ErrProcessDone:
			return nil
		case err != nil:
			return err
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}
	}
}