diff --git a/cmd/ja4sentinel/main.go b/cmd/ja4sentinel/main.go index 6e6fd61..6a7737a 100644 --- a/cmd/ja4sentinel/main.go +++ b/cmd/ja4sentinel/main.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "os/signal" + "sync" "syscall" "time" @@ -83,23 +84,37 @@ func main() { helloChan := make(chan api.TLSClientHello, 1000) errorChan := make(chan error, 100) + var wg sync.WaitGroup + // Setup signal handling for graceful shutdown sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) // Start capture goroutine + wg.Add(1) go func() { + defer wg.Done() + defer close(packetChan) + logger.Info("capture", "Starting packet capture", map[string]string{ "interface": cfg.Core.Interface, }) + err := captureImpl.Run(cfg.Core, packetChan) if err != nil { - errorChan <- fmt.Errorf("capture error: %w", err) + select { + case errorChan <- fmt.Errorf("capture error: %w", err): + default: + } } }() // Start TLS parsing goroutine + wg.Add(1) go func() { + defer wg.Done() + defer close(helloChan) + for pkt := range packetChan { hello, err := parser.Process(pkt) if err != nil { @@ -121,7 +136,10 @@ func main() { }() // Start fingerprinting and output goroutine + wg.Add(1) go func() { + defer wg.Done() + for hello := range helloChan { fingerprints, err := engine.FromClientHello(hello) if err != nil { @@ -162,6 +180,21 @@ func main() { // Graceful shutdown logger.Info("service", "Shutting down", nil) + if err := captureImpl.Close(); err != nil { + logger.Error("capture", "Error closing capture", map[string]string{ + "error": err.Error(), + }) + } + + wg.Wait() + + // Close parser (stops cleanup goroutine) + if err := parser.Close(); err != nil { + logger.Error("tlsparse", "Error closing parser", map[string]string{ + "error": err.Error(), + }) + } + // Close output writer if closer, ok := writer.(interface{ CloseAll() error }); ok { if err := closer.CloseAll(); err != nil { @@ -171,19 +204,5 @@ func main() { } } - // Close parser (stops cleanup goroutine) - if err := parser.Close(); err != nil { - logger.Error("tlsparse", "Error closing parser", map[string]string{ - "error": err.Error(), - }) - } - - // Close capture - if err := captureImpl.Close(); err != nil { - logger.Error("capture", "Error closing capture", map[string]string{ - "error": err.Error(), - }) - } - logger.Info("service", "ja4sentinel stopped", nil) } diff --git a/internal/capture/capture.go b/internal/capture/capture.go index cb3ac8e..7646334 100644 --- a/internal/capture/capture.go +++ b/internal/capture/capture.go @@ -3,6 +3,7 @@ package capture import ( "fmt" + "sync" "github.com/google/gopacket" "github.com/google/gopacket/pcap" @@ -13,6 +14,7 @@ import ( // CaptureImpl implements the capture.Capture interface for packet capture type CaptureImpl struct { handle *pcap.Handle + mu sync.Mutex } // New creates a new capture instance @@ -22,29 +24,40 @@ func New() *CaptureImpl { // Run starts network packet capture according to the configuration func (c *CaptureImpl) Run(cfg api.Config, out chan<- api.RawPacket) error { - var err error - c.handle, err = pcap.OpenLive(cfg.Interface, 1600, true, pcap.BlockForever) + handle, err := pcap.OpenLive(cfg.Interface, 1600, true, pcap.BlockForever) if err != nil { return fmt.Errorf("failed to open interface %s: %w", cfg.Interface, err) } - defer c.handle.Close() + + c.mu.Lock() + c.handle = handle + c.mu.Unlock() + + defer func() { + c.mu.Lock() + if c.handle != nil { + c.handle.Close() + c.handle = nil + } + c.mu.Unlock() + }() // Apply BPF filter if provided if cfg.BPFFilter != "" { - err = c.handle.SetBPFFilter(cfg.BPFFilter) + err = handle.SetBPFFilter(cfg.BPFFilter) if err != nil { return fmt.Errorf("failed to set BPF filter: %w", err) } } else { // Create default filter for monitored ports defaultFilter := buildBPFForPorts(cfg.ListenPorts) - err = c.handle.SetBPFFilter(defaultFilter) + err = handle.SetBPFFilter(defaultFilter) if err != nil { return fmt.Errorf("failed to set default BPF filter: %w", err) } } - packetSource := gopacket.NewPacketSource(c.handle, c.handle.LinkType()) + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) for packet := range packetSource.Packets() { // Convert packet to RawPacket @@ -102,8 +115,12 @@ func packetToRawPacket(packet gopacket.Packet) *api.RawPacket { // Close properly closes the capture handle func (c *CaptureImpl) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.handle != nil { c.handle.Close() + c.handle = nil return nil } return nil diff --git a/internal/capture/capture_test.go b/internal/capture/capture_test.go index b95ff3d..c923970 100644 --- a/internal/capture/capture_test.go +++ b/internal/capture/capture_test.go @@ -79,3 +79,20 @@ func TestJoinString(t *testing.T) { func TestCaptureIntegration(t *testing.T) { t.Skip("Skipping integration test requiring network access and elevated privileges") } + +func TestClose_NoHandle_NoError(t *testing.T) { + c := New() + if err := c.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } +} + +func TestClose_Idempotent_NoHandle(t *testing.T) { + c := New() + if err := c.Close(); err != nil { + t.Fatalf("first Close() error = %v", err) + } + if err := c.Close(); err != nil { + t.Fatalf("second Close() error = %v", err) + } +} diff --git a/internal/config/loader.go b/internal/config/loader.go index b40f143..f27b2b8 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -3,6 +3,7 @@ package config import ( "encoding/json" + "errors" "fmt" "os" "strconv" @@ -28,13 +29,17 @@ func NewLoader(configPath string) *LoaderImpl { func (l *LoaderImpl) Load() (api.AppConfig, error) { config := api.DefaultConfig() - // Load from YAML file if path is provided - if l.configPath != "" { - fileConfig, err := l.loadFromFile(l.configPath) - if err != nil { - return config, fmt.Errorf("failed to load config file: %w", err) - } + path := l.configPath + explicit := path != "" + if !explicit { + path = "config.yml" + } + + fileConfig, err := l.loadFromFile(path) + if err == nil { config = mergeConfigs(config, fileConfig) + } else if !( !explicit && errors.Is(err, os.ErrNotExist)) { + return config, fmt.Errorf("failed to load config file: %w", err) } // Override with environment variables diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index a65241f..52eda30 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -2,6 +2,7 @@ package config import ( "os" + "path/filepath" "strings" "testing" @@ -211,3 +212,51 @@ func TestToJSON(t *testing.T) { t.Error("ToJSON() result doesn't contain 'eth0'") } } + +func TestLoad_DefaultConfigFileAbsent_DoesNotFail(t *testing.T) { + t.Setenv("JA4SENTINEL_INTERFACE", "") + t.Setenv("JA4SENTINEL_PORTS", "") + t.Setenv("JA4SENTINEL_BPF_FILTER", "") + t.Setenv("JA4SENTINEL_FLOW_TIMEOUT", "") + + tempDir := t.TempDir() + oldWD, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd() error = %v", err) + } + defer func() { + _ = os.Chdir(oldWD) + }() + + if err := os.Chdir(tempDir); err != nil { + t.Fatalf("Chdir() error = %v", err) + } + + _ = os.Remove(filepath.Join(tempDir, "config.yml")) + + loader := NewLoader("") + cfg, err := loader.Load() + if err != nil { + t.Fatalf("Load() error = %v", err) + } + + if cfg.Core.Interface != api.DefaultInterface { + t.Errorf("Interface = %q, want %q", cfg.Core.Interface, api.DefaultInterface) + } + if len(cfg.Core.ListenPorts) == 0 || cfg.Core.ListenPorts[0] != api.DefaultPort { + t.Errorf("ListenPorts = %v, want first port %d", cfg.Core.ListenPorts, api.DefaultPort) + } +} + +func TestLoad_ExplicitMissingConfig_Fails(t *testing.T) { + t.Setenv("JA4SENTINEL_INTERFACE", "") + t.Setenv("JA4SENTINEL_PORTS", "") + t.Setenv("JA4SENTINEL_BPF_FILTER", "") + t.Setenv("JA4SENTINEL_FLOW_TIMEOUT", "") + + loader := NewLoader("/tmp/definitely-missing-ja4sentinel.yml") + _, err := loader.Load() + if err == nil { + t.Fatal("Load() should fail with explicit missing config path") + } +} diff --git a/internal/logging/service_logger.go b/internal/logging/service_logger.go index 3a38951..b200ad9 100644 --- a/internal/logging/service_logger.go +++ b/internal/logging/service_logger.go @@ -3,7 +3,6 @@ package logging import ( "encoding/json" - "fmt" "log" "os" "strings" @@ -49,7 +48,8 @@ func NewServiceLogger(level string) *ServiceLogger { // Log emits a structured log entry to stdout in JSON format func (l *ServiceLogger) Log(component, level, message string, details map[string]string) { - if !l.isLogLevelEnabled(level) { + normalizedLevel := strings.ToLower(level) + if !l.isLogLevelEnabled(normalizedLevel) { return } @@ -58,7 +58,7 @@ func (l *ServiceLogger) Log(component, level, message string, details map[string defer l.mutex.Unlock() serviceLog := api.ServiceLog{ - Level: level, + Level: normalizedLevel, Component: component, Message: message, Details: details, @@ -67,40 +67,32 @@ func (l *ServiceLogger) Log(component, level, message string, details map[string jsonData, err := l.formatter(serviceLog) if err != nil { // Fallback to simple logging if JSON formatting fails - fmt.Printf(`{"timestamp":%d,"level":"ERROR","component":"logging","message":"%s","original_message":"%s"}`, + l.out.Printf(`{"timestamp":%d,"level":"ERROR","component":"logging","message":"%s","original_message":"%s"}`+"\n", time.Now().UnixNano(), err.Error(), message) return } - fmt.Println(string(jsonData)) + l.out.Println(string(jsonData)) } // Debug logs a debug level entry func (l *ServiceLogger) Debug(component, message string, details map[string]string) { - if l.isLogLevelEnabled("debug") { - l.Log(component, "DEBUG", message, details) - } + l.Log(component, "debug", message, details) } // Info logs an info level entry func (l *ServiceLogger) Info(component, message string, details map[string]string) { - if l.isLogLevelEnabled("info") { - l.Log(component, "INFO", message, details) - } + l.Log(component, "info", message, details) } // Warn logs a warning level entry func (l *ServiceLogger) Warn(component, message string, details map[string]string) { - if l.isLogLevelEnabled("warn") { - l.Log(component, "WARN", message, details) - } + l.Log(component, "warn", message, details) } // Error logs an error level entry func (l *ServiceLogger) Error(component, message string, details map[string]string) { - if l.isLogLevelEnabled("error") { - l.Log(component, "ERROR", message, details) - } + l.Log(component, "error", message, details) } // isLogLevelEnabled checks if a log level should be emitted based on configured level diff --git a/internal/logging/service_logger_test.go b/internal/logging/service_logger_test.go new file mode 100644 index 0000000..fdbcf9a --- /dev/null +++ b/internal/logging/service_logger_test.go @@ -0,0 +1,59 @@ +package logging + +import ( + "bytes" + "log" + "strings" + "testing" +) + +func TestIsLogLevelEnabled(t *testing.T) { + tests := []struct { + name string + loggerLevel string + messageLevel string + want bool + }{ + {name: "debug logger accepts debug", loggerLevel: "debug", messageLevel: "debug", want: true}, + {name: "debug logger accepts info", loggerLevel: "debug", messageLevel: "info", want: true}, + {name: "info logger rejects debug", loggerLevel: "info", messageLevel: "debug", want: false}, + {name: "info logger accepts info", loggerLevel: "info", messageLevel: "info", want: true}, + {name: "warn logger rejects info", loggerLevel: "warn", messageLevel: "info", want: false}, + {name: "warn logger accepts error", loggerLevel: "warn", messageLevel: "error", want: true}, + {name: "error logger accepts only error", loggerLevel: "error", messageLevel: "error", want: true}, + {name: "error logger rejects warn", loggerLevel: "error", messageLevel: "warn", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := NewServiceLogger(tt.loggerLevel) + if got := logger.isLogLevelEnabled(tt.messageLevel); got != tt.want { + t.Fatalf("isLogLevelEnabled(%q) = %v, want %v", tt.messageLevel, got, tt.want) + } + }) + } +} + +func TestDebug_NotEmittedWhenLoggerLevelInfo(t *testing.T) { + logger := NewServiceLogger("info") + var buf bytes.Buffer + logger.out = log.New(&buf, "", 0) + + logger.Debug("service", "debug message", map[string]string{"k": "v"}) + + if buf.Len() != 0 { + t.Fatalf("expected no output for debug at info level, got: %s", buf.String()) + } +} + +func TestLog_UppercaseDebug_NotEmittedWhenLoggerLevelInfo(t *testing.T) { + logger := NewServiceLogger("info") + var buf bytes.Buffer + logger.out = log.New(&buf, "", 0) + + logger.Log("service", "DEBUG", "debug message", nil) + + if strings.TrimSpace(buf.String()) != "" { + t.Fatalf("expected no output for uppercase DEBUG at info level, got: %s", buf.String()) + } +} diff --git a/internal/output/writers.go b/internal/output/writers.go index d17c6f7..d061693 100644 --- a/internal/output/writers.go +++ b/internal/output/writers.go @@ -103,13 +103,20 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error { w.mutex.Lock() defer w.mutex.Unlock() - // Connect if not already connected - if w.conn == nil { + ensureConn := func() error { + if w.conn != nil { + return nil + } conn, err := net.Dial("unix", w.socketPath) if err != nil { return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err) } w.conn = conn + return nil + } + + if err := ensureConn(); err != nil { + return err } data, err := json.Marshal(rec) @@ -120,12 +127,18 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error { // Add newline for line-based protocols data = append(data, '\n') - _, err = w.conn.Write(data) - if err != nil { - // Connection failed, try to reconnect - w.conn.Close() + if _, err = w.conn.Write(data); err != nil { + _ = w.conn.Close() w.conn = nil - return fmt.Errorf("failed to write to socket: %w", err) + + if err2 := ensureConn(); err2 != nil { + return fmt.Errorf("failed to write to socket and reconnect failed: %w", err2) + } + if _, err2 := w.conn.Write(data); err2 != nil { + _ = w.conn.Close() + w.conn = nil + return fmt.Errorf("failed to write to socket after reconnect: %w", err2) + } } return nil diff --git a/internal/output/writers_test.go b/internal/output/writers_test.go index f19458f..a4d9b4c 100644 --- a/internal/output/writers_test.go +++ b/internal/output/writers_test.go @@ -1,10 +1,15 @@ package output import ( + "bufio" "bytes" "encoding/json" + "net" "os" + "path/filepath" + "sync" "testing" + "time" "ja4sentinel/api" ) @@ -102,7 +107,6 @@ func TestMultiWriter(t *testing.T) { defer fileWriter.Close() multiWriter.Add(fileWriter) - multiWriter.Add(NewStdoutWriter()) rec := api.LogRecord{ SrcIP: "192.168.1.1", @@ -233,3 +237,123 @@ func TestUnixSocketWriter(t *testing.T) { writer.Close() } + +type unixTestServer struct { + listener net.Listener + received chan string + mu sync.Mutex + conns map[net.Conn]struct{} +} + +func newUnixTestServer(path string) (*unixTestServer, error) { + _ = os.Remove(path) + ln, err := net.Listen("unix", path) + if err != nil { + return nil, err + } + + s := &unixTestServer{ + listener: ln, + received: make(chan string, 10), + conns: make(map[net.Conn]struct{}), + } + + go s.serve() + return s, nil +} + +func (s *unixTestServer) serve() { + for { + conn, err := s.listener.Accept() + if err != nil { + return + } + + s.mu.Lock() + s.conns[conn] = struct{}{} + s.mu.Unlock() + + go func(c net.Conn) { + defer func() { + s.mu.Lock() + delete(s.conns, c) + s.mu.Unlock() + _ = c.Close() + }() + + scanner := bufio.NewScanner(c) + for scanner.Scan() { + s.received <- scanner.Text() + } + }(conn) + } +} + +func (s *unixTestServer) close(path string) { + _ = s.listener.Close() + + s.mu.Lock() + for c := range s.conns { + _ = c.Close() + } + s.mu.Unlock() + + _ = os.Remove(path) +} + +func TestUnixSocketWriter_ReconnectAndWrite(t *testing.T) { + socketPath := filepath.Join(t.TempDir(), "ja4sentinel.sock") + + server1, err := newUnixTestServer(socketPath) + if err != nil { + t.Fatalf("failed to start first unix test server: %v", err) + } + + writer, err := NewUnixSocketWriter(socketPath) + if err != nil { + t.Fatalf("NewUnixSocketWriter() error = %v", err) + } + defer writer.Close() + + rec1 := api.LogRecord{ + SrcIP: "192.168.1.1", + SrcPort: 11111, + DstIP: "10.0.0.1", + DstPort: 443, + JA4: "first", + } + if err := writer.Write(rec1); err != nil { + t.Fatalf("first Write() error = %v", err) + } + + select { + case <-server1.received: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting first message on unix socket") + } + + server1.close(socketPath) + + server2, err := newUnixTestServer(socketPath) + if err != nil { + t.Fatalf("failed to restart unix test server: %v", err) + } + defer server2.close(socketPath) + + rec2 := api.LogRecord{ + SrcIP: "192.168.1.2", + SrcPort: 22222, + DstIP: "10.0.0.2", + DstPort: 443, + JA4: "second", + } + if err := writer.Write(rec2); err != nil { + t.Fatalf("second Write() after reconnect error = %v", err) + } + + select { + case <-server2.received: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting second message after reconnect") + } +} diff --git a/internal/tlsparse/parser.go b/internal/tlsparse/parser.go index 3355ab2..8ee11f4 100644 --- a/internal/tlsparse/parser.go +++ b/internal/tlsparse/parser.go @@ -46,6 +46,7 @@ type ParserImpl struct { flowTimeout time.Duration cleanupDone chan struct{} cleanupClose chan struct{} + closeOnce sync.Once } // NewParser creates a new TLS parser with connection state tracking @@ -260,8 +261,10 @@ func (p *ParserImpl) getOrCreateFlow(key string, srcIP string, srcPort uint16, d // Close cleans up the parser and stops background goroutines func (p *ParserImpl) Close() error { - close(p.cleanupClose) - <-p.cleanupDone + p.closeOnce.Do(func() { + close(p.cleanupClose) + <-p.cleanupDone + }) return nil } @@ -296,8 +299,12 @@ func extractTCPMeta(tcp *layers.TCP) api.TCPMeta { for _, opt := range tcp.Options { switch opt.OptionType { case layers.TCPOptionKindMSS: - meta.MSS = binary.BigEndian.Uint16(opt.OptionData) - meta.Options = append(meta.Options, "MSS") + if len(opt.OptionData) >= 2 { + meta.MSS = binary.BigEndian.Uint16(opt.OptionData[:2]) + meta.Options = append(meta.Options, "MSS") + } else { + meta.Options = append(meta.Options, "MSS_INVALID") + } case layers.TCPOptionKindWindowScale: if len(opt.OptionData) > 0 { meta.WindowScale = opt.OptionData[0] diff --git a/internal/tlsparse/parser_test.go b/internal/tlsparse/parser_test.go index ce0bd68..7244390 100644 --- a/internal/tlsparse/parser_test.go +++ b/internal/tlsparse/parser_test.go @@ -251,3 +251,39 @@ func TestParserConnectionStateTracking(t *testing.T) { t.Error("IsClientHello() should return true for valid ClientHello") } } + +func TestParserClose_Idempotent(t *testing.T) { + parser := NewParser() + + if err := parser.Close(); err != nil { + t.Fatalf("first Close() error = %v", err) + } + if err := parser.Close(); err != nil { + t.Fatalf("second Close() error = %v", err) + } +} + +func TestExtractTCPMeta_MSSInvalid_NoPanic(t *testing.T) { + tcp := &layers.TCP{ + Window: 1234, + Options: []layers.TCPOption{ + { + OptionType: layers.TCPOptionKindMSS, + OptionData: []byte{0x05}, // malformed (1 byte instead of 2) + }, + }, + } + + meta := extractTCPMeta(tcp) + + found := false + for _, opt := range meta.Options { + if opt == "MSS_INVALID" { + found = true + break + } + } + if !found { + t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options) + } +}