package tlsparse

import (
	"testing"

	"github.com/google/gopacket/layers"
)

// TestIsClientHello runs IsClientHello over a table of payloads: empty and
// truncated input, ClientHello records built with two different record
// versions, a record whose first byte is 0x17 instead of 0x16, and a
// ServerHello record.
func TestIsClientHello(t *testing.T) {
	tests := []struct {
		name    string
		payload []byte
		want    bool
	}{
		{
			name:    "empty payload",
			payload: []byte{},
			want:    false,
		},
		{
			name:    "too short",
			payload: []byte{0x16, 0x03, 0x03},
			want:    false,
		},
		{
			name:    "valid TLS 1.2 ClientHello",
			payload: createTLSClientHello(0x0303),
			want:    true,
		},
		{
			name:    "valid TLS 1.3 ClientHello",
			payload: createTLSClientHello(0x0304),
			want:    true,
		},
		{
			name: "not a handshake",
			payload: []byte{
				0x17, 0x03, 0x03, 0x00, 0x10,
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
			},
			want: false,
		},
		{
			name:    "ServerHello (type 2)",
			payload: createTLSServerHello(0x0303),
			want:    false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := IsClientHello(tt.payload)
			if got != tt.want {
				t.Errorf("IsClientHello() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestParseClientHello checks, for each payload, whether parseClientHello
// returns an error and whether it returns a nil result.
func TestParseClientHello(t *testing.T) {
	tests := []struct {
		name    string
		payload []byte
		wantErr bool
		wantNil bool
	}{
		{
			name:    "empty payload",
			payload: []byte{},
			wantErr: false,
			wantNil: true,
		},
		{
			name:    "valid ClientHello",
			payload: createTLSClientHello(0x0303),
			wantErr: false,
			wantNil: false,
		},
		{
			// The record header announces 0x0100 bytes but only one follows.
			name:    "incomplete record",
			payload: []byte{0x16, 0x03, 0x03, 0x01, 0x00, 0x01},
			wantErr: false,
			wantNil: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := parseClientHello(tt.payload)
			if (err != nil) != tt.wantErr {
				t.Errorf("parseClientHello() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if (got == nil) != tt.wantNil {
				t.Errorf("parseClientHello() = %v, wantNil %v", got == nil, tt.wantNil)
			}
		})
	}
}

// TestExtractIPMeta checks that extractIPMeta copies TTL, total length, IP ID
// and the don't-fragment flag out of an IPv4 layer.
func TestExtractIPMeta(t *testing.T) {
	ipLayer := &layers.IPv4{
		TTL:    64,
		Length: 1500,
		Id:     12345,
		Flags:  layers.IPv4DontFragment,
		SrcIP:  []byte{192, 168, 1, 1},
		DstIP:  []byte{10, 0, 0, 1},
	}

	meta := extractIPMeta(ipLayer)

	if meta.TTL != 64 {
		t.Errorf("TTL = %v, want 64", meta.TTL)
	}
	if meta.TotalLength != 1500 {
		t.Errorf("TotalLength = %v, want 1500", meta.TotalLength)
	}
	if meta.IPID != 12345 {
		t.Errorf("IPID = %v, want 12345", meta.IPID)
	}
	if !meta.DF {
		t.Error("DF = false, want true")
	}
}

// TestExtractTCPMeta checks that extractTCPMeta reports the window size, the
// decoded MSS and window-scale option values, and one entry per TCP option.
func TestExtractTCPMeta(t *testing.T) {
	tcp := &layers.TCP{
		SrcPort: 12345,
		DstPort: 443,
		Window:  65535,
		Options: []layers.TCPOption{
			{
				OptionType: layers.TCPOptionKindMSS,
				OptionData: []byte{0x05, 0xb4}, // 1460
			},
			{
				OptionType: layers.TCPOptionKindWindowScale,
				OptionData: []byte{0x07}, // scale 7
			},
			{
				OptionType: layers.TCPOptionKindSACKPermitted,
				OptionData: []byte{},
			},
		},
	}

	meta := extractTCPMeta(tcp)

	if meta.WindowSize != 65535 {
		t.Errorf("WindowSize = %v, want 65535", meta.WindowSize)
	}
	if meta.MSS != 1460 {
		t.Errorf("MSS = %v, want 1460", meta.MSS)
	}
	if meta.WindowScale != 7 {
		t.Errorf("WindowScale = %v, want 7", meta.WindowScale)
	}
	if len(meta.Options) != 3 {
		t.Errorf("Options length = %v, want 3", len(meta.Options))
	}
}

// Helper functions to create test TLS records.

// tlsHandshakeRecord prefixes handshake with a 5-byte record header: content
// type 0x16, the big-endian version, and the big-endian 16-bit length of
// handshake.
func tlsHandshakeRecord(version uint16, handshake []byte) []byte {
	record := make([]byte, 5+len(handshake))
	record[0] = 0x16 // Handshake
	record[1] = byte(version >> 8)
	record[2] = byte(version)
	record[3] = byte(len(handshake) >> 8)
	record[4] = byte(len(handshake))
	copy(record[5:], handshake)
	return record
}

// createTLSClientHello returns a minimal record whose handshake type byte is
// 0x01 (ClientHello) and whose record-layer version is version.
func createTLSClientHello(version uint16) []byte {
	handshake := []byte{
		// Handshake type: ClientHello.
		0x01,
		// Originally labelled "handshake length (16 bytes)".
		// NOTE(review): RFC 5246 section 7.4 declares the handshake length as
		// a uint24, so a parser following it would read 0x000000 as the length
		// and 0x10 as the first body byte. The bytes are kept as-is; confirm
		// the fixture is intentional before changing it.
		0x00, 0x00, 0x00, 0x10,
		// ClientHello body (simplified).
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	}
	return tlsHandshakeRecord(version, handshake)
}

// createTLSServerHello returns the same minimal record as
// createTLSClientHello, except that the handshake type byte is 0x02
// (ServerHello).
func createTLSServerHello(version uint16) []byte {
	handshake := []byte{
		// Handshake type: ServerHello.
		0x02,
		// Same four bytes as the ClientHello fixture's length label.
		0x00, 0x00, 0x00, 0x10,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	}
	return tlsHandshakeRecord(version, handshake)
}

// TestNewParser checks that NewParser returns a parser with an initialized
// flows map and a non-zero flow timeout.
func TestNewParser(t *testing.T) {
	parser := NewParser()
	if parser == nil {
		// Fatal, not Error: the field checks below would dereference nil.
		t.Fatal("NewParser() returned nil")
	}
	if parser.flows == nil {
		t.Error("NewParser() flows map not initialized")
	}
	if parser.flowTimeout == 0 {
		t.Error("NewParser() flowTimeout not set")
	}
}

// TestParserClose checks that Close on a fresh parser returns no error.
func TestParserClose(t *testing.T) {
	parser := NewParser()
	err := parser.Close()
	if err != nil {
		t.Errorf("Close() error = %v", err)
	}
}

// TestFlowKey checks the string form produced by flowKey.
func TestFlowKey(t *testing.T) {
	// Test unidirectional flow key (client -> server only)
	key := flowKey("192.168.1.1", 12345, "10.0.0.1", 443)
	expected := "192.168.1.1:12345->10.0.0.1:443"
	if key != expected {
		t.Errorf("flowKey() = %v, want %v", key, expected)
	}
}

// TestParserConnectionStateTracking creates a parser and then exercises
// parseClientHello and IsClientHello directly on a generated ClientHello.
// NOTE(review): despite the name, no parser method other than Close is called
// here.
func TestParserConnectionStateTracking(t *testing.T) {
	parser := NewParser()
	defer parser.Close()

	// Create a valid ClientHello payload
	clientHello := createTLSClientHello(0x0303)

	// Test parseClientHello directly (lower-level test)
	result, err := parseClientHello(clientHello)
	if err != nil {
		t.Errorf("parseClientHello() error = %v", err)
	}
	if result == nil {
		t.Error("parseClientHello() should return ClientHello")
	}

	// Test IsClientHello helper
	if !IsClientHello(clientHello) {
		t.Error("IsClientHello() should return true for valid ClientHello")
	}
}

// TestParserClose_Idempotent checks that calling Close twice on the same
// parser returns no error either time.
func TestParserClose_Idempotent(t *testing.T) {
	parser := NewParser()
	if err := parser.Close(); err != nil {
		t.Fatalf("first Close() error = %v", err)
	}
	if err := parser.Close(); err != nil {
		t.Fatalf("second Close() error = %v", err)
	}
}

// TestExtractTCPMeta_MSSInvalid_NoPanic feeds extractTCPMeta an MSS option
// with one data byte instead of two and expects "MSS_INVALID" to appear in the
// reported options.
func TestExtractTCPMeta_MSSInvalid_NoPanic(t *testing.T) {
	tcp := &layers.TCP{
		Window: 1234,
		Options: []layers.TCPOption{
			{
				OptionType: layers.TCPOptionKindMSS,
				OptionData: []byte{0x05}, // malformed (1 byte instead of 2)
			},
		},
	}

	meta := extractTCPMeta(tcp)

	found := false
	for _, opt := range meta.Options {
		if opt == "MSS_INVALID" {
			found = true
			break
		}
	}
	if !found {
		t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options)
	}
}