package tlsparse

import (
	"fmt"
	"net"
	"testing"
	"time"

	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
	tlsfingerprint "github.com/psanford/tlsfingerprint"

	"ja4sentinel/api"
)

// TestIsClientHello checks the ClientHello pre-filter against empty,
// truncated, non-handshake and ServerHello payloads.
func TestIsClientHello(t *testing.T) {
	tests := []struct {
		name    string
		payload []byte
		want    bool
	}{
		{
			name:    "empty payload",
			payload: []byte{},
			want:    false,
		},
		{
			name:    "too short",
			payload: []byte{0x16, 0x03, 0x03},
			want:    false,
		},
		{
			name:    "valid TLS 1.2 ClientHello",
			payload: createTLSClientHello(0x0303),
			want:    true,
		},
		{
			name:    "valid TLS 1.3 ClientHello",
			payload: createTLSClientHello(0x0304),
			want:    true,
		},
		{
			name:    "not a handshake",
			payload: []byte{0x17, 0x03, 0x03, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
			want:    false,
		},
		{
			name:    "ServerHello (type 2)",
			payload: createTLSServerHello(0x0303),
			want:    false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := IsClientHello(tt.payload); got != tt.want {
				t.Errorf("IsClientHello() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestParseClientHello checks that parseClientHello returns (nil, nil) for
// empty or incomplete input and a non-nil result for a ClientHello fixture.
func TestParseClientHello(t *testing.T) {
	tests := []struct {
		name    string
		payload []byte
		wantErr bool
		wantNil bool
	}{
		{
			name:    "empty payload",
			payload: []byte{},
			wantErr: false,
			wantNil: true,
		},
		{
			name:    "valid ClientHello",
			payload: createTLSClientHello(0x0303),
			wantErr: false,
			wantNil: false,
		},
		{
			name:    "incomplete record",
			payload: []byte{0x16, 0x03, 0x03, 0x01, 0x00, 0x01},
			wantErr: false,
			wantNil: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := parseClientHello(tt.payload)
			if (err != nil) != tt.wantErr {
				t.Errorf("parseClientHello() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if (got == nil) != tt.wantNil {
				t.Errorf("parseClientHello() returned nil = %v, want nil = %v", got == nil, tt.wantNil)
			}
		})
	}
}

// TestExtractIPMeta checks that TTL, total length, IP ID and the DF flag are
// copied from an IPv4 layer.
func TestExtractIPMeta(t *testing.T) {
	ipLayer := &layers.IPv4{
		TTL:    64,
		Length: 1500,
		Id:     12345,
		Flags:  layers.IPv4DontFragment,
		SrcIP:  []byte{192, 168, 1, 1},
		DstIP:  []byte{10, 0, 0, 1},
	}

	meta := extractIPMeta(ipLayer)

	if meta.TTL != 64 {
		t.Errorf("TTL = %v, want 64", meta.TTL)
	}
	if meta.TotalLength != 1500 {
		t.Errorf("TotalLength = %v, want 1500", meta.TotalLength)
	}
	if meta.IPID != 12345 {
		t.Errorf("IPID = %v, want 12345", meta.IPID)
	}
	if !meta.DF {
		t.Error("DF = false, want true")
	}
}

// TestExtractTCPMeta checks window size, MSS, window scale and option count
// extraction from a TCP layer carrying MSS, WScale and SACK-permitted options.
func TestExtractTCPMeta(t *testing.T) {
	tcp := &layers.TCP{
		SrcPort: 12345,
		DstPort: 443,
		Window:  65535,
		Options: []layers.TCPOption{
			{
				OptionType: layers.TCPOptionKindMSS,
				OptionData: []byte{0x05, 0xb4}, // 1460
			},
			{
				OptionType: layers.TCPOptionKindWindowScale,
				OptionData: []byte{0x07}, // scale 7
			},
			{
				OptionType: layers.TCPOptionKindSACKPermitted,
				OptionData: []byte{},
			},
		},
	}

	meta := extractTCPMeta(tcp)

	if meta.WindowSize != 65535 {
		t.Errorf("WindowSize = %v, want 65535", meta.WindowSize)
	}
	if meta.MSS != 1460 {
		t.Errorf("MSS = %v, want 1460", meta.MSS)
	}
	if meta.WindowScale != 7 {
		t.Errorf("WindowScale = %v, want 7", meta.WindowScale)
	}
	if len(meta.Options) != 3 {
		t.Errorf("Options length = %v, want 3", len(meta.Options))
	}
}

// Helper functions to create test TLS records.

// createTLSClientHello returns a minimal TLS handshake record whose handshake
// type byte is 1 (ClientHello); version is written into the record header.
func createTLSClientHello(version uint16) []byte {
	return createTLSHandshakeRecord(version, 0x01)
}

// createTLSServerHello returns a minimal TLS handshake record whose handshake
// type byte is 2 (ServerHello); version is written into the record header.
func createTLSServerHello(version uint16) []byte {
	return createTLSHandshakeRecord(version, 0x02)
}

// createTLSHandshakeRecord wraps a minimal, mostly zero-filled handshake
// message of the given type in a TLS record with content type 22 (handshake).
// The record length field matches the handshake bytes that follow.
//
// NOTE(review): the handshake header is type (1 byte) + uint24 length, so the
// declared handshake length here is 0 and the 0x10 that follows is the first
// body byte. The fixture layout is kept unchanged because the expectations in
// this file were written against it.
func createTLSHandshakeRecord(version uint16, handshakeType byte) []byte {
	handshake := []byte{
		handshakeType,
		// Handshake length (uint24): 0.
		0x00, 0x00, 0x00,
		// Simplified body: 0x10 followed by 14 zero bytes.
		0x10,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	}

	record := []byte{
		0x16, // content type: handshake
		byte(version >> 8), byte(version),
		byte(len(handshake) >> 8), byte(len(handshake)),
	}
	return append(record, handshake...)
}

// TestNewParser checks that NewParser returns an initialized parser.
func TestNewParser(t *testing.T) {
	parser := NewParser()
	// Stop before deferring Close or touching fields: both would dereference
	// a nil parser.
	if parser == nil {
		t.Fatal("NewParser() returned nil")
	}
	defer parser.Close()

	if parser.flows == nil {
		t.Error("NewParser() flows map not initialized")
	}
	if parser.flowTimeout == 0 {
		t.Error("NewParser() flowTimeout not set")
	}
}

// TestParserClose checks that closing a fresh parser reports no error.
func TestParserClose(t *testing.T) {
	parser := NewParser()
	if err := parser.Close(); err != nil {
		t.Errorf("Close() error = %v", err)
	}
}

// TestFlowKey checks the unidirectional "src:port->dst:port" key format.
func TestFlowKey(t *testing.T) {
	// Test unidirectional flow key (client -> server only).
	key := flowKey("192.168.1.1", 12345, "10.0.0.1", 443)
	expected := "192.168.1.1:12345->10.0.0.1:443"
	if key != expected {
		t.Errorf("flowKey() = %v, want %v", key, expected)
	}
}

// TestParserConnectionStateTracking exercises the stateless helpers on a
// ClientHello fixture. The parser is only constructed and closed here.
func TestParserConnectionStateTracking(t *testing.T) {
	parser := NewParser()
	defer parser.Close()

	// Create a valid ClientHello payload.
	clientHello := createTLSClientHello(0x0303)

	// Test parseClientHello directly (lower-level test).
	result, err := parseClientHello(clientHello)
	if err != nil {
		t.Errorf("parseClientHello() error = %v", err)
	}
	if result == nil {
		t.Error("parseClientHello() should return ClientHello")
	}

	// Test IsClientHello helper.
	if !IsClientHello(clientHello) {
		t.Error("IsClientHello() should return true for valid ClientHello")
	}
}

// TestParserClose_Idempotent checks that Close can be called twice without error.
func TestParserClose_Idempotent(t *testing.T) {
	parser := NewParser()
	if err := parser.Close(); err != nil {
		t.Fatalf("first Close() error = %v", err)
	}
	if err := parser.Close(); err != nil {
		t.Fatalf("second Close() error = %v", err)
	}
}

// TestExtractTCPMeta_MSSInvalid_NoPanic checks that a truncated MSS option is
// reported as "MSS_INVALID" instead of panicking.
func TestExtractTCPMeta_MSSInvalid_NoPanic(t *testing.T) {
	tcp := &layers.TCP{
		Window: 1234,
		Options: []layers.TCPOption{
			{
				OptionType: layers.TCPOptionKindMSS,
				OptionData: []byte{0x05}, // malformed (1 byte instead of 2)
			},
		},
	}

	meta := extractTCPMeta(tcp)

	found := false
	for _, opt := range meta.Options {
		if opt == "MSS_INVALID" {
			found = true
			break
		}
	}
	if !found {
		t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options)
	}
}

// TestGetOrCreateFlow_RespectsMaxTrackedFlows checks that no new flow is
// created once maxTrackedFlows is reached.
func TestGetOrCreateFlow_RespectsMaxTrackedFlows(t *testing.T) {
	parser := NewParser()
	defer parser.Close()
	parser.maxTrackedFlows = 1

	flow1 := parser.getOrCreateFlow(
		flowKey("192.168.1.1", 12345, "10.0.0.1", 443),
		"192.168.1.1", 12345, "10.0.0.1", 443,
		api.IPMeta{}, api.TCPMeta{},
	)
	if flow1 == nil {
		t.Fatal("first flow should be created")
	}

	flow2 := parser.getOrCreateFlow(
		flowKey("192.168.1.2", 12346, "10.0.0.1", 443),
		"192.168.1.2", 12346, "10.0.0.1", 443,
		api.IPMeta{}, api.TCPMeta{},
	)
	if flow2 != nil {
		t.Fatal("second flow should be nil when maxTrackedFlows is reached")
	}
}

// TestProcess_DropsWhenHelloBufferExceedsLimit checks that a flow whose
// accumulated handshake bytes exceed maxHelloBufferBytes is evicted.
func TestProcess_DropsWhenHelloBufferExceedsLimit(t *testing.T) {
	parser := NewParserWithTimeout(30 * time.Second)
	defer parser.Close()
	parser.maxHelloBufferBytes = 10

	srcIP := "192.168.1.10"
	dstIP := "10.0.0.1"
	srcPort := uint16(12345)
	dstPort := uint16(443)

	// TLS-like payload, but intentionally incomplete to trigger accumulation.
	payloadChunk := []byte{0x16, 0x03, 0x03, 0x00, 0x20, 0x01} // len = 6

	pkt1 := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, payloadChunk)
	ch, err := parser.Process(pkt1)
	if err != nil {
		t.Fatalf("first Process() error = %v", err)
	}
	if ch != nil {
		t.Fatal("first Process() should not return complete ClientHello")
	}

	key := flowKey(srcIP, srcPort, dstIP, dstPort)
	parser.mu.RLock()
	_, existsAfterFirst := parser.flows[key]
	parser.mu.RUnlock()
	if !existsAfterFirst {
		t.Fatal("flow should exist after first chunk")
	}

	pkt2 := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, payloadChunk)
	ch, err = parser.Process(pkt2)
	if err != nil {
		t.Fatalf("second Process() error = %v", err)
	}
	if ch != nil {
		t.Fatal("second Process() should not return ClientHello")
	}

	parser.mu.RLock()
	_, existsAfterSecond := parser.flows[key]
	parser.mu.RUnlock()
	if existsAfterSecond {
		t.Fatal("flow should be removed when hello buffer exceeds maxHelloBufferBytes")
	}
}

// TestProcess_NonTLSNewFlowNotTracked checks that a first packet whose content
// type is not handshake (22) does not create flow state.
func TestProcess_NonTLSNewFlowNotTracked(t *testing.T) {
	parser := NewParser()
	defer parser.Close()

	srcIP := "192.168.1.20"
	dstIP := "10.0.0.2"
	srcPort := uint16(23456)
	dstPort := uint16(443)

	// Non-TLS content type (not 22).
	payload := []byte{0x17, 0x03, 0x03, 0x00, 0x05, 0x00}
	pkt := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, payload)

	ch, err := parser.Process(pkt)
	if err != nil {
		t.Fatalf("Process() error = %v", err)
	}
	if ch != nil {
		t.Fatal("Process() should return nil for non-TLS new flow")
	}

	key := flowKey(srcIP, srcPort, dstIP, dstPort)
	parser.mu.RLock()
	_, exists := parser.flows[key]
	parser.mu.RUnlock()
	if exists {
		t.Fatal("non-TLS new flow should not be tracked")
	}
}

// buildRawPacket serializes an Ethernet/IPv4/TCP frame carrying payload, with
// lengths and checksums computed by gopacket, and stamps it with the current
// time in nanoseconds. Serialization failures abort the calling test.
func buildRawPacket(t *testing.T, srcIP, dstIP string, srcPort, dstPort uint16, payload []byte) api.RawPacket {
	t.Helper()

	ip := &layers.IPv4{
		Version:  4,
		TTL:      64,
		SrcIP:    net.ParseIP(srcIP).To4(),
		DstIP:    net.ParseIP(dstIP).To4(),
		Protocol: layers.IPProtocolTCP,
	}
	tcp := &layers.TCP{
		SrcPort: layers.TCPPort(srcPort),
		DstPort: layers.TCPPort(dstPort),
		Seq:     1,
		ACK:     true,
		Window:  65535,
	}
	// The TCP checksum covers an IPv4 pseudo-header, so the layer needs the IP.
	if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
		t.Fatalf("SetNetworkLayerForChecksum() error = %v", err)
	}

	eth := &layers.Ethernet{
		SrcMAC:       net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55},
		DstMAC:       net.HardwareAddr{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff},
		EthernetType: layers.EthernetTypeIPv4,
	}

	buf := gopacket.NewSerializeBuffer()
	opts := gopacket.SerializeOptions{
		FixLengths:       true,
		ComputeChecksums: true,
	}
	if err := gopacket.SerializeLayers(buf, opts, eth, ip, tcp, gopacket.Payload(payload)); err != nil {
		t.Fatalf("SerializeLayers() error = %v", err)
	}

	return api.RawPacket{
		Data:      buf.Bytes(),
		Timestamp: time.Now().UnixNano(),
	}
}

// TestTLSVersionToString checks the wire-version to "1.x" mapping; versions
// outside 0x0301..0x0304 map to "".
func TestTLSVersionToString(t *testing.T) {
	tests := []struct {
		version uint16
		want    string
	}{
		{0x0301, "1.0"},
		{0x0302, "1.1"},
		{0x0303, "1.2"},
		{0x0304, "1.3"},
		{0x0300, ""},
		{0x0305, ""},
	}

	for _, tt := range tests {
		// Name subtests by input: several cases share want == "".
		t.Run(fmt.Sprintf("%#04x", tt.version), func(t *testing.T) {
			if got := tlsVersionToString(tt.version); got != tt.want {
				t.Errorf("tlsVersionToString(%#x) = %q, want %q", tt.version, got, tt.want)
			}
		})
	}
}

// TestExtractTLSExtensions checks that short inputs yield (nil, nil) and that
// the record fixture yields the expected TLS version.
func TestExtractTLSExtensions(t *testing.T) {
	tests := []struct {
		name        string
		payload     []byte
		wantVersion string
		wantNil     bool
	}{
		{
			name:    "empty payload",
			payload: []byte{},
			wantNil: true,
		},
		{
			name:    "too short",
			payload: []byte{0x16, 0x03, 0x03},
			wantNil: true,
		},
		{
			name:        "TLS 1.2 ClientHello without extensions",
			payload:     createTLSClientHello(0x0303),
			wantVersion: "1.2",
			wantNil:     false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := extractTLSExtensions(tt.payload)
			if err != nil {
				t.Errorf("extractTLSExtensions() unexpected error = %v", err)
				return
			}
			if (got == nil) != tt.wantNil {
				t.Errorf("extractTLSExtensions() returned nil = %v, want nil = %v", got == nil, tt.wantNil)
				return
			}
			if got != nil && got.TLSVersion != tt.wantVersion {
				t.Errorf("TLSVersion = %v, want %v", got.TLSVersion, tt.wantVersion)
			}
		})
	}
}

// TestParser_ExtractsTLSFields feeds a ClientHello carrying SNI and ALPN
// through Process and checks that SNI, ALPN, version, ConnID and SynToCHMs are
// populated. Intermediate results are logged to aid debugging on failure.
func TestParser_ExtractsTLSFields(t *testing.T) {
	parser := NewParser()
	defer parser.Close()

	// Minimal valid TLS 1.2 ClientHello with SNI and ALPN extensions.
	clientHelloWithExt := createMinimalTLSClientHelloWithSNIAndALPN("example.com", []string{"h2", "http/1.1"})

	// Diagnostics: what extractTLSExtensions sees.
	extInfo, err := extractTLSExtensions(clientHelloWithExt)
	if err != nil {
		t.Logf("extractTLSExtensions error: %v", err)
	}
	if extInfo != nil {
		t.Logf("extInfo: SNI=%q, ALPN=%v, Version=%q", extInfo.SNI, extInfo.ALPN, extInfo.TLSVersion)
	} else {
		t.Log("extInfo is nil")
	}

	// Diagnostics: what tlsfingerprint sees directly.
	fp, err := tlsfingerprint.ParseClientHello(clientHelloWithExt)
	if err != nil {
		t.Logf("tlsfingerprint error: %v", err)
	} else {
		t.Logf("tlsfingerprint: ALPN=%v, Version=%#x, HasSNI=%v", fp.ALPNProtocols, fp.Version, fp.HasSNI)
	}

	t.Logf("ClientHello hex: % x", clientHelloWithExt[:min(50, len(clientHelloWithExt))])

	srcIP := "192.168.1.100"
	dstIP := "10.0.0.1"
	srcPort := uint16(54321)
	dstPort := uint16(443)

	pkt := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, clientHelloWithExt)

	result, err := parser.Process(pkt)
	if err != nil {
		t.Fatalf("Process() error = %v", err)
	}
	if result == nil {
		t.Fatal("Process() should return TLSClientHello")
	}

	// Verify new fields are populated.
	if result.SNI != "example.com" {
		t.Errorf("SNI = %v, want example.com", result.SNI)
	}
	if result.ALPN == "" {
		t.Error("ALPN should not be empty")
	}
	if result.TLSVersion != "1.2" {
		t.Errorf("TLSVersion = %v, want 1.2", result.TLSVersion)
	}
	if result.ConnID == "" {
		t.Error("ConnID should not be empty")
	}
	if result.SynToCHMs == nil {
		t.Error("SynToCHMs should not be nil")
	}
}

// createMinimalTLSClientHelloWithSNIAndALPN builds a complete TLS 1.2 record
// holding a ClientHello with a zero random, an empty session ID, four cipher
// suites, null compression, and SNI, ALPN and supported_versions (TLS 1.2
// only) extensions.
func createMinimalTLSClientHelloWithSNIAndALPN(sni string, alpnProtocols []string) []byte {
	var extensions []byte
	extensions = append(extensions, buildSNIExtension(sni)...)
	extensions = append(extensions, buildALPNExtension(alpnProtocols)...)
	// supported_versions: type 43, ext length 3, list length 2, TLS 1.2.
	extensions = append(extensions, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x03)

	// Two bytes per suite. The uint16 list-length prefix is added below, so it
	// must not be embedded here (doing so would parse it as an extra suite).
	cipherSuites := []byte{
		0x13, 0x01, // TLS_AES_128_GCM_SHA256
		0x13, 0x02, // TLS_AES_256_GCM_SHA384
		0x13, 0x03, // TLS_CHACHA20_POLY1305_SHA256
		0xc0, 0x2f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
	}

	// ClientHello body: legacy_version, random, session ID.
	body := []byte{0x03, 0x03}
	body = append(body, make([]byte, 32)...)
	body = append(body, 0x00)
	// Cipher suites with uint16 length prefix.
	body = append(body, byte(len(cipherSuites)>>8), byte(len(cipherSuites)))
	body = append(body, cipherSuites...)
	// Compression methods: one method, null.
	body = append(body, 0x01, 0x00)
	// Extensions with uint16 length prefix.
	body = append(body, byte(len(extensions)>>8), byte(len(extensions)))
	body = append(body, extensions...)

	// Handshake header: type ClientHello + uint24 length.
	handshake := []byte{0x01, byte(len(body) >> 16), byte(len(body) >> 8), byte(len(body))}
	handshake = append(handshake, body...)

	// Record header: content type handshake, version TLS 1.2, uint16 length.
	record := []byte{0x16, 0x03, 0x03, byte(len(handshake) >> 8), byte(len(handshake))}
	return append(record, handshake...)
}

// buildSNIExtension builds a server_name extension holding one host_name entry.
//
// Layout: type(2)=0 | ext_len(2) | name_list_len(2) | name_type(1)=0 |
// name_len(2) | name.
func buildSNIExtension(sni string) []byte {
	nameLen := len(sni)
	nameListLen := 1 + 2 + nameLen // name_type + name_len + name
	extDataLen := 2 + nameListLen  // name_list_len + list

	ext := make([]byte, 4+extDataLen)
	ext[0] = 0x00 // extension type: server_name (0)
	ext[1] = 0x00
	ext[2] = byte(extDataLen >> 8)
	ext[3] = byte(extDataLen)
	ext[4] = byte(nameListLen >> 8)
	ext[5] = byte(nameListLen)
	ext[6] = 0x00 // name_type: host_name
	ext[7] = byte(nameLen >> 8)
	ext[8] = byte(nameLen)
	copy(ext[9:], sni)
	return ext
}

// buildALPNExtension builds an application_layer_protocol_negotiation
// extension (type 16) listing protocols in order.
//
// Layout: type(2)=16 | ext_len(2) | list_len(2) | { len(1) | name }...
// Each protocol name must be at most 255 bytes (single-byte length).
func buildALPNExtension(protocols []string) []byte {
	alpnDataLen := 0
	for _, proto := range protocols {
		alpnDataLen += 1 + len(proto) // length byte + protocol string
	}
	extDataLen := 2 + alpnDataLen // list_len + list

	ext := make([]byte, 4+extDataLen)
	ext[0] = 0x00 // extension type: ALPN (16 = 0x10)
	ext[1] = 0x10
	ext[2] = byte(extDataLen >> 8)
	ext[3] = byte(extDataLen)
	ext[4] = byte(alpnDataLen >> 8)
	ext[5] = byte(alpnDataLen)

	offset := 6
	for _, proto := range protocols {
		ext[offset] = byte(len(proto))
		offset++
		copy(ext[offset:], proto)
		offset += len(proto)
	}
	return ext
}

// min returns the minimum of two integers.
//
// NOTE(review): this package-level helper shadows the Go 1.21+ builtin min for
// the whole package when tests are built; kept in case other test files in
// the package rely on it.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}