From c7e8fe874fdc68c4e67ace10d1287edc7480be33 Mon Sep 17 00:00:00 2001 From: Jacquin Antoine Date: Sat, 28 Feb 2026 20:01:39 +0100 Subject: [PATCH] fix: renforcer limites TLS, timeouts socket et validation config Co-authored-by: aider (openrouter/openai/gpt-5.3-codex) --- cmd/ja4sentinel/main.go | 16 ++- internal/config/loader.go | 55 ++++++- internal/config/loader_test.go | 236 ++++++++++++++++++++++++++++--- internal/output/writers.go | 55 ++++--- internal/output/writers_test.go | 108 ++++++++++++++ internal/tlsparse/parser.go | 56 +++++--- internal/tlsparse/parser_test.go | 156 ++++++++++++++++++++ 7 files changed, 618 insertions(+), 64 deletions(-) diff --git a/cmd/ja4sentinel/main.go b/cmd/ja4sentinel/main.go index a50f707..7bed7e0 100644 --- a/cmd/ja4sentinel/main.go +++ b/cmd/ja4sentinel/main.go @@ -42,7 +42,7 @@ func main() { appLogger := loggerFactory.NewDefaultLogger() appLogger.Info("main", "Starting ja4sentinel", map[string]string{ - "version": Version, + "version": Version, "build_time": BuildTime, "git_commit": GitCommit, }) @@ -58,7 +58,7 @@ func main() { } appLogger.Info("main", "Configuration loaded", map[string]string{ - "interface": appConfig.Core.Interface, + "interface": appConfig.Core.Interface, "listen_ports": formatPorts(appConfig.Core.ListenPorts), }) @@ -127,9 +127,9 @@ func main() { } appLogger.Debug("tlsparse", "ClientHello extracted", map[string]string{ - "src_ip": clientHello.SrcIP, + "src_ip": clientHello.SrcIP, "src_port": fmt.Sprintf("%d", clientHello.SrcPort), - "dst_ip": clientHello.DstIP, + "dst_ip": clientHello.DstIP, "dst_port": fmt.Sprintf("%d", clientHello.DstPort), }) @@ -191,7 +191,13 @@ func main() { }) } - if closer, ok := outputWriter.(interface{ Close() error }); ok { + if mw, ok := outputWriter.(interface{ CloseAll() error }); ok { + if err := mw.CloseAll(); err != nil { + appLogger.Error("main", "Failed to close output writers", map[string]string{ + "error": err.Error(), + }) + } + } else if closer, ok := outputWriter.(interface{ Close() error }); ok { if err := closer.Close(); err != nil { appLogger.Error("main", "Failed to close output writer", map[string]string{ "error": err.Error(), diff --git a/internal/config/loader.go b/internal/config/loader.go index 51d6ac7..e345d3e 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -38,7 +38,7 @@ func (l *LoaderImpl) Load() (api.AppConfig, error) { fileConfig, err := l.loadFromFile(path) if err == nil { config = mergeConfigs(config, fileConfig) - } else if !( !explicit && errors.Is(err, os.ErrNotExist)) { + } else if !(!explicit && errors.Is(err, os.ErrNotExist)) { return config, fmt.Errorf("failed to load config file: %w", err) } @@ -115,6 +115,7 @@ func parsePorts(s string) []uint16 { parts := strings.Split(s, ",") ports := make([]uint16, 0, len(parts)) + seen := make(map[uint16]struct{}, len(parts)) for _, part := range parts { part = strings.TrimSpace(part) @@ -123,9 +124,19 @@ func parsePorts(s string) []uint16 { } port, err := strconv.ParseUint(part, 10, 16) - if err == nil { - ports = append(ports, uint16(port)) + if err != nil { + continue } + + p := uint16(port) + if p == 0 { + continue + } + if _, exists := seen[p]; exists { + continue + } + seen[p] = struct{}{} + ports = append(ports, p) } return ports @@ -164,19 +175,53 @@ func mergeConfigs(base, override api.AppConfig) api.AppConfig { // validate checks if the configuration is valid func (l *LoaderImpl) validate(config api.AppConfig) error { - if config.Core.Interface == "" { + if strings.TrimSpace(config.Core.Interface) == "" { return fmt.Errorf("interface cannot be empty") } if len(config.Core.ListenPorts) == 0 { return fmt.Errorf("at least one listen port is required") } + for _, p := range config.Core.ListenPorts { + if p == 0 { + return fmt.Errorf("listen port 0 is invalid") + } + } + + if config.Core.FlowTimeoutSec <= 0 || config.Core.FlowTimeoutSec > 300 { + return fmt.Errorf("flow_timeout_sec must be between 1 and 300") + } + + if config.Core.PacketBufferSize <= 0 || config.Core.PacketBufferSize > 1_000_000 { + return fmt.Errorf("packet_buffer_size must be between 1 and 1000000") + } + + allowedTypes := map[string]struct{}{ + "stdout": {}, + "file": {}, + "unix_socket": {}, + } // Validate outputs for i, output := range config.Outputs { - if output.Type == "" { + outputType := strings.TrimSpace(output.Type) + if outputType == "" { return fmt.Errorf("output[%d]: type cannot be empty", i) } + if _, ok := allowedTypes[outputType]; !ok { + return fmt.Errorf("output[%d]: unknown type %q", i, outputType) + } + + switch outputType { + case "file": + if strings.TrimSpace(output.Params["path"]) == "" { + return fmt.Errorf("output[%d]: file output requires non-empty path", i) + } + case "unix_socket": + if strings.TrimSpace(output.Params["socket_path"]) == "" { + return fmt.Errorf("output[%d]: unix_socket output requires non-empty socket_path", i) + } + } } return nil diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 52eda30..fdb4703 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -58,21 +58,39 @@ func TestParsePorts(t *testing.T) { } } +func TestParsePorts_DeduplicateAndIgnoreZero(t *testing.T) { + got := parsePorts("443, 0, 443, 8443") + want := []uint16{443, 8443} + + if len(got) != len(want) { + t.Fatalf("parsePorts() length = %d, want %d (got: %v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("parsePorts()[%d] = %d, want %d", i, got[i], want[i]) + } + } +} + func TestMergeConfigs(t *testing.T) { base := api.AppConfig{ Core: api.Config{ - Interface: "eth0", - ListenPorts: []uint16{443}, - BPFFilter: "", + Interface: "eth0", + ListenPorts: []uint16{443}, + BPFFilter: "", + FlowTimeoutSec: 30, + PacketBufferSize: 1000, }, Outputs: []api.OutputConfig{}, } override := api.AppConfig{ Core: api.Config{ - Interface: "lo", - ListenPorts: []uint16{8443}, - BPFFilter: "tcp", + Interface: "lo", + ListenPorts: []uint16{8443}, + BPFFilter: "tcp", + FlowTimeoutSec: 60, + PacketBufferSize: 2000, }, Outputs: []api.OutputConfig{ {Type: "stdout", Enabled: true}, @@ -93,6 +111,12 @@ func TestMergeConfigs(t *testing.T) { if len(result.Outputs) != 1 { t.Errorf("Outputs length = %v, want 1", len(result.Outputs)) } + if result.Core.FlowTimeoutSec != 60 { + t.Errorf("FlowTimeoutSec = %v, want 60", result.Core.FlowTimeoutSec) + } + if result.Core.PacketBufferSize != 2000 { + t.Errorf("PacketBufferSize = %v, want 2000", result.Core.PacketBufferSize) + } } func TestValidate(t *testing.T) { @@ -107,8 +131,10 @@ func TestValidate(t *testing.T) { name: "valid config", config: api.AppConfig{ Core: api.Config{ - Interface: "eth0", - ListenPorts: []uint16{443}, + Interface: "eth0", + ListenPorts: []uint16{443}, + FlowTimeoutSec: 30, + PacketBufferSize: 1000, }, Outputs: []api.OutputConfig{ {Type: "stdout", Enabled: true}, @@ -120,8 +146,10 @@ func TestValidate(t *testing.T) { name: "empty interface", config: api.AppConfig{ Core: api.Config{ - Interface: "", - ListenPorts: []uint16{443}, + Interface: "", + ListenPorts: []uint16{443}, + FlowTimeoutSec: 30, + PacketBufferSize: 1000, }, }, wantErr: true, @@ -130,8 +158,10 @@ func TestValidate(t *testing.T) { name: "no listen ports", config: api.AppConfig{ Core: api.Config{ - Interface: "eth0", - ListenPorts: []uint16{}, + Interface: "eth0", + ListenPorts: []uint16{}, + FlowTimeoutSec: 30, + PacketBufferSize: 1000, }, }, wantErr: true, @@ -140,8 +170,10 @@ func TestValidate(t *testing.T) { name: "output with empty type", config: api.AppConfig{ Core: api.Config{ - Interface: "eth0", - ListenPorts: []uint16{443}, + Interface: "eth0", + ListenPorts: []uint16{443}, + FlowTimeoutSec: 30, + PacketBufferSize: 1000, }, Outputs: []api.OutputConfig{ {Type: "", Enabled: true}, @@ -149,6 +181,18 @@ func TestValidate(t *testing.T) { }, wantErr: true, }, + { + name: "listen port zero", + config: api.AppConfig{ + Core: api.Config{ + Interface: "eth0", + ListenPorts: []uint16{0}, + FlowTimeoutSec: 30, + PacketBufferSize: 1000, + }, + }, + wantErr: true, + }, } for _, tt := range tests { @@ -161,6 +205,162 @@ func TestValidate(t *testing.T) { } } +func TestValidate_InvalidCoreBounds(t *testing.T) { + loader := &LoaderImpl{} + + tests := []struct { + name string + cfg api.AppConfig + hasErr bool + }{ + { + name: "timeout zero", + cfg: api.AppConfig{ + Core: api.Config{ + Interface: "eth0", + ListenPorts: []uint16{443}, + FlowTimeoutSec: 0, + PacketBufferSize: 1000, + }, + }, + hasErr: true, + }, + { + name: "timeout too high", + cfg: api.AppConfig{ + Core: api.Config{ + Interface: "eth0", + ListenPorts: []uint16{443}, + FlowTimeoutSec: 301, + PacketBufferSize: 1000, + }, + }, + hasErr: true, + }, + { + name: "buffer zero", + cfg: api.AppConfig{ + Core: api.Config{ + Interface: "eth0", + ListenPorts: []uint16{443}, + FlowTimeoutSec: 30, + PacketBufferSize: 0, + }, + }, + hasErr: true, + }, + { + name: "buffer too high", + cfg: api.AppConfig{ + Core: api.Config{ + Interface: "eth0", + ListenPorts: []uint16{443}, + FlowTimeoutSec: 30, + PacketBufferSize: 1_000_001, + }, + }, + hasErr: true, + }, + { + name: "valid bounds", + cfg: api.AppConfig{ + Core: api.Config{ + Interface: "eth0", + ListenPorts: []uint16{443}, + FlowTimeoutSec: 30, + PacketBufferSize: 1000, + }, + Outputs: []api.OutputConfig{ + {Type: "stdout", Enabled: true}, + }, + }, + hasErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := loader.validate(tt.cfg) + if (err != nil) != tt.hasErr { + t.Fatalf("validate() error = %v, wantErr %v", err, tt.hasErr) + } + }) + } +} + +func TestValidate_InvalidOutputs(t *testing.T) { + loader := &LoaderImpl{} + + baseCore := api.Config{ + Interface: "eth0", + ListenPorts: []uint16{443}, + FlowTimeoutSec: 30, + PacketBufferSize: 1000, + } + + tests := []struct { + name string + outputs []api.OutputConfig + wantErr bool + }{ + { + name: "unknown output type", + outputs: []api.OutputConfig{ + {Type: "unknown", Enabled: true}, + }, + wantErr: true, + }, + { + name: "file without path", + outputs: []api.OutputConfig{ + {Type: "file", Enabled: true, Params: map[string]string{}}, + }, + wantErr: true, + }, + { + name: "unix socket without socket_path", + outputs: []api.OutputConfig{ + {Type: "unix_socket", Enabled: true, Params: map[string]string{}}, + }, + wantErr: true, + }, + { + name: "valid file output", + outputs: []api.OutputConfig{ + {Type: "file", Enabled: true, Params: map[string]string{"path": "/tmp/x.log"}}, + }, + wantErr: false, + }, + { + name: "valid unix socket output", + outputs: []api.OutputConfig{ + {Type: "unix_socket", Enabled: true, Params: map[string]string{"socket_path": "/tmp/x.sock"}}, + }, + wantErr: false, + }, + { + name: "valid stdout output", + outputs: []api.OutputConfig{ + {Type: "stdout", Enabled: true}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := api.AppConfig{ + Core: baseCore, + Outputs: tt.outputs, + } + err := loader.validate(cfg) + if (err != nil) != tt.wantErr { + t.Fatalf("validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + func TestLoadFromEnv(t *testing.T) { // Save original env vars origInterface := os.Getenv("JA4SENTINEL_INTERFACE") @@ -195,9 +395,11 @@ func TestLoadFromEnv(t *testing.T) { func TestToJSON(t *testing.T) { config := api.AppConfig{ Core: api.Config{ - Interface: "eth0", - ListenPorts: []uint16{443, 8443}, - BPFFilter: "tcp", + Interface: "eth0", + ListenPorts: []uint16{443, 8443}, + BPFFilter: "tcp", + FlowTimeoutSec: 30, + PacketBufferSize: 1000, }, Outputs: []api.OutputConfig{ {Type: "stdout", Enabled: true, Params: map[string]string{}}, diff --git a/internal/output/writers.go b/internal/output/writers.go index d061693..33f56ee 100644 --- a/internal/output/writers.go +++ b/internal/output/writers.go @@ -8,6 +8,7 @@ import ( "net" "os" "sync" + "time" "ja4sentinel/api" ) @@ -76,25 +77,27 @@ func (w *FileWriter) Close() error { // UnixSocketWriter writes log records to a UNIX socket type UnixSocketWriter struct { - socketPath string - conn net.Conn - mutex sync.Mutex + socketPath string + conn net.Conn + mutex sync.Mutex + dialTimeout time.Duration + writeTimeout time.Duration } // NewUnixSocketWriter creates a new UNIX socket writer func NewUnixSocketWriter(socketPath string) (*UnixSocketWriter, error) { w := &UnixSocketWriter{ - socketPath: socketPath, + socketPath: socketPath, + dialTimeout: 2 * time.Second, + writeTimeout: 2 * time.Second, } // Try to connect (socket may not exist yet) - conn, err := net.Dial("unix", socketPath) - if err != nil { - // Socket doesn't exist yet, we'll try to connect on first write - return w, nil + conn, err := net.DialTimeout("unix", socketPath, w.dialTimeout) + if err == nil { + w.conn = conn } - w.conn = conn return w, nil } @@ -107,7 +110,7 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error { if w.conn != nil { return nil } - conn, err := net.Dial("unix", w.socketPath) + conn, err := net.DialTimeout("unix", w.socketPath, w.dialTimeout) if err != nil { return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err) } @@ -123,22 +126,32 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error { if err != nil { return fmt.Errorf("failed to marshal record: %w", err) } - - // Add newline for line-based protocols data = append(data, '\n') - if _, err = w.conn.Write(data); err != nil { + if err := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil { + return fmt.Errorf("failed to set write deadline: %w", err) + } + if _, err = w.conn.Write(data); err == nil { + return nil + } + + _ = w.conn.Close() + w.conn = nil + + if errConn := ensureConn(); errConn != nil { + return fmt.Errorf("failed to write to socket and reconnect failed: %w", errConn) + } + + if errDeadline := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); errDeadline != nil { _ = w.conn.Close() w.conn = nil + return fmt.Errorf("failed to set write deadline after reconnect: %w", errDeadline) + } - if err2 := ensureConn(); err2 != nil { - return fmt.Errorf("failed to write to socket and reconnect failed: %w", err2) - } - if _, err2 := w.conn.Write(data); err2 != nil { - _ = w.conn.Close() - w.conn = nil - return fmt.Errorf("failed to write to socket after reconnect: %w", err2) - } + if _, errRetry := w.conn.Write(data); errRetry != nil { + _ = w.conn.Close() + w.conn = nil + return fmt.Errorf("failed to write to socket after reconnect: %w", errRetry) } return nil diff --git a/internal/output/writers_test.go b/internal/output/writers_test.go index a4d9b4c..896d92d 100644 --- a/internal/output/writers_test.go +++ b/internal/output/writers_test.go @@ -4,9 +4,11 @@ import ( "bufio" "bytes" "encoding/json" + "errors" "net" "os" "path/filepath" + "strings" "sync" "testing" "time" @@ -238,6 +240,112 @@ func TestUnixSocketWriter(t *testing.T) { writer.Close() } +func TestUnixSocketWriter_Write_NonexistentSocket_ReturnsQuickly(t *testing.T) { + socketPath := filepath.Join(t.TempDir(), "ja4sentinel_missing.sock") + writer, err := NewUnixSocketWriter(socketPath) + if err != nil { + t.Fatalf("NewUnixSocketWriter() error = %v", err) + } + defer writer.Close() + + start := time.Now() + err = writer.Write(api.LogRecord{ + SrcIP: "192.168.1.10", + SrcPort: 44444, + DstIP: "10.0.0.10", + DstPort: 443, + }) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("Write() should fail for non-existent socket") + } + if elapsed >= 3*time.Second { + t.Fatalf("Write() took too long: %v (expected < 3s)", elapsed) + } +} + +type timeoutError struct{} + +func (timeoutError) Error() string { return "i/o timeout" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +type mockAddr string + +func (a mockAddr) Network() string { return "unix" } +func (a mockAddr) String() string { return string(a) } + +type mockConn struct { + writeCalls int + closeCalled bool + setWriteDeadlineCalled bool + setReadDeadlineCalled bool + setAnyDeadlineWasCalled bool +} + +func (m *mockConn) Read(_ []byte) (int, error) { return 0, errors.New("not implemented") } + +func (m *mockConn) Write(_ []byte) (int, error) { + m.writeCalls++ + return 0, timeoutError{} +} + +func (m *mockConn) Close() error { + m.closeCalled = true + return nil +} + +func (m *mockConn) LocalAddr() net.Addr { return mockAddr("local") } +func (m *mockConn) RemoteAddr() net.Addr { return mockAddr("remote") } + +func (m *mockConn) SetDeadline(_ time.Time) error { + m.setAnyDeadlineWasCalled = true + return nil +} + +func (m *mockConn) SetReadDeadline(_ time.Time) error { + m.setReadDeadlineCalled = true + return nil +} + +func (m *mockConn) SetWriteDeadline(_ time.Time) error { + m.setWriteDeadlineCalled = true + return nil +} + +func TestUnixSocketWriter_Write_UsesWriteDeadline(t *testing.T) { + mc := &mockConn{} + writer := &UnixSocketWriter{ + socketPath: filepath.Join(t.TempDir(), "missing.sock"), + conn: mc, + dialTimeout: 100 * time.Millisecond, + writeTimeout: 100 * time.Millisecond, + } + + err := writer.Write(api.LogRecord{ + SrcIP: "192.168.1.20", + SrcPort: 55555, + DstIP: "10.0.0.20", + DstPort: 443, + }) + if err == nil { + t.Fatal("Write() should fail because reconnect target does not exist") + } + if !mc.setWriteDeadlineCalled { + t.Fatal("expected SetWriteDeadline to be called before write") + } + if !mc.closeCalled { + t.Fatal("expected connection to be closed after first write failure") + } + if mc.writeCalls != 1 { + t.Fatalf("expected exactly 1 write on initial conn, got %d", mc.writeCalls) + } + if !strings.Contains(err.Error(), "reconnect failed") { + t.Fatalf("expected reconnect failure error, got: %v", err) + } +} + type unixTestServer struct { listener net.Listener received chan string diff --git a/internal/tlsparse/parser.go b/internal/tlsparse/parser.go index ba4f2de..fb6c6e7 100644 --- a/internal/tlsparse/parser.go +++ b/internal/tlsparse/parser.go @@ -42,12 +42,14 @@ type ConnectionFlow struct { // ParserImpl implements the api.Parser interface for TLS parsing type ParserImpl struct { - mu sync.RWMutex - flows map[string]*ConnectionFlow - flowTimeout time.Duration - cleanupDone chan struct{} - cleanupClose chan struct{} - closeOnce sync.Once + mu sync.RWMutex + flows map[string]*ConnectionFlow + flowTimeout time.Duration + cleanupDone chan struct{} + cleanupClose chan struct{} + closeOnce sync.Once + maxTrackedFlows int + maxHelloBufferBytes int } // NewParser creates a new TLS parser with connection state tracking @@ -58,10 +60,12 @@ func NewParser() *ParserImpl { // NewParserWithTimeout creates a new TLS parser with a custom flow timeout func NewParserWithTimeout(timeout time.Duration) *ParserImpl { p := &ParserImpl{ - flows: make(map[string]*ConnectionFlow), - flowTimeout: timeout, - cleanupDone: make(chan struct{}), - cleanupClose: make(chan struct{}), + flows: make(map[string]*ConnectionFlow), + flowTimeout: timeout, + cleanupDone: make(chan struct{}), + cleanupClose: make(chan struct{}), + maxTrackedFlows: 50000, + maxHelloBufferBytes: 256 * 1024, // 256 KiB } go p.cleanupLoop() return p @@ -164,15 +168,26 @@ func (p *ParserImpl) Process(pkt api.RawPacket) (*api.TLSClientHello, error) { return nil, nil // No payload } - // Get or create connection flow key := flowKey(srcIP, srcPort, dstIP, dstPort) + + p.mu.RLock() + _, flowExists := p.flows[key] + p.mu.RUnlock() + + if !flowExists && payload[0] != 22 { + return nil, nil + } + flow := p.getOrCreateFlow(key, srcIP, srcPort, dstIP, dstPort, ipMeta, tcpMeta) + if flow == nil { + return nil, nil + } // Check if flow is already done p.mu.RLock() - isDone := flow.State == JA4_DONE + state := flow.State p.mu.RUnlock() - if isDone { + if state == JA4_DONE { return nil, nil // Already processed this flow } @@ -201,8 +216,13 @@ func (p *ParserImpl) Process(pkt api.RawPacket) (*api.TLSClientHello, error) { } // Check for fragmented ClientHello (accumulate segments) - if flow.State == WAIT_CLIENT_HELLO || flow.State == NEW { + if state == WAIT_CLIENT_HELLO || state == NEW { p.mu.Lock() + if len(flow.HelloBuffer)+len(payload) > p.maxHelloBufferBytes { + delete(p.flows, key) + p.mu.Unlock() + return nil, nil + } flow.State = WAIT_CLIENT_HELLO flow.HelloBuffer = append(flow.HelloBuffer, payload...) bufferCopy := make([]byte, len(flow.HelloBuffer)) @@ -246,13 +266,17 @@ func (p *ParserImpl) getOrCreateFlow(key string, srcIP string, srcPort uint16, d return flow } + if len(p.flows) >= p.maxTrackedFlows { + return nil + } + flow := &ConnectionFlow{ State: NEW, CreatedAt: time.Now(), LastSeen: time.Now(), - SrcIP: srcIP, // Client IP + SrcIP: srcIP, // Client IP SrcPort: srcPort, // Client port - DstIP: dstIP, // Server IP (local machine) + DstIP: dstIP, // Server IP (local machine) DstPort: dstPort, // Server port (local machine) IPMeta: ipMeta, TCPMeta: tcpMeta, diff --git a/internal/tlsparse/parser_test.go b/internal/tlsparse/parser_test.go index ababcb9..08d2828 100644 --- a/internal/tlsparse/parser_test.go +++ b/internal/tlsparse/parser_test.go @@ -1,8 +1,13 @@ package tlsparse import ( + "net" "testing" + "time" + "ja4sentinel/api" + + "github.com/google/gopacket" "github.com/google/gopacket/layers" ) @@ -203,6 +208,8 @@ func createTLSServerHello(version uint16) []byte { func TestNewParser(t *testing.T) { parser := NewParser() + defer parser.Close() + if parser == nil { t.Error("NewParser() returned nil") } @@ -288,3 +295,152 @@ func TestExtractTCPMeta_MSSInvalid_NoPanic(t *testing.T) { t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options) } } + +func TestGetOrCreateFlow_RespectsMaxTrackedFlows(t *testing.T) { + parser := NewParser() + defer parser.Close() + + parser.maxTrackedFlows = 1 + + flow1 := parser.getOrCreateFlow( + flowKey("192.168.1.1", 12345, "10.0.0.1", 443), + "192.168.1.1", 12345, "10.0.0.1", 443, + api.IPMeta{}, api.TCPMeta{}, + ) + if flow1 == nil { + t.Fatal("first flow should be created") + } + + flow2 := parser.getOrCreateFlow( + flowKey("192.168.1.2", 12346, "10.0.0.1", 443), + "192.168.1.2", 12346, "10.0.0.1", 443, + api.IPMeta{}, api.TCPMeta{}, + ) + if flow2 != nil { + t.Fatal("second flow should be nil when maxTrackedFlows is reached") + } +} + +func TestProcess_DropsWhenHelloBufferExceedsLimit(t *testing.T) { + parser := NewParserWithTimeout(30 * time.Second) + defer parser.Close() + + parser.maxHelloBufferBytes = 10 + + srcIP := "192.168.1.10" + dstIP := "10.0.0.1" + srcPort := uint16(12345) + dstPort := uint16(443) + + // TLS-like payload, but intentionally incomplete to trigger accumulation. + payloadChunk := []byte{0x16, 0x03, 0x03, 0x00, 0x20, 0x01} // len = 6 + + pkt1 := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, payloadChunk) + ch, err := parser.Process(pkt1) + if err != nil { + t.Fatalf("first Process() error = %v", err) + } + if ch != nil { + t.Fatal("first Process() should not return complete ClientHello") + } + + key := flowKey(srcIP, srcPort, dstIP, dstPort) + + parser.mu.RLock() + _, existsAfterFirst := parser.flows[key] + parser.mu.RUnlock() + if !existsAfterFirst { + t.Fatal("flow should exist after first chunk") + } + + pkt2 := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, payloadChunk) + ch, err = parser.Process(pkt2) + if err != nil { + t.Fatalf("second Process() error = %v", err) + } + if ch != nil { + t.Fatal("second Process() should not return ClientHello") + } + + parser.mu.RLock() + _, existsAfterSecond := parser.flows[key] + parser.mu.RUnlock() + if existsAfterSecond { + t.Fatal("flow should be removed when hello buffer exceeds maxHelloBufferBytes") + } +} + +func TestProcess_NonTLSNewFlowNotTracked(t *testing.T) { + parser := NewParser() + defer parser.Close() + + srcIP := "192.168.1.20" + dstIP := "10.0.0.2" + srcPort := uint16(23456) + dstPort := uint16(443) + + // Non-TLS content type (not 22) + payload := []byte{0x17, 0x03, 0x03, 0x00, 0x05, 0x00} + + pkt := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, payload) + ch, err := parser.Process(pkt) + if err != nil { + t.Fatalf("Process() error = %v", err) + } + if ch != nil { + t.Fatal("Process() should return nil for non-TLS new flow") + } + + key := flowKey(srcIP, srcPort, dstIP, dstPort) + + parser.mu.RLock() + _, exists := parser.flows[key] + parser.mu.RUnlock() + if exists { + t.Fatal("non-TLS new flow should not be tracked") + } +} + +func buildRawPacket(t *testing.T, srcIP, dstIP string, srcPort, dstPort uint16, payload []byte) api.RawPacket { + t.Helper() + + ip := &layers.IPv4{ + Version: 4, + TTL: 64, + SrcIP: net.ParseIP(srcIP).To4(), + DstIP: net.ParseIP(dstIP).To4(), + Protocol: layers.IPProtocolTCP, + } + + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + Seq: 1, + ACK: true, + Window: 65535, + } + if err := tcp.SetNetworkLayerForChecksum(ip); err != nil { + t.Fatalf("SetNetworkLayerForChecksum() error = %v", err) + } + + eth := &layers.Ethernet{ + SrcMAC: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55}, + DstMAC: net.HardwareAddr{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, + EthernetType: layers.EthernetTypeIPv4, + } + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + + if err := gopacket.SerializeLayers(buf, opts, eth, ip, tcp, gopacket.Payload(payload)); err != nil { + t.Fatalf("SerializeLayers() error = %v", err) + } + + return api.RawPacket{ + Data: buf.Bytes(), + Timestamp: time.Now().UnixNano(), + } +}