Files
ja4sentinel/internal/tlsparse/parser_test.go
Jacquin Antoine 23f3012fb1
Some checks failed
Build RPM Package / Build RPM Packages (CentOS 7, Rocky 8/9/10) (push) Has been cancelled
release: version 1.1.2 - Add error callback mechanism and comprehensive test suite
Features:
- Add ErrorCallback type for UNIX socket connection error reporting
- Add WithErrorCallback option for UnixSocketWriter configuration
- Add BuilderImpl.WithErrorCallback() for propagating callbacks
- Add consecutive failure tracking in processQueue

Testing (50+ new tests):
- Add integration tests for full pipeline (capture → tlsparse → fingerprint → output)
- Add tests for FileWriter.rotate() and Reopen() log rotation
- Add tests for cleanupExpiredFlows() and cleanupLoop() in TLS parser
- Add tests for extractSNIFromPayload() and extractJA4Hash() helpers
- Add tests for config load error paths (invalid YAML, permission denied)
- Add tests for capture.Run() error conditions
- Add tests for signal handling documentation

Documentation:
- Update architecture.yml with new fields (LogLevel, TLSClientHello extensions)
- Update architecture.yml with Close() methods for Capture and Parser interfaces
- Update RPM spec changelog

Cleanup:
- Remove empty internal/api/ directory

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
2026-03-02 23:24:56 +01:00

1050 lines
26 KiB
Go

package tlsparse
import (
"net"
"testing"
"time"
"ja4sentinel/api"
tlsfingerprint "github.com/psanford/tlsfingerprint"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// TestIsClientHello feeds IsClientHello a table of payloads (empty, truncated,
// records produced by the local helpers, and a record with a different
// content type) and compares the returned flag with the expected one.
func TestIsClientHello(t *testing.T) {
	type testCase struct {
		name    string
		payload []byte
		want    bool
	}

	// Record whose first byte is 0x17 rather than the handshake content type 0x16.
	notHandshake := []byte{0x17, 0x03, 0x03, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}

	cases := []testCase{
		{"empty payload", []byte{}, false},
		{"too short", []byte{0x16, 0x03, 0x03}, false},
		{"valid TLS 1.2 ClientHello", createTLSClientHello(0x0303), true},
		{"valid TLS 1.3 ClientHello", createTLSClientHello(0x0304), true},
		{"not a handshake", notHandshake, false},
		{"ServerHello (type 2)", createTLSServerHello(0x0303), false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := IsClientHello(tc.payload); got != tc.want {
				t.Errorf("IsClientHello() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestParseClientHello calls parseClientHello with an empty payload, a record
// built by createTLSClientHello and a truncated record, and checks for each
// whether an error and/or a nil result is returned.
func TestParseClientHello(t *testing.T) {
	cases := []struct {
		name    string
		payload []byte
		wantErr bool
		wantNil bool
	}{
		{name: "empty payload", payload: []byte{}, wantErr: false, wantNil: true},
		{name: "valid ClientHello", payload: createTLSClientHello(0x0303), wantErr: false, wantNil: false},
		{name: "incomplete record", payload: []byte{0x16, 0x03, 0x03, 0x01, 0x00, 0x01}, wantErr: false, wantNil: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			hello, err := parseClientHello(tc.payload)

			// An unexpected error state makes the nil check meaningless, so stop here.
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("parseClientHello() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if gotNil := hello == nil; gotNil != tc.wantNil {
				t.Errorf("parseClientHello() = %v, wantNil %v", gotNil, tc.wantNil)
			}
		})
	}
}
func TestExtractIPMeta(t *testing.T) {
ipLayer := &layers.IPv4{
TTL: 64,
Length: 1500,
Id: 12345,
Flags: layers.IPv4DontFragment,
SrcIP: []byte{192, 168, 1, 1},
DstIP: []byte{10, 0, 0, 1},
}
meta := extractIPMeta(ipLayer)
if meta.TTL != 64 {
t.Errorf("TTL = %v, want 64", meta.TTL)
}
if meta.TotalLength != 1500 {
t.Errorf("TotalLength = %v, want 1500", meta.TotalLength)
}
if meta.IPID != 12345 {
t.Errorf("IPID = %v, want 12345", meta.IPID)
}
if !meta.DF {
t.Error("DF = false, want true")
}
}
// TestExtractTCPMeta passes a TCP layer carrying MSS, window-scale and
// SACK-permitted options to extractTCPMeta and checks the window size, MSS,
// window scale and number of options reported in the result.
func TestExtractTCPMeta(t *testing.T) {
	options := []layers.TCPOption{
		// MSS option data 0x05b4 = 1460.
		{OptionType: layers.TCPOptionKindMSS, OptionData: []byte{0x05, 0xb4}},
		// Window scale option data: shift count 7.
		{OptionType: layers.TCPOptionKindWindowScale, OptionData: []byte{0x07}},
		// SACK-permitted carries no data.
		{OptionType: layers.TCPOptionKindSACKPermitted, OptionData: []byte{}},
	}

	meta := extractTCPMeta(&layers.TCP{
		SrcPort: 12345,
		DstPort: 443,
		Window:  65535,
		Options: options,
	})

	if got := meta.WindowSize; got != 65535 {
		t.Errorf("WindowSize = %v, want 65535", got)
	}
	if got := meta.MSS; got != 1460 {
		t.Errorf("MSS = %v, want 1460", got)
	}
	if got := meta.WindowScale; got != 7 {
		t.Errorf("WindowScale = %v, want 7", got)
	}
	if n := len(meta.Options); n != 3 {
		t.Errorf("Options length = %v, want 3", n)
	}
}
// Helper functions to create test TLS records.

// newStubHandshakeRecord returns a 24-byte TLS record: a 5-byte record header
// (content type 0x16, the given version, 2-byte big-endian length) followed by
// a 19-byte stub handshake message whose first byte is msgType.
//
// NOTE(review): the stub writes FOUR bytes (0x00 0x00 0x00 0x10) after the
// message type, while the TLS handshake header carries a 3-byte length. A
// strict parser would therefore read a handshake length of 0 and treat 0x10 as
// the first body byte. The bytes are kept exactly as they are because the
// tests in this file were written against them.
func newStubHandshakeRecord(msgType byte, version uint16) []byte {
	handshake := []byte{
		msgType,                // Handshake type
		0x00, 0x00, 0x00, 0x10, // see NOTE(review) above
		// Zero-filled remainder of the stub (14 bytes).
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	}

	record := make([]byte, 0, 5+len(handshake))
	record = append(record,
		0x16, // Content type: Handshake
		byte(version>>8), byte(version),
		byte(len(handshake)>>8), byte(len(handshake)),
	)
	return append(record, handshake...)
}

// createTLSClientHello returns a stub TLS record whose handshake type byte is
// 0x01 (ClientHello) and whose record-layer version is version.
func createTLSClientHello(version uint16) []byte {
	return newStubHandshakeRecord(0x01, version)
}

// createTLSServerHello returns a stub TLS record whose handshake type byte is
// 0x02 (ServerHello) and whose record-layer version is version.
func createTLSServerHello(version uint16) []byte {
	return newStubHandshakeRecord(0x02, version)
}
// TestNewParser checks that NewParser returns a non-nil parser whose flows
// map is initialised and whose flowTimeout is non-zero.
func TestNewParser(t *testing.T) {
	parser := NewParser()
	// Stop immediately on nil: both the deferred Close and the field accesses
	// below would otherwise dereference a nil parser.
	if parser == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer parser.Close()

	if parser.flows == nil {
		t.Error("NewParser() flows map not initialized")
	}
	if parser.flowTimeout == 0 {
		t.Error("NewParser() flowTimeout not set")
	}
}
// TestParserClose checks that closing a freshly created parser reports no error.
func TestParserClose(t *testing.T) {
	if err := NewParser().Close(); err != nil {
		t.Errorf("Close() error = %v", err)
	}
}
// TestFlowKey checks the string flowKey produces for one client -> server
// address pair (the key is unidirectional).
func TestFlowKey(t *testing.T) {
	const want = "192.168.1.1:12345->10.0.0.1:443"

	if got := flowKey("192.168.1.1", 12345, "10.0.0.1", 443); got != want {
		t.Errorf("flowKey() = %v, want %v", got, want)
	}
}
// TestParserConnectionStateTracking creates a parser (closed on return) and
// then exercises the lower-level helpers directly: parseClientHello must
// return a non-nil result without error for a stub ClientHello record, and
// IsClientHello must report true for the same bytes.
func TestParserConnectionStateTracking(t *testing.T) {
	parser := NewParser()
	defer parser.Close()

	payload := createTLSClientHello(0x0303)

	hello, parseErr := parseClientHello(payload)
	if parseErr != nil {
		t.Errorf("parseClientHello() error = %v", parseErr)
	}
	if hello == nil {
		t.Error("parseClientHello() should return ClientHello")
	}

	if ok := IsClientHello(payload); !ok {
		t.Error("IsClientHello() should return true for valid ClientHello")
	}
}
// TestParserClose_Idempotent calls Close twice on the same parser and fails
// on the first call that returns an error.
func TestParserClose_Idempotent(t *testing.T) {
	parser := NewParser()

	for _, attempt := range []string{"first", "second"} {
		if err := parser.Close(); err != nil {
			t.Fatalf("%s Close() error = %v", attempt, err)
		}
	}
}
func TestExtractTCPMeta_MSSInvalid_NoPanic(t *testing.T) {
tcp := &layers.TCP{
Window: 1234,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionData: []byte{0x05}, // malformed (1 byte instead of 2)
},
},
}
meta := extractTCPMeta(tcp)
found := false
for _, opt := range meta.Options {
if opt == "MSS_INVALID" {
found = true
break
}
}
if !found {
t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options)
}
}
// TestGetOrCreateFlow_RespectsMaxTrackedFlows sets maxTrackedFlows to 1 and
// checks that getOrCreateFlow returns a flow for the first key and nil for a
// second, different key.
func TestGetOrCreateFlow_RespectsMaxTrackedFlows(t *testing.T) {
	parser := NewParser()
	defer parser.Close()
	parser.maxTrackedFlows = 1

	if first := parser.getOrCreateFlow(
		flowKey("192.168.1.1", 12345, "10.0.0.1", 443),
		"192.168.1.1", 12345, "10.0.0.1", 443,
		api.IPMeta{}, api.TCPMeta{},
	); first == nil {
		t.Fatal("first flow should be created")
	}

	if second := parser.getOrCreateFlow(
		flowKey("192.168.1.2", 12346, "10.0.0.1", 443),
		"192.168.1.2", 12346, "10.0.0.1", 443,
		api.IPMeta{}, api.TCPMeta{},
	); second != nil {
		t.Fatal("second flow should be nil when maxTrackedFlows is reached")
	}
}
// TestProcess_DropsWhenHelloBufferExceedsLimit sets maxHelloBufferBytes to 10
// and sends the same 6-byte incomplete handshake fragment twice on one flow.
// The flow must be present in parser.flows after the first packet and absent
// after the second; neither call may return an error or a ClientHello.
func TestProcess_DropsWhenHelloBufferExceedsLimit(t *testing.T) {
	const (
		srcIP = "192.168.1.10"
		dstIP = "10.0.0.1"
	)
	var srcPort, dstPort uint16 = 12345, 443

	parser := NewParserWithTimeout(30 * time.Second)
	defer parser.Close()
	parser.maxHelloBufferBytes = 10

	key := flowKey(srcIP, srcPort, dstIP, dstPort)
	// tracked reports whether the flow is currently in parser.flows.
	tracked := func() bool {
		parser.mu.RLock()
		defer parser.mu.RUnlock()
		_, ok := parser.flows[key]
		return ok
	}

	// Starts like a TLS handshake record but is incomplete (6 bytes); two of
	// these together exceed the 10-byte limit configured above.
	chunk := []byte{0x16, 0x03, 0x03, 0x00, 0x20, 0x01}

	hello, err := parser.Process(buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, chunk))
	if err != nil {
		t.Fatalf("first Process() error = %v", err)
	}
	if hello != nil {
		t.Fatal("first Process() should not return complete ClientHello")
	}
	if !tracked() {
		t.Fatal("flow should exist after first chunk")
	}

	hello, err = parser.Process(buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, chunk))
	if err != nil {
		t.Fatalf("second Process() error = %v", err)
	}
	if hello != nil {
		t.Fatal("second Process() should not return ClientHello")
	}
	if tracked() {
		t.Fatal("flow should be removed when hello buffer exceeds maxHelloBufferBytes")
	}
}
// TestProcess_NonTLSNewFlowNotTracked sends a single packet whose payload
// does not start with the handshake content type and checks that Process
// returns (nil, nil) and leaves no entry for the flow in parser.flows.
func TestProcess_NonTLSNewFlowNotTracked(t *testing.T) {
	const (
		srcIP = "192.168.1.20"
		dstIP = "10.0.0.2"
	)
	var srcPort, dstPort uint16 = 23456, 443

	parser := NewParser()
	defer parser.Close()

	// Content type 0x17, i.e. not 22 (handshake).
	pkt := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, []byte{0x17, 0x03, 0x03, 0x00, 0x05, 0x00})

	hello, err := parser.Process(pkt)
	if err != nil {
		t.Fatalf("Process() error = %v", err)
	}
	if hello != nil {
		t.Fatal("Process() should return nil for non-TLS new flow")
	}

	parser.mu.RLock()
	_, tracked := parser.flows[flowKey(srcIP, srcPort, dstIP, dstPort)]
	parser.mu.RUnlock()
	if tracked {
		t.Fatal("non-TLS new flow should not be tracked")
	}
}
// buildRawPacket serialises an Ethernet/IPv4/TCP frame carrying payload and
// wraps the bytes in an api.RawPacket stamped with the current time in
// nanoseconds. Lengths and checksums are filled in by the serializer. Any
// serialization failure aborts the calling test via t.Fatalf.
func buildRawPacket(t *testing.T, srcIP, dstIP string, srcPort, dstPort uint16, payload []byte) api.RawPacket {
	t.Helper()

	ethLayer := &layers.Ethernet{
		SrcMAC:       net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55},
		DstMAC:       net.HardwareAddr{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff},
		EthernetType: layers.EthernetTypeIPv4,
	}
	ipLayer := &layers.IPv4{
		Version:  4,
		TTL:      64,
		Protocol: layers.IPProtocolTCP,
		SrcIP:    net.ParseIP(srcIP).To4(),
		DstIP:    net.ParseIP(dstIP).To4(),
	}
	tcpLayer := &layers.TCP{
		SrcPort: layers.TCPPort(srcPort),
		DstPort: layers.TCPPort(dstPort),
		Seq:     1,
		ACK:     true,
		Window:  65535,
	}
	// The TCP checksum is computed over the IP layer as well, so the two
	// layers have to be linked before serialising.
	if err := tcpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil {
		t.Fatalf("SetNetworkLayerForChecksum() error = %v", err)
	}

	serialized := gopacket.NewSerializeBuffer()
	serializeOpts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}
	err := gopacket.SerializeLayers(serialized, serializeOpts, ethLayer, ipLayer, tcpLayer, gopacket.Payload(payload))
	if err != nil {
		t.Fatalf("SerializeLayers() error = %v", err)
	}

	return api.RawPacket{
		Data:      serialized.Bytes(),
		Timestamp: time.Now().UnixNano(),
	}
}
// TestTLSVersionToString checks the string tlsVersionToString returns for the
// four version codes 0x0301-0x0304 and expects an empty string for 0x0300 and
// 0x0305.
//
// Each case carries an explicit name: naming subtests after the expected
// string would give the two empty-string cases no usable name.
func TestTLSVersionToString(t *testing.T) {
	tests := []struct {
		name    string
		version uint16
		want    string
	}{
		{name: "0x0301", version: 0x0301, want: "1.0"},
		{name: "0x0302", version: 0x0302, want: "1.1"},
		{name: "0x0303", version: 0x0303, want: "1.2"},
		{name: "0x0304", version: 0x0304, want: "1.3"},
		{name: "0x0300", version: 0x0300, want: ""},
		{name: "0x0305", version: 0x0305, want: ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tlsVersionToString(tt.version)
			if got != tt.want {
				t.Errorf("tlsVersionToString(%#x) = %v, want %v", tt.version, got, tt.want)
			}
		})
	}
}
// TestExtractTLSExtensions calls extractTLSExtensions with an empty payload,
// a 3-byte payload and a stub ClientHello record. No case may return an
// error; the first two must yield nil and the last a result whose TLSVersion
// is "1.2".
//
// Only the TLS version is asserted here, so the table carries no SNI or ALPN
// expectations.
func TestExtractTLSExtensions(t *testing.T) {
	tests := []struct {
		name        string
		payload     []byte
		wantVersion string
		wantNil     bool
	}{
		{
			name:    "empty payload",
			payload: []byte{},
			wantNil: true,
		},
		{
			name:    "too short",
			payload: []byte{0x16, 0x03, 0x03},
			wantNil: true,
		},
		{
			name:        "TLS 1.2 ClientHello without extensions",
			payload:     createTLSClientHello(0x0303),
			wantVersion: "1.2",
			wantNil:     false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := extractTLSExtensions(tt.payload)
			if err != nil {
				t.Errorf("extractTLSExtensions() unexpected error = %v", err)
				return
			}
			if (got == nil) != tt.wantNil {
				t.Errorf("extractTLSExtensions() = %v, wantNil %v", got == nil, tt.wantNil)
				return
			}
			if got != nil && got.TLSVersion != tt.wantVersion {
				t.Errorf("TLSVersion = %v, want %v", got.TLSVersion, tt.wantVersion)
			}
		})
	}
}
// TestParser_ExtractsTLSFields sends one ClientHello carrying SNI, ALPN and
// supported_versions extensions through Process and checks the SNI, ALPN,
// TLSVersion, ConnID and SynToCHMs fields of the result.
//
// The t.Log/t.Logf calls before Process are diagnostics only: they print what
// extractTLSExtensions and tlsfingerprint.ParseClientHello make of the same
// bytes and never fail the test.
func TestParser_ExtractsTLSFields(t *testing.T) {
	parser := NewParser()
	defer parser.Close()

	hello := createMinimalTLSClientHelloWithSNIAndALPN("example.com", []string{"h2", "http/1.1"})

	// Diagnostic: result of the package's own extension extraction.
	extInfo, extErr := extractTLSExtensions(hello)
	if extErr != nil {
		t.Logf("extractTLSExtensions error: %v", extErr)
	}
	if extInfo == nil {
		t.Log("extInfo is nil")
	} else {
		t.Logf("extInfo: SNI=%q, ALPN=%v, Version=%q", extInfo.SNI, extInfo.ALPN, extInfo.TLSVersion)
	}

	// Diagnostic: result of the tlsfingerprint library on the same bytes.
	if fp, fpErr := tlsfingerprint.ParseClientHello(hello); fpErr != nil {
		t.Logf("tlsfingerprint error: %v", fpErr)
	} else {
		t.Logf("tlsfingerprint: ALPN=%v, Version=%#x, HasSNI=%v", fp.ALPNProtocols, fp.Version, fp.HasSNI)
	}

	// Diagnostic: at most the first 50 bytes of the record.
	t.Logf("ClientHello hex: % x", hello[:min(50, len(hello))])

	pkt := buildRawPacket(t, "192.168.1.100", "10.0.0.1", uint16(54321), uint16(443), hello)
	result, err := parser.Process(pkt)
	if err != nil {
		t.Fatalf("Process() error = %v", err)
	}
	if result == nil {
		t.Fatal("Process() should return TLSClientHello")
	}

	if got := result.SNI; got != "example.com" {
		t.Errorf("SNI = %v, want example.com", got)
	}
	if result.ALPN == "" {
		t.Error("ALPN should not be empty")
	}
	if got := result.TLSVersion; got != "1.2" {
		t.Errorf("TLSVersion = %v, want 1.2", got)
	}
	if result.ConnID == "" {
		t.Error("ConnID should not be empty")
	}
	if result.SynToCHMs == nil {
		t.Error("SynToCHMs should not be nil")
	}
}
// createMinimalTLSClientHelloWithSNIAndALPN assembles a TLS record (content
// type 0x16, record version 0x0303) holding one ClientHello handshake message.
// The message has client version 0x0303, an all-zero 32-byte random, an empty
// session ID, a fixed cipher-suite list, one null compression method and three
// extensions in this order: SNI for sni, ALPN for alpnProtocols, and
// supported_versions listing 0x0303 only.
func createMinimalTLSClientHelloWithSNIAndALPN(sni string, alpnProtocols []string) []byte {
	// Extensions block: SNI, then ALPN, then supported_versions
	// (type 43, length 3, list length 2, version 0x0303).
	var exts []byte
	exts = append(exts, buildSNIExtension(sni)...)
	exts = append(exts, buildALPNExtension(alpnProtocols)...)
	exts = append(exts, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x03)

	// All eight bytes are emitted as cipher-suite data because a separate
	// 2-byte length is written in front of them below. The list therefore
	// holds four 2-byte entries: 0x0004, 0x1301, 0x1302 and 0xc02f.
	suites := []byte{0x00, 0x04, 0x13, 0x01, 0x13, 0x02, 0xc0, 0x2f}

	// ClientHello body. The first 35 bytes are version (2), random (32, left
	// as zeros) and session ID length (1, zero).
	body := make([]byte, 35, 35+2+len(suites)+2+2+len(exts))
	body[0], body[1] = 0x03, 0x03
	body = append(body, byte(len(suites)>>8), byte(len(suites)))
	body = append(body, suites...)
	body = append(body, 0x01, 0x00) // compression: one method, null
	body = append(body, byte(len(exts)>>8), byte(len(exts)))
	body = append(body, exts...)

	// Handshake header: type 0x01 (ClientHello) + 3-byte body length.
	msg := make([]byte, 0, 4+len(body))
	msg = append(msg, 0x01, byte(len(body)>>16), byte(len(body)>>8), byte(len(body)))
	msg = append(msg, body...)

	// Record header: handshake content type, version 0x0303, 2-byte length.
	rec := make([]byte, 0, 5+len(msg))
	rec = append(rec, 0x16, 0x03, 0x03, byte(len(msg)>>8), byte(len(msg)))
	return append(rec, msg...)
}
// buildSNIExtension encodes a Server Name Indication extension carrying a
// single host_name entry for sni.
//
// Layout of the returned bytes:
//
//	extension type   (2) 0x0000
//	extension length (2) = 2 + list length
//	name list length (2) = 1 + 2 + len(sni)
//	name type        (1) 0x00 (host_name)
//	name length      (2) = len(sni)
//	name             (len(sni))
func buildSNIExtension(sni string) []byte {
	hostLen := len(sni)
	listLen := 3 + hostLen // name type + name length + name
	dataLen := 2 + listLen // list length field + list

	out := make([]byte, 0, 4+dataLen)
	out = append(out,
		0x00, 0x00, // extension type: SNI
		byte(dataLen>>8), byte(dataLen),
		byte(listLen>>8), byte(listLen),
		0x00, // name type: host_name
		byte(hostLen>>8), byte(hostLen),
	)
	return append(out, sni...)
}
// buildALPNExtension encodes an Application-Layer Protocol Negotiation
// extension (type 16) listing protocols in the given order.
//
// Layout of the returned bytes:
//
//	extension type       (2) 0x0010
//	extension length     (2) = 2 + protocol list length
//	protocol list length (2)
//	for each protocol: length (1) followed by the protocol bytes
func buildALPNExtension(protocols []string) []byte {
	listLen := 0
	for _, p := range protocols {
		listLen += 1 + len(p) // length byte + protocol string
	}
	dataLen := 2 + listLen

	out := make([]byte, 0, 4+dataLen)
	out = append(out,
		0x00, 0x10, // extension type: ALPN
		byte(dataLen>>8), byte(dataLen),
		byte(listLen>>8), byte(listLen),
	)
	for _, p := range protocols {
		out = append(out, byte(len(p)))
		out = append(out, p...)
	}
	return out
}
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// TestExtractSNIFromPayload runs extractSNIFromPayload over handshake
// payloads built by the local helpers and checks the host name it returns
// (empty when none is expected).
func TestExtractSNIFromPayload(t *testing.T) {
	type testCase struct {
		name    string
		payload []byte
		wantSNI string
	}

	const longDomain = "very-long-subdomain.example-test-domain.com"

	cases := []testCase{
		{"empty_payload", []byte{}, ""},
		// Only 4 bytes.
		{"payload_too_short", []byte{0x01, 0x00, 0x00, 0x10}, ""},
		{"no_extensions", buildClientHelloWithoutExtensions(), ""},
		{"with_sni_extension", buildClientHelloWithSNI("example.com"), "example.com"},
		{"with_sni_long_domain", buildClientHelloWithSNI(longDomain), longDomain},
		{"malformed_sni_truncated", buildTruncatedSNI(), ""},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := extractSNIFromPayload(tc.payload); got != tc.wantSNI {
				t.Errorf("extractSNIFromPayload() = %q, want %q", got, tc.wantSNI)
			}
		})
	}
}
// TestCleanupExpiredFlows inserts a flow in state NEW whose LastSeen is two
// hours in the past and expects cleanupExpiredFlows to remove it.
func TestCleanupExpiredFlows(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	const key = "192.168.1.1:12345->10.0.0.1:443"
	lastSeen := time.Now().Add(-2 * time.Hour) // old flow

	// Insert the flow by hand, under the parser's write lock.
	p.mu.Lock()
	p.flows[key] = &ConnectionFlow{
		State:       NEW,
		LastSeen:    lastSeen,
		HelloBuffer: make([]byte, 0),
	}
	p.mu.Unlock()

	p.cleanupExpiredFlows()

	p.mu.RLock()
	_, stillTracked := p.flows[key]
	p.mu.RUnlock()
	if stillTracked {
		t.Error("cleanupExpiredFlows() should have removed the expired flow")
	}
}
// TestCleanupExpiredFlows_JA4Done inserts a flow in state JA4_DONE with a
// current LastSeen and expects cleanupExpiredFlows to remove it anyway.
func TestCleanupExpiredFlows_JA4Done(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	const key = "192.168.1.1:12345->10.0.0.1:443"

	// LastSeen is recent on purpose: removal is expected because of the state.
	done := &ConnectionFlow{
		State:       JA4_DONE,
		LastSeen:    time.Now(),
		HelloBuffer: make([]byte, 0),
	}

	p.mu.Lock()
	p.flows[key] = done
	p.mu.Unlock()

	p.cleanupExpiredFlows()

	p.mu.RLock()
	_, stillTracked := p.flows[key]
	p.mu.RUnlock()
	if stillTracked {
		t.Error("cleanupExpiredFlows() should have removed the JA4_DONE flow")
	}
}
// TestCleanupExpiredFlows_RecentFlow inserts a flow in state
// WAIT_CLIENT_HELLO with a current LastSeen and expects cleanupExpiredFlows to
// leave it in place.
func TestCleanupExpiredFlows_RecentFlow(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	const key = "192.168.1.1:12345->10.0.0.1:443"

	recent := &ConnectionFlow{
		State:       WAIT_CLIENT_HELLO,
		LastSeen:    time.Now(),
		HelloBuffer: make([]byte, 0),
	}

	p.mu.Lock()
	p.flows[key] = recent
	p.mu.Unlock()

	p.cleanupExpiredFlows()

	p.mu.RLock()
	_, stillTracked := p.flows[key]
	p.mu.RUnlock()
	if !stillTracked {
		t.Error("cleanupExpiredFlows() should NOT have removed the recent flow")
	}
}
// TestCleanupLoop_Shutdown checks that after Close() the parser's cleanupDone
// channel is closed (a receive yields ok == false).
//
// The receive blocks for at most one second. That way the test returns as soon
// as the channel is closed and does not depend on a fixed sleep being long
// enough on a loaded machine.
func TestCleanupLoop_Shutdown(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}

	if err := p.Close(); err != nil {
		t.Errorf("Close() error = %v", err)
	}

	select {
	case _, ok := <-p.cleanupDone:
		// ok == true means a value was delivered, i.e. the channel is still open.
		if ok {
			t.Error("cleanupDone channel should be closed after Close()")
		}
	case <-time.After(time.Second):
		t.Error("cleanupDone channel should be closed after Close() (timed out waiting)")
	}
}
// buildClientHelloWithoutExtensions returns a 46-byte ClientHello handshake
// message (no record header) that ends with a zero extensions length.
//
// Layout:
//
//	[0]     handshake type 0x01 (ClientHello)
//	[1:4]   handshake length = 42 (everything after this header)
//	[4:6]   version 0x0303 (TLS 1.2)
//	[6:38]  random, all zeros
//	[38]    session ID length 0
//	[39:41] cipher suite length 2
//	[41:43] cipher suite 0x1301
//	[43]    compression methods length 0
//	[44:46] extensions length 0
//
// The slice is sized to include the compression-length and extensions-length
// bytes, so the message is complete rather than cut off after the cipher
// suite.
func buildClientHelloWithoutExtensions() []byte {
	const (
		headerLen = 4  // type (1) + length (3)
		totalLen  = 46 // see layout above
	)

	handshake := make([]byte, totalLen) // zero-filled: random, lengths of 0
	handshake[0] = 0x01                 // ClientHello type
	handshake[3] = totalLen - headerLen // Length (fits in the low byte)
	handshake[4] = 0x03                 // Version TLS 1.2
	handshake[5] = 0x03
	handshake[40] = 0x02 // Cipher suite length (low byte)
	handshake[41] = 0x13 // Cipher suite data
	handshake[42] = 0x01
	return handshake
}
// buildClientHelloWithSNI returns a ClientHello handshake message (no record
// header) whose only extension is the SNI extension for sni.
//
// The fixed 43-byte prefix holds: type 0x01, a 3-byte length (patched last),
// version 0x0303, a zero random, session ID length 0, cipher suite length 2
// and cipher suite 0x1301. It is followed by a compression length of 0, the
// 2-byte extensions length and the SNI extension itself.
func buildClientHelloWithSNI(sni string) []byte {
	sniExt := buildSNIExtension(sni)

	msg := make([]byte, 43, 43+1+2+len(sniExt))
	msg[0] = 0x01                 // ClientHello type
	msg[4], msg[5] = 0x03, 0x03   // Version TLS 1.2
	msg[40] = 0x02                // Cipher suite length (low byte)
	msg[41], msg[42] = 0x13, 0x01 // Cipher suite data

	msg = append(msg, 0x00) // compression length
	msg = append(msg, byte(len(sniExt)>>8), byte(len(sniExt)))
	msg = append(msg, sniExt...)

	// The handshake length covers everything after the 4-byte header.
	bodyLen := len(msg) - 4
	msg[1], msg[2], msg[3] = byte(bodyLen>>16), byte(bodyLen>>8), byte(bodyLen)
	return msg
}
// buildTruncatedSNI returns a 53-byte malformed ClientHello handshake message
// (no record header) whose trailing extension bytes are shorter than the
// length they declare.
//
// The 44-byte prefix holds: type 0x01, a handshake length left at zero,
// version 0x0303, a zero random, session ID length 0, cipher suite length 2,
// cipher suite 0x1301 and compression length 0. Nine bytes follow:
//
//	00 0a          declares 10 bytes to follow
//	00 05          next two bytes
//	00 03 74 65 73 0x0003 followed by "tes"
//
// NOTE(review): read as extensions-length / extension-type / extension-length,
// these bytes declare 10 bytes of extensions but supply 7, and the extension
// type is 0x0005 rather than SNI (0x0000). Confirm this is the malformation
// the "truncated SNI" case is meant to cover.
func buildTruncatedSNI() []byte {
	msg := make([]byte, 44, 53)
	msg[0] = 0x01
	msg[4], msg[5] = 0x03, 0x03
	msg[40] = 0x02
	msg[41], msg[42] = 0x13, 0x01
	// msg[43] (compression length) stays 0.

	return append(msg,
		0x00, 0x0a, // length field claiming 10 bytes
		0x00, 0x05,
		0x00, 0x03, 0x74, 0x65, 0x73, // "tes" truncated
	)
}
// TestJoinStringSlice checks joinStringSlice (a deprecated helper) with an
// empty slice, a single element and several elements, all joined with ",".
func TestJoinStringSlice(t *testing.T) {
	type testCase struct {
		name  string
		slice []string
		sep   string
		want  string
	}

	cases := []testCase{
		{"empty_slice", []string{}, ",", ""},
		{"single_element", []string{"MSS"}, ",", "MSS"},
		{"multiple_elements", []string{"MSS", "SACK", "TS"}, ",", "MSS,SACK,TS"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := joinStringSlice(tc.slice, tc.sep); got != tc.want {
				t.Errorf("joinStringSlice() = %q, want %q", got, tc.want)
			}
		})
	}
}
// TestProcess_NilPacketData checks that Process rejects a packet whose Data
// is nil with the error message "empty packet data".
func TestProcess_NilPacketData(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	pkt := api.RawPacket{
		Data:      nil,
		Timestamp: time.Now().UnixNano(),
	}
	_, err := p.Process(pkt)
	// Fatal, not Error: the message check below calls err.Error() and would
	// panic on a nil error.
	if err == nil {
		t.Fatal("Process() with nil data should return error")
	}
	if err.Error() != "empty packet data" {
		t.Errorf("Process() error = %v, want 'empty packet data'", err)
	}
}
// TestProcess_EmptyPacketData checks that Process returns an error for a
// packet whose Data is a zero-length, non-nil slice.
func TestProcess_EmptyPacketData(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	empty := api.RawPacket{Data: []byte{}, Timestamp: time.Now().UnixNano()}
	if _, err := p.Process(empty); err == nil {
		t.Error("Process() with empty data should return error")
	}
}