fix: sécuriser shutdown, config par défaut et reconnexion socket
Co-authored-by: aider (openrouter/openai/gpt-5.3-codex) <aider@aider.chat>
This commit is contained in:
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -83,23 +84,37 @@ func main() {
|
|||||||
helloChan := make(chan api.TLSClientHello, 1000)
|
helloChan := make(chan api.TLSClientHello, 1000)
|
||||||
errorChan := make(chan error, 100)
|
errorChan := make(chan error, 100)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
// Setup signal handling for graceful shutdown
|
// Setup signal handling for graceful shutdown
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
// Start capture goroutine
|
// Start capture goroutine
|
||||||
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
defer close(packetChan)
|
||||||
|
|
||||||
logger.Info("capture", "Starting packet capture", map[string]string{
|
logger.Info("capture", "Starting packet capture", map[string]string{
|
||||||
"interface": cfg.Core.Interface,
|
"interface": cfg.Core.Interface,
|
||||||
})
|
})
|
||||||
|
|
||||||
err := captureImpl.Run(cfg.Core, packetChan)
|
err := captureImpl.Run(cfg.Core, packetChan)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorChan <- fmt.Errorf("capture error: %w", err)
|
select {
|
||||||
|
case errorChan <- fmt.Errorf("capture error: %w", err):
|
||||||
|
default:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Start TLS parsing goroutine
|
// Start TLS parsing goroutine
|
||||||
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
defer close(helloChan)
|
||||||
|
|
||||||
for pkt := range packetChan {
|
for pkt := range packetChan {
|
||||||
hello, err := parser.Process(pkt)
|
hello, err := parser.Process(pkt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -121,7 +136,10 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Start fingerprinting and output goroutine
|
// Start fingerprinting and output goroutine
|
||||||
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
for hello := range helloChan {
|
for hello := range helloChan {
|
||||||
fingerprints, err := engine.FromClientHello(hello)
|
fingerprints, err := engine.FromClientHello(hello)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -162,6 +180,21 @@ func main() {
|
|||||||
// Graceful shutdown
|
// Graceful shutdown
|
||||||
logger.Info("service", "Shutting down", nil)
|
logger.Info("service", "Shutting down", nil)
|
||||||
|
|
||||||
|
if err := captureImpl.Close(); err != nil {
|
||||||
|
logger.Error("capture", "Error closing capture", map[string]string{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Close parser (stops cleanup goroutine)
|
||||||
|
if err := parser.Close(); err != nil {
|
||||||
|
logger.Error("tlsparse", "Error closing parser", map[string]string{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Close output writer
|
// Close output writer
|
||||||
if closer, ok := writer.(interface{ CloseAll() error }); ok {
|
if closer, ok := writer.(interface{ CloseAll() error }); ok {
|
||||||
if err := closer.CloseAll(); err != nil {
|
if err := closer.CloseAll(); err != nil {
|
||||||
@ -171,19 +204,5 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close parser (stops cleanup goroutine)
|
|
||||||
if err := parser.Close(); err != nil {
|
|
||||||
logger.Error("tlsparse", "Error closing parser", map[string]string{
|
|
||||||
"error": err.Error(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close capture
|
|
||||||
if err := captureImpl.Close(); err != nil {
|
|
||||||
logger.Error("capture", "Error closing capture", map[string]string{
|
|
||||||
"error": err.Error(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("service", "ja4sentinel stopped", nil)
|
logger.Info("service", "ja4sentinel stopped", nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@ package capture
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/pcap"
|
"github.com/google/gopacket/pcap"
|
||||||
@ -13,6 +14,7 @@ import (
|
|||||||
// CaptureImpl implements the capture.Capture interface for packet capture
|
// CaptureImpl implements the capture.Capture interface for packet capture
|
||||||
type CaptureImpl struct {
|
type CaptureImpl struct {
|
||||||
handle *pcap.Handle
|
handle *pcap.Handle
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new capture instance
|
// New creates a new capture instance
|
||||||
@ -22,29 +24,40 @@ func New() *CaptureImpl {
|
|||||||
|
|
||||||
// Run starts network packet capture according to the configuration
|
// Run starts network packet capture according to the configuration
|
||||||
func (c *CaptureImpl) Run(cfg api.Config, out chan<- api.RawPacket) error {
|
func (c *CaptureImpl) Run(cfg api.Config, out chan<- api.RawPacket) error {
|
||||||
var err error
|
handle, err := pcap.OpenLive(cfg.Interface, 1600, true, pcap.BlockForever)
|
||||||
c.handle, err = pcap.OpenLive(cfg.Interface, 1600, true, pcap.BlockForever)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open interface %s: %w", cfg.Interface, err)
|
return fmt.Errorf("failed to open interface %s: %w", cfg.Interface, err)
|
||||||
}
|
}
|
||||||
defer c.handle.Close()
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.handle = handle
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.handle != nil {
|
||||||
|
c.handle.Close()
|
||||||
|
c.handle = nil
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
// Apply BPF filter if provided
|
// Apply BPF filter if provided
|
||||||
if cfg.BPFFilter != "" {
|
if cfg.BPFFilter != "" {
|
||||||
err = c.handle.SetBPFFilter(cfg.BPFFilter)
|
err = handle.SetBPFFilter(cfg.BPFFilter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to set BPF filter: %w", err)
|
return fmt.Errorf("failed to set BPF filter: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Create default filter for monitored ports
|
// Create default filter for monitored ports
|
||||||
defaultFilter := buildBPFForPorts(cfg.ListenPorts)
|
defaultFilter := buildBPFForPorts(cfg.ListenPorts)
|
||||||
err = c.handle.SetBPFFilter(defaultFilter)
|
err = handle.SetBPFFilter(defaultFilter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to set default BPF filter: %w", err)
|
return fmt.Errorf("failed to set default BPF filter: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
packetSource := gopacket.NewPacketSource(c.handle, c.handle.LinkType())
|
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
|
||||||
|
|
||||||
for packet := range packetSource.Packets() {
|
for packet := range packetSource.Packets() {
|
||||||
// Convert packet to RawPacket
|
// Convert packet to RawPacket
|
||||||
@ -102,8 +115,12 @@ func packetToRawPacket(packet gopacket.Packet) *api.RawPacket {
|
|||||||
|
|
||||||
// Close properly closes the capture handle
|
// Close properly closes the capture handle
|
||||||
func (c *CaptureImpl) Close() error {
|
func (c *CaptureImpl) Close() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
if c.handle != nil {
|
if c.handle != nil {
|
||||||
c.handle.Close()
|
c.handle.Close()
|
||||||
|
c.handle = nil
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -79,3 +79,20 @@ func TestJoinString(t *testing.T) {
|
|||||||
func TestCaptureIntegration(t *testing.T) {
|
func TestCaptureIntegration(t *testing.T) {
|
||||||
t.Skip("Skipping integration test requiring network access and elevated privileges")
|
t.Skip("Skipping integration test requiring network access and elevated privileges")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClose_NoHandle_NoError(t *testing.T) {
|
||||||
|
c := New()
|
||||||
|
if err := c.Close(); err != nil {
|
||||||
|
t.Fatalf("Close() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClose_Idempotent_NoHandle(t *testing.T) {
|
||||||
|
c := New()
|
||||||
|
if err := c.Close(); err != nil {
|
||||||
|
t.Fatalf("first Close() error = %v", err)
|
||||||
|
}
|
||||||
|
if err := c.Close(); err != nil {
|
||||||
|
t.Fatalf("second Close() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -28,13 +29,17 @@ func NewLoader(configPath string) *LoaderImpl {
|
|||||||
func (l *LoaderImpl) Load() (api.AppConfig, error) {
|
func (l *LoaderImpl) Load() (api.AppConfig, error) {
|
||||||
config := api.DefaultConfig()
|
config := api.DefaultConfig()
|
||||||
|
|
||||||
// Load from YAML file if path is provided
|
path := l.configPath
|
||||||
if l.configPath != "" {
|
explicit := path != ""
|
||||||
fileConfig, err := l.loadFromFile(l.configPath)
|
if !explicit {
|
||||||
if err != nil {
|
path = "config.yml"
|
||||||
return config, fmt.Errorf("failed to load config file: %w", err)
|
}
|
||||||
}
|
|
||||||
|
fileConfig, err := l.loadFromFile(path)
|
||||||
|
if err == nil {
|
||||||
config = mergeConfigs(config, fileConfig)
|
config = mergeConfigs(config, fileConfig)
|
||||||
|
} else if !( !explicit && errors.Is(err, os.ErrNotExist)) {
|
||||||
|
return config, fmt.Errorf("failed to load config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Override with environment variables
|
// Override with environment variables
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -211,3 +212,51 @@ func TestToJSON(t *testing.T) {
|
|||||||
t.Error("ToJSON() result doesn't contain 'eth0'")
|
t.Error("ToJSON() result doesn't contain 'eth0'")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoad_DefaultConfigFileAbsent_DoesNotFail(t *testing.T) {
|
||||||
|
t.Setenv("JA4SENTINEL_INTERFACE", "")
|
||||||
|
t.Setenv("JA4SENTINEL_PORTS", "")
|
||||||
|
t.Setenv("JA4SENTINEL_BPF_FILTER", "")
|
||||||
|
t.Setenv("JA4SENTINEL_FLOW_TIMEOUT", "")
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
oldWD, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Getwd() error = %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = os.Chdir(oldWD)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := os.Chdir(tempDir); err != nil {
|
||||||
|
t.Fatalf("Chdir() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = os.Remove(filepath.Join(tempDir, "config.yml"))
|
||||||
|
|
||||||
|
loader := NewLoader("")
|
||||||
|
cfg, err := loader.Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Core.Interface != api.DefaultInterface {
|
||||||
|
t.Errorf("Interface = %q, want %q", cfg.Core.Interface, api.DefaultInterface)
|
||||||
|
}
|
||||||
|
if len(cfg.Core.ListenPorts) == 0 || cfg.Core.ListenPorts[0] != api.DefaultPort {
|
||||||
|
t.Errorf("ListenPorts = %v, want first port %d", cfg.Core.ListenPorts, api.DefaultPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoad_ExplicitMissingConfig_Fails(t *testing.T) {
|
||||||
|
t.Setenv("JA4SENTINEL_INTERFACE", "")
|
||||||
|
t.Setenv("JA4SENTINEL_PORTS", "")
|
||||||
|
t.Setenv("JA4SENTINEL_BPF_FILTER", "")
|
||||||
|
t.Setenv("JA4SENTINEL_FLOW_TIMEOUT", "")
|
||||||
|
|
||||||
|
loader := NewLoader("/tmp/definitely-missing-ja4sentinel.yml")
|
||||||
|
_, err := loader.Load()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Load() should fail with explicit missing config path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -3,7 +3,6 @@ package logging
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
@ -49,7 +48,8 @@ func NewServiceLogger(level string) *ServiceLogger {
|
|||||||
|
|
||||||
// Log emits a structured log entry to stdout in JSON format
|
// Log emits a structured log entry to stdout in JSON format
|
||||||
func (l *ServiceLogger) Log(component, level, message string, details map[string]string) {
|
func (l *ServiceLogger) Log(component, level, message string, details map[string]string) {
|
||||||
if !l.isLogLevelEnabled(level) {
|
normalizedLevel := strings.ToLower(level)
|
||||||
|
if !l.isLogLevelEnabled(normalizedLevel) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ func (l *ServiceLogger) Log(component, level, message string, details map[string
|
|||||||
defer l.mutex.Unlock()
|
defer l.mutex.Unlock()
|
||||||
|
|
||||||
serviceLog := api.ServiceLog{
|
serviceLog := api.ServiceLog{
|
||||||
Level: level,
|
Level: normalizedLevel,
|
||||||
Component: component,
|
Component: component,
|
||||||
Message: message,
|
Message: message,
|
||||||
Details: details,
|
Details: details,
|
||||||
@ -67,40 +67,32 @@ func (l *ServiceLogger) Log(component, level, message string, details map[string
|
|||||||
jsonData, err := l.formatter(serviceLog)
|
jsonData, err := l.formatter(serviceLog)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Fallback to simple logging if JSON formatting fails
|
// Fallback to simple logging if JSON formatting fails
|
||||||
fmt.Printf(`{"timestamp":%d,"level":"ERROR","component":"logging","message":"%s","original_message":"%s"}`,
|
l.out.Printf(`{"timestamp":%d,"level":"ERROR","component":"logging","message":"%s","original_message":"%s"}`+"\n",
|
||||||
time.Now().UnixNano(), err.Error(), message)
|
time.Now().UnixNano(), err.Error(), message)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(string(jsonData))
|
l.out.Println(string(jsonData))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Debug logs a debug level entry
|
// Debug logs a debug level entry
|
||||||
func (l *ServiceLogger) Debug(component, message string, details map[string]string) {
|
func (l *ServiceLogger) Debug(component, message string, details map[string]string) {
|
||||||
if l.isLogLevelEnabled("debug") {
|
l.Log(component, "debug", message, details)
|
||||||
l.Log(component, "DEBUG", message, details)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Info logs an info level entry
|
// Info logs an info level entry
|
||||||
func (l *ServiceLogger) Info(component, message string, details map[string]string) {
|
func (l *ServiceLogger) Info(component, message string, details map[string]string) {
|
||||||
if l.isLogLevelEnabled("info") {
|
l.Log(component, "info", message, details)
|
||||||
l.Log(component, "INFO", message, details)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warn logs a warning level entry
|
// Warn logs a warning level entry
|
||||||
func (l *ServiceLogger) Warn(component, message string, details map[string]string) {
|
func (l *ServiceLogger) Warn(component, message string, details map[string]string) {
|
||||||
if l.isLogLevelEnabled("warn") {
|
l.Log(component, "warn", message, details)
|
||||||
l.Log(component, "WARN", message, details)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error logs an error level entry
|
// Error logs an error level entry
|
||||||
func (l *ServiceLogger) Error(component, message string, details map[string]string) {
|
func (l *ServiceLogger) Error(component, message string, details map[string]string) {
|
||||||
if l.isLogLevelEnabled("error") {
|
l.Log(component, "error", message, details)
|
||||||
l.Log(component, "ERROR", message, details)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLogLevelEnabled checks if a log level should be emitted based on configured level
|
// isLogLevelEnabled checks if a log level should be emitted based on configured level
|
||||||
|
|||||||
59
internal/logging/service_logger_test.go
Normal file
59
internal/logging/service_logger_test.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsLogLevelEnabled(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
loggerLevel string
|
||||||
|
messageLevel string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "debug logger accepts debug", loggerLevel: "debug", messageLevel: "debug", want: true},
|
||||||
|
{name: "debug logger accepts info", loggerLevel: "debug", messageLevel: "info", want: true},
|
||||||
|
{name: "info logger rejects debug", loggerLevel: "info", messageLevel: "debug", want: false},
|
||||||
|
{name: "info logger accepts info", loggerLevel: "info", messageLevel: "info", want: true},
|
||||||
|
{name: "warn logger rejects info", loggerLevel: "warn", messageLevel: "info", want: false},
|
||||||
|
{name: "warn logger accepts error", loggerLevel: "warn", messageLevel: "error", want: true},
|
||||||
|
{name: "error logger accepts only error", loggerLevel: "error", messageLevel: "error", want: true},
|
||||||
|
{name: "error logger rejects warn", loggerLevel: "error", messageLevel: "warn", want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
logger := NewServiceLogger(tt.loggerLevel)
|
||||||
|
if got := logger.isLogLevelEnabled(tt.messageLevel); got != tt.want {
|
||||||
|
t.Fatalf("isLogLevelEnabled(%q) = %v, want %v", tt.messageLevel, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDebug_NotEmittedWhenLoggerLevelInfo(t *testing.T) {
|
||||||
|
logger := NewServiceLogger("info")
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger.out = log.New(&buf, "", 0)
|
||||||
|
|
||||||
|
logger.Debug("service", "debug message", map[string]string{"k": "v"})
|
||||||
|
|
||||||
|
if buf.Len() != 0 {
|
||||||
|
t.Fatalf("expected no output for debug at info level, got: %s", buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLog_UppercaseDebug_NotEmittedWhenLoggerLevelInfo(t *testing.T) {
|
||||||
|
logger := NewServiceLogger("info")
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger.out = log.New(&buf, "", 0)
|
||||||
|
|
||||||
|
logger.Log("service", "DEBUG", "debug message", nil)
|
||||||
|
|
||||||
|
if strings.TrimSpace(buf.String()) != "" {
|
||||||
|
t.Fatalf("expected no output for uppercase DEBUG at info level, got: %s", buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -103,13 +103,20 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error {
|
|||||||
w.mutex.Lock()
|
w.mutex.Lock()
|
||||||
defer w.mutex.Unlock()
|
defer w.mutex.Unlock()
|
||||||
|
|
||||||
// Connect if not already connected
|
ensureConn := func() error {
|
||||||
if w.conn == nil {
|
if w.conn != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
conn, err := net.Dial("unix", w.socketPath)
|
conn, err := net.Dial("unix", w.socketPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err)
|
return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err)
|
||||||
}
|
}
|
||||||
w.conn = conn
|
w.conn = conn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ensureConn(); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(rec)
|
data, err := json.Marshal(rec)
|
||||||
@ -120,12 +127,18 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error {
|
|||||||
// Add newline for line-based protocols
|
// Add newline for line-based protocols
|
||||||
data = append(data, '\n')
|
data = append(data, '\n')
|
||||||
|
|
||||||
_, err = w.conn.Write(data)
|
if _, err = w.conn.Write(data); err != nil {
|
||||||
if err != nil {
|
_ = w.conn.Close()
|
||||||
// Connection failed, try to reconnect
|
|
||||||
w.conn.Close()
|
|
||||||
w.conn = nil
|
w.conn = nil
|
||||||
return fmt.Errorf("failed to write to socket: %w", err)
|
|
||||||
|
if err2 := ensureConn(); err2 != nil {
|
||||||
|
return fmt.Errorf("failed to write to socket and reconnect failed: %w", err2)
|
||||||
|
}
|
||||||
|
if _, err2 := w.conn.Write(data); err2 != nil {
|
||||||
|
_ = w.conn.Close()
|
||||||
|
w.conn = nil
|
||||||
|
return fmt.Errorf("failed to write to socket after reconnect: %w", err2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -1,10 +1,15 @@
|
|||||||
package output
|
package output
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"ja4sentinel/api"
|
"ja4sentinel/api"
|
||||||
)
|
)
|
||||||
@ -102,7 +107,6 @@ func TestMultiWriter(t *testing.T) {
|
|||||||
defer fileWriter.Close()
|
defer fileWriter.Close()
|
||||||
|
|
||||||
multiWriter.Add(fileWriter)
|
multiWriter.Add(fileWriter)
|
||||||
multiWriter.Add(NewStdoutWriter())
|
|
||||||
|
|
||||||
rec := api.LogRecord{
|
rec := api.LogRecord{
|
||||||
SrcIP: "192.168.1.1",
|
SrcIP: "192.168.1.1",
|
||||||
@ -233,3 +237,123 @@ func TestUnixSocketWriter(t *testing.T) {
|
|||||||
|
|
||||||
writer.Close()
|
writer.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type unixTestServer struct {
|
||||||
|
listener net.Listener
|
||||||
|
received chan string
|
||||||
|
mu sync.Mutex
|
||||||
|
conns map[net.Conn]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUnixTestServer(path string) (*unixTestServer, error) {
|
||||||
|
_ = os.Remove(path)
|
||||||
|
ln, err := net.Listen("unix", path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &unixTestServer{
|
||||||
|
listener: ln,
|
||||||
|
received: make(chan string, 10),
|
||||||
|
conns: make(map[net.Conn]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.serve()
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *unixTestServer) serve() {
|
||||||
|
for {
|
||||||
|
conn, err := s.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.conns[conn] = struct{}{}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
go func(c net.Conn) {
|
||||||
|
defer func() {
|
||||||
|
s.mu.Lock()
|
||||||
|
delete(s.conns, c)
|
||||||
|
s.mu.Unlock()
|
||||||
|
_ = c.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(c)
|
||||||
|
for scanner.Scan() {
|
||||||
|
s.received <- scanner.Text()
|
||||||
|
}
|
||||||
|
}(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *unixTestServer) close(path string) {
|
||||||
|
_ = s.listener.Close()
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
for c := range s.conns {
|
||||||
|
_ = c.Close()
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
_ = os.Remove(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnixSocketWriter_ReconnectAndWrite(t *testing.T) {
|
||||||
|
socketPath := filepath.Join(t.TempDir(), "ja4sentinel.sock")
|
||||||
|
|
||||||
|
server1, err := newUnixTestServer(socketPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to start first unix test server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
writer, err := NewUnixSocketWriter(socketPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewUnixSocketWriter() error = %v", err)
|
||||||
|
}
|
||||||
|
defer writer.Close()
|
||||||
|
|
||||||
|
rec1 := api.LogRecord{
|
||||||
|
SrcIP: "192.168.1.1",
|
||||||
|
SrcPort: 11111,
|
||||||
|
DstIP: "10.0.0.1",
|
||||||
|
DstPort: 443,
|
||||||
|
JA4: "first",
|
||||||
|
}
|
||||||
|
if err := writer.Write(rec1); err != nil {
|
||||||
|
t.Fatalf("first Write() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-server1.received:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting first message on unix socket")
|
||||||
|
}
|
||||||
|
|
||||||
|
server1.close(socketPath)
|
||||||
|
|
||||||
|
server2, err := newUnixTestServer(socketPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to restart unix test server: %v", err)
|
||||||
|
}
|
||||||
|
defer server2.close(socketPath)
|
||||||
|
|
||||||
|
rec2 := api.LogRecord{
|
||||||
|
SrcIP: "192.168.1.2",
|
||||||
|
SrcPort: 22222,
|
||||||
|
DstIP: "10.0.0.2",
|
||||||
|
DstPort: 443,
|
||||||
|
JA4: "second",
|
||||||
|
}
|
||||||
|
if err := writer.Write(rec2); err != nil {
|
||||||
|
t.Fatalf("second Write() after reconnect error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-server2.received:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting second message after reconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -46,6 +46,7 @@ type ParserImpl struct {
|
|||||||
flowTimeout time.Duration
|
flowTimeout time.Duration
|
||||||
cleanupDone chan struct{}
|
cleanupDone chan struct{}
|
||||||
cleanupClose chan struct{}
|
cleanupClose chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewParser creates a new TLS parser with connection state tracking
|
// NewParser creates a new TLS parser with connection state tracking
|
||||||
@ -260,8 +261,10 @@ func (p *ParserImpl) getOrCreateFlow(key string, srcIP string, srcPort uint16, d
|
|||||||
|
|
||||||
// Close cleans up the parser and stops background goroutines
|
// Close cleans up the parser and stops background goroutines
|
||||||
func (p *ParserImpl) Close() error {
|
func (p *ParserImpl) Close() error {
|
||||||
close(p.cleanupClose)
|
p.closeOnce.Do(func() {
|
||||||
<-p.cleanupDone
|
close(p.cleanupClose)
|
||||||
|
<-p.cleanupDone
|
||||||
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -296,8 +299,12 @@ func extractTCPMeta(tcp *layers.TCP) api.TCPMeta {
|
|||||||
for _, opt := range tcp.Options {
|
for _, opt := range tcp.Options {
|
||||||
switch opt.OptionType {
|
switch opt.OptionType {
|
||||||
case layers.TCPOptionKindMSS:
|
case layers.TCPOptionKindMSS:
|
||||||
meta.MSS = binary.BigEndian.Uint16(opt.OptionData)
|
if len(opt.OptionData) >= 2 {
|
||||||
meta.Options = append(meta.Options, "MSS")
|
meta.MSS = binary.BigEndian.Uint16(opt.OptionData[:2])
|
||||||
|
meta.Options = append(meta.Options, "MSS")
|
||||||
|
} else {
|
||||||
|
meta.Options = append(meta.Options, "MSS_INVALID")
|
||||||
|
}
|
||||||
case layers.TCPOptionKindWindowScale:
|
case layers.TCPOptionKindWindowScale:
|
||||||
if len(opt.OptionData) > 0 {
|
if len(opt.OptionData) > 0 {
|
||||||
meta.WindowScale = opt.OptionData[0]
|
meta.WindowScale = opt.OptionData[0]
|
||||||
|
|||||||
@ -251,3 +251,39 @@ func TestParserConnectionStateTracking(t *testing.T) {
|
|||||||
t.Error("IsClientHello() should return true for valid ClientHello")
|
t.Error("IsClientHello() should return true for valid ClientHello")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParserClose_Idempotent(t *testing.T) {
|
||||||
|
parser := NewParser()
|
||||||
|
|
||||||
|
if err := parser.Close(); err != nil {
|
||||||
|
t.Fatalf("first Close() error = %v", err)
|
||||||
|
}
|
||||||
|
if err := parser.Close(); err != nil {
|
||||||
|
t.Fatalf("second Close() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractTCPMeta_MSSInvalid_NoPanic(t *testing.T) {
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
Window: 1234,
|
||||||
|
Options: []layers.TCPOption{
|
||||||
|
{
|
||||||
|
OptionType: layers.TCPOptionKindMSS,
|
||||||
|
OptionData: []byte{0x05}, // malformed (1 byte instead of 2)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := extractTCPMeta(tcp)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, opt := range meta.Options {
|
||||||
|
if opt == "MSS_INVALID" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user