feat: ja4-platform monorepo — 5 services unified, tests & RPM builds standardized

Services:
- ja4sentinel: TLS/JA4 fingerprint capture daemon (Go, libpcap)
- logcorrelator: JA4 log correlation engine (Go, ClickHouse)
- mod_reqin_log: Apache module (C, JSON request logging)
- bot_detector: ML bot detection pipeline (Python)
- dashboard: FastAPI/Streamlit analytics UI (Python)

Shared libraries:
- shared/go/ja4common: logger, config, shutdown, ipfilter (Go module)
- shared/python/ja4_common: ClickHouseClient, ClickHouseSettings (Python package)
- shared/clickhouse/: canonical SQL migrations (10 files)

Build & packaging:
- Unified 3-stage Dockerfile.package for Go RPMs (el8/el9/el10)
- go.work workspace linking sentinel, correlator, ja4common
- Makefile with test-all, build-all, rpm-* targets

Fixes applied:
- go.work: 1.21 → 1.24.6 (required by sentinel)
- correlator Dockerfiles: golang:1.21 → golang:1.24
- replace directives in go.mod for ja4common local path
- pyproject.toml: setuptools.backends → setuptools.build_meta
- Removed static libpcap linking (unavailable on Rocky 9)
- Fixed data races in output/writers_test.go (sync.Mutex + atomic.Int32)
- Rewrote corrupted test files (logger_test.go × 2)

Test coverage:
- correlator: 67.1% total (unixsocket 80.5%, config 91.7%, app 83.3%, multi 87.7%, stdout 100%)
- sentinel: all 10 packages pass (api, capture, config, fingerprint, ipfilter, logging, output, tlsparse)

Documentation:
- README.md + docs/ (architecture, development, 5 services, shared libs, DB schema & migrations)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
toto
2026-04-07 16:42:59 +02:00
commit d469e39da7
278 changed files with 1621301 additions and 0 deletions

View File

@ -0,0 +1,114 @@
// Package config provides generic YAML config loading with env override support.
package config
import (
"fmt"
"os"
"reflect"
"strconv"
"strings"
"gopkg.in/yaml.v3"
)
// LoadYAML reads a YAML file at path and unmarshals it into T.
// If path is empty or the file does not exist and optional is true, the zero value of T is returned.
func LoadYAML[T any](path string, optional bool) (T, error) {
var zero T
if path == "" {
if optional {
return zero, nil
}
return zero, fmt.Errorf("config path is empty")
}
data, err := os.ReadFile(path)
if err != nil {
if optional && os.IsNotExist(err) {
return zero, nil
}
return zero, fmt.Errorf("reading config file %q: %w", path, err)
}
var cfg T
if err := yaml.Unmarshal(data, &cfg); err != nil {
return zero, fmt.Errorf("parsing config file %q: %w", path, err)
}
return cfg, nil
}
// OverrideFromEnv applies environment variable overrides to a struct using struct tags.
// Tag format: env:"ENV_VAR_NAME"
// Supports field types: string, int, bool, []string (comma-separated).
// envPrefix is prepended to tag values if non-empty (e.g. envPrefix="APP" + tag="PORT" → "APP_PORT").
func OverrideFromEnv[T any](cfg *T, envPrefix string) error {
return overrideStruct(reflect.ValueOf(cfg).Elem(), envPrefix)
}
func overrideStruct(v reflect.Value, envPrefix string) error {
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fv := v.Field(i)
if !fv.CanSet() {
continue
}
// Recurse into embedded/nested structs
if fv.Kind() == reflect.Struct {
if err := overrideStruct(fv, envPrefix); err != nil {
return err
}
continue
}
envTag := field.Tag.Get("env")
if envTag == "" {
continue
}
envKey := envTag
if envPrefix != "" {
envKey = envPrefix + "_" + envTag
}
val := os.Getenv(envKey)
if val == "" {
continue
}
switch fv.Kind() {
case reflect.String:
fv.SetString(val)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return fmt.Errorf("env %s: cannot parse %q as int: %w", envKey, val, err)
}
fv.SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
n, err := strconv.ParseUint(val, 10, 64)
if err != nil {
return fmt.Errorf("env %s: cannot parse %q as uint: %w", envKey, val, err)
}
fv.SetUint(n)
case reflect.Bool:
b, err := strconv.ParseBool(val)
if err != nil {
return fmt.Errorf("env %s: cannot parse %q as bool: %w", envKey, val, err)
}
fv.SetBool(b)
case reflect.Slice:
if fv.Type().Elem().Kind() == reflect.String {
parts := strings.Split(val, ",")
for i, p := range parts {
parts[i] = strings.TrimSpace(p)
}
fv.Set(reflect.ValueOf(parts))
}
}
}
return nil
}

View File

@ -0,0 +1,139 @@
package config
import (
"os"
"path/filepath"
"testing"
)
type testConfig struct {
Host string `yaml:"host" env:"HOST"`
Port int `yaml:"port" env:"PORT"`
TLS bool `yaml:"tls" env:"TLS"`
Tags []string `yaml:"tags" env:"TAGS"`
}
func TestLoadYAML_Basic(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.yml")
content := `
host: myhost
port: 9000
tls: true
tags:
- a
- b
`
if err := os.WriteFile(path, []byte(content), 0600); err != nil {
t.Fatal(err)
}
cfg, err := LoadYAML[testConfig](path, false)
if err != nil {
t.Fatalf("LoadYAML error: %v", err)
}
if cfg.Host != "myhost" {
t.Errorf("Host = %q, want %q", cfg.Host, "myhost")
}
if cfg.Port != 9000 {
t.Errorf("Port = %d, want 9000", cfg.Port)
}
if !cfg.TLS {
t.Error("TLS should be true")
}
if len(cfg.Tags) != 2 {
t.Errorf("Tags len = %d, want 2", len(cfg.Tags))
}
}
func TestLoadYAML_Optional_MissingFile(t *testing.T) {
cfg, err := LoadYAML[testConfig]("/nonexistent/path.yml", true)
if err != nil {
t.Fatalf("optional missing file should not error: %v", err)
}
if cfg.Host != "" {
t.Errorf("zero value expected, got host=%q", cfg.Host)
}
}
func TestLoadYAML_Required_MissingFile(t *testing.T) {
_, err := LoadYAML[testConfig]("/nonexistent/path.yml", false)
if err == nil {
t.Error("expected error for missing required file")
}
}
func TestLoadYAML_EmptyPath_Optional(t *testing.T) {
cfg, err := LoadYAML[testConfig]("", true)
if err != nil {
t.Fatalf("empty optional path should not error: %v", err)
}
_ = cfg
}
func TestOverrideFromEnv_String(t *testing.T) {
t.Setenv("HOST", "envhost")
cfg := testConfig{Host: "original"}
if err := OverrideFromEnv(&cfg, ""); err != nil {
t.Fatal(err)
}
if cfg.Host != "envhost" {
t.Errorf("Host = %q, want envhost", cfg.Host)
}
}
func TestOverrideFromEnv_Int(t *testing.T) {
t.Setenv("PORT", "8080")
cfg := testConfig{Port: 1234}
if err := OverrideFromEnv(&cfg, ""); err != nil {
t.Fatal(err)
}
if cfg.Port != 8080 {
t.Errorf("Port = %d, want 8080", cfg.Port)
}
}
func TestOverrideFromEnv_Bool(t *testing.T) {
t.Setenv("TLS", "false")
cfg := testConfig{TLS: true}
if err := OverrideFromEnv(&cfg, ""); err != nil {
t.Fatal(err)
}
if cfg.TLS {
t.Error("TLS should be false after env override")
}
}
func TestOverrideFromEnv_Slice(t *testing.T) {
t.Setenv("TAGS", "x,y,z")
cfg := testConfig{}
if err := OverrideFromEnv(&cfg, ""); err != nil {
t.Fatal(err)
}
if len(cfg.Tags) != 3 || cfg.Tags[0] != "x" {
t.Errorf("Tags = %v, want [x y z]", cfg.Tags)
}
}
func TestOverrideFromEnv_WithPrefix(t *testing.T) {
t.Setenv("APP_HOST", "prefixed")
cfg := testConfig{Host: "original"}
if err := OverrideFromEnv(&cfg, "APP"); err != nil {
t.Fatal(err)
}
if cfg.Host != "prefixed" {
t.Errorf("Host = %q, want prefixed", cfg.Host)
}
}
func TestOverrideFromEnv_NoEnvSet_NoChange(t *testing.T) {
os.Unsetenv("HOST")
os.Unsetenv("PORT")
cfg := testConfig{Host: "keep", Port: 42}
if err := OverrideFromEnv(&cfg, ""); err != nil {
t.Fatal(err)
}
if cfg.Host != "keep" || cfg.Port != 42 {
t.Errorf("unset env should not change values: host=%q port=%d", cfg.Host, cfg.Port)
}
}

View File

@ -0,0 +1,9 @@
module github.com/antitbone/ja4/ja4common
go 1.21
require (
gopkg.in/yaml.v3 v3.0.1
)
require gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect

View File

@ -0,0 +1,84 @@
// Package ipfilter provides IP address and CIDR range matching for ja4-platform services.
package ipfilter
import (
"fmt"
"net"
"sync"
)
// Filter checks if an IP address should be excluded based on a list of IPs or CIDR ranges
type Filter struct {
mu sync.RWMutex
networks []*net.IPNet
ips []net.IP
}
// New creates a new IP filter from a list of IP addresses or CIDR ranges
// Accepts formats like: "192.168.1.1", "10.0.0.0/8", "2001:db8::/32"
func New(excludeList []string) (*Filter, error) {
f := &Filter{
networks: make([]*net.IPNet, 0),
ips: make([]net.IP, 0),
}
for _, entry := range excludeList {
if entry == "" {
continue
}
// Try parsing as CIDR first
if _, ipNet, err := net.ParseCIDR(entry); err == nil {
f.networks = append(f.networks, ipNet)
continue
}
// Try parsing as single IP
if ip := net.ParseIP(entry); ip != nil {
f.ips = append(f.ips, ip)
continue
}
return nil, fmt.Errorf("invalid IP or CIDR: %s", entry)
}
return f, nil
}
// ShouldExclude checks if an IP address should be excluded
func (f *Filter) ShouldExclude(ipStr string) bool {
if f == nil {
return false
}
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
f.mu.RLock()
defer f.mu.RUnlock()
// Check against single IPs
for _, filterIP := range f.ips {
if ip.Equal(filterIP) {
return true
}
}
// Check against CIDR ranges
for _, network := range f.networks {
if network.Contains(ip) {
return true
}
}
return false
}
// Count returns the number of loaded filter entries
func (f *Filter) Count() (ips int, networks int) {
f.mu.RLock()
defer f.mu.RUnlock()
return len(f.ips), len(f.networks)
}

View File

@ -0,0 +1,160 @@
package ipfilter
import (
"testing"
)
func TestFilter_New(t *testing.T) {
tests := []struct {
name string
list []string
wantErr bool
}{
{
name: "empty list",
list: []string{},
wantErr: false,
},
{
name: "single IP",
list: []string{"192.168.1.1"},
wantErr: false,
},
{
name: "single CIDR",
list: []string{"10.0.0.0/8"},
wantErr: false,
},
{
name: "mixed IPs and CIDRs",
list: []string{"192.168.1.1", "10.0.0.0/8", "172.16.0.0/12"},
wantErr: false,
},
{
name: "invalid IP",
list: []string{"999.999.999.999"},
wantErr: true,
},
{
name: "invalid CIDR",
list: []string{"10.0.0.0/33"},
wantErr: true,
},
{
name: "IPv6 address",
list: []string{"2001:db8::1"},
wantErr: false,
},
{
name: "IPv6 CIDR",
list: []string{"2001:db8::/32"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f, err := New(tt.list)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil && f == nil {
t.Error("New() should return non-nil filter on success")
}
})
}
}
func TestFilter_ShouldExclude(t *testing.T) {
f, err := New([]string{
"192.168.1.1",
"10.0.0.0/8",
"172.16.0.0/12",
"2001:db8::1",
"fc00::/7",
})
if err != nil {
t.Fatalf("New() error = %v", err)
}
tests := []struct {
name string
ip string
want bool
}{
// Exact IP matches
{"exact match", "192.168.1.1", true},
{"exact IPv6 match", "2001:db8::1", true},
// CIDR matches
{"CIDR match 10.0.0.1", "10.0.0.1", true},
{"CIDR match 10.255.255.255", "10.255.255.255", true},
{"CIDR match 172.16.0.1", "172.16.0.1", true},
{"CIDR match 172.31.255.255", "172.31.255.255", true},
{"CIDR IPv6 match", "fc00::1", true},
// No matches
{"no match 192.168.2.1", "192.168.2.1", false},
{"no match 11.0.0.1", "11.0.0.1", false},
{"no match 172.32.0.1", "172.32.0.1", false},
{"no match 8.8.8.8", "8.8.8.8", false},
// Invalid IP
{"invalid IP", "invalid", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := f.ShouldExclude(tt.ip); got != tt.want {
t.Errorf("ShouldExclude(%q) = %v, want %v", tt.ip, got, tt.want)
}
})
}
}
func TestFilter_ShouldExclude_NilFilter(t *testing.T) {
var f *Filter
if f.ShouldExclude("192.168.1.1") {
t.Error("ShouldExclude on nil filter should return false")
}
}
func TestFilter_Count(t *testing.T) {
f, err := New([]string{
"192.168.1.1",
"10.0.0.1",
"10.0.0.0/8",
"172.16.0.0/12",
})
if err != nil {
t.Fatalf("New() error = %v", err)
}
ips, networks := f.Count()
if ips != 2 {
t.Errorf("Count() ips = %d, want 2", ips)
}
if networks != 2 {
t.Errorf("Count() networks = %d, want 2", networks)
}
}
func TestFilter_EmptyEntries(t *testing.T) {
f, err := New([]string{"", "192.168.1.1", ""})
if err != nil {
t.Fatalf("New() error = %v", err)
}
ips, _ := f.Count()
if ips != 1 {
t.Errorf("Count() ips = %d, want 1 (empty entries should be skipped)", ips)
}
if !f.ShouldExclude("192.168.1.1") {
t.Error("Should exclude 192.168.1.1")
}
if f.ShouldExclude("192.168.1.2") {
t.Error("Should not exclude 192.168.1.2")
}
}

View File

@ -0,0 +1,263 @@
// Package logger provides unified structured logging for the ja4-platform services.
// It merges the component-based logger from sentinel and the prefix/fields-based
// logger from correlator into a single implementation.
package logger
import (
"fmt"
"log"
"os"
"sort"
"strings"
"sync"
)
// LogLevel represents the severity of a log message.
type LogLevel int
const (
DEBUG LogLevel = iota
INFO
WARN
ERROR
)
// ParseLogLevel converts a string to LogLevel.
func ParseLogLevel(level string) LogLevel {
switch strings.ToUpper(level) {
case "DEBUG":
return DEBUG
case "INFO":
return INFO
case "WARN", "WARNING":
return WARN
case "ERROR":
return ERROR
default:
return INFO
}
}
// String returns the string representation of a LogLevel.
func (l LogLevel) String() string {
switch l {
case DEBUG:
return "DEBUG"
case INFO:
return "INFO"
case WARN:
return "WARN"
case ERROR:
return "ERROR"
default:
return "INFO"
}
}
// Logger provides structured prefix+fields-based logging (correlator style).
type Logger struct {
mu sync.RWMutex
logger *log.Logger
prefix string
fields map[string]any
minLevel LogLevel
}
// New creates a new Logger with INFO level.
func New(prefix string) *Logger {
return &Logger{
logger: log.New(os.Stderr, "", log.LstdFlags|log.Lmicroseconds),
prefix: prefix,
fields: make(map[string]any),
minLevel: INFO,
}
}
// NewWithLevel creates a new Logger with the specified minimum level.
func NewWithLevel(prefix string, level string) *Logger {
return &Logger{
logger: log.New(os.Stderr, "", log.LstdFlags|log.Lmicroseconds),
prefix: prefix,
fields: make(map[string]any),
minLevel: ParseLogLevel(level),
}
}
// SetLevel sets the minimum log level.
func (l *Logger) SetLevel(level string) {
l.mu.Lock()
defer l.mu.Unlock()
l.minLevel = ParseLogLevel(level)
}
// ShouldLog returns true if the given level should be logged.
func (l *Logger) ShouldLog(level LogLevel) bool {
l.mu.RLock()
defer l.mu.RUnlock()
return level >= l.minLevel
}
// WithFields returns a new Logger with additional structured fields.
func (l *Logger) WithFields(fields map[string]any) *Logger {
l.mu.RLock()
minLevel := l.minLevel
prefix := l.prefix
existing := make(map[string]any, len(l.fields))
for k, v := range l.fields {
existing[k] = v
}
l.mu.RUnlock()
for k, v := range fields {
existing[k] = v
}
return &Logger{
logger: l.logger,
prefix: prefix,
fields: existing,
minLevel: minLevel,
}
}
// Info logs an info message.
func (l *Logger) Info(msg string) {
if l.ShouldLog(INFO) {
l.emit("INFO", msg)
}
}
// Infof logs a formatted info message.
func (l *Logger) Infof(msg string, args ...any) {
if l.ShouldLog(INFO) {
l.emit("INFO", fmt.Sprintf(msg, args...))
}
}
// Warn logs a warning message.
func (l *Logger) Warn(msg string) {
if l.ShouldLog(WARN) {
l.emit("WARN", msg)
}
}
// Warnf logs a formatted warning message.
func (l *Logger) Warnf(msg string, args ...any) {
if l.ShouldLog(WARN) {
l.emit("WARN", fmt.Sprintf(msg, args...))
}
}
// Error logs an error message with an optional error value.
func (l *Logger) Error(msg string, err error) {
if !l.ShouldLog(ERROR) {
return
}
if err != nil {
l.emit("ERROR", msg+" "+err.Error())
} else {
l.emit("ERROR", msg)
}
}
// Debug logs a debug message.
func (l *Logger) Debug(msg string) {
if l.ShouldLog(DEBUG) {
l.emit("DEBUG", msg)
}
}
// Debugf logs a formatted debug message.
func (l *Logger) Debugf(msg string, args ...any) {
if l.ShouldLog(DEBUG) {
l.emit("DEBUG", fmt.Sprintf(msg, args...))
}
}
func (l *Logger) emit(level, msg string) {
l.mu.RLock()
prefix := l.prefix
fields := make(map[string]any, len(l.fields))
for k, v := range l.fields {
fields[k] = v
}
l.mu.RUnlock()
var b strings.Builder
if prefix != "" {
b.WriteString("[")
b.WriteString(prefix)
b.WriteString("] ")
}
b.WriteString(level)
b.WriteString(" ")
b.WriteString(msg)
if len(fields) > 0 {
keys := make([]string, 0, len(fields))
for k := range fields {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
b.WriteString(" ")
b.WriteString(k)
b.WriteString("=")
b.WriteString(fmt.Sprintf("%v", fields[k]))
}
}
l.logger.Print(b.String())
}
// ComponentLogger wraps Logger to satisfy the sentinel-style component-based Logger interface.
// This allows new services to use ja4common while sentinel's existing api.Logger interface
// is still satisfied.
type ComponentLogger struct {
*Logger
}
// NewComponentLogger creates a ComponentLogger with the specified log level.
func NewComponentLogger(level string) *ComponentLogger {
return &ComponentLogger{Logger: NewWithLevel("", level)}
}
// Log emits a structured log entry for the given component.
func (c *ComponentLogger) Log(component, level, message string, details map[string]string) {
fields := make(map[string]any, len(details)+1)
fields["component"] = component
for k, v := range details {
fields[k] = v
}
cl := c.Logger.WithFields(fields)
switch strings.ToLower(level) {
case "debug":
cl.Debug(message)
case "warn", "warning":
cl.Warn(message)
case "error":
cl.Error(message, nil)
default:
cl.Info(message)
}
}
// Debug logs a debug entry for the given component.
func (c *ComponentLogger) Debug(component, message string, details map[string]string) {
c.Log(component, "debug", message, details)
}
// Info logs an info entry for the given component.
func (c *ComponentLogger) Info(component, message string, details map[string]string) {
c.Log(component, "info", message, details)
}
// Warn logs a warning entry for the given component.
func (c *ComponentLogger) Warn(component, message string, details map[string]string) {
c.Log(component, "warn", message, details)
}
// Error logs an error entry for the given component.
func (c *ComponentLogger) Error(component, message string, details map[string]string) {
c.Log(component, "error", message, details)
}

View File

@ -0,0 +1,139 @@
package logger
import (
"strings"
"testing"
)
func TestParseLogLevel(t *testing.T) {
tests := []struct {
input string
want LogLevel
}{
{"debug", DEBUG},
{"DEBUG", DEBUG},
{"info", INFO},
{"INFO", INFO},
{"warn", WARN},
{"WARN", WARN},
{"warning", WARN},
{"WARNING", WARN},
{"error", ERROR},
{"ERROR", ERROR},
{"invalid", INFO},
{"", INFO},
}
for _, tt := range tests {
got := ParseLogLevel(tt.input)
if got != tt.want {
t.Errorf("ParseLogLevel(%q) = %v, want %v", tt.input, got, tt.want)
}
}
}
func TestLogger_LevelFiltering(t *testing.T) {
tests := []struct {
loggerLevel string
logLevel LogLevel
shouldLog bool
}{
{"debug", DEBUG, true},
{"debug", INFO, true},
{"info", DEBUG, false},
{"info", INFO, true},
{"warn", INFO, false},
{"warn", WARN, true},
{"error", WARN, false},
{"error", ERROR, true},
}
for _, tt := range tests {
l := NewWithLevel("test", tt.loggerLevel)
got := l.ShouldLog(tt.logLevel)
if got != tt.shouldLog {
t.Errorf("level=%s ShouldLog(%v)=%v want %v", tt.loggerLevel, tt.logLevel, got, tt.shouldLog)
}
}
}
func TestLogger_WithFields(t *testing.T) {
l := New("test")
l2 := l.WithFields(map[string]any{"key": "value", "n": 42})
if l2 == l {
t.Error("WithFields should return a new Logger")
}
if len(l2.fields) != 2 {
t.Errorf("expected 2 fields, got %d", len(l2.fields))
}
}
func TestLogger_SetLevel(t *testing.T) {
l := New("test")
if l.minLevel != INFO {
t.Errorf("default level should be INFO, got %v", l.minLevel)
}
l.SetLevel("debug")
if l.minLevel != DEBUG {
t.Errorf("level after SetLevel(debug) should be DEBUG, got %v", l.minLevel)
}
}
func TestComponentLogger_Interface(t *testing.T) {
cl := NewComponentLogger("debug")
// Verify it implements the component-based interface by calling all methods
cl.Debug("component", "debug msg", nil)
cl.Info("component", "info msg", map[string]string{"key": "val"})
cl.Warn("component", "warn msg", nil)
cl.Error("component", "error msg", map[string]string{"err": "test"})
cl.Log("component", "info", "log msg", nil)
}
func TestComponentLogger_LevelFiltering(t *testing.T) {
cl := NewComponentLogger("warn")
// At warn level, debug and info should be filtered
if cl.Logger.ShouldLog(DEBUG) {
t.Error("DEBUG should be filtered at warn level")
}
if cl.Logger.ShouldLog(INFO) {
t.Error("INFO should be filtered at warn level")
}
if !cl.Logger.ShouldLog(WARN) {
t.Error("WARN should pass at warn level")
}
}
func TestLogger_LogLevelString(t *testing.T) {
tests := []struct {
level LogLevel
want string
}{
{DEBUG, "DEBUG"},
{INFO, "INFO"},
{WARN, "WARN"},
{ERROR, "ERROR"},
}
for _, tt := range tests {
if got := tt.level.String(); got != tt.want {
t.Errorf("%v.String() = %q, want %q", tt.level, got, tt.want)
}
}
}
func TestLogger_EmitContainsLevel(t *testing.T) {
// Use a custom logger that captures output
var buf strings.Builder
l := New("myservice")
l.logger.SetOutput(&buf)
l.SetLevel("debug")
l.Info("hello from info")
if !strings.Contains(buf.String(), "INFO") {
t.Errorf("expected INFO in output, got: %s", buf.String())
}
buf.Reset()
l.Debug("hello from debug")
if !strings.Contains(buf.String(), "DEBUG") {
t.Errorf("expected DEBUG in output, got: %s", buf.String())
}
}

View File

@ -0,0 +1,45 @@
// Package shutdown provides graceful shutdown handling for services.
package shutdown
import (
"context"
"os"
"os/signal"
"syscall"
)
// Hook is a cleanup function called during shutdown.
type Hook struct {
Name string
Fn func() error
}
// simpleLogger is the minimal interface required from a logger.
type simpleLogger interface {
Info(string)
Error(string, error)
}
// Handle blocks until SIGTERM or SIGINT is received, then cancels the context
// and runs hooks in order. Each hook error is logged but does not abort remaining hooks.
func Handle(ctx context.Context, cancel context.CancelFunc, hooks []Hook, logger simpleLogger) {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(quit)
select {
case sig := <-quit:
logger.Info("shutdown signal received: " + sig.String())
case <-ctx.Done():
logger.Info("context cancelled, shutting down")
}
cancel()
for _, h := range hooks {
logger.Info("running shutdown hook: " + h.Name)
if err := h.Fn(); err != nil {
logger.Error("shutdown hook "+h.Name+" failed", err)
}
}
}

View File

@ -0,0 +1,133 @@
package shutdown
import (
"context"
"errors"
"sync/atomic"
"syscall"
"testing"
"time"
)
type mockLogger struct {
infoMsgs []string
errorMsgs []string
}
func (m *mockLogger) Info(msg string) { m.infoMsgs = append(m.infoMsgs, msg) }
func (m *mockLogger) Error(msg string, _ error) { m.errorMsgs = append(m.errorMsgs, msg) }
func TestHandle_RunsHooks(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
logger := &mockLogger{}
var hookCalled int32
hooks := []Hook{
{
Name: "test-hook",
Fn: func() error {
atomic.StoreInt32(&hookCalled, 1)
return nil
},
},
}
done := make(chan struct{})
go func() {
Handle(ctx, cancel, hooks, logger)
close(done)
}()
// Send SIGTERM to trigger shutdown
time.Sleep(50 * time.Millisecond)
p, _ := syscall.Getpid(), 0
syscall.Kill(p, syscall.SIGTERM)
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatal("Handle did not return within timeout")
}
if atomic.LoadInt32(&hookCalled) != 1 {
t.Error("hook was not called")
}
}
func TestHandle_HookError_ContinuesOtherHooks(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
logger := &mockLogger{}
var secondCalled int32
hooks := []Hook{
{
Name: "failing-hook",
Fn: func() error { return errors.New("hook error") },
},
{
Name: "second-hook",
Fn: func() error {
atomic.StoreInt32(&secondCalled, 1)
return nil
},
},
}
done := make(chan struct{})
go func() {
Handle(ctx, cancel, hooks, logger)
close(done)
}()
time.Sleep(50 * time.Millisecond)
syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatal("Handle did not return within timeout")
}
if atomic.LoadInt32(&secondCalled) != 1 {
t.Error("second hook should still run after first hook error")
}
if len(logger.errorMsgs) == 0 {
t.Error("error should be logged for failing hook")
}
}
func TestHandle_ContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
logger := &mockLogger{}
var hookCalled int32
hooks := []Hook{
{
Name: "ctx-hook",
Fn: func() error {
atomic.StoreInt32(&hookCalled, 1)
return nil
},
},
}
done := make(chan struct{})
go func() {
Handle(ctx, cancel, hooks, logger)
close(done)
}()
// Cancel context directly instead of sending signal
time.Sleep(50 * time.Millisecond)
cancel()
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatal("Handle did not return within timeout after context cancel")
}
if atomic.LoadInt32(&hookCalled) != 1 {
t.Error("hook should run on context cancel")
}
}