fix: sécuriser shutdown, config par défaut et reconnexion socket

Co-authored-by: aider (openrouter/openai/gpt-5.3-codex) <aider@aider.chat>
This commit is contained in:
Jacquin Antoine
2026-02-25 21:44:40 +01:00
parent 617ecd2014
commit 6cd6c4c3b8
11 changed files with 394 additions and 56 deletions

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
"time" "time"
@ -83,23 +84,37 @@ func main() {
helloChan := make(chan api.TLSClientHello, 1000) helloChan := make(chan api.TLSClientHello, 1000)
errorChan := make(chan error, 100) errorChan := make(chan error, 100)
var wg sync.WaitGroup
// Setup signal handling for graceful shutdown // Setup signal handling for graceful shutdown
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Start capture goroutine // Start capture goroutine
wg.Add(1)
go func() { go func() {
defer wg.Done()
defer close(packetChan)
logger.Info("capture", "Starting packet capture", map[string]string{ logger.Info("capture", "Starting packet capture", map[string]string{
"interface": cfg.Core.Interface, "interface": cfg.Core.Interface,
}) })
err := captureImpl.Run(cfg.Core, packetChan) err := captureImpl.Run(cfg.Core, packetChan)
if err != nil { if err != nil {
errorChan <- fmt.Errorf("capture error: %w", err) select {
case errorChan <- fmt.Errorf("capture error: %w", err):
default:
}
} }
}() }()
// Start TLS parsing goroutine // Start TLS parsing goroutine
wg.Add(1)
go func() { go func() {
defer wg.Done()
defer close(helloChan)
for pkt := range packetChan { for pkt := range packetChan {
hello, err := parser.Process(pkt) hello, err := parser.Process(pkt)
if err != nil { if err != nil {
@ -121,7 +136,10 @@ func main() {
}() }()
// Start fingerprinting and output goroutine // Start fingerprinting and output goroutine
wg.Add(1)
go func() { go func() {
defer wg.Done()
for hello := range helloChan { for hello := range helloChan {
fingerprints, err := engine.FromClientHello(hello) fingerprints, err := engine.FromClientHello(hello)
if err != nil { if err != nil {
@ -162,6 +180,21 @@ func main() {
// Graceful shutdown // Graceful shutdown
logger.Info("service", "Shutting down", nil) logger.Info("service", "Shutting down", nil)
if err := captureImpl.Close(); err != nil {
logger.Error("capture", "Error closing capture", map[string]string{
"error": err.Error(),
})
}
wg.Wait()
// Close parser (stops cleanup goroutine)
if err := parser.Close(); err != nil {
logger.Error("tlsparse", "Error closing parser", map[string]string{
"error": err.Error(),
})
}
// Close output writer // Close output writer
if closer, ok := writer.(interface{ CloseAll() error }); ok { if closer, ok := writer.(interface{ CloseAll() error }); ok {
if err := closer.CloseAll(); err != nil { if err := closer.CloseAll(); err != nil {
@ -171,19 +204,5 @@ func main() {
} }
} }
// Close parser (stops cleanup goroutine)
if err := parser.Close(); err != nil {
logger.Error("tlsparse", "Error closing parser", map[string]string{
"error": err.Error(),
})
}
// Close capture
if err := captureImpl.Close(); err != nil {
logger.Error("capture", "Error closing capture", map[string]string{
"error": err.Error(),
})
}
logger.Info("service", "ja4sentinel stopped", nil) logger.Info("service", "ja4sentinel stopped", nil)
} }

View File

@ -3,6 +3,7 @@ package capture
import ( import (
"fmt" "fmt"
"sync"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/pcap" "github.com/google/gopacket/pcap"
@ -13,6 +14,7 @@ import (
// CaptureImpl implements the capture.Capture interface for packet capture // CaptureImpl implements the capture.Capture interface for packet capture
type CaptureImpl struct { type CaptureImpl struct {
handle *pcap.Handle handle *pcap.Handle
mu sync.Mutex
} }
// New creates a new capture instance // New creates a new capture instance
@ -22,29 +24,40 @@ func New() *CaptureImpl {
// Run starts network packet capture according to the configuration // Run starts network packet capture according to the configuration
func (c *CaptureImpl) Run(cfg api.Config, out chan<- api.RawPacket) error { func (c *CaptureImpl) Run(cfg api.Config, out chan<- api.RawPacket) error {
var err error handle, err := pcap.OpenLive(cfg.Interface, 1600, true, pcap.BlockForever)
c.handle, err = pcap.OpenLive(cfg.Interface, 1600, true, pcap.BlockForever)
if err != nil { if err != nil {
return fmt.Errorf("failed to open interface %s: %w", cfg.Interface, err) return fmt.Errorf("failed to open interface %s: %w", cfg.Interface, err)
} }
defer c.handle.Close()
c.mu.Lock()
c.handle = handle
c.mu.Unlock()
defer func() {
c.mu.Lock()
if c.handle != nil {
c.handle.Close()
c.handle = nil
}
c.mu.Unlock()
}()
// Apply BPF filter if provided // Apply BPF filter if provided
if cfg.BPFFilter != "" { if cfg.BPFFilter != "" {
err = c.handle.SetBPFFilter(cfg.BPFFilter) err = handle.SetBPFFilter(cfg.BPFFilter)
if err != nil { if err != nil {
return fmt.Errorf("failed to set BPF filter: %w", err) return fmt.Errorf("failed to set BPF filter: %w", err)
} }
} else { } else {
// Create default filter for monitored ports // Create default filter for monitored ports
defaultFilter := buildBPFForPorts(cfg.ListenPorts) defaultFilter := buildBPFForPorts(cfg.ListenPorts)
err = c.handle.SetBPFFilter(defaultFilter) err = handle.SetBPFFilter(defaultFilter)
if err != nil { if err != nil {
return fmt.Errorf("failed to set default BPF filter: %w", err) return fmt.Errorf("failed to set default BPF filter: %w", err)
} }
} }
packetSource := gopacket.NewPacketSource(c.handle, c.handle.LinkType()) packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
for packet := range packetSource.Packets() { for packet := range packetSource.Packets() {
// Convert packet to RawPacket // Convert packet to RawPacket
@ -102,8 +115,12 @@ func packetToRawPacket(packet gopacket.Packet) *api.RawPacket {
// Close properly closes the capture handle // Close properly closes the capture handle
func (c *CaptureImpl) Close() error { func (c *CaptureImpl) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.handle != nil { if c.handle != nil {
c.handle.Close() c.handle.Close()
c.handle = nil
return nil return nil
} }
return nil return nil

View File

@ -79,3 +79,20 @@ func TestJoinString(t *testing.T) {
func TestCaptureIntegration(t *testing.T) { func TestCaptureIntegration(t *testing.T) {
t.Skip("Skipping integration test requiring network access and elevated privileges") t.Skip("Skipping integration test requiring network access and elevated privileges")
} }
func TestClose_NoHandle_NoError(t *testing.T) {
c := New()
if err := c.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
}
func TestClose_Idempotent_NoHandle(t *testing.T) {
c := New()
if err := c.Close(); err != nil {
t.Fatalf("first Close() error = %v", err)
}
if err := c.Close(); err != nil {
t.Fatalf("second Close() error = %v", err)
}
}

View File

@ -3,6 +3,7 @@ package config
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"os" "os"
"strconv" "strconv"
@ -28,13 +29,17 @@ func NewLoader(configPath string) *LoaderImpl {
func (l *LoaderImpl) Load() (api.AppConfig, error) { func (l *LoaderImpl) Load() (api.AppConfig, error) {
config := api.DefaultConfig() config := api.DefaultConfig()
// Load from YAML file if path is provided path := l.configPath
if l.configPath != "" { explicit := path != ""
fileConfig, err := l.loadFromFile(l.configPath) if !explicit {
if err != nil { path = "config.yml"
return config, fmt.Errorf("failed to load config file: %w", err) }
}
fileConfig, err := l.loadFromFile(path)
if err == nil {
config = mergeConfigs(config, fileConfig) config = mergeConfigs(config, fileConfig)
} else if !( !explicit && errors.Is(err, os.ErrNotExist)) {
return config, fmt.Errorf("failed to load config file: %w", err)
} }
// Override with environment variables // Override with environment variables

View File

@ -2,6 +2,7 @@ package config
import ( import (
"os" "os"
"path/filepath"
"strings" "strings"
"testing" "testing"
@ -211,3 +212,51 @@ func TestToJSON(t *testing.T) {
t.Error("ToJSON() result doesn't contain 'eth0'") t.Error("ToJSON() result doesn't contain 'eth0'")
} }
} }
func TestLoad_DefaultConfigFileAbsent_DoesNotFail(t *testing.T) {
t.Setenv("JA4SENTINEL_INTERFACE", "")
t.Setenv("JA4SENTINEL_PORTS", "")
t.Setenv("JA4SENTINEL_BPF_FILTER", "")
t.Setenv("JA4SENTINEL_FLOW_TIMEOUT", "")
tempDir := t.TempDir()
oldWD, err := os.Getwd()
if err != nil {
t.Fatalf("Getwd() error = %v", err)
}
defer func() {
_ = os.Chdir(oldWD)
}()
if err := os.Chdir(tempDir); err != nil {
t.Fatalf("Chdir() error = %v", err)
}
_ = os.Remove(filepath.Join(tempDir, "config.yml"))
loader := NewLoader("")
cfg, err := loader.Load()
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.Core.Interface != api.DefaultInterface {
t.Errorf("Interface = %q, want %q", cfg.Core.Interface, api.DefaultInterface)
}
if len(cfg.Core.ListenPorts) == 0 || cfg.Core.ListenPorts[0] != api.DefaultPort {
t.Errorf("ListenPorts = %v, want first port %d", cfg.Core.ListenPorts, api.DefaultPort)
}
}
func TestLoad_ExplicitMissingConfig_Fails(t *testing.T) {
t.Setenv("JA4SENTINEL_INTERFACE", "")
t.Setenv("JA4SENTINEL_PORTS", "")
t.Setenv("JA4SENTINEL_BPF_FILTER", "")
t.Setenv("JA4SENTINEL_FLOW_TIMEOUT", "")
loader := NewLoader("/tmp/definitely-missing-ja4sentinel.yml")
_, err := loader.Load()
if err == nil {
t.Fatal("Load() should fail with explicit missing config path")
}
}

View File

@ -3,7 +3,6 @@ package logging
import ( import (
"encoding/json" "encoding/json"
"fmt"
"log" "log"
"os" "os"
"strings" "strings"
@ -49,7 +48,8 @@ func NewServiceLogger(level string) *ServiceLogger {
// Log emits a structured log entry to stdout in JSON format // Log emits a structured log entry to stdout in JSON format
func (l *ServiceLogger) Log(component, level, message string, details map[string]string) { func (l *ServiceLogger) Log(component, level, message string, details map[string]string) {
if !l.isLogLevelEnabled(level) { normalizedLevel := strings.ToLower(level)
if !l.isLogLevelEnabled(normalizedLevel) {
return return
} }
@ -58,7 +58,7 @@ func (l *ServiceLogger) Log(component, level, message string, details map[string
defer l.mutex.Unlock() defer l.mutex.Unlock()
serviceLog := api.ServiceLog{ serviceLog := api.ServiceLog{
Level: level, Level: normalizedLevel,
Component: component, Component: component,
Message: message, Message: message,
Details: details, Details: details,
@ -67,40 +67,32 @@ func (l *ServiceLogger) Log(component, level, message string, details map[string
jsonData, err := l.formatter(serviceLog) jsonData, err := l.formatter(serviceLog)
if err != nil { if err != nil {
// Fallback to simple logging if JSON formatting fails // Fallback to simple logging if JSON formatting fails
fmt.Printf(`{"timestamp":%d,"level":"ERROR","component":"logging","message":"%s","original_message":"%s"}`, l.out.Printf(`{"timestamp":%d,"level":"ERROR","component":"logging","message":"%s","original_message":"%s"}`+"\n",
time.Now().UnixNano(), err.Error(), message) time.Now().UnixNano(), err.Error(), message)
return return
} }
fmt.Println(string(jsonData)) l.out.Println(string(jsonData))
} }
// Debug logs a debug level entry // Debug logs a debug level entry
func (l *ServiceLogger) Debug(component, message string, details map[string]string) { func (l *ServiceLogger) Debug(component, message string, details map[string]string) {
if l.isLogLevelEnabled("debug") { l.Log(component, "debug", message, details)
l.Log(component, "DEBUG", message, details)
}
} }
// Info logs an info level entry // Info logs an info level entry
func (l *ServiceLogger) Info(component, message string, details map[string]string) { func (l *ServiceLogger) Info(component, message string, details map[string]string) {
if l.isLogLevelEnabled("info") { l.Log(component, "info", message, details)
l.Log(component, "INFO", message, details)
}
} }
// Warn logs a warning level entry // Warn logs a warning level entry
func (l *ServiceLogger) Warn(component, message string, details map[string]string) { func (l *ServiceLogger) Warn(component, message string, details map[string]string) {
if l.isLogLevelEnabled("warn") { l.Log(component, "warn", message, details)
l.Log(component, "WARN", message, details)
}
} }
// Error logs an error level entry // Error logs an error level entry
func (l *ServiceLogger) Error(component, message string, details map[string]string) { func (l *ServiceLogger) Error(component, message string, details map[string]string) {
if l.isLogLevelEnabled("error") { l.Log(component, "error", message, details)
l.Log(component, "ERROR", message, details)
}
} }
// isLogLevelEnabled checks if a log level should be emitted based on configured level // isLogLevelEnabled checks if a log level should be emitted based on configured level

View File

@ -0,0 +1,59 @@
package logging
import (
"bytes"
"log"
"strings"
"testing"
)
func TestIsLogLevelEnabled(t *testing.T) {
tests := []struct {
name string
loggerLevel string
messageLevel string
want bool
}{
{name: "debug logger accepts debug", loggerLevel: "debug", messageLevel: "debug", want: true},
{name: "debug logger accepts info", loggerLevel: "debug", messageLevel: "info", want: true},
{name: "info logger rejects debug", loggerLevel: "info", messageLevel: "debug", want: false},
{name: "info logger accepts info", loggerLevel: "info", messageLevel: "info", want: true},
{name: "warn logger rejects info", loggerLevel: "warn", messageLevel: "info", want: false},
{name: "warn logger accepts error", loggerLevel: "warn", messageLevel: "error", want: true},
{name: "error logger accepts only error", loggerLevel: "error", messageLevel: "error", want: true},
{name: "error logger rejects warn", loggerLevel: "error", messageLevel: "warn", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := NewServiceLogger(tt.loggerLevel)
if got := logger.isLogLevelEnabled(tt.messageLevel); got != tt.want {
t.Fatalf("isLogLevelEnabled(%q) = %v, want %v", tt.messageLevel, got, tt.want)
}
})
}
}
func TestDebug_NotEmittedWhenLoggerLevelInfo(t *testing.T) {
logger := NewServiceLogger("info")
var buf bytes.Buffer
logger.out = log.New(&buf, "", 0)
logger.Debug("service", "debug message", map[string]string{"k": "v"})
if buf.Len() != 0 {
t.Fatalf("expected no output for debug at info level, got: %s", buf.String())
}
}
func TestLog_UppercaseDebug_NotEmittedWhenLoggerLevelInfo(t *testing.T) {
logger := NewServiceLogger("info")
var buf bytes.Buffer
logger.out = log.New(&buf, "", 0)
logger.Log("service", "DEBUG", "debug message", nil)
if strings.TrimSpace(buf.String()) != "" {
t.Fatalf("expected no output for uppercase DEBUG at info level, got: %s", buf.String())
}
}

View File

@ -103,13 +103,20 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error {
w.mutex.Lock() w.mutex.Lock()
defer w.mutex.Unlock() defer w.mutex.Unlock()
// Connect if not already connected ensureConn := func() error {
if w.conn == nil { if w.conn != nil {
return nil
}
conn, err := net.Dial("unix", w.socketPath) conn, err := net.Dial("unix", w.socketPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err) return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err)
} }
w.conn = conn w.conn = conn
return nil
}
if err := ensureConn(); err != nil {
return err
} }
data, err := json.Marshal(rec) data, err := json.Marshal(rec)
@ -120,12 +127,18 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error {
// Add newline for line-based protocols // Add newline for line-based protocols
data = append(data, '\n') data = append(data, '\n')
_, err = w.conn.Write(data) if _, err = w.conn.Write(data); err != nil {
if err != nil { _ = w.conn.Close()
// Connection failed, try to reconnect
w.conn.Close()
w.conn = nil w.conn = nil
return fmt.Errorf("failed to write to socket: %w", err)
if err2 := ensureConn(); err2 != nil {
return fmt.Errorf("failed to write to socket and reconnect failed: %w", err2)
}
if _, err2 := w.conn.Write(data); err2 != nil {
_ = w.conn.Close()
w.conn = nil
return fmt.Errorf("failed to write to socket after reconnect: %w", err2)
}
} }
return nil return nil

View File

@ -1,10 +1,15 @@
package output package output
import ( import (
"bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"net"
"os" "os"
"path/filepath"
"sync"
"testing" "testing"
"time"
"ja4sentinel/api" "ja4sentinel/api"
) )
@ -102,7 +107,6 @@ func TestMultiWriter(t *testing.T) {
defer fileWriter.Close() defer fileWriter.Close()
multiWriter.Add(fileWriter) multiWriter.Add(fileWriter)
multiWriter.Add(NewStdoutWriter())
rec := api.LogRecord{ rec := api.LogRecord{
SrcIP: "192.168.1.1", SrcIP: "192.168.1.1",
@ -233,3 +237,123 @@ func TestUnixSocketWriter(t *testing.T) {
writer.Close() writer.Close()
} }
type unixTestServer struct {
listener net.Listener
received chan string
mu sync.Mutex
conns map[net.Conn]struct{}
}
func newUnixTestServer(path string) (*unixTestServer, error) {
_ = os.Remove(path)
ln, err := net.Listen("unix", path)
if err != nil {
return nil, err
}
s := &unixTestServer{
listener: ln,
received: make(chan string, 10),
conns: make(map[net.Conn]struct{}),
}
go s.serve()
return s, nil
}
func (s *unixTestServer) serve() {
for {
conn, err := s.listener.Accept()
if err != nil {
return
}
s.mu.Lock()
s.conns[conn] = struct{}{}
s.mu.Unlock()
go func(c net.Conn) {
defer func() {
s.mu.Lock()
delete(s.conns, c)
s.mu.Unlock()
_ = c.Close()
}()
scanner := bufio.NewScanner(c)
for scanner.Scan() {
s.received <- scanner.Text()
}
}(conn)
}
}
func (s *unixTestServer) close(path string) {
_ = s.listener.Close()
s.mu.Lock()
for c := range s.conns {
_ = c.Close()
}
s.mu.Unlock()
_ = os.Remove(path)
}
func TestUnixSocketWriter_ReconnectAndWrite(t *testing.T) {
socketPath := filepath.Join(t.TempDir(), "ja4sentinel.sock")
server1, err := newUnixTestServer(socketPath)
if err != nil {
t.Fatalf("failed to start first unix test server: %v", err)
}
writer, err := NewUnixSocketWriter(socketPath)
if err != nil {
t.Fatalf("NewUnixSocketWriter() error = %v", err)
}
defer writer.Close()
rec1 := api.LogRecord{
SrcIP: "192.168.1.1",
SrcPort: 11111,
DstIP: "10.0.0.1",
DstPort: 443,
JA4: "first",
}
if err := writer.Write(rec1); err != nil {
t.Fatalf("first Write() error = %v", err)
}
select {
case <-server1.received:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting first message on unix socket")
}
server1.close(socketPath)
server2, err := newUnixTestServer(socketPath)
if err != nil {
t.Fatalf("failed to restart unix test server: %v", err)
}
defer server2.close(socketPath)
rec2 := api.LogRecord{
SrcIP: "192.168.1.2",
SrcPort: 22222,
DstIP: "10.0.0.2",
DstPort: 443,
JA4: "second",
}
if err := writer.Write(rec2); err != nil {
t.Fatalf("second Write() after reconnect error = %v", err)
}
select {
case <-server2.received:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting second message after reconnect")
}
}

View File

@ -46,6 +46,7 @@ type ParserImpl struct {
flowTimeout time.Duration flowTimeout time.Duration
cleanupDone chan struct{} cleanupDone chan struct{}
cleanupClose chan struct{} cleanupClose chan struct{}
closeOnce sync.Once
} }
// NewParser creates a new TLS parser with connection state tracking // NewParser creates a new TLS parser with connection state tracking
@ -260,8 +261,10 @@ func (p *ParserImpl) getOrCreateFlow(key string, srcIP string, srcPort uint16, d
// Close cleans up the parser and stops background goroutines // Close cleans up the parser and stops background goroutines
func (p *ParserImpl) Close() error { func (p *ParserImpl) Close() error {
close(p.cleanupClose) p.closeOnce.Do(func() {
<-p.cleanupDone close(p.cleanupClose)
<-p.cleanupDone
})
return nil return nil
} }
@ -296,8 +299,12 @@ func extractTCPMeta(tcp *layers.TCP) api.TCPMeta {
for _, opt := range tcp.Options { for _, opt := range tcp.Options {
switch opt.OptionType { switch opt.OptionType {
case layers.TCPOptionKindMSS: case layers.TCPOptionKindMSS:
meta.MSS = binary.BigEndian.Uint16(opt.OptionData) if len(opt.OptionData) >= 2 {
meta.Options = append(meta.Options, "MSS") meta.MSS = binary.BigEndian.Uint16(opt.OptionData[:2])
meta.Options = append(meta.Options, "MSS")
} else {
meta.Options = append(meta.Options, "MSS_INVALID")
}
case layers.TCPOptionKindWindowScale: case layers.TCPOptionKindWindowScale:
if len(opt.OptionData) > 0 { if len(opt.OptionData) > 0 {
meta.WindowScale = opt.OptionData[0] meta.WindowScale = opt.OptionData[0]

View File

@ -251,3 +251,39 @@ func TestParserConnectionStateTracking(t *testing.T) {
t.Error("IsClientHello() should return true for valid ClientHello") t.Error("IsClientHello() should return true for valid ClientHello")
} }
} }
func TestParserClose_Idempotent(t *testing.T) {
parser := NewParser()
if err := parser.Close(); err != nil {
t.Fatalf("first Close() error = %v", err)
}
if err := parser.Close(); err != nil {
t.Fatalf("second Close() error = %v", err)
}
}
func TestExtractTCPMeta_MSSInvalid_NoPanic(t *testing.T) {
tcp := &layers.TCP{
Window: 1234,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionData: []byte{0x05}, // malformed (1 byte instead of 2)
},
},
}
meta := extractTCPMeta(tcp)
found := false
for _, opt := range meta.Options {
if opt == "MSS_INVALID" {
found = true
break
}
}
if !found {
t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options)
}
}