Files
ja4sentinel/internal/tlsparse/parser_test.go
toto d22b0634da
Some checks failed
Build RPM Package / Build RPM Packages (CentOS 7, Rocky 8/9/10) (push) Has been cancelled
release: version 1.1.15 - Fix ALPN detection for malformed TLS extensions
- FIX: ALPN (tls_alpn) not appearing in logs for packets with truncated extensions
- Add sanitizeTLSRecord fallback in extractTLSExtensions (tlsparse/parser.go)
- Mirrors sanitization already present in fingerprint/engine.go
- ALPN now correctly extracted even when ParseClientHello fails on raw payload
- Bump version to 1.1.15 in main.go and packaging/rpm/ja4sentinel.spec

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-05 14:42:15 +01:00

1793 lines
47 KiB
Go

package tlsparse
import (
"net"
"testing"
"time"
"ja4sentinel/api"
tlsfingerprint "github.com/psanford/tlsfingerprint"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// TestIsClientHello checks the ClientHello detector against empty and
// truncated input, TLS 1.2/1.3 ClientHello records, a non-handshake record and
// a ServerHello record.
func TestIsClientHello(t *testing.T) {
	type testCase struct {
		name    string
		payload []byte
		want    bool
	}
	cases := []testCase{
		{"empty payload", []byte{}, false},
		{"too short", []byte{0x16, 0x03, 0x03}, false},
		{"valid TLS 1.2 ClientHello", createTLSClientHello(0x0303), true},
		{"valid TLS 1.3 ClientHello", createTLSClientHello(0x0304), true},
		{"not a handshake", []byte{0x17, 0x03, 0x03, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, false},
		{"ServerHello (type 2)", createTLSServerHello(0x0303), false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := IsClientHello(tc.payload); got != tc.want {
				t.Errorf("IsClientHello() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestParseClientHello expects parseClientHello to yield no error in every
// case, a nil result for empty or incomplete records, and a non-nil result for
// the minimal ClientHello fixture.
func TestParseClientHello(t *testing.T) {
	type testCase struct {
		name    string
		payload []byte
		wantErr bool
		wantNil bool
	}
	cases := []testCase{
		{name: "empty payload", payload: []byte{}, wantNil: true},
		{name: "valid ClientHello", payload: createTLSClientHello(0x0303)},
		{name: "incomplete record", payload: []byte{0x16, 0x03, 0x03, 0x01, 0x00, 0x01}, wantNil: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := parseClientHello(tc.payload)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("parseClientHello() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if isNil := got == nil; isNil != tc.wantNil {
				t.Errorf("parseClientHello() = %v, wantNil %v", isNil, tc.wantNil)
			}
		})
	}
}
func TestExtractIPMeta(t *testing.T) {
ipLayer := &layers.IPv4{
TTL: 64,
Length: 1500,
Id: 12345,
Flags: layers.IPv4DontFragment,
SrcIP: []byte{192, 168, 1, 1},
DstIP: []byte{10, 0, 0, 1},
}
meta := extractIPMeta(ipLayer)
if meta.TTL != 64 {
t.Errorf("TTL = %v, want 64", meta.TTL)
}
if meta.TotalLength != 1500 {
t.Errorf("TotalLength = %v, want 1500", meta.TotalLength)
}
if meta.IPID != 12345 {
t.Errorf("IPID = %v, want 12345", meta.IPID)
}
if !meta.DF {
t.Error("DF = false, want true")
}
}
// TestExtractTCPMeta verifies window size, MSS, window scale and the option
// list derived from a TCP header carrying MSS, WScale and SACK-permitted.
func TestExtractTCPMeta(t *testing.T) {
	opts := []layers.TCPOption{
		{OptionType: layers.TCPOptionKindMSS, OptionData: []byte{0x05, 0xb4}}, // 0x05b4 = 1460
		{OptionType: layers.TCPOptionKindWindowScale, OptionData: []byte{0x07}},
		{OptionType: layers.TCPOptionKindSACKPermitted, OptionData: []byte{}},
	}
	meta := extractTCPMeta(&layers.TCP{
		SrcPort: 12345,
		DstPort: 443,
		Window:  65535,
		Options: opts,
	})

	if got := meta.WindowSize; got != 65535 {
		t.Errorf("WindowSize = %v, want 65535", got)
	}
	if got := meta.MSS; got != 1460 {
		t.Errorf("MSS = %v, want 1460", got)
	}
	if got := meta.WindowScale; got != 7 {
		t.Errorf("WindowScale = %v, want 7", got)
	}
	if n := len(meta.Options); n != 3 {
		t.Errorf("Options length = %v, want 3", n)
	}
}
// Helper functions to create test TLS records
// createTLSClientHello returns a 24-byte TLS record: a 5-byte record header
// (content type 0x16, the given version, length 19) followed by a 19-byte
// handshake fragment whose first byte is 0x01 (ClientHello). version is only
// written into the record-layer version field.
//
// NOTE(review): the handshake bytes following the type are 00 00 00 10 and then
// 14 zero bytes. The TLS handshake length field is only 3 bytes wide, so the
// declared handshake length is 0 and 0x10 is the first body byte. Tests in
// this file depend on these exact bytes; confirm intent before changing them.
func createTLSClientHello(version uint16) []byte {
	hs := make([]byte, 19)
	hs[0] = 0x01 // handshake type: ClientHello
	hs[4] = 0x10

	rec := []byte{
		0x16, // content type: handshake
		byte(version >> 8), byte(version),
		byte(len(hs) >> 8), byte(len(hs)),
	}
	return append(rec, hs...)
}
// createTLSServerHello returns a 24-byte TLS record: a 5-byte record header
// (content type 0x16, the given version, length 19) followed by a 19-byte
// handshake fragment whose first byte is 0x02 (ServerHello).
//
// NOTE(review): like createTLSClientHello, the bytes after the type are
// 00 00 00 10 plus 14 zeros, so the 3-byte handshake length reads as 0.
// Only the type byte matters to the tests that use it; confirm before changing.
func createTLSServerHello(version uint16) []byte {
	hs := make([]byte, 19)
	hs[0] = 0x02 // handshake type: ServerHello
	hs[4] = 0x10

	rec := []byte{
		0x16, // content type: handshake
		byte(version >> 8), byte(version),
		byte(len(hs) >> 8), byte(len(hs)),
	}
	return append(rec, hs...)
}
// TestNewParser verifies that NewParser returns a parser with an initialized
// flow map and a non-zero flow timeout.
func TestNewParser(t *testing.T) {
	parser := NewParser()
	// Fatal (not Error) and defer only after the check: continuing with a nil
	// parser would dereference it on the field checks below.
	if parser == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer parser.Close()

	if parser.flows == nil {
		t.Error("NewParser() flows map not initialized")
	}
	if parser.flowTimeout == 0 {
		t.Error("NewParser() flowTimeout not set")
	}
}
// TestParserClose verifies that closing a freshly created parser succeeds.
func TestParserClose(t *testing.T) {
	if err := NewParser().Close(); err != nil {
		t.Errorf("Close() error = %v", err)
	}
}
// TestFlowKey verifies the unidirectional "src:port->dst:port" key format.
func TestFlowKey(t *testing.T) {
	const want = "192.168.1.1:12345->10.0.0.1:443"
	if got := flowKey("192.168.1.1", 12345, "10.0.0.1", 443); got != want {
		t.Errorf("flowKey() = %v, want %v", got, want)
	}
}
// TestParserConnectionStateTracking checks that the minimal ClientHello
// fixture is accepted both by parseClientHello and by IsClientHello.
func TestParserConnectionStateTracking(t *testing.T) {
	p := NewParser()
	defer p.Close()

	hello := createTLSClientHello(0x0303)

	// Lower-level path: parse the record directly.
	ch, err := parseClientHello(hello)
	if err != nil {
		t.Errorf("parseClientHello() error = %v", err)
	}
	if ch == nil {
		t.Error("parseClientHello() should return ClientHello")
	}

	// The cheap detector should agree with the parser.
	if ok := IsClientHello(hello); !ok {
		t.Error("IsClientHello() should return true for valid ClientHello")
	}
}
// TestParserClose_Idempotent verifies that Close can be called twice without
// returning an error.
func TestParserClose_Idempotent(t *testing.T) {
	p := NewParser()
	for _, attempt := range []string{"first", "second"} {
		if err := p.Close(); err != nil {
			t.Fatalf("%s Close() error = %v", attempt, err)
		}
	}
}
func TestExtractTCPMeta_MSSInvalid_NoPanic(t *testing.T) {
tcp := &layers.TCP{
Window: 1234,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionData: []byte{0x05}, // malformed (1 byte instead of 2)
},
},
}
meta := extractTCPMeta(tcp)
found := false
for _, opt := range meta.Options {
if opt == "MSS_INVALID" {
found = true
break
}
}
if !found {
t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options)
}
}
// TestGetOrCreateFlow_RespectsMaxTrackedFlows caps the flow table at one entry
// and expects a second, distinct flow to be refused (nil).
func TestGetOrCreateFlow_RespectsMaxTrackedFlows(t *testing.T) {
	p := NewParser()
	defer p.Close()
	p.maxTrackedFlows = 1

	const (
		serverIP   = "10.0.0.1"
		serverPort = 443
	)

	first := p.getOrCreateFlow(
		flowKey("192.168.1.1", 12345, serverIP, serverPort),
		"192.168.1.1", 12345, serverIP, serverPort,
		api.IPMeta{}, api.TCPMeta{},
	)
	if first == nil {
		t.Fatal("first flow should be created")
	}

	// The table is full now, so a different key must not be admitted.
	second := p.getOrCreateFlow(
		flowKey("192.168.1.2", 12346, serverIP, serverPort),
		"192.168.1.2", 12346, serverIP, serverPort,
		api.IPMeta{}, api.TCPMeta{},
	)
	if second != nil {
		t.Fatal("second flow should be nil when maxTrackedFlows is reached")
	}
}
// TestProcess_DropsWhenHelloBufferExceedsLimit sends two incomplete TLS chunks
// on one flow with maxHelloBufferBytes = 10 and expects the flow to be tracked
// after the first chunk and removed after the second.
func TestProcess_DropsWhenHelloBufferExceedsLimit(t *testing.T) {
	p := NewParserWithTimeout(30 * time.Second)
	defer p.Close()
	p.maxHelloBufferBytes = 10

	const (
		clientIP   = "192.168.1.10"
		serverIP   = "10.0.0.1"
		clientPort = uint16(12345)
		serverPort = uint16(443)
	)
	// Record header announcing 0x20 bytes followed by a single byte: deliberately
	// incomplete so the parser has to accumulate it. Six bytes per chunk.
	chunk := []byte{0x16, 0x03, 0x03, 0x00, 0x20, 0x01}

	key := flowKey(clientIP, clientPort, serverIP, serverPort)
	tracked := func() bool {
		p.mu.RLock()
		defer p.mu.RUnlock()
		_, ok := p.flows[key]
		return ok
	}

	ch, err := p.Process(buildRawPacketWithSeq(t, clientIP, serverIP, clientPort, serverPort, chunk, 1))
	if err != nil {
		t.Fatalf("first Process() error = %v", err)
	}
	if ch != nil {
		t.Fatal("first Process() should not return complete ClientHello")
	}
	if !tracked() {
		t.Fatal("flow should exist after first chunk")
	}

	nextSeq := 1 + uint32(len(chunk))
	ch, err = p.Process(buildRawPacketWithSeq(t, clientIP, serverIP, clientPort, serverPort, chunk, nextSeq))
	if err != nil {
		t.Fatalf("second Process() error = %v", err)
	}
	if ch != nil {
		t.Fatal("second Process() should not return ClientHello")
	}
	if tracked() {
		t.Fatal("flow should be removed when hello buffer exceeds maxHelloBufferBytes")
	}
}
// TestProcess_NonTLSNewFlowNotTracked sends a first packet whose content type
// is 0x17 (not handshake) and expects no result and no flow-table entry.
func TestProcess_NonTLSNewFlowNotTracked(t *testing.T) {
	p := NewParser()
	defer p.Close()

	const (
		clientIP   = "192.168.1.20"
		serverIP   = "10.0.0.2"
		clientPort = uint16(23456)
		serverPort = uint16(443)
	)
	// Content type 0x17 rather than 22 (0x16, handshake).
	body := []byte{0x17, 0x03, 0x03, 0x00, 0x05, 0x00}

	ch, err := p.Process(buildRawPacket(t, clientIP, serverIP, clientPort, serverPort, body))
	if err != nil {
		t.Fatalf("Process() error = %v", err)
	}
	if ch != nil {
		t.Fatal("Process() should return nil for non-TLS new flow")
	}

	p.mu.RLock()
	_, tracked := p.flows[flowKey(clientIP, clientPort, serverIP, serverPort)]
	p.mu.RUnlock()
	if tracked {
		t.Fatal("non-TLS new flow should not be tracked")
	}
}
// buildRawPacket builds an Ethernet/IPv4/TCP packet carrying payload, using
// TCP sequence number 1. It is marked as a test helper so that a build failure
// is reported at the calling test's line rather than at this wrapper.
func buildRawPacket(t *testing.T, srcIP, dstIP string, srcPort, dstPort uint16, payload []byte) api.RawPacket {
	t.Helper()
	return buildRawPacketWithSeq(t, srcIP, dstIP, srcPort, dstPort, payload, 1)
}
// buildRawPacketWithSeq serializes Ethernet + IPv4 + TCP(ACK, seq) + payload
// with lengths and checksums computed, and wraps the bytes in an api.RawPacket
// with link type 1 (Ethernet) and the current time as timestamp.
func buildRawPacketWithSeq(t *testing.T, srcIP, dstIP string, srcPort, dstPort uint16, payload []byte, seq uint32) api.RawPacket {
	t.Helper()

	ipv4 := &layers.IPv4{
		Version:  4,
		TTL:      64,
		SrcIP:    net.ParseIP(srcIP).To4(),
		DstIP:    net.ParseIP(dstIP).To4(),
		Protocol: layers.IPProtocolTCP,
	}
	tcpLayer := &layers.TCP{
		SrcPort: layers.TCPPort(srcPort),
		DstPort: layers.TCPPort(dstPort),
		Seq:     seq,
		ACK:     true,
		Window:  65535,
	}
	// The TCP checksum covers a pseudo-header taken from the IP layer.
	if err := tcpLayer.SetNetworkLayerForChecksum(ipv4); err != nil {
		t.Fatalf("SetNetworkLayerForChecksum() error = %v", err)
	}
	ethernet := &layers.Ethernet{
		SrcMAC:       net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55},
		DstMAC:       net.HardwareAddr{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff},
		EthernetType: layers.EthernetTypeIPv4,
	}

	sb := gopacket.NewSerializeBuffer()
	serializeOpts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}
	err := gopacket.SerializeLayers(sb, serializeOpts, ethernet, ipv4, tcpLayer, gopacket.Payload(payload))
	if err != nil {
		t.Fatalf("SerializeLayers() error = %v", err)
	}

	return api.RawPacket{
		Data:      sb.Bytes(),
		Timestamp: time.Now().UnixNano(),
		LinkType:  1, // Ethernet
	}
}
// TestTLSVersionToString verifies the mapping 0x0301..0x0304 -> "1.0".."1.3"
// and that SSL 3.0 (0x0300) and an unknown value (0x0305) map to "".
func TestTLSVersionToString(t *testing.T) {
	tests := []struct {
		name    string
		version uint16
		want    string
	}{
		// Explicit names: deriving the subtest name from want gave two
		// subtests the empty name "" (reported as "#00"/"#01").
		{"TLS1.0", 0x0301, "1.0"},
		{"TLS1.1", 0x0302, "1.1"},
		{"TLS1.2", 0x0303, "1.2"},
		{"TLS1.3", 0x0304, "1.3"},
		{"SSL3.0_unsupported", 0x0300, ""},
		{"unknown_0x0305", 0x0305, ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tlsVersionToString(tt.version)
			if got != tt.want {
				t.Errorf("tlsVersionToString(%#x) = %v, want %v", tt.version, got, tt.want)
			}
		})
	}
}
// TestExtractTLSExtensions checks that extractTLSExtensions returns nil
// (without error) for empty/too-short input, and for the minimal ClientHello
// fixture reports TLS version "1.2" with no SNI and no ALPN.
func TestExtractTLSExtensions(t *testing.T) {
	tests := []struct {
		name        string
		payload     []byte
		wantSNI     string
		wantALPN    []string
		wantVersion string
		wantNil     bool
	}{
		{
			name:    "empty payload",
			payload: []byte{},
			wantNil: true,
		},
		{
			name:    "too short",
			payload: []byte{0x16, 0x03, 0x03},
			wantNil: true,
		},
		{
			name:        "TLS 1.2 ClientHello without extensions",
			payload:     createTLSClientHello(0x0303),
			wantVersion: "1.2",
			wantNil:     false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := extractTLSExtensions(tt.payload)
			if err != nil {
				t.Errorf("extractTLSExtensions() unexpected error = %v", err)
				return
			}
			if (got == nil) != tt.wantNil {
				t.Errorf("extractTLSExtensions() = %v, wantNil %v", got == nil, tt.wantNil)
				return
			}
			if got == nil {
				return
			}
			if got.TLSVersion != tt.wantVersion {
				t.Errorf("TLSVersion = %v, want %v", got.TLSVersion, tt.wantVersion)
			}
			// wantSNI and wantALPN were previously declared but never asserted.
			if got.SNI != tt.wantSNI {
				t.Errorf("SNI = %q, want %q", got.SNI, tt.wantSNI)
			}
			// Only presence is compared, so the check holds whether ALPN is
			// stored as a joined string or a slice.
			if (len(got.ALPN) == 0) != (len(tt.wantALPN) == 0) {
				t.Errorf("ALPN = %v, want %v", got.ALPN, tt.wantALPN)
			}
		})
	}
}
// TestParser_ExtractsTLSFields runs a ClientHello with SNI "example.com" and
// ALPN [h2, http/1.1] through Process and checks that SNI, ALPN, TLS version,
// ConnID and SynToCHMs are populated. The extraction paths are logged first to
// help diagnose failures.
func TestParser_ExtractsTLSFields(t *testing.T) {
	p := NewParser()
	defer p.Close()

	hello := createMinimalTLSClientHelloWithSNIAndALPN("example.com", []string{"h2", "http/1.1"})

	// Diagnostics only: none of these log lines fail the test.
	extInfo, extErr := extractTLSExtensions(hello)
	if extErr != nil {
		t.Logf("extractTLSExtensions error: %v", extErr)
	}
	if extInfo == nil {
		t.Log("extInfo is nil")
	} else {
		t.Logf("extInfo: SNI=%q, ALPN=%v, Version=%q", extInfo.SNI, extInfo.ALPN, extInfo.TLSVersion)
	}
	if fp, fpErr := tlsfingerprint.ParseClientHello(hello); fpErr != nil {
		t.Logf("tlsfingerprint error: %v", fpErr)
	} else {
		t.Logf("tlsfingerprint: ALPN=%v, Version=%#x, HasSNI=%v", fp.ALPNProtocols, fp.Version, fp.HasSNI)
	}
	t.Logf("ClientHello hex: % x", hello[:min(50, len(hello))])

	pkt := buildRawPacket(t, "192.168.1.100", "10.0.0.1", 54321, 443, hello)
	got, err := p.Process(pkt)
	if err != nil {
		t.Fatalf("Process() error = %v", err)
	}
	if got == nil {
		t.Fatal("Process() should return TLSClientHello")
	}

	if got.SNI != "example.com" {
		t.Errorf("SNI = %v, want example.com", got.SNI)
	}
	if got.ALPN == "" {
		t.Error("ALPN should not be empty")
	}
	if got.TLSVersion != "1.2" {
		t.Errorf("TLSVersion = %v, want 1.2", got.TLSVersion)
	}
	if got.ConnID == "" {
		t.Error("ConnID should not be empty")
	}
	if got.SynToCHMs == nil {
		t.Error("SynToCHMs should not be nil")
	}
}
// createMinimalTLSClientHelloWithSNIAndALPN builds a complete TLS record
// containing a TLS 1.2 ClientHello with an all-zero random, an empty session
// ID, four cipher suites, the null compression method, and three extensions:
// server_name (sni), ALPN (alpnProtocols) and supported_versions listing only
// TLS 1.2.
func createMinimalTLSClientHelloWithSNIAndALPN(sni string, alpnProtocols []string) []byte {
	// Extensions, in order: SNI, ALPN, supported_versions
	// (type 43, ext len 3, list len 2, version 0x0303).
	var extensions []byte
	extensions = append(extensions, buildSNIExtension(sni)...)
	extensions = append(extensions, buildALPNExtension(alpnProtocols)...)
	extensions = append(extensions, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x03)
	extLen := len(extensions)

	// Cipher suites, 2 bytes each (the length prefix is added separately below):
	// TLS_AES_128_GCM_SHA256 (0x1301), TLS_AES_256_GCM_SHA384 (0x1302),
	// TLS_CHACHA20_POLY1305_SHA256 (0x1303),
	// TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 (0xc02f).
	// Previously the slice began with 0x00, 0x04 (an embedded length), which put a
	// bogus 0x0004 suite on the wire and dropped 0x1303.
	cipherSuites := []byte{0x13, 0x01, 0x13, 0x02, 0x13, 0x03, 0xc0, 0x2f}

	// Compression methods: length 1, null (0x00).
	compressionMethods := []byte{0x01, 0x00}

	// ClientHello body: version, random, session ID.
	handshakeBody := make([]byte, 0, 2+32+1+2+len(cipherSuites)+len(compressionMethods)+2+extLen)
	handshakeBody = append(handshakeBody, 0x03, 0x03)          // client_version: TLS 1.2
	handshakeBody = append(handshakeBody, make([]byte, 32)...) // random: all zeros
	handshakeBody = append(handshakeBody, 0x00)                // session ID length: 0

	cipherSuiteLen := len(cipherSuites)
	handshakeBody = append(handshakeBody, byte(cipherSuiteLen>>8), byte(cipherSuiteLen))
	handshakeBody = append(handshakeBody, cipherSuites...)

	// compressionMethods already carries its 1-byte length prefix.
	handshakeBody = append(handshakeBody, compressionMethods...)

	handshakeBody = append(handshakeBody, byte(extLen>>8), byte(extLen))
	handshakeBody = append(handshakeBody, extensions...)

	// Handshake header: type ClientHello + 3-byte length.
	handshakeLen := len(handshakeBody)
	handshake := append([]byte{
		0x01,
		byte(handshakeLen >> 16), byte(handshakeLen >> 8), byte(handshakeLen),
	}, handshakeBody...)

	// Record header: content type handshake, record version TLS 1.2, 2-byte length.
	recordLen := len(handshake)
	record := make([]byte, 5+recordLen)
	record[0] = 0x16
	record[1] = 0x03
	record[2] = 0x03
	record[3] = byte(recordLen >> 8)
	record[4] = byte(recordLen)
	copy(record[5:], handshake)
	return record
}
// buildSNIExtension returns a complete server_name extension (type 0) carrying
// a single host_name entry for sni:
//
//	type(2)=0 | ext_len(2) | list_len(2) | name_type(1)=0 | name_len(2) | name
func buildSNIExtension(sni string) []byte {
	hostLen := len(sni)
	listLen := 1 + 2 + hostLen // name_type + name_len + name
	dataLen := 2 + listLen     // list length prefix + list

	ext := []byte{
		0x00, 0x00, // extension type: server_name
		byte(dataLen >> 8), byte(dataLen),
		byte(listLen >> 8), byte(listLen),
		0x00, // name_type: host_name
		byte(hostLen >> 8), byte(hostLen),
	}
	return append(ext, sni...)
}
// buildALPNExtension returns a complete ALPN extension (type 16 = 0x0010):
//
//	type(2) | ext_len(2) | list_len(2) | { len(1) | protocol }...
//
// Each protocol length is written as a single byte.
func buildALPNExtension(protocols []string) []byte {
	var list []byte
	for _, p := range protocols {
		list = append(list, byte(len(p)))
		list = append(list, p...)
	}
	dataLen := 2 + len(list) // list length prefix + list

	ext := []byte{
		0x00, 0x10, // extension type: ALPN
		byte(dataLen >> 8), byte(dataLen),
		byte(len(list) >> 8), byte(len(list)),
	}
	return append(ext, list...)
}
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// TestExtractSNIFromPayload runs extractSNIFromPayload over bare handshake
// messages: empty and too-short input, a hello without extensions, hellos
// carrying an SNI, and a malformed (truncated) extension block.
func TestExtractSNIFromPayload(t *testing.T) {
	const longHost = "very-long-subdomain.example-test-domain.com"
	cases := []struct {
		name    string
		payload []byte
		want    string
	}{
		{"empty_payload", []byte{}, ""},
		{"payload_too_short", []byte{0x01, 0x00, 0x00, 0x10}, ""}, // only 4 bytes
		{"no_extensions", buildClientHelloWithoutExtensions(), ""},
		{"with_sni_extension", buildClientHelloWithSNI("example.com"), "example.com"},
		{"with_sni_long_domain", buildClientHelloWithSNI(longHost), longHost},
		{"malformed_sni_truncated", buildTruncatedSNI(), ""},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := extractSNIFromPayload(c.payload); got != c.want {
				t.Errorf("extractSNIFromPayload() = %q, want %q", got, c.want)
			}
		})
	}
}
// TestCleanupExpiredFlows inserts a NEW-state flow last seen two hours ago and
// expects cleanupExpiredFlows to remove it.
func TestCleanupExpiredFlows(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	const key = "192.168.1.1:12345->10.0.0.1:443"
	stale := &ConnectionFlow{
		State:       NEW,
		LastSeen:    time.Now().Add(-2 * time.Hour),
		HelloBuffer: make([]byte, 0),
	}
	p.mu.Lock()
	p.flows[key] = stale
	p.mu.Unlock()

	p.cleanupExpiredFlows()

	p.mu.RLock()
	_, stillThere := p.flows[key]
	p.mu.RUnlock()
	if stillThere {
		t.Error("cleanupExpiredFlows() should have removed the expired flow")
	}
}
// TestCleanupExpiredFlows_JA4Done inserts a JA4_DONE flow with a fresh
// LastSeen and expects cleanupExpiredFlows to remove it anyway.
func TestCleanupExpiredFlows_JA4Done(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	const key = "192.168.1.1:12345->10.0.0.1:443"
	p.mu.Lock()
	p.flows[key] = &ConnectionFlow{
		State:       JA4_DONE,
		LastSeen:    time.Now(), // recent on purpose: state alone should trigger removal
		HelloBuffer: make([]byte, 0),
	}
	p.mu.Unlock()

	p.cleanupExpiredFlows()

	p.mu.RLock()
	_, stillThere := p.flows[key]
	p.mu.RUnlock()
	if stillThere {
		t.Error("cleanupExpiredFlows() should have removed the JA4_DONE flow")
	}
}
// TestCleanupExpiredFlows_RecentFlow inserts a WAIT_CLIENT_HELLO flow seen just
// now and expects cleanupExpiredFlows to keep it.
func TestCleanupExpiredFlows_RecentFlow(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	const key = "192.168.1.1:12345->10.0.0.1:443"
	p.mu.Lock()
	p.flows[key] = &ConnectionFlow{
		State:       WAIT_CLIENT_HELLO,
		LastSeen:    time.Now(),
		HelloBuffer: make([]byte, 0),
	}
	p.mu.Unlock()

	p.cleanupExpiredFlows()

	p.mu.RLock()
	_, kept := p.flows[key]
	p.mu.RUnlock()
	if !kept {
		t.Error("cleanupExpiredFlows() should NOT have removed the recent flow")
	}
}
// TestCleanupLoop_Shutdown verifies that Close stops the background cleanup
// goroutine, observed as p.cleanupDone being closed.
func TestCleanupLoop_Shutdown(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}

	if err := p.Close(); err != nil {
		t.Errorf("Close() error = %v", err)
	}

	// Block (with an upper bound) until the channel is closed. A fixed sleep
	// followed by a single non-blocking poll was both slower in the normal
	// case and flaky on a loaded machine.
	select {
	case _, ok := <-p.cleanupDone:
		if ok {
			t.Error("cleanupDone channel should be closed after Close()")
		}
	case <-time.After(2 * time.Second):
		t.Error("cleanupDone channel should be closed after Close()")
	}
}
// buildClientHelloWithoutExtensions returns a bare ClientHello handshake
// message (no record header) with an empty extensions block:
//
//	type(1)=0x01 | length(3) | version(2)=0x0303 | random(32)=0 |
//	session_id_len(1)=0 | cipher_len(2)=2 | cipher(2)=0x1301 |
//	compression_len(1)=0 | extensions_len(2)=0
//
// The compression-length and extensions-length fields used to be documented
// but missing (the buffer stopped after the cipher suite). That made this a
// truncated message rather than one without extensions. The compression list
// is left empty to match buildClientHelloWithSNI.
func buildClientHelloWithoutExtensions() []byte {
	body := make([]byte, 0, 42)
	body = append(body, 0x03, 0x03)             // client_version: TLS 1.2
	body = append(body, make([]byte, 32)...)    // random: all zeros
	body = append(body, 0x00)                   // session ID length
	body = append(body, 0x00, 0x02, 0x13, 0x01) // cipher suites: len 2, TLS_AES_128_GCM_SHA256
	body = append(body, 0x00)                   // compression methods length
	body = append(body, 0x00, 0x00)             // extensions length

	n := len(body)
	return append([]byte{0x01, byte(n >> 16), byte(n >> 8), byte(n)}, body...)
}
// buildClientHelloWithSNI returns a bare ClientHello handshake message (no
// record header) with TLS 1.2 version, zero random, empty session ID, one
// cipher suite (0x1301), an empty compression list and a single SNI extension
// for sni. The 3-byte handshake length is filled in once the size is known.
func buildClientHelloWithSNI(sni string) []byte {
	// Type ClientHello, length placeholder, client_version TLS 1.2.
	msg := []byte{0x01, 0x00, 0x00, 0x00, 0x03, 0x03}
	msg = append(msg, make([]byte, 32)...) // random: all zeros
	msg = append(msg,
		0x00,             // session ID length
		0x00, 0x02,       // cipher suites length
		0x13, 0x01,       // TLS_AES_128_GCM_SHA256
		0x00,             // compression methods length
	)

	ext := buildSNIExtension(sni)
	msg = append(msg, byte(len(ext)>>8), byte(len(ext)))
	msg = append(msg, ext...)

	n := len(msg) - 4 // the length excludes the 4-byte handshake header
	msg[1], msg[2], msg[3] = byte(n>>16), byte(n>>8), byte(n)
	return msg
}
// buildTruncatedSNI returns a bare ClientHello handshake message whose
// server_name extension is cut short. The fixed prefix matches
// buildClientHelloWithSNI; then the extensions block declares a full SNI entry
// for an 11-byte host name, but only the first 3 name bytes ("tes") are present.
// The handshake length is set to the bytes actually present, so the truncation
// lies purely inside the extension data that an SNI extractor must bounds-check.
//
// Previously the single extension had type 0x0005 rather than server_name
// (0x0000), and the handshake length was left 0. The "truncated SNI" case
// therefore never reached SNI parsing.
func buildTruncatedSNI() []byte {
	// Type ClientHello, length placeholder, client_version TLS 1.2.
	msg := []byte{0x01, 0x00, 0x00, 0x00, 0x03, 0x03}
	msg = append(msg, make([]byte, 32)...) // random: all zeros
	msg = append(msg,
		0x00,       // session ID length
		0x00, 0x02, // cipher suites length
		0x13, 0x01, // TLS_AES_128_GCM_SHA256
		0x00,       // compression methods length
	)
	msg = append(msg,
		0x00, 0x14, // extensions length: claims 20 bytes (only 12 follow)
		0x00, 0x00, // extension type: server_name
		0x00, 0x10, // extension length: claims 16 bytes
		0x00, 0x0e, // server_name_list length: claims 14 bytes
		0x00,       // name_type: host_name
		0x00, 0x0b, // name length: claims 11 bytes
		't', 'e', 's', // only 3 bytes of the name are present
	)

	n := len(msg) - 4
	msg[1], msg[2], msg[3] = byte(n>>16), byte(n>>8), byte(n)
	return msg
}
// TestJoinStringSlice covers the deprecated joinStringSlice helper on empty,
// single-element and multi-element inputs.
func TestJoinStringSlice(t *testing.T) {
	cases := []struct {
		name  string
		parts []string
		sep   string
		want  string
	}{
		{"empty_slice", []string{}, ",", ""},
		{"single_element", []string{"MSS"}, ",", "MSS"},
		{"multiple_elements", []string{"MSS", "SACK", "TS"}, ",", "MSS,SACK,TS"},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := joinStringSlice(c.parts, c.sep); got != c.want {
				t.Errorf("joinStringSlice() = %q, want %q", got, c.want)
			}
		})
	}
}
// TestProcess_NilPacketData expects Process to reject a packet with nil Data
// with the error message "empty packet data".
func TestProcess_NilPacketData(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	pkt := api.RawPacket{
		Data:      nil,
		Timestamp: time.Now().UnixNano(),
	}
	_, err := p.Process(pkt)
	if err == nil {
		// Fatal, not Error: the message check below would call Error() on a
		// nil error and panic.
		t.Fatal("Process() with nil data should return error")
	}
	if err.Error() != "empty packet data" {
		t.Errorf("Process() error = %v, want 'empty packet data'", err)
	}
}
// TestProcess_EmptyPacketData expects Process to return an error for a packet
// whose Data is a zero-length (non-nil) slice.
func TestProcess_EmptyPacketData(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	empty := api.RawPacket{Data: []byte{}, Timestamp: time.Now().UnixNano()}
	if _, err := p.Process(empty); err == nil {
		t.Error("Process() with empty data should return error")
	}
}
// TestProcess_SLLPacket sends a ClientHello inside a Linux cooked-capture
// (SLL) IPv4 frame and expects Process to return it with the original
// source and destination addresses.
func TestProcess_SLLPacket(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	const (
		srcIP   = "192.168.1.100"
		dstIP   = "10.0.0.1"
		srcPort = uint16(54321)
		dstPort = uint16(443)
	)
	pkt := buildSLLRawPacket(t, srcIP, dstIP, srcPort, dstPort, createTLSClientHello(0x0303))

	// Diagnostics: if gopacket cannot see IPv4 behind the SLL header, dump the
	// frame and decoded layers before Process runs.
	decoded := gopacket.NewPacket(pkt.Data, layers.LinkTypeLinuxSLL, gopacket.Default)
	if decoded.Layer(layers.LayerTypeIPv4) == nil {
		t.Logf("DEBUG: SLL packet - no IPv4 layer found")
		t.Logf("DEBUG: Packet data (first 50 bytes): % x", pkt.Data[:min(50, len(pkt.Data))])
		t.Logf("DEBUG: Packet layers: %v", decoded.Layers())
	}

	got, err := p.Process(pkt)
	if err != nil {
		t.Fatalf("Process() with SLL packet error = %v", err)
	}
	if got == nil {
		t.Fatal("Process() with SLL packet should return TLSClientHello")
	}
	if got.SrcIP != srcIP {
		t.Errorf("SrcIP = %v, want %v", got.SrcIP, srcIP)
	}
	if got.DstIP != dstIP {
		t.Errorf("DstIP = %v, want %v", got.DstIP, dstIP)
	}
}
// TestProcess_SLLPacket_IPv6 sends a ClientHello inside a Linux cooked-capture
// (SLL) IPv6 frame and expects Process to return it with the original
// source and destination addresses.
func TestProcess_SLLPacket_IPv6(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	const (
		srcIP   = "2001:db8::1"
		dstIP   = "2001:db8::2"
		srcPort = uint16(54321)
		dstPort = uint16(443)
	)
	pkt := buildSLLRawPacketIPv6(t, srcIP, dstIP, srcPort, dstPort, createTLSClientHello(0x0303))

	got, err := p.Process(pkt)
	if err != nil {
		t.Fatalf("Process() with SLL IPv6 packet error = %v", err)
	}
	if got == nil {
		t.Fatal("Process() with SLL IPv6 packet should return TLSClientHello")
	}
	if got.SrcIP != srcIP {
		t.Errorf("SrcIP = %v, want %v", got.SrcIP, srcIP)
	}
	if got.DstIP != dstIP {
		t.Errorf("DstIP = %v, want %v", got.DstIP, dstIP)
	}
}
// TestProcess_EthernetFallback checks that a ClientHello in a plain
// Ethernet/IPv4 frame is still returned by Process.
func TestProcess_EthernetFallback(t *testing.T) {
	p := NewParser()
	if p == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer p.Close()

	pkt := buildRawPacket(t, "192.168.1.100", "10.0.0.1", 54321, 443, createTLSClientHello(0x0303))

	got, err := p.Process(pkt)
	if err != nil {
		t.Fatalf("Process() with Ethernet packet error = %v", err)
	}
	if got == nil {
		t.Fatal("Process() with Ethernet packet should return TLSClientHello")
	}
}
// buildSLLRawPacket builds a Linux cooked-capture (LINKTYPE_LINUX_SLL, link
// type 101) frame carrying IPv4/TCP(ACK, seq 1) and payload.
// layers.LinuxSLL does not implement gopacket.SerializableLayer, so the 16-byte
// SLL header is written by hand and prepended to the serialized IPv4/TCP bytes.
//
// SLL header layout (https://www.tcpdump.org/linktypes/LINKTYPE_LINUX_SLL.html):
//
//	bytes 0-1   packet type            0x0000 = LINUX_SLL_HOST
//	bytes 2-3   ARPHRD_ type           0x0001 = ARPHRD_ETHER
//	bytes 4-5   link-layer addr length 0x0006
//	bytes 6-13  link-layer address     00:11:22:33:44:55 + 2 padding bytes
//	bytes 14-15 protocol type          0x0800 = IPv4
//
// Omitting the ARPHRD_ type field shifts every later field by two bytes and
// leaves bytes 14-15 as 0x0000, which SLL decoders do not recognize as IPv4.
func buildSLLRawPacket(t *testing.T, srcIP, dstIP string, srcPort, dstPort uint16, payload []byte) api.RawPacket {
	t.Helper()

	sllHeader := []byte{
		0x00, 0x00, // packet type: LINUX_SLL_HOST
		0x00, 0x01, // ARPHRD_ type: ARPHRD_ETHER
		0x00, 0x06, // link-layer address length
		0x00, 0x11, 0x22, 0x33, 0x44, 0x55, // link-layer address
		0x00, 0x00, // address padding to 8 bytes
		0x08, 0x00, // protocol type: IPv4
	}

	ip := &layers.IPv4{
		Version:  4,
		TTL:      64,
		SrcIP:    net.ParseIP(srcIP).To4(),
		DstIP:    net.ParseIP(dstIP).To4(),
		Protocol: layers.IPProtocolTCP,
	}
	tcp := &layers.TCP{
		SrcPort: layers.TCPPort(srcPort),
		DstPort: layers.TCPPort(dstPort),
		Seq:     1,
		ACK:     true,
		Window:  65535,
	}
	if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
		t.Fatalf("SetNetworkLayerForChecksum() error = %v", err)
	}

	buf := gopacket.NewSerializeBuffer()
	opts := gopacket.SerializeOptions{
		FixLengths:       true,
		ComputeChecksums: true,
	}
	// Serialize IP + TCP + payload; the SLL header is prepended manually.
	if err := gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload(payload)); err != nil {
		t.Fatalf("SerializeLayers() error = %v", err)
	}

	packetData := append(sllHeader, buf.Bytes()...)
	return api.RawPacket{
		Data:      packetData,
		Timestamp: time.Now().UnixNano(),
		LinkType:  101, // Linux SLL
	}
}
// buildSLLRawPacketIPv6 builds a Linux cooked-capture (LINKTYPE_LINUX_SLL,
// link type 101) frame carrying IPv6/TCP(ACK, seq 1) and payload. The 16-byte
// SLL header is written by hand:
//
//	bytes 0-1   packet type            0x0000 = LINUX_SLL_HOST
//	bytes 2-3   ARPHRD_ type           0x0001 = ARPHRD_ETHER
//	bytes 4-5   link-layer addr length 0x0006
//	bytes 6-13  link-layer address     00:11:22:33:44:55 + 2 padding bytes
//	bytes 14-15 protocol type          0x86DD = IPv6
//
// Omitting the ARPHRD_ type field shifts the protocol into the padding and
// leaves bytes 14-15 as 0x0000, which SLL decoders do not recognize as IPv6.
func buildSLLRawPacketIPv6(t *testing.T, srcIP, dstIP string, srcPort, dstPort uint16, payload []byte) api.RawPacket {
	t.Helper()

	sllHeader := []byte{
		0x00, 0x00, // packet type: LINUX_SLL_HOST
		0x00, 0x01, // ARPHRD_ type: ARPHRD_ETHER
		0x00, 0x06, // link-layer address length
		0x00, 0x11, 0x22, 0x33, 0x44, 0x55, // link-layer address
		0x00, 0x00, // address padding to 8 bytes
		0x86, 0xDD, // protocol type: IPv6
	}

	ip := &layers.IPv6{
		Version:    6,
		HopLimit:   64,
		SrcIP:      net.ParseIP(srcIP).To16(),
		DstIP:      net.ParseIP(dstIP).To16(),
		NextHeader: layers.IPProtocolTCP,
	}
	tcp := &layers.TCP{
		SrcPort: layers.TCPPort(srcPort),
		DstPort: layers.TCPPort(dstPort),
		Seq:     1,
		ACK:     true,
		Window:  65535,
	}
	if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
		t.Fatalf("SetNetworkLayerForChecksum() error = %v", err)
	}

	buf := gopacket.NewSerializeBuffer()
	opts := gopacket.SerializeOptions{
		FixLengths:       true,
		ComputeChecksums: true,
	}
	if err := gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload(payload)); err != nil {
		t.Fatalf("SerializeLayers() error = %v", err)
	}

	packetData := append(sllHeader, buf.Bytes()...)
	return api.RawPacket{
		Data:      packetData,
		Timestamp: time.Now().UnixNano(),
		LinkType:  101, // Linux SLL
	}
}
// TestParser_SLLPacketType checks that a TLS 1.2 ClientHello wrapped in a
// Linux SLL frame built by buildSLLRawPacket (packet type PACKET_HOST, 0) is
// recognised by the parser. Only PACKET_HOST is exercised here; other SLL
// packet types are not covered by this test.
func TestParser_SLLPacketType(t *testing.T) {
	p := NewParser()
	defer p.Close()

	clientAddr, serverAddr := "192.168.1.100", "10.0.0.1"
	clientPort, serverPort := uint16(54321), uint16(443)
	hello := createTLSClientHello(0x0303)
	frame := buildSLLRawPacket(t, clientAddr, serverAddr, clientPort, serverPort, hello)

	got, err := p.Process(frame)
	switch {
	case err != nil:
		t.Fatalf("Process() error = %v", err)
	case got == nil:
		t.Fatal("Process() should return TLSClientHello for PACKET_HOST")
	}
}
// buildSYNPacket returns an Ethernet/IPv4/TCP SYN frame with no payload
// (LinkType 1).
//
// IPv4: TTL 64, ID 0x1234, Don't-Fragment set. TCP: Seq 1000, window 65535,
// and the options MSS(mss), WindowScale(windowScale) and SACK-permitted, in
// that order. gopacket computes lengths and checksums.
func buildSYNPacket(t *testing.T, srcIP, dstIP string, srcPort, dstPort uint16, mss uint16, windowScale uint8) api.RawPacket {
	t.Helper()

	mssOption := layers.TCPOption{
		OptionType:   layers.TCPOptionKindMSS,
		OptionLength: 4,
		OptionData:   []byte{byte(mss >> 8), byte(mss)}, // big-endian MSS
	}
	wscaleOption := layers.TCPOption{
		OptionType:   layers.TCPOptionKindWindowScale,
		OptionLength: 3,
		OptionData:   []byte{windowScale},
	}
	sackOption := layers.TCPOption{
		OptionType:   layers.TCPOptionKindSACKPermitted,
		OptionLength: 2,
	}

	v4 := &layers.IPv4{
		Version:  4,
		TTL:      64,
		Id:       0x1234,
		Flags:    layers.IPv4DontFragment,
		SrcIP:    net.ParseIP(srcIP).To4(),
		DstIP:    net.ParseIP(dstIP).To4(),
		Protocol: layers.IPProtocolTCP,
	}
	syn := &layers.TCP{
		SrcPort: layers.TCPPort(srcPort),
		DstPort: layers.TCPPort(dstPort),
		Seq:     1000,
		SYN:     true,
		Window:  65535,
		Options: []layers.TCPOption{mssOption, wscaleOption, sackOption},
	}
	if err := syn.SetNetworkLayerForChecksum(v4); err != nil {
		t.Fatalf("SetNetworkLayerForChecksum() error = %v", err)
	}

	link := &layers.Ethernet{
		SrcMAC:       net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55},
		DstMAC:       net.HardwareAddr{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff},
		EthernetType: layers.EthernetTypeIPv4,
	}

	out := gopacket.NewSerializeBuffer()
	serializeOpts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}
	if err := gopacket.SerializeLayers(out, serializeOpts, link, v4, syn); err != nil {
		t.Fatalf("SerializeLayers() error = %v", err)
	}

	return api.RawPacket{
		Data:      out.Bytes(),
		Timestamp: time.Now().UnixNano(),
		LinkType:  1,
	}
}
// TestProcess_SYNCreatesFlowWithTCPMeta verifies that a bare SYN opens a
// flow in state NEW. The flow must record the SYN's MSS, window scale,
// window size, SACK option, IP TTL and DF flag. The ClientHello later
// produced on that flow must report the SYN's metadata rather than the data
// packet's.
func TestProcess_SYNCreatesFlowWithTCPMeta(t *testing.T) {
	parser := NewParser()
	defer parser.Close()

	srcIP, dstIP := "192.168.1.50", "10.0.0.1"
	srcPort, dstPort := uint16(44444), uint16(443)
	expectedMSS, expectedWS := uint16(1460), uint8(7)

	// A lone SYN carries no ClientHello, so nothing should be emitted yet.
	ch, err := parser.Process(buildSYNPacket(t, srcIP, dstIP, srcPort, dstPort, expectedMSS, expectedWS))
	if err != nil {
		t.Fatalf("Process(SYN) error = %v", err)
	}
	if ch != nil {
		t.Fatal("Process(SYN) should return nil (no ClientHello yet)")
	}

	key := flowKey(srcIP, srcPort, dstIP, dstPort)
	parser.mu.RLock()
	flow, exists := parser.flows[key]
	parser.mu.RUnlock()
	if !exists {
		t.Fatal("SYN should create a flow")
	}

	// Inspect the stored metadata while holding the flow's own lock; the
	// deferred unlock fires when the closure returns.
	func() {
		flow.mu.Lock()
		defer flow.mu.Unlock()

		if flow.State != NEW {
			t.Errorf("flow state = %v, want NEW", flow.State)
		}
		if got := flow.TCPMeta.MSS; got != expectedMSS {
			t.Errorf("flow TCPMeta.MSS = %d, want %d", got, expectedMSS)
		}
		if got := flow.TCPMeta.WindowScale; got != expectedWS {
			t.Errorf("flow TCPMeta.WindowScale = %d, want %d", got, expectedWS)
		}
		if got := flow.TCPMeta.WindowSize; got != 65535 {
			t.Errorf("flow TCPMeta.WindowSize = %d, want 65535", got)
		}

		sackSeen := false
		for _, name := range flow.TCPMeta.Options {
			if name == "SACK" {
				sackSeen = true
				break
			}
		}
		if !sackSeen {
			t.Errorf("flow TCPMeta.Options = %v, want SACK", flow.TCPMeta.Options)
		}

		if got := flow.IPMeta.TTL; got != 64 {
			t.Errorf("flow IPMeta.TTL = %d, want 64", got)
		}
		if !flow.IPMeta.DF {
			t.Error("flow IPMeta.DF should be true")
		}
	}()

	// The SYN used Seq=1000, so the ClientHello's first byte is at Seq=1001.
	dataPkt := buildRawPacketWithSeq(t, srcIP, dstIP, srcPort, dstPort, createTLSClientHello(0x0303), 1001)
	result, err := parser.Process(dataPkt)
	if err != nil {
		t.Fatalf("Process(ClientHello) error = %v", err)
	}
	if result == nil {
		t.Fatal("Process(ClientHello) should return TLSClientHello")
	}

	// These values must come from the SYN, not from the data packet.
	if got := result.TCPMeta.MSS; got != expectedMSS {
		t.Errorf("result TCPMeta.MSS = %d, want %d (from SYN)", got, expectedMSS)
	}
	if got := result.TCPMeta.WindowScale; got != expectedWS {
		t.Errorf("result TCPMeta.WindowScale = %d, want %d (from SYN)", got, expectedWS)
	}
	if got := result.IPMeta.TTL; got != 64 {
		t.Errorf("result IPMeta.TTL = %d, want 64 (from SYN)", got)
	}
	if !result.IPMeta.DF {
		t.Error("result IPMeta.DF should be true (from SYN)")
	}
}
// TestProcess_SynToCHMs_Timing checks that SynToCHMs reflects the elapsed
// time between the SYN and the ClientHello. The test sleeps 50ms between the
// two packets and accepts any value of at least 40ms.
func TestProcess_SynToCHMs_Timing(t *testing.T) {
	parser := NewParser()
	defer parser.Close()

	srcIP, dstIP := "192.168.1.60", "10.0.0.1"
	srcPort, dstPort := uint16(55555), uint16(443)

	if _, err := parser.Process(buildSYNPacket(t, srcIP, dstIP, srcPort, dstPort, 1460, 7)); err != nil {
		t.Fatalf("Process(SYN) error = %v", err)
	}

	// Leave a measurable gap before the ClientHello.
	time.Sleep(50 * time.Millisecond)

	// SYN used Seq=1000, so the data starts at Seq=1001.
	dataPkt := buildRawPacketWithSeq(t, srcIP, dstIP, srcPort, dstPort, createTLSClientHello(0x0303), 1001)
	result, err := parser.Process(dataPkt)
	switch {
	case err != nil:
		t.Fatalf("Process(ClientHello) error = %v", err)
	case result == nil:
		t.Fatal("Process(ClientHello) should return TLSClientHello")
	case result.SynToCHMs == nil:
		t.Fatal("SynToCHMs should not be nil")
	}

	if elapsed := *result.SynToCHMs; elapsed < 40 {
		t.Errorf("SynToCHMs = %d ms, want >= 40ms (slept 50ms)", elapsed)
	}
}
// TestProcess_NoSYN_StillWorks covers captures that start mid-connection.
// With no SYN observed, a ClientHello data packet must still create a flow
// and yield a result whose SrcIP matches the sender.
func TestProcess_NoSYN_StillWorks(t *testing.T) {
	parser := NewParser()
	defer parser.Close()

	var (
		srcIP   = "192.168.1.70"
		dstIP   = "10.0.0.1"
		srcPort = uint16(55666)
		dstPort = uint16(443)
	)
	dataPkt := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, createTLSClientHello(0x0303))

	result, err := parser.Process(dataPkt)
	switch {
	case err != nil:
		t.Fatalf("Process() error = %v", err)
	case result == nil:
		t.Fatal("Process() should return TLSClientHello even without SYN")
	}

	if got := result.SrcIP; got != srcIP {
		t.Errorf("SrcIP = %v, want %v", got, srcIP)
	}
}
// TestProcess_FragmentedClientHello_UsesFlowMeta splits a ClientHello across
// two in-order TCP segments after a SYN. The first segment must yield
// nothing, and the second must complete the ClientHello. The result's
// TCP/IP metadata must come from the SYN, not from the last data fragment.
func TestProcess_FragmentedClientHello_UsesFlowMeta(t *testing.T) {
	parser := NewParser()
	defer parser.Close()

	srcIP, dstIP := "192.168.1.80", "10.0.0.1"
	srcPort, dstPort := uint16(33333), uint16(443)
	expectedMSS, expectedWS := uint16(1460), uint8(7)

	if _, err := parser.Process(buildSYNPacket(t, srcIP, dstIP, srcPort, dstPort, expectedMSS, expectedWS)); err != nil {
		t.Fatalf("Process(SYN) error = %v", err)
	}

	// SYN used Seq=1000, so the stream's first data byte is at Seq=1001.
	const firstSeq = uint32(1001)
	hello := createTLSClientHello(0x0303)
	split := len(hello) / 2
	head, tail := hello[:split], hello[split:]

	partial, err := parser.Process(buildRawPacketWithSeq(t, srcIP, dstIP, srcPort, dstPort, head, firstSeq))
	if err != nil {
		t.Fatalf("Process(fragment1) error = %v", err)
	}
	if partial != nil {
		t.Fatal("Process(fragment1) should return nil (incomplete)")
	}

	result, err := parser.Process(buildRawPacketWithSeq(t, srcIP, dstIP, srcPort, dstPort, tail, firstSeq+uint32(split)))
	if err != nil {
		t.Fatalf("Process(fragment2) error = %v", err)
	}
	if result == nil {
		t.Fatal("Process(fragment2) should return complete TLSClientHello")
	}

	if got := result.TCPMeta.MSS; got != expectedMSS {
		t.Errorf("result TCPMeta.MSS = %d, want %d (from SYN)", got, expectedMSS)
	}
	if got := result.TCPMeta.WindowScale; got != expectedWS {
		t.Errorf("result TCPMeta.WindowScale = %d, want %d (from SYN)", got, expectedWS)
	}
	if got := result.IPMeta.TTL; got != 64 {
		t.Errorf("result IPMeta.TTL = %d, want 64 (from SYN)", got)
	}
}
// TestProcess_TCPRetransmission_Ignored verifies that a retransmitted segment
// (same sequence number as one already buffered) yields nothing and does not
// disturb reassembly. The ClientHello must still complete when the real next
// segment arrives.
func TestProcess_TCPRetransmission_Ignored(t *testing.T) {
	parser := NewParser()
	defer parser.Close()
	srcIP := "192.168.1.90"
	dstIP := "10.0.0.1"
	srcPort := uint16(44321)
	dstPort := uint16(443)

	// Step 1: SYN (Seq=1000). Setup errors are checked so a failure is
	// reported where it happens rather than as a confusing later failure.
	synPkt := buildSYNPacket(t, srcIP, dstIP, srcPort, dstPort, 1460, 7)
	if _, err := parser.Process(synPkt); err != nil {
		t.Fatalf("Process(SYN) error = %v", err)
	}

	// Step 2: first half of the ClientHello (Seq=1001). It is incomplete,
	// so nothing should be emitted yet.
	clientHello := createTLSClientHello(0x0303)
	half := len(clientHello) / 2
	fragment1 := clientHello[:half]
	pkt1 := buildRawPacketWithSeq(t, srcIP, dstIP, srcPort, dstPort, fragment1, 1001)
	ch, err := parser.Process(pkt1)
	if err != nil {
		t.Fatalf("Process(fragment1) error = %v", err)
	}
	if ch != nil {
		t.Fatal("Process(fragment1) should return nil (incomplete)")
	}

	// Step 3: retransmit fragment 1 (same Seq=1001); it must be ignored.
	pkt1dup := buildRawPacketWithSeq(t, srcIP, dstIP, srcPort, dstPort, fragment1, 1001)
	ch, err = parser.Process(pkt1dup)
	if err != nil {
		t.Fatalf("Process(retransmit) error = %v", err)
	}
	if ch != nil {
		t.Fatal("Process(retransmit) should return nil")
	}

	// Step 4: second half at the correct next Seq completes the ClientHello.
	fragment2 := clientHello[half:]
	pkt2 := buildRawPacketWithSeq(t, srcIP, dstIP, srcPort, dstPort, fragment2, 1001+uint32(half))
	result, err := parser.Process(pkt2)
	if err != nil {
		t.Fatalf("Process(fragment2) error = %v", err)
	}
	if result == nil {
		t.Fatal("Process(fragment2) should return complete TLSClientHello after retransmission")
	}
}
// TestProcess_TCPGap_DropsFlow verifies that a segment arriving beyond the
// expected next sequence number (leaving a hole in the stream) yields nothing
// and makes the parser remove the flow from its table.
func TestProcess_TCPGap_DropsFlow(t *testing.T) {
	parser := NewParser()
	defer parser.Close()
	srcIP := "192.168.1.91"
	dstIP := "10.0.0.1"
	srcPort := uint16(44322)
	dstPort := uint16(443)

	// flowExists reports whether the parser currently tracks this 4-tuple.
	key := flowKey(srcIP, srcPort, dstIP, dstPort)
	flowExists := func() bool {
		parser.mu.RLock()
		defer parser.mu.RUnlock()
		_, ok := parser.flows[key]
		return ok
	}

	// Step 1: SYN (Seq=1000). Setup errors are checked so a failure is
	// reported where it happens.
	synPkt := buildSYNPacket(t, srcIP, dstIP, srcPort, dstPort, 1460, 7)
	if _, err := parser.Process(synPkt); err != nil {
		t.Fatalf("Process(SYN) error = %v", err)
	}

	// Step 2: first half of the ClientHello (Seq=1001). It is incomplete,
	// so nothing should be emitted yet.
	clientHello := createTLSClientHello(0x0303)
	half := len(clientHello) / 2
	fragment1 := clientHello[:half]
	pkt1 := buildRawPacketWithSeq(t, srcIP, dstIP, srcPort, dstPort, fragment1, 1001)
	ch, err := parser.Process(pkt1)
	if err != nil {
		t.Fatalf("Process(fragment1) error = %v", err)
	}
	if ch != nil {
		t.Fatal("Process(fragment1) should return nil (incomplete)")
	}

	// Precondition: the flow must be tracked before the gap. Otherwise the
	// removal check below would pass even if no flow had ever been created.
	if !flowExists() {
		t.Fatal("flow should exist after SYN and first fragment")
	}

	// Step 3: send the rest 100 bytes past the expected Seq; this is a gap,
	// so the flow should be dropped.
	fragment2 := clientHello[half:]
	gapSeq := uint32(1001 + half + 100)
	pkt2 := buildRawPacketWithSeq(t, srcIP, dstIP, srcPort, dstPort, fragment2, gapSeq)
	ch, err = parser.Process(pkt2)
	if err != nil {
		t.Fatalf("Process(gap) error = %v", err)
	}
	if ch != nil {
		t.Fatal("Process(gap) should return nil")
	}

	if flowExists() {
		t.Fatal("flow should be removed after sequence gap")
	}
}
// createTLS13ClientHelloWithSNI builds a complete TLS handshake record
// holding a ClientHello that advertises TLS 1.3 the way real clients do. The
// record-layer version and legacy_version are both 0x0303 (TLS 1.2), and the
// supported_versions extension (43) lists 0x0304 then 0x0303.
//
// Extensions, in order: SNI for sni (buildSNIExtension), ALPN
// ["h2", "http/1.1"] (buildALPNExtension), then supported_versions. The
// session ID is empty, the random is 32 bytes of 0x01, and only the null
// compression method is offered.
func createTLS13ClientHelloWithSNI(sni string) []byte {
	// Extensions block. Starting from a fresh slice means appending never
	// writes into spare capacity of the slice buildSNIExtension returned.
	var extensions []byte
	extensions = append(extensions, buildSNIExtension(sni)...)
	extensions = append(extensions, buildALPNExtension([]string{"h2", "http/1.1"})...)
	extensions = append(extensions,
		0x00, 0x2b, // Extension type: supported_versions (43)
		0x00, 0x05, // Extension data length: 5
		0x04,       // Supported versions list length: 4 bytes (2 versions)
		0x03, 0x04, // TLS 1.3
		0x03, 0x03, // TLS 1.2
	)

	// Cipher suite list WITHOUT a length prefix; the 2-byte length is written
	// below from len(cipherSuites). An embedded prefix here would itself be
	// parsed as a (bogus) cipher suite.
	cipherSuites := []byte{
		0x13, 0x01, // TLS_AES_128_GCM_SHA256 (TLS 1.3)
		0x13, 0x02, // TLS_AES_256_GCM_SHA384 (TLS 1.3)
		0xc0, 0x2f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 (TLS 1.2 fallback)
	}

	// ClientHello body.
	body := make([]byte, 0, 2+32+1+2+len(cipherSuites)+2+2+len(extensions))
	body = append(body, 0x03, 0x03) // legacy_version: TLS 1.2 (mandatory for TLS 1.3 ClientHello)
	for i := 0; i < 32; i++ {
		body = append(body, 0x01) // random
	}
	body = append(body, 0x00) // session ID length: 0
	body = append(body, byte(len(cipherSuites)>>8), byte(len(cipherSuites)))
	body = append(body, cipherSuites...)
	body = append(body, 0x01, 0x00) // compression methods: 1 method, null
	body = append(body, byte(len(extensions)>>8), byte(len(extensions)))
	body = append(body, extensions...)

	// Handshake header: type ClientHello (1) followed by a 24-bit length.
	handshake := make([]byte, 0, 4+len(body))
	handshake = append(handshake, 0x01, byte(len(body)>>16), byte(len(body)>>8), byte(len(body)))
	handshake = append(handshake, body...)

	// Record header: content type handshake (0x16), version 0x0303 (TLS 1.2
	// in the record layer, as TLS 1.3 clients send it), 16-bit length.
	record := make([]byte, 0, 5+len(handshake))
	record = append(record, 0x16, 0x03, 0x03, byte(len(handshake)>>8), byte(len(handshake)))
	return append(record, handshake...)
}
func TestExtractTLSExtensions_TLS13(t *testing.T) {
payload := createTLS13ClientHelloWithSNI("example.com")
info, err := extractTLSExtensions(payload)
if err != nil {
t.Fatalf("extractTLSExtensions() error = %v", err)
}
if info == nil {
t.Fatal("extractTLSExtensions() returned nil")
}
// TLS 1.3 should be detected via supported_versions extension
if info.TLSVersion != "1.3" {
t.Errorf("TLSVersion = %q, want \"1.3\"", info.TLSVersion)
}
if info.SNI != "example.com" {
t.Errorf("SNI = %q, want \"example.com\"", info.SNI)
}
}
// TestProcess_TLS13ClientHello_CorrectVersion runs a TLS 1.3 ClientHello
// through the full Process path, with no preceding SYN. It checks that the
// emitted result reports TLSVersion "1.3" and the expected SNI.
func TestProcess_TLS13ClientHello_CorrectVersion(t *testing.T) {
	parser := NewParser()
	defer parser.Close()

	const host = "tls13.example.com"
	srcIP, dstIP := "192.168.1.200", "10.0.0.1"
	srcPort, dstPort := uint16(44555), uint16(443)

	pkt := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, createTLS13ClientHelloWithSNI(host))
	result, err := parser.Process(pkt)
	switch {
	case err != nil:
		t.Fatalf("Process() error = %v", err)
	case result == nil:
		t.Fatal("Process() should return TLSClientHello")
	}

	if got := result.TLSVersion; got != "1.3" {
		t.Errorf("TLSVersion = %q, want \"1.3\"", got)
	}
	if got := result.SNI; got != host {
		t.Errorf("SNI = %q, want \"tls13.example.com\"", got)
	}
}