fix: sécuriser shutdown, config par défaut et reconnexion socket

Co-authored-by: aider (openrouter/openai/gpt-5.3-codex) <aider@aider.chat>
This commit is contained in:
Jacquin Antoine
2026-02-25 21:44:40 +01:00
parent 617ecd2014
commit 6cd6c4c3b8
11 changed files with 394 additions and 56 deletions

View File

@ -6,6 +6,7 @@ import (
"fmt"
"os"
"os/signal"
"sync"
"syscall"
"time"
@ -83,23 +84,37 @@ func main() {
helloChan := make(chan api.TLSClientHello, 1000)
errorChan := make(chan error, 100)
var wg sync.WaitGroup
// Setup signal handling for graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Start capture goroutine
wg.Add(1)
go func() {
defer wg.Done()
defer close(packetChan)
logger.Info("capture", "Starting packet capture", map[string]string{
"interface": cfg.Core.Interface,
})
err := captureImpl.Run(cfg.Core, packetChan)
if err != nil {
errorChan <- fmt.Errorf("capture error: %w", err)
select {
case errorChan <- fmt.Errorf("capture error: %w", err):
default:
}
}
}()
// Start TLS parsing goroutine
wg.Add(1)
go func() {
defer wg.Done()
defer close(helloChan)
for pkt := range packetChan {
hello, err := parser.Process(pkt)
if err != nil {
@ -121,7 +136,10 @@ func main() {
}()
// Start fingerprinting and output goroutine
wg.Add(1)
go func() {
defer wg.Done()
for hello := range helloChan {
fingerprints, err := engine.FromClientHello(hello)
if err != nil {
@ -162,6 +180,21 @@ func main() {
// Graceful shutdown
logger.Info("service", "Shutting down", nil)
if err := captureImpl.Close(); err != nil {
logger.Error("capture", "Error closing capture", map[string]string{
"error": err.Error(),
})
}
wg.Wait()
// Close parser (stops cleanup goroutine)
if err := parser.Close(); err != nil {
logger.Error("tlsparse", "Error closing parser", map[string]string{
"error": err.Error(),
})
}
// Close output writer
if closer, ok := writer.(interface{ CloseAll() error }); ok {
if err := closer.CloseAll(); err != nil {
@ -171,19 +204,5 @@ func main() {
}
}
// Close parser (stops cleanup goroutine)
if err := parser.Close(); err != nil {
logger.Error("tlsparse", "Error closing parser", map[string]string{
"error": err.Error(),
})
}
// Close capture
if err := captureImpl.Close(); err != nil {
logger.Error("capture", "Error closing capture", map[string]string{
"error": err.Error(),
})
}
logger.Info("service", "ja4sentinel stopped", nil)
}

View File

@ -3,6 +3,7 @@ package capture
import (
"fmt"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/pcap"
@ -13,6 +14,7 @@ import (
// CaptureImpl implements the capture.Capture interface for packet capture
type CaptureImpl struct {
handle *pcap.Handle
mu sync.Mutex
}
// New creates a new capture instance
@ -22,29 +24,40 @@ func New() *CaptureImpl {
// Run starts network packet capture according to the configuration
func (c *CaptureImpl) Run(cfg api.Config, out chan<- api.RawPacket) error {
var err error
c.handle, err = pcap.OpenLive(cfg.Interface, 1600, true, pcap.BlockForever)
handle, err := pcap.OpenLive(cfg.Interface, 1600, true, pcap.BlockForever)
if err != nil {
return fmt.Errorf("failed to open interface %s: %w", cfg.Interface, err)
}
defer c.handle.Close()
c.mu.Lock()
c.handle = handle
c.mu.Unlock()
defer func() {
c.mu.Lock()
if c.handle != nil {
c.handle.Close()
c.handle = nil
}
c.mu.Unlock()
}()
// Apply BPF filter if provided
if cfg.BPFFilter != "" {
err = c.handle.SetBPFFilter(cfg.BPFFilter)
err = handle.SetBPFFilter(cfg.BPFFilter)
if err != nil {
return fmt.Errorf("failed to set BPF filter: %w", err)
}
} else {
// Create default filter for monitored ports
defaultFilter := buildBPFForPorts(cfg.ListenPorts)
err = c.handle.SetBPFFilter(defaultFilter)
err = handle.SetBPFFilter(defaultFilter)
if err != nil {
return fmt.Errorf("failed to set default BPF filter: %w", err)
}
}
packetSource := gopacket.NewPacketSource(c.handle, c.handle.LinkType())
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
for packet := range packetSource.Packets() {
// Convert packet to RawPacket
@ -102,8 +115,12 @@ func packetToRawPacket(packet gopacket.Packet) *api.RawPacket {
// Close properly closes the capture handle
func (c *CaptureImpl) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.handle != nil {
c.handle.Close()
c.handle = nil
return nil
}
return nil

View File

@ -79,3 +79,20 @@ func TestJoinString(t *testing.T) {
func TestCaptureIntegration(t *testing.T) {
t.Skip("Skipping integration test requiring network access and elevated privileges")
}
func TestClose_NoHandle_NoError(t *testing.T) {
c := New()
if err := c.Close(); err != nil {
t.Fatalf("Close() error = %v", err)
}
}
func TestClose_Idempotent_NoHandle(t *testing.T) {
c := New()
if err := c.Close(); err != nil {
t.Fatalf("first Close() error = %v", err)
}
if err := c.Close(); err != nil {
t.Fatalf("second Close() error = %v", err)
}
}

View File

@ -3,6 +3,7 @@ package config
import (
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
@ -28,13 +29,17 @@ func NewLoader(configPath string) *LoaderImpl {
func (l *LoaderImpl) Load() (api.AppConfig, error) {
config := api.DefaultConfig()
// Load from YAML file if path is provided
if l.configPath != "" {
fileConfig, err := l.loadFromFile(l.configPath)
if err != nil {
return config, fmt.Errorf("failed to load config file: %w", err)
}
path := l.configPath
explicit := path != ""
if !explicit {
path = "config.yml"
}
fileConfig, err := l.loadFromFile(path)
if err == nil {
config = mergeConfigs(config, fileConfig)
} else if !( !explicit && errors.Is(err, os.ErrNotExist)) {
return config, fmt.Errorf("failed to load config file: %w", err)
}
// Override with environment variables

View File

@ -2,6 +2,7 @@ package config
import (
"os"
"path/filepath"
"strings"
"testing"
@ -211,3 +212,51 @@ func TestToJSON(t *testing.T) {
t.Error("ToJSON() result doesn't contain 'eth0'")
}
}
func TestLoad_DefaultConfigFileAbsent_DoesNotFail(t *testing.T) {
t.Setenv("JA4SENTINEL_INTERFACE", "")
t.Setenv("JA4SENTINEL_PORTS", "")
t.Setenv("JA4SENTINEL_BPF_FILTER", "")
t.Setenv("JA4SENTINEL_FLOW_TIMEOUT", "")
tempDir := t.TempDir()
oldWD, err := os.Getwd()
if err != nil {
t.Fatalf("Getwd() error = %v", err)
}
defer func() {
_ = os.Chdir(oldWD)
}()
if err := os.Chdir(tempDir); err != nil {
t.Fatalf("Chdir() error = %v", err)
}
_ = os.Remove(filepath.Join(tempDir, "config.yml"))
loader := NewLoader("")
cfg, err := loader.Load()
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.Core.Interface != api.DefaultInterface {
t.Errorf("Interface = %q, want %q", cfg.Core.Interface, api.DefaultInterface)
}
if len(cfg.Core.ListenPorts) == 0 || cfg.Core.ListenPorts[0] != api.DefaultPort {
t.Errorf("ListenPorts = %v, want first port %d", cfg.Core.ListenPorts, api.DefaultPort)
}
}
func TestLoad_ExplicitMissingConfig_Fails(t *testing.T) {
t.Setenv("JA4SENTINEL_INTERFACE", "")
t.Setenv("JA4SENTINEL_PORTS", "")
t.Setenv("JA4SENTINEL_BPF_FILTER", "")
t.Setenv("JA4SENTINEL_FLOW_TIMEOUT", "")
loader := NewLoader("/tmp/definitely-missing-ja4sentinel.yml")
_, err := loader.Load()
if err == nil {
t.Fatal("Load() should fail with explicit missing config path")
}
}

View File

@ -3,7 +3,6 @@ package logging
import (
"encoding/json"
"fmt"
"log"
"os"
"strings"
@ -49,7 +48,8 @@ func NewServiceLogger(level string) *ServiceLogger {
// Log emits a structured log entry to stdout in JSON format
func (l *ServiceLogger) Log(component, level, message string, details map[string]string) {
if !l.isLogLevelEnabled(level) {
normalizedLevel := strings.ToLower(level)
if !l.isLogLevelEnabled(normalizedLevel) {
return
}
@ -58,7 +58,7 @@ func (l *ServiceLogger) Log(component, level, message string, details map[string
defer l.mutex.Unlock()
serviceLog := api.ServiceLog{
Level: level,
Level: normalizedLevel,
Component: component,
Message: message,
Details: details,
@ -67,40 +67,32 @@ func (l *ServiceLogger) Log(component, level, message string, details map[string
jsonData, err := l.formatter(serviceLog)
if err != nil {
// Fallback to simple logging if JSON formatting fails
fmt.Printf(`{"timestamp":%d,"level":"ERROR","component":"logging","message":"%s","original_message":"%s"}`,
l.out.Printf(`{"timestamp":%d,"level":"ERROR","component":"logging","message":"%s","original_message":"%s"}`+"\n",
time.Now().UnixNano(), err.Error(), message)
return
}
fmt.Println(string(jsonData))
l.out.Println(string(jsonData))
}
// Debug logs a debug level entry
func (l *ServiceLogger) Debug(component, message string, details map[string]string) {
if l.isLogLevelEnabled("debug") {
l.Log(component, "DEBUG", message, details)
}
l.Log(component, "debug", message, details)
}
// Info logs an info level entry
func (l *ServiceLogger) Info(component, message string, details map[string]string) {
if l.isLogLevelEnabled("info") {
l.Log(component, "INFO", message, details)
}
l.Log(component, "info", message, details)
}
// Warn logs a warning level entry
func (l *ServiceLogger) Warn(component, message string, details map[string]string) {
if l.isLogLevelEnabled("warn") {
l.Log(component, "WARN", message, details)
}
l.Log(component, "warn", message, details)
}
// Error logs an error level entry
func (l *ServiceLogger) Error(component, message string, details map[string]string) {
if l.isLogLevelEnabled("error") {
l.Log(component, "ERROR", message, details)
}
l.Log(component, "error", message, details)
}
// isLogLevelEnabled checks if a log level should be emitted based on configured level

View File

@ -0,0 +1,59 @@
package logging
import (
"bytes"
"log"
"strings"
"testing"
)
func TestIsLogLevelEnabled(t *testing.T) {
tests := []struct {
name string
loggerLevel string
messageLevel string
want bool
}{
{name: "debug logger accepts debug", loggerLevel: "debug", messageLevel: "debug", want: true},
{name: "debug logger accepts info", loggerLevel: "debug", messageLevel: "info", want: true},
{name: "info logger rejects debug", loggerLevel: "info", messageLevel: "debug", want: false},
{name: "info logger accepts info", loggerLevel: "info", messageLevel: "info", want: true},
{name: "warn logger rejects info", loggerLevel: "warn", messageLevel: "info", want: false},
{name: "warn logger accepts error", loggerLevel: "warn", messageLevel: "error", want: true},
{name: "error logger accepts only error", loggerLevel: "error", messageLevel: "error", want: true},
{name: "error logger rejects warn", loggerLevel: "error", messageLevel: "warn", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := NewServiceLogger(tt.loggerLevel)
if got := logger.isLogLevelEnabled(tt.messageLevel); got != tt.want {
t.Fatalf("isLogLevelEnabled(%q) = %v, want %v", tt.messageLevel, got, tt.want)
}
})
}
}
func TestDebug_NotEmittedWhenLoggerLevelInfo(t *testing.T) {
logger := NewServiceLogger("info")
var buf bytes.Buffer
logger.out = log.New(&buf, "", 0)
logger.Debug("service", "debug message", map[string]string{"k": "v"})
if buf.Len() != 0 {
t.Fatalf("expected no output for debug at info level, got: %s", buf.String())
}
}
func TestLog_UppercaseDebug_NotEmittedWhenLoggerLevelInfo(t *testing.T) {
logger := NewServiceLogger("info")
var buf bytes.Buffer
logger.out = log.New(&buf, "", 0)
logger.Log("service", "DEBUG", "debug message", nil)
if strings.TrimSpace(buf.String()) != "" {
t.Fatalf("expected no output for uppercase DEBUG at info level, got: %s", buf.String())
}
}

View File

@ -103,13 +103,20 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error {
w.mutex.Lock()
defer w.mutex.Unlock()
// Connect if not already connected
if w.conn == nil {
ensureConn := func() error {
if w.conn != nil {
return nil
}
conn, err := net.Dial("unix", w.socketPath)
if err != nil {
return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err)
}
w.conn = conn
return nil
}
if err := ensureConn(); err != nil {
return err
}
data, err := json.Marshal(rec)
@ -120,12 +127,18 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error {
// Add newline for line-based protocols
data = append(data, '\n')
_, err = w.conn.Write(data)
if err != nil {
// Connection failed, try to reconnect
w.conn.Close()
if _, err = w.conn.Write(data); err != nil {
_ = w.conn.Close()
w.conn = nil
return fmt.Errorf("failed to write to socket: %w", err)
if err2 := ensureConn(); err2 != nil {
return fmt.Errorf("failed to write to socket and reconnect failed: %w", err2)
}
if _, err2 := w.conn.Write(data); err2 != nil {
_ = w.conn.Close()
w.conn = nil
return fmt.Errorf("failed to write to socket after reconnect: %w", err2)
}
}
return nil

View File

@ -1,10 +1,15 @@
package output
import (
"bufio"
"bytes"
"encoding/json"
"net"
"os"
"path/filepath"
"sync"
"testing"
"time"
"ja4sentinel/api"
)
@ -102,7 +107,6 @@ func TestMultiWriter(t *testing.T) {
defer fileWriter.Close()
multiWriter.Add(fileWriter)
multiWriter.Add(NewStdoutWriter())
rec := api.LogRecord{
SrcIP: "192.168.1.1",
@ -233,3 +237,123 @@ func TestUnixSocketWriter(t *testing.T) {
writer.Close()
}
type unixTestServer struct {
listener net.Listener
received chan string
mu sync.Mutex
conns map[net.Conn]struct{}
}
func newUnixTestServer(path string) (*unixTestServer, error) {
_ = os.Remove(path)
ln, err := net.Listen("unix", path)
if err != nil {
return nil, err
}
s := &unixTestServer{
listener: ln,
received: make(chan string, 10),
conns: make(map[net.Conn]struct{}),
}
go s.serve()
return s, nil
}
func (s *unixTestServer) serve() {
for {
conn, err := s.listener.Accept()
if err != nil {
return
}
s.mu.Lock()
s.conns[conn] = struct{}{}
s.mu.Unlock()
go func(c net.Conn) {
defer func() {
s.mu.Lock()
delete(s.conns, c)
s.mu.Unlock()
_ = c.Close()
}()
scanner := bufio.NewScanner(c)
for scanner.Scan() {
s.received <- scanner.Text()
}
}(conn)
}
}
func (s *unixTestServer) close(path string) {
_ = s.listener.Close()
s.mu.Lock()
for c := range s.conns {
_ = c.Close()
}
s.mu.Unlock()
_ = os.Remove(path)
}
func TestUnixSocketWriter_ReconnectAndWrite(t *testing.T) {
socketPath := filepath.Join(t.TempDir(), "ja4sentinel.sock")
server1, err := newUnixTestServer(socketPath)
if err != nil {
t.Fatalf("failed to start first unix test server: %v", err)
}
writer, err := NewUnixSocketWriter(socketPath)
if err != nil {
t.Fatalf("NewUnixSocketWriter() error = %v", err)
}
defer writer.Close()
rec1 := api.LogRecord{
SrcIP: "192.168.1.1",
SrcPort: 11111,
DstIP: "10.0.0.1",
DstPort: 443,
JA4: "first",
}
if err := writer.Write(rec1); err != nil {
t.Fatalf("first Write() error = %v", err)
}
select {
case <-server1.received:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting first message on unix socket")
}
server1.close(socketPath)
server2, err := newUnixTestServer(socketPath)
if err != nil {
t.Fatalf("failed to restart unix test server: %v", err)
}
defer server2.close(socketPath)
rec2 := api.LogRecord{
SrcIP: "192.168.1.2",
SrcPort: 22222,
DstIP: "10.0.0.2",
DstPort: 443,
JA4: "second",
}
if err := writer.Write(rec2); err != nil {
t.Fatalf("second Write() after reconnect error = %v", err)
}
select {
case <-server2.received:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting second message after reconnect")
}
}

View File

@ -46,6 +46,7 @@ type ParserImpl struct {
flowTimeout time.Duration
cleanupDone chan struct{}
cleanupClose chan struct{}
closeOnce sync.Once
}
// NewParser creates a new TLS parser with connection state tracking
@ -260,8 +261,10 @@ func (p *ParserImpl) getOrCreateFlow(key string, srcIP string, srcPort uint16, d
// Close cleans up the parser and stops background goroutines
func (p *ParserImpl) Close() error {
close(p.cleanupClose)
<-p.cleanupDone
p.closeOnce.Do(func() {
close(p.cleanupClose)
<-p.cleanupDone
})
return nil
}
@ -296,8 +299,12 @@ func extractTCPMeta(tcp *layers.TCP) api.TCPMeta {
for _, opt := range tcp.Options {
switch opt.OptionType {
case layers.TCPOptionKindMSS:
meta.MSS = binary.BigEndian.Uint16(opt.OptionData)
meta.Options = append(meta.Options, "MSS")
if len(opt.OptionData) >= 2 {
meta.MSS = binary.BigEndian.Uint16(opt.OptionData[:2])
meta.Options = append(meta.Options, "MSS")
} else {
meta.Options = append(meta.Options, "MSS_INVALID")
}
case layers.TCPOptionKindWindowScale:
if len(opt.OptionData) > 0 {
meta.WindowScale = opt.OptionData[0]

View File

@ -251,3 +251,39 @@ func TestParserConnectionStateTracking(t *testing.T) {
t.Error("IsClientHello() should return true for valid ClientHello")
}
}
func TestParserClose_Idempotent(t *testing.T) {
parser := NewParser()
if err := parser.Close(); err != nil {
t.Fatalf("first Close() error = %v", err)
}
if err := parser.Close(); err != nil {
t.Fatalf("second Close() error = %v", err)
}
}
func TestExtractTCPMeta_MSSInvalid_NoPanic(t *testing.T) {
tcp := &layers.TCP{
Window: 1234,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionData: []byte{0x05}, // malformed (1 byte instead of 2)
},
},
}
meta := extractTCPMeta(tcp)
found := false
for _, opt := range meta.Options {
if opt == "MSS_INVALID" {
found = true
break
}
}
if !found {
t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options)
}
}