fix: renforcer limites TLS, timeouts socket et validation config
Some checks failed
Build RPM Package / Build RPM Packages (CentOS 7, Rocky 8/9/10) (push) Has been cancelled

Co-authored-by: aider (openrouter/openai/gpt-5.3-codex) <aider@aider.chat>
This commit is contained in:
Jacquin Antoine
2026-02-28 20:01:39 +01:00
parent b15c20b4cc
commit c7e8fe874f
7 changed files with 618 additions and 64 deletions

View File

@ -42,7 +42,7 @@ func main() {
appLogger := loggerFactory.NewDefaultLogger() appLogger := loggerFactory.NewDefaultLogger()
appLogger.Info("main", "Starting ja4sentinel", map[string]string{ appLogger.Info("main", "Starting ja4sentinel", map[string]string{
"version": Version, "version": Version,
"build_time": BuildTime, "build_time": BuildTime,
"git_commit": GitCommit, "git_commit": GitCommit,
}) })
@ -58,7 +58,7 @@ func main() {
} }
appLogger.Info("main", "Configuration loaded", map[string]string{ appLogger.Info("main", "Configuration loaded", map[string]string{
"interface": appConfig.Core.Interface, "interface": appConfig.Core.Interface,
"listen_ports": formatPorts(appConfig.Core.ListenPorts), "listen_ports": formatPorts(appConfig.Core.ListenPorts),
}) })
@ -127,9 +127,9 @@ func main() {
} }
appLogger.Debug("tlsparse", "ClientHello extracted", map[string]string{ appLogger.Debug("tlsparse", "ClientHello extracted", map[string]string{
"src_ip": clientHello.SrcIP, "src_ip": clientHello.SrcIP,
"src_port": fmt.Sprintf("%d", clientHello.SrcPort), "src_port": fmt.Sprintf("%d", clientHello.SrcPort),
"dst_ip": clientHello.DstIP, "dst_ip": clientHello.DstIP,
"dst_port": fmt.Sprintf("%d", clientHello.DstPort), "dst_port": fmt.Sprintf("%d", clientHello.DstPort),
}) })
@ -191,7 +191,13 @@ func main() {
}) })
} }
if closer, ok := outputWriter.(interface{ Close() error }); ok { if mw, ok := outputWriter.(interface{ CloseAll() error }); ok {
if err := mw.CloseAll(); err != nil {
appLogger.Error("main", "Failed to close output writers", map[string]string{
"error": err.Error(),
})
}
} else if closer, ok := outputWriter.(interface{ Close() error }); ok {
if err := closer.Close(); err != nil { if err := closer.Close(); err != nil {
appLogger.Error("main", "Failed to close output writer", map[string]string{ appLogger.Error("main", "Failed to close output writer", map[string]string{
"error": err.Error(), "error": err.Error(),

View File

@ -38,7 +38,7 @@ func (l *LoaderImpl) Load() (api.AppConfig, error) {
fileConfig, err := l.loadFromFile(path) fileConfig, err := l.loadFromFile(path)
if err == nil { if err == nil {
config = mergeConfigs(config, fileConfig) config = mergeConfigs(config, fileConfig)
} else if !( !explicit && errors.Is(err, os.ErrNotExist)) { } else if !(!explicit && errors.Is(err, os.ErrNotExist)) {
return config, fmt.Errorf("failed to load config file: %w", err) return config, fmt.Errorf("failed to load config file: %w", err)
} }
@ -115,6 +115,7 @@ func parsePorts(s string) []uint16 {
parts := strings.Split(s, ",") parts := strings.Split(s, ",")
ports := make([]uint16, 0, len(parts)) ports := make([]uint16, 0, len(parts))
seen := make(map[uint16]struct{}, len(parts))
for _, part := range parts { for _, part := range parts {
part = strings.TrimSpace(part) part = strings.TrimSpace(part)
@ -123,9 +124,19 @@ func parsePorts(s string) []uint16 {
} }
port, err := strconv.ParseUint(part, 10, 16) port, err := strconv.ParseUint(part, 10, 16)
if err == nil { if err != nil {
ports = append(ports, uint16(port)) continue
} }
p := uint16(port)
if p == 0 {
continue
}
if _, exists := seen[p]; exists {
continue
}
seen[p] = struct{}{}
ports = append(ports, p)
} }
return ports return ports
@ -164,19 +175,53 @@ func mergeConfigs(base, override api.AppConfig) api.AppConfig {
// validate checks if the configuration is valid // validate checks if the configuration is valid
func (l *LoaderImpl) validate(config api.AppConfig) error { func (l *LoaderImpl) validate(config api.AppConfig) error {
if config.Core.Interface == "" { if strings.TrimSpace(config.Core.Interface) == "" {
return fmt.Errorf("interface cannot be empty") return fmt.Errorf("interface cannot be empty")
} }
if len(config.Core.ListenPorts) == 0 { if len(config.Core.ListenPorts) == 0 {
return fmt.Errorf("at least one listen port is required") return fmt.Errorf("at least one listen port is required")
} }
for _, p := range config.Core.ListenPorts {
if p == 0 {
return fmt.Errorf("listen port 0 is invalid")
}
}
if config.Core.FlowTimeoutSec <= 0 || config.Core.FlowTimeoutSec > 300 {
return fmt.Errorf("flow_timeout_sec must be between 1 and 300")
}
if config.Core.PacketBufferSize <= 0 || config.Core.PacketBufferSize > 1_000_000 {
return fmt.Errorf("packet_buffer_size must be between 1 and 1000000")
}
allowedTypes := map[string]struct{}{
"stdout": {},
"file": {},
"unix_socket": {},
}
// Validate outputs // Validate outputs
for i, output := range config.Outputs { for i, output := range config.Outputs {
if output.Type == "" { outputType := strings.TrimSpace(output.Type)
if outputType == "" {
return fmt.Errorf("output[%d]: type cannot be empty", i) return fmt.Errorf("output[%d]: type cannot be empty", i)
} }
if _, ok := allowedTypes[outputType]; !ok {
return fmt.Errorf("output[%d]: unknown type %q", i, outputType)
}
switch outputType {
case "file":
if strings.TrimSpace(output.Params["path"]) == "" {
return fmt.Errorf("output[%d]: file output requires non-empty path", i)
}
case "unix_socket":
if strings.TrimSpace(output.Params["socket_path"]) == "" {
return fmt.Errorf("output[%d]: unix_socket output requires non-empty socket_path", i)
}
}
} }
return nil return nil

View File

@ -58,21 +58,39 @@ func TestParsePorts(t *testing.T) {
} }
} }
func TestParsePorts_DeduplicateAndIgnoreZero(t *testing.T) {
got := parsePorts("443, 0, 443, 8443")
want := []uint16{443, 8443}
if len(got) != len(want) {
t.Fatalf("parsePorts() length = %d, want %d (got: %v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("parsePorts()[%d] = %d, want %d", i, got[i], want[i])
}
}
}
func TestMergeConfigs(t *testing.T) { func TestMergeConfigs(t *testing.T) {
base := api.AppConfig{ base := api.AppConfig{
Core: api.Config{ Core: api.Config{
Interface: "eth0", Interface: "eth0",
ListenPorts: []uint16{443}, ListenPorts: []uint16{443},
BPFFilter: "", BPFFilter: "",
FlowTimeoutSec: 30,
PacketBufferSize: 1000,
}, },
Outputs: []api.OutputConfig{}, Outputs: []api.OutputConfig{},
} }
override := api.AppConfig{ override := api.AppConfig{
Core: api.Config{ Core: api.Config{
Interface: "lo", Interface: "lo",
ListenPorts: []uint16{8443}, ListenPorts: []uint16{8443},
BPFFilter: "tcp", BPFFilter: "tcp",
FlowTimeoutSec: 60,
PacketBufferSize: 2000,
}, },
Outputs: []api.OutputConfig{ Outputs: []api.OutputConfig{
{Type: "stdout", Enabled: true}, {Type: "stdout", Enabled: true},
@ -93,6 +111,12 @@ func TestMergeConfigs(t *testing.T) {
if len(result.Outputs) != 1 { if len(result.Outputs) != 1 {
t.Errorf("Outputs length = %v, want 1", len(result.Outputs)) t.Errorf("Outputs length = %v, want 1", len(result.Outputs))
} }
if result.Core.FlowTimeoutSec != 60 {
t.Errorf("FlowTimeoutSec = %v, want 60", result.Core.FlowTimeoutSec)
}
if result.Core.PacketBufferSize != 2000 {
t.Errorf("PacketBufferSize = %v, want 2000", result.Core.PacketBufferSize)
}
} }
func TestValidate(t *testing.T) { func TestValidate(t *testing.T) {
@ -107,8 +131,10 @@ func TestValidate(t *testing.T) {
name: "valid config", name: "valid config",
config: api.AppConfig{ config: api.AppConfig{
Core: api.Config{ Core: api.Config{
Interface: "eth0", Interface: "eth0",
ListenPorts: []uint16{443}, ListenPorts: []uint16{443},
FlowTimeoutSec: 30,
PacketBufferSize: 1000,
}, },
Outputs: []api.OutputConfig{ Outputs: []api.OutputConfig{
{Type: "stdout", Enabled: true}, {Type: "stdout", Enabled: true},
@ -120,8 +146,10 @@ func TestValidate(t *testing.T) {
name: "empty interface", name: "empty interface",
config: api.AppConfig{ config: api.AppConfig{
Core: api.Config{ Core: api.Config{
Interface: "", Interface: "",
ListenPorts: []uint16{443}, ListenPorts: []uint16{443},
FlowTimeoutSec: 30,
PacketBufferSize: 1000,
}, },
}, },
wantErr: true, wantErr: true,
@ -130,8 +158,10 @@ func TestValidate(t *testing.T) {
name: "no listen ports", name: "no listen ports",
config: api.AppConfig{ config: api.AppConfig{
Core: api.Config{ Core: api.Config{
Interface: "eth0", Interface: "eth0",
ListenPorts: []uint16{}, ListenPorts: []uint16{},
FlowTimeoutSec: 30,
PacketBufferSize: 1000,
}, },
}, },
wantErr: true, wantErr: true,
@ -140,8 +170,10 @@ func TestValidate(t *testing.T) {
name: "output with empty type", name: "output with empty type",
config: api.AppConfig{ config: api.AppConfig{
Core: api.Config{ Core: api.Config{
Interface: "eth0", Interface: "eth0",
ListenPorts: []uint16{443}, ListenPorts: []uint16{443},
FlowTimeoutSec: 30,
PacketBufferSize: 1000,
}, },
Outputs: []api.OutputConfig{ Outputs: []api.OutputConfig{
{Type: "", Enabled: true}, {Type: "", Enabled: true},
@ -149,6 +181,18 @@ func TestValidate(t *testing.T) {
}, },
wantErr: true, wantErr: true,
}, },
{
name: "listen port zero",
config: api.AppConfig{
Core: api.Config{
Interface: "eth0",
ListenPorts: []uint16{0},
FlowTimeoutSec: 30,
PacketBufferSize: 1000,
},
},
wantErr: true,
},
} }
for _, tt := range tests { for _, tt := range tests {
@ -161,6 +205,162 @@ func TestValidate(t *testing.T) {
} }
} }
func TestValidate_InvalidCoreBounds(t *testing.T) {
loader := &LoaderImpl{}
tests := []struct {
name string
cfg api.AppConfig
hasErr bool
}{
{
name: "timeout zero",
cfg: api.AppConfig{
Core: api.Config{
Interface: "eth0",
ListenPorts: []uint16{443},
FlowTimeoutSec: 0,
PacketBufferSize: 1000,
},
},
hasErr: true,
},
{
name: "timeout too high",
cfg: api.AppConfig{
Core: api.Config{
Interface: "eth0",
ListenPorts: []uint16{443},
FlowTimeoutSec: 301,
PacketBufferSize: 1000,
},
},
hasErr: true,
},
{
name: "buffer zero",
cfg: api.AppConfig{
Core: api.Config{
Interface: "eth0",
ListenPorts: []uint16{443},
FlowTimeoutSec: 30,
PacketBufferSize: 0,
},
},
hasErr: true,
},
{
name: "buffer too high",
cfg: api.AppConfig{
Core: api.Config{
Interface: "eth0",
ListenPorts: []uint16{443},
FlowTimeoutSec: 30,
PacketBufferSize: 1_000_001,
},
},
hasErr: true,
},
{
name: "valid bounds",
cfg: api.AppConfig{
Core: api.Config{
Interface: "eth0",
ListenPorts: []uint16{443},
FlowTimeoutSec: 30,
PacketBufferSize: 1000,
},
Outputs: []api.OutputConfig{
{Type: "stdout", Enabled: true},
},
},
hasErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := loader.validate(tt.cfg)
if (err != nil) != tt.hasErr {
t.Fatalf("validate() error = %v, wantErr %v", err, tt.hasErr)
}
})
}
}
func TestValidate_InvalidOutputs(t *testing.T) {
loader := &LoaderImpl{}
baseCore := api.Config{
Interface: "eth0",
ListenPorts: []uint16{443},
FlowTimeoutSec: 30,
PacketBufferSize: 1000,
}
tests := []struct {
name string
outputs []api.OutputConfig
wantErr bool
}{
{
name: "unknown output type",
outputs: []api.OutputConfig{
{Type: "unknown", Enabled: true},
},
wantErr: true,
},
{
name: "file without path",
outputs: []api.OutputConfig{
{Type: "file", Enabled: true, Params: map[string]string{}},
},
wantErr: true,
},
{
name: "unix socket without socket_path",
outputs: []api.OutputConfig{
{Type: "unix_socket", Enabled: true, Params: map[string]string{}},
},
wantErr: true,
},
{
name: "valid file output",
outputs: []api.OutputConfig{
{Type: "file", Enabled: true, Params: map[string]string{"path": "/tmp/x.log"}},
},
wantErr: false,
},
{
name: "valid unix socket output",
outputs: []api.OutputConfig{
{Type: "unix_socket", Enabled: true, Params: map[string]string{"socket_path": "/tmp/x.sock"}},
},
wantErr: false,
},
{
name: "valid stdout output",
outputs: []api.OutputConfig{
{Type: "stdout", Enabled: true},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := api.AppConfig{
Core: baseCore,
Outputs: tt.outputs,
}
err := loader.validate(cfg)
if (err != nil) != tt.wantErr {
t.Fatalf("validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestLoadFromEnv(t *testing.T) { func TestLoadFromEnv(t *testing.T) {
// Save original env vars // Save original env vars
origInterface := os.Getenv("JA4SENTINEL_INTERFACE") origInterface := os.Getenv("JA4SENTINEL_INTERFACE")
@ -195,9 +395,11 @@ func TestLoadFromEnv(t *testing.T) {
func TestToJSON(t *testing.T) { func TestToJSON(t *testing.T) {
config := api.AppConfig{ config := api.AppConfig{
Core: api.Config{ Core: api.Config{
Interface: "eth0", Interface: "eth0",
ListenPorts: []uint16{443, 8443}, ListenPorts: []uint16{443, 8443},
BPFFilter: "tcp", BPFFilter: "tcp",
FlowTimeoutSec: 30,
PacketBufferSize: 1000,
}, },
Outputs: []api.OutputConfig{ Outputs: []api.OutputConfig{
{Type: "stdout", Enabled: true, Params: map[string]string{}}, {Type: "stdout", Enabled: true, Params: map[string]string{}},

View File

@ -8,6 +8,7 @@ import (
"net" "net"
"os" "os"
"sync" "sync"
"time"
"ja4sentinel/api" "ja4sentinel/api"
) )
@ -76,25 +77,27 @@ func (w *FileWriter) Close() error {
// UnixSocketWriter writes log records to a UNIX socket // UnixSocketWriter writes log records to a UNIX socket
type UnixSocketWriter struct { type UnixSocketWriter struct {
socketPath string socketPath string
conn net.Conn conn net.Conn
mutex sync.Mutex mutex sync.Mutex
dialTimeout time.Duration
writeTimeout time.Duration
} }
// NewUnixSocketWriter creates a new UNIX socket writer // NewUnixSocketWriter creates a new UNIX socket writer
func NewUnixSocketWriter(socketPath string) (*UnixSocketWriter, error) { func NewUnixSocketWriter(socketPath string) (*UnixSocketWriter, error) {
w := &UnixSocketWriter{ w := &UnixSocketWriter{
socketPath: socketPath, socketPath: socketPath,
dialTimeout: 2 * time.Second,
writeTimeout: 2 * time.Second,
} }
// Try to connect (socket may not exist yet) // Try to connect (socket may not exist yet)
conn, err := net.Dial("unix", socketPath) conn, err := net.DialTimeout("unix", socketPath, w.dialTimeout)
if err != nil { if err == nil {
// Socket doesn't exist yet, we'll try to connect on first write w.conn = conn
return w, nil
} }
w.conn = conn
return w, nil return w, nil
} }
@ -107,7 +110,7 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error {
if w.conn != nil { if w.conn != nil {
return nil return nil
} }
conn, err := net.Dial("unix", w.socketPath) conn, err := net.DialTimeout("unix", w.socketPath, w.dialTimeout)
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err) return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err)
} }
@ -123,22 +126,32 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal record: %w", err) return fmt.Errorf("failed to marshal record: %w", err)
} }
// Add newline for line-based protocols
data = append(data, '\n') data = append(data, '\n')
if _, err = w.conn.Write(data); err != nil { if err := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil {
return fmt.Errorf("failed to set write deadline: %w", err)
}
if _, err = w.conn.Write(data); err == nil {
return nil
}
_ = w.conn.Close()
w.conn = nil
if errConn := ensureConn(); errConn != nil {
return fmt.Errorf("failed to write to socket and reconnect failed: %w", errConn)
}
if errDeadline := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); errDeadline != nil {
_ = w.conn.Close() _ = w.conn.Close()
w.conn = nil w.conn = nil
return fmt.Errorf("failed to set write deadline after reconnect: %w", errDeadline)
}
if err2 := ensureConn(); err2 != nil { if _, errRetry := w.conn.Write(data); errRetry != nil {
return fmt.Errorf("failed to write to socket and reconnect failed: %w", err2) _ = w.conn.Close()
} w.conn = nil
if _, err2 := w.conn.Write(data); err2 != nil { return fmt.Errorf("failed to write to socket after reconnect: %w", errRetry)
_ = w.conn.Close()
w.conn = nil
return fmt.Errorf("failed to write to socket after reconnect: %w", err2)
}
} }
return nil return nil

View File

@ -4,9 +4,11 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -238,6 +240,112 @@ func TestUnixSocketWriter(t *testing.T) {
writer.Close() writer.Close()
} }
func TestUnixSocketWriter_Write_NonexistentSocket_ReturnsQuickly(t *testing.T) {
socketPath := filepath.Join(t.TempDir(), "ja4sentinel_missing.sock")
writer, err := NewUnixSocketWriter(socketPath)
if err != nil {
t.Fatalf("NewUnixSocketWriter() error = %v", err)
}
defer writer.Close()
start := time.Now()
err = writer.Write(api.LogRecord{
SrcIP: "192.168.1.10",
SrcPort: 44444,
DstIP: "10.0.0.10",
DstPort: 443,
})
elapsed := time.Since(start)
if err == nil {
t.Fatal("Write() should fail for non-existent socket")
}
if elapsed >= 3*time.Second {
t.Fatalf("Write() took too long: %v (expected < 3s)", elapsed)
}
}
type timeoutError struct{}
func (timeoutError) Error() string { return "i/o timeout" }
func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true }
type mockAddr string
func (a mockAddr) Network() string { return "unix" }
func (a mockAddr) String() string { return string(a) }
type mockConn struct {
writeCalls int
closeCalled bool
setWriteDeadlineCalled bool
setReadDeadlineCalled bool
setAnyDeadlineWasCalled bool
}
func (m *mockConn) Read(_ []byte) (int, error) { return 0, errors.New("not implemented") }
func (m *mockConn) Write(_ []byte) (int, error) {
m.writeCalls++
return 0, timeoutError{}
}
func (m *mockConn) Close() error {
m.closeCalled = true
return nil
}
func (m *mockConn) LocalAddr() net.Addr { return mockAddr("local") }
func (m *mockConn) RemoteAddr() net.Addr { return mockAddr("remote") }
func (m *mockConn) SetDeadline(_ time.Time) error {
m.setAnyDeadlineWasCalled = true
return nil
}
func (m *mockConn) SetReadDeadline(_ time.Time) error {
m.setReadDeadlineCalled = true
return nil
}
func (m *mockConn) SetWriteDeadline(_ time.Time) error {
m.setWriteDeadlineCalled = true
return nil
}
func TestUnixSocketWriter_Write_UsesWriteDeadline(t *testing.T) {
mc := &mockConn{}
writer := &UnixSocketWriter{
socketPath: filepath.Join(t.TempDir(), "missing.sock"),
conn: mc,
dialTimeout: 100 * time.Millisecond,
writeTimeout: 100 * time.Millisecond,
}
err := writer.Write(api.LogRecord{
SrcIP: "192.168.1.20",
SrcPort: 55555,
DstIP: "10.0.0.20",
DstPort: 443,
})
if err == nil {
t.Fatal("Write() should fail because reconnect target does not exist")
}
if !mc.setWriteDeadlineCalled {
t.Fatal("expected SetWriteDeadline to be called before write")
}
if !mc.closeCalled {
t.Fatal("expected connection to be closed after first write failure")
}
if mc.writeCalls != 1 {
t.Fatalf("expected exactly 1 write on initial conn, got %d", mc.writeCalls)
}
if !strings.Contains(err.Error(), "reconnect failed") {
t.Fatalf("expected reconnect failure error, got: %v", err)
}
}
type unixTestServer struct { type unixTestServer struct {
listener net.Listener listener net.Listener
received chan string received chan string

View File

@ -42,12 +42,14 @@ type ConnectionFlow struct {
// ParserImpl implements the api.Parser interface for TLS parsing // ParserImpl implements the api.Parser interface for TLS parsing
type ParserImpl struct { type ParserImpl struct {
mu sync.RWMutex mu sync.RWMutex
flows map[string]*ConnectionFlow flows map[string]*ConnectionFlow
flowTimeout time.Duration flowTimeout time.Duration
cleanupDone chan struct{} cleanupDone chan struct{}
cleanupClose chan struct{} cleanupClose chan struct{}
closeOnce sync.Once closeOnce sync.Once
maxTrackedFlows int
maxHelloBufferBytes int
} }
// NewParser creates a new TLS parser with connection state tracking // NewParser creates a new TLS parser with connection state tracking
@ -58,10 +60,12 @@ func NewParser() *ParserImpl {
// NewParserWithTimeout creates a new TLS parser with a custom flow timeout // NewParserWithTimeout creates a new TLS parser with a custom flow timeout
func NewParserWithTimeout(timeout time.Duration) *ParserImpl { func NewParserWithTimeout(timeout time.Duration) *ParserImpl {
p := &ParserImpl{ p := &ParserImpl{
flows: make(map[string]*ConnectionFlow), flows: make(map[string]*ConnectionFlow),
flowTimeout: timeout, flowTimeout: timeout,
cleanupDone: make(chan struct{}), cleanupDone: make(chan struct{}),
cleanupClose: make(chan struct{}), cleanupClose: make(chan struct{}),
maxTrackedFlows: 50000,
maxHelloBufferBytes: 256 * 1024, // 256 KiB
} }
go p.cleanupLoop() go p.cleanupLoop()
return p return p
@ -164,15 +168,26 @@ func (p *ParserImpl) Process(pkt api.RawPacket) (*api.TLSClientHello, error) {
return nil, nil // No payload return nil, nil // No payload
} }
// Get or create connection flow
key := flowKey(srcIP, srcPort, dstIP, dstPort) key := flowKey(srcIP, srcPort, dstIP, dstPort)
p.mu.RLock()
_, flowExists := p.flows[key]
p.mu.RUnlock()
if !flowExists && payload[0] != 22 {
return nil, nil
}
flow := p.getOrCreateFlow(key, srcIP, srcPort, dstIP, dstPort, ipMeta, tcpMeta) flow := p.getOrCreateFlow(key, srcIP, srcPort, dstIP, dstPort, ipMeta, tcpMeta)
if flow == nil {
return nil, nil
}
// Check if flow is already done // Check if flow is already done
p.mu.RLock() p.mu.RLock()
isDone := flow.State == JA4_DONE state := flow.State
p.mu.RUnlock() p.mu.RUnlock()
if isDone { if state == JA4_DONE {
return nil, nil // Already processed this flow return nil, nil // Already processed this flow
} }
@ -201,8 +216,13 @@ func (p *ParserImpl) Process(pkt api.RawPacket) (*api.TLSClientHello, error) {
} }
// Check for fragmented ClientHello (accumulate segments) // Check for fragmented ClientHello (accumulate segments)
if flow.State == WAIT_CLIENT_HELLO || flow.State == NEW { if state == WAIT_CLIENT_HELLO || state == NEW {
p.mu.Lock() p.mu.Lock()
if len(flow.HelloBuffer)+len(payload) > p.maxHelloBufferBytes {
delete(p.flows, key)
p.mu.Unlock()
return nil, nil
}
flow.State = WAIT_CLIENT_HELLO flow.State = WAIT_CLIENT_HELLO
flow.HelloBuffer = append(flow.HelloBuffer, payload...) flow.HelloBuffer = append(flow.HelloBuffer, payload...)
bufferCopy := make([]byte, len(flow.HelloBuffer)) bufferCopy := make([]byte, len(flow.HelloBuffer))
@ -246,13 +266,17 @@ func (p *ParserImpl) getOrCreateFlow(key string, srcIP string, srcPort uint16, d
return flow return flow
} }
if len(p.flows) >= p.maxTrackedFlows {
return nil
}
flow := &ConnectionFlow{ flow := &ConnectionFlow{
State: NEW, State: NEW,
CreatedAt: time.Now(), CreatedAt: time.Now(),
LastSeen: time.Now(), LastSeen: time.Now(),
SrcIP: srcIP, // Client IP SrcIP: srcIP, // Client IP
SrcPort: srcPort, // Client port SrcPort: srcPort, // Client port
DstIP: dstIP, // Server IP (local machine) DstIP: dstIP, // Server IP (local machine)
DstPort: dstPort, // Server port (local machine) DstPort: dstPort, // Server port (local machine)
IPMeta: ipMeta, IPMeta: ipMeta,
TCPMeta: tcpMeta, TCPMeta: tcpMeta,

View File

@ -1,8 +1,13 @@
package tlsparse package tlsparse
import ( import (
"net"
"testing" "testing"
"time"
"ja4sentinel/api"
"github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
) )
@ -203,6 +208,8 @@ func createTLSServerHello(version uint16) []byte {
func TestNewParser(t *testing.T) { func TestNewParser(t *testing.T) {
parser := NewParser() parser := NewParser()
defer parser.Close()
if parser == nil { if parser == nil {
t.Error("NewParser() returned nil") t.Error("NewParser() returned nil")
} }
@ -288,3 +295,152 @@ func TestExtractTCPMeta_MSSInvalid_NoPanic(t *testing.T) {
t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options) t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options)
} }
} }
func TestGetOrCreateFlow_RespectsMaxTrackedFlows(t *testing.T) {
parser := NewParser()
defer parser.Close()
parser.maxTrackedFlows = 1
flow1 := parser.getOrCreateFlow(
flowKey("192.168.1.1", 12345, "10.0.0.1", 443),
"192.168.1.1", 12345, "10.0.0.1", 443,
api.IPMeta{}, api.TCPMeta{},
)
if flow1 == nil {
t.Fatal("first flow should be created")
}
flow2 := parser.getOrCreateFlow(
flowKey("192.168.1.2", 12346, "10.0.0.1", 443),
"192.168.1.2", 12346, "10.0.0.1", 443,
api.IPMeta{}, api.TCPMeta{},
)
if flow2 != nil {
t.Fatal("second flow should be nil when maxTrackedFlows is reached")
}
}
func TestProcess_DropsWhenHelloBufferExceedsLimit(t *testing.T) {
parser := NewParserWithTimeout(30 * time.Second)
defer parser.Close()
parser.maxHelloBufferBytes = 10
srcIP := "192.168.1.10"
dstIP := "10.0.0.1"
srcPort := uint16(12345)
dstPort := uint16(443)
// TLS-like payload, but intentionally incomplete to trigger accumulation.
payloadChunk := []byte{0x16, 0x03, 0x03, 0x00, 0x20, 0x01} // len = 6
pkt1 := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, payloadChunk)
ch, err := parser.Process(pkt1)
if err != nil {
t.Fatalf("first Process() error = %v", err)
}
if ch != nil {
t.Fatal("first Process() should not return complete ClientHello")
}
key := flowKey(srcIP, srcPort, dstIP, dstPort)
parser.mu.RLock()
_, existsAfterFirst := parser.flows[key]
parser.mu.RUnlock()
if !existsAfterFirst {
t.Fatal("flow should exist after first chunk")
}
pkt2 := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, payloadChunk)
ch, err = parser.Process(pkt2)
if err != nil {
t.Fatalf("second Process() error = %v", err)
}
if ch != nil {
t.Fatal("second Process() should not return ClientHello")
}
parser.mu.RLock()
_, existsAfterSecond := parser.flows[key]
parser.mu.RUnlock()
if existsAfterSecond {
t.Fatal("flow should be removed when hello buffer exceeds maxHelloBufferBytes")
}
}
func TestProcess_NonTLSNewFlowNotTracked(t *testing.T) {
parser := NewParser()
defer parser.Close()
srcIP := "192.168.1.20"
dstIP := "10.0.0.2"
srcPort := uint16(23456)
dstPort := uint16(443)
// Non-TLS content type (not 22)
payload := []byte{0x17, 0x03, 0x03, 0x00, 0x05, 0x00}
pkt := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, payload)
ch, err := parser.Process(pkt)
if err != nil {
t.Fatalf("Process() error = %v", err)
}
if ch != nil {
t.Fatal("Process() should return nil for non-TLS new flow")
}
key := flowKey(srcIP, srcPort, dstIP, dstPort)
parser.mu.RLock()
_, exists := parser.flows[key]
parser.mu.RUnlock()
if exists {
t.Fatal("non-TLS new flow should not be tracked")
}
}
func buildRawPacket(t *testing.T, srcIP, dstIP string, srcPort, dstPort uint16, payload []byte) api.RawPacket {
t.Helper()
ip := &layers.IPv4{
Version: 4,
TTL: 64,
SrcIP: net.ParseIP(srcIP).To4(),
DstIP: net.ParseIP(dstIP).To4(),
Protocol: layers.IPProtocolTCP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
Seq: 1,
ACK: true,
Window: 65535,
}
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
t.Fatalf("SetNetworkLayerForChecksum() error = %v", err)
}
eth := &layers.Ethernet{
SrcMAC: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55},
DstMAC: net.HardwareAddr{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff},
EthernetType: layers.EthernetTypeIPv4,
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}
if err := gopacket.SerializeLayers(buf, opts, eth, ip, tcp, gopacket.Payload(payload)); err != nil {
t.Fatalf("SerializeLayers() error = %v", err)
}
return api.RawPacket{
Data: buf.Bytes(),
Timestamp: time.Now().UnixNano(),
}
}