Files
ja4sentinel/internal/tlsparse/parser.go
toto e166fdab2e
Some checks failed
Build RPM Package / Build RPM Packages (CentOS 7, Rocky 8/9/10) (push) Has been cancelled
feature: 1.1.18
+- FEATURE: Add comprehensive metrics for capture and TLS parser monitoring
+- Capture metrics: packets_received, packets_sent, packets_dropped (atomic counters)
+- Parser metrics: retransmit_count, gap_detected_count, buffer_exceeded_count, segment_exceeded_count
+- New GetStats() method on Capture interface for capture statistics
+- New GetMetrics() method on Parser interface for parser statistics
+- Add DefaultMaxHelloSegments constant (100) to prevent memory leaks from fragmented handshakes
+- Add Segments field to ConnectionFlow for per-flow segment tracking
+- Increase DefaultMaxTrackedFlows from 50000 to 100000 for high-traffic scenarios
+- Improve TCP reassembly: better handling of retransmissions and sequence gaps
+- Memory leak prevention: limit segments per flow and buffer size
+- Aggressive flow cleanup: clean up JA4_DONE flows when approaching flow limit
+- Lock ordering fix: release flow.mu before acquiring p.mu to avoid deadlocks
+- Exclude IPv6 link-local addresses (fe80::) from local IP detection
+- Improve error logging with detailed connection and TLS extension information
+- Add capture diagnostics logging (interface, link_type, local_ips, bpf_filter)
+- Fix false positive retransmission counter when SYN packet is missed
+- Fix gap handling: reset sequence tracking instead of dropping flow
+- Fix extractTLSExtensions: return error details with basic TLS info for debugging
2026-03-09 16:38:40 +01:00

1009 lines
28 KiB
Go

// Package tlsparse provides TLS ClientHello extraction from captured packets
package tlsparse
import (
"encoding/binary"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"ja4sentinel/api"
"ja4sentinel/internal/ipfilter"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
tlsfingerprint "github.com/psanford/tlsfingerprint"
)
// ConnectionState represents the state of a TCP connection for TLS parsing.
// Every tracked flow starts in NEW and only ever moves forward:
// NEW -> WAIT_CLIENT_HELLO -> JA4_DONE (or straight from NEW to JA4_DONE when
// the ClientHello fits in a single segment).
type ConnectionState int
const (
// NEW: flow has just been created, either from a client SYN or from the
// first data packet when the SYN was not captured. No payload buffered yet.
NEW ConnectionState = iota
// WAIT_CLIENT_HELLO: at least one payload segment has been buffered but the
// accumulated bytes do not yet form a complete ClientHello record.
WAIT_CLIENT_HELLO
// JA4_DONE: a complete ClientHello was extracted for this flow. Further
// packets on the flow are ignored and the flow is removed by cleanup.
JA4_DONE
)
// Parser configuration constants. NewParserWithTimeoutAndFilter copies the
// limits below into the ParserImpl it creates.
const (
// DefaultMaxTrackedFlows is the maximum number of concurrent flows to track.
// When the limit is reached, finished and expired flows are evicted before a
// new flow is refused.
// Increased from 50000 to 100000 to handle high-traffic scenarios
DefaultMaxTrackedFlows = 100000
// DefaultMaxHelloBufferBytes is the maximum number of payload bytes buffered
// per flow while reassembling a fragmented ClientHello.
DefaultMaxHelloBufferBytes = 256 * 1024 // 256 KiB
// DefaultMaxHelloSegments is the maximum number of segments to accumulate per
// flow; a flow exceeding it is removed from tracking.
DefaultMaxHelloSegments = 100
// DefaultCleanupInterval is the interval between background cleanup runs.
DefaultCleanupInterval = 10 * time.Second
)
// ConnectionFlow tracks a single TCP flow for TLS handshake extraction.
// Only the client -> server direction is tracked; the flow is keyed by the
// (SrcIP, SrcPort, DstIP, DstPort) tuple.
type ConnectionFlow struct {
mu sync.Mutex // Protects all fields below
State ConnectionState // Current position in the NEW -> WAIT_CLIENT_HELLO -> JA4_DONE lifecycle
CreatedAt time.Time // When the flow was first seen; basis for the SYN-to-ClientHello delay
LastSeen time.Time // Last time a packet touched this flow; basis for timeout expiry
SrcIP string // Client IP
SrcPort uint16 // Client port
DstIP string // Server IP (local machine)
DstPort uint16 // Server port (local machine)
IPMeta api.IPMeta // IP metadata of the packet that created the flow
TCPMeta api.TCPMeta // TCP metadata of the packet that created the flow
HelloBuffer []byte // Payload bytes accumulated so far (or the complete record once found)
Segments int // Number of segments accumulated (for memory leak prevention)
NextSeq uint32 // Expected next TCP sequence number for reassembly
SeqInit bool // Whether NextSeq has been initialized
}
// ParserImpl implements the api.Parser interface for TLS parsing.
// It keeps per-flow reassembly state in flows and runs a background goroutine
// (cleanupLoop) that must be stopped with Close.
//
// Lock ordering: mu is always acquired before a flow's own mutex; code that
// already holds a flow mutex releases it before taking mu.
type ParserImpl struct {
mu sync.RWMutex // Guards the flows map (not the flows themselves)
flows map[string]*ConnectionFlow // Tracked flows, keyed by flowKey
flowTimeout time.Duration // Idle time after which a flow is considered expired
cleanupDone chan struct{} // Closed by cleanupLoop when it exits
cleanupClose chan struct{} // Closed by Close to ask cleanupLoop to stop
closeOnce sync.Once // Makes Close idempotent
maxTrackedFlows int // Upper bound on len(flows)
maxHelloBufferBytes int // Upper bound on buffered bytes per flow
maxHelloSegments int // Upper bound on buffered segments per flow
sourceIPFilter *ipfilter.Filter // Optional source-IP exclusion filter; nil when disabled
// Metrics counters (atomic)
filteredCount uint64 // Packets skipped because the source IP is excluded
retransmitCount uint64 // Packets ignored as retransmissions (sequence behind expected)
gapDetectedCount uint64 // Sequence gaps seen during reassembly (buffer is reset, flow is kept)
bufferExceededCount uint64 // Counter for flows dropped due to buffer limits
segmentExceededCount uint64 // Counter for flows dropped due to segment limits
}
// NewParser returns a parser with the default 30-second flow timeout and no
// source-IP exclusions. The returned parser owns a background cleanup
// goroutine; call Close when done with it.
func NewParser() *ParserImpl {
	const defaultFlowTimeout = 30 * time.Second
	return NewParserWithTimeoutAndFilter(defaultFlowTimeout, nil)
}
// NewParserWithTimeout returns a parser whose idle flows expire after timeout.
// No source-IP exclusions are configured.
func NewParserWithTimeout(timeout time.Duration) *ParserImpl {
	var noExclusions []string // nil: filtering disabled
	return NewParserWithTimeoutAndFilter(timeout, noExclusions)
}
// NewParserWithTimeoutAndFilter returns a parser whose idle flows expire after
// timeout and which ignores packets from the given source IPs/networks.
//
// An empty excludeSourceIPs disables filtering. If the exclusion list cannot
// be turned into a filter, the parser is still created, just without a filter.
// The background cleanup goroutine is started before returning; stop it with
// Close.
func NewParserWithTimeoutAndFilter(timeout time.Duration, excludeSourceIPs []string) *ParserImpl {
	var srcFilter *ipfilter.Filter
	if len(excludeSourceIPs) > 0 {
		// A filter that fails to build is deliberately ignored: run unfiltered.
		if built, err := ipfilter.New(excludeSourceIPs); err == nil {
			srcFilter = built
			// Entry counts are queried but currently unused (kept for parity
			// with the place where they would be logged).
			_, _ = srcFilter.Count()
		}
	}

	// Metric counters and closeOnce rely on their zero values.
	parser := &ParserImpl{
		flows:               make(map[string]*ConnectionFlow),
		flowTimeout:         timeout,
		cleanupDone:         make(chan struct{}),
		cleanupClose:        make(chan struct{}),
		maxTrackedFlows:     DefaultMaxTrackedFlows,
		maxHelloBufferBytes: DefaultMaxHelloBufferBytes,
		maxHelloSegments:    DefaultMaxHelloSegments,
		sourceIPFilter:      srcFilter,
	}
	go parser.cleanupLoop()
	return parser
}
// flowKey builds the map key identifying one direction of a TCP flow
// (client -> server), formatted as "srcIP:srcPort->dstIP:dstPort".
func flowKey(srcIP string, srcPort uint16, dstIP string, dstPort uint16) string {
	// Sprint only inserts spaces between two adjacent non-string operands,
	// which never occurs here, so the pieces are concatenated as-is.
	return fmt.Sprint(srcIP, ":", srcPort, "->", dstIP, ":", dstPort)
}
// cleanupLoop runs in its own goroutine and purges finished/expired flows
// every DefaultCleanupInterval until cleanupClose is closed. On exit it
// closes cleanupDone so that Close can wait for the goroutine to finish.
func (p *ParserImpl) cleanupLoop() {
	tick := time.NewTicker(DefaultCleanupInterval)
	defer tick.Stop()
	// Deferred after Stop, so it runs first: signal completion, then stop the ticker.
	defer close(p.cleanupDone)

	for {
		select {
		case <-p.cleanupClose:
			return
		case <-tick.C:
			p.cleanupExpiredFlows()
		}
	}
}
// cleanupExpiredFlows drops every flow that is either finished (JA4_DONE) or
// has been idle for longer than flowTimeout. It holds p.mu for the whole
// sweep and takes each flow's mutex briefly (p.mu -> flow.mu order).
func (p *ParserImpl) cleanupExpiredFlows() {
	p.mu.Lock()
	defer p.mu.Unlock()

	sweepTime := time.Now()
	for id, f := range p.flows {
		f.mu.Lock()
		idle := sweepTime.Sub(f.LastSeen)
		keep := f.State != JA4_DONE && idle <= p.flowTimeout
		f.mu.Unlock()

		if keep {
			continue
		}
		delete(p.flows, id)
	}
}
// Process inspects one captured packet and returns the TLS ClientHello it
// completes, if any.
//
// Behaviour by packet kind:
//   - Non-IP/TCP packets, packets from excluded source IPs and packets without
//     payload yield (nil, nil).
//   - A pure SYN (SYN without ACK) creates the flow, records the IP/TCP
//     metadata of the SYN and initialises sequence tracking; it yields (nil, nil).
//   - A data packet is first parsed on its own; if it holds a complete
//     ClientHello record the result is returned immediately. Otherwise the
//     payload is appended to the flow's buffer and the buffer is re-parsed.
//   - Data that does not start with a TLS handshake byte is only considered
//     when the flow is already tracked.
//
// Once a ClientHello has been returned for a flow, later packets on that flow
// are ignored. Flows exceeding the segment or buffer limits are removed.
//
// The only error returned directly is for an empty packet or an unexpected
// TCP layer type; all "nothing to report" cases return (nil, nil).
func (p *ParserImpl) Process(pkt api.RawPacket) (*api.TLSClientHello, error) {
	if len(pkt.Data) == 0 {
		return nil, fmt.Errorf("empty packet data")
	}

	ip, tcpLayer := decodeIPTCPLayers(pkt)
	if ip == nil || tcpLayer == nil {
		return nil, nil // Not an IP/TCP packet (or unsupported SLL protocol)
	}
	tcp, ok := tcpLayer.(*layers.TCP)
	if !ok {
		return nil, fmt.Errorf("failed to cast TCP layer")
	}

	ipMeta := extractIPMeta(ip)
	tcpMeta := extractTCPMeta(tcp)

	var srcIP, dstIP string
	switch v := ip.(type) {
	case *layers.IPv4:
		srcIP, dstIP = v.SrcIP.String(), v.DstIP.String()
	case *layers.IPv6:
		srcIP, dstIP = v.SrcIP.String(), v.DstIP.String()
	}
	srcPort, dstPort := uint16(tcp.SrcPort), uint16(tcp.DstPort)

	if p.sourceIPFilter != nil && p.sourceIPFilter.ShouldExclude(srcIP) {
		atomic.AddUint64(&p.filteredCount, 1)
		return nil, nil // Source IP is excluded
	}

	key := flowKey(srcIP, srcPort, dstIP, dstPort)

	// SYN: create the flow so that its stored metadata comes from the SYN,
	// the only packet carrying the full set of TCP options.
	if tcp.SYN && !tcp.ACK {
		if flow := p.getOrCreateFlow(key, srcIP, srcPort, dstIP, dstPort, ipMeta, tcpMeta); flow != nil {
			flow.mu.Lock()
			flow.NextSeq = tcp.Seq + 1 // SYN consumes one sequence number
			flow.SeqInit = true
			flow.mu.Unlock()
		}
		return nil, nil
	}

	payload := tcp.Payload
	if len(payload) == 0 {
		return nil, nil // No payload (ACK, FIN, etc.)
	}

	// Do not start tracking a flow on data that is not a TLS handshake
	// record (content type 22); such data only matters for known flows.
	if payload[0] != 22 {
		p.mu.RLock()
		_, tracked := p.flows[key]
		p.mu.RUnlock()
		if !tracked {
			return nil, nil
		}
	}

	flow := p.getOrCreateFlow(key, srcIP, srcPort, dstIP, dstPort, ipMeta, tcpMeta)
	if flow == nil {
		return nil, nil // Flow table is full
	}

	// Hold the flow lock for the whole state update. flowMuLocked lets the
	// limit-exceeded paths release it early (before taking p.mu) without the
	// deferred unlock firing a second time.
	flow.mu.Lock()
	flowMuLocked := true
	defer func() {
		if flowMuLocked {
			flow.mu.Unlock()
		}
	}()

	if flow.State == JA4_DONE {
		return nil, nil // Already processed this flow
	}

	// Sequence tracking. The comparison is done on the signed 32-bit
	// difference so it stays correct when sequence numbers wrap around.
	// When SeqInit is false (SYN missed, first data packet) there is nothing
	// to compare against and the packet is simply accepted.
	seq := tcp.Seq
	if flow.SeqInit {
		delta := int32(seq - flow.NextSeq)
		if delta < 0 && flow.State != NEW {
			// Already-seen data while reassembling: ignore it.
			atomic.AddUint64(&p.retransmitCount, 1)
			return nil, nil
		}
		if delta > 0 && flow.State == WAIT_CLIENT_HELLO {
			// A segment is missing; the bytes buffered so far can no longer
			// be completed, so restart accumulation from this segment.
			atomic.AddUint64(&p.gapDetectedCount, 1)
			flow.HelloBuffer = flow.HelloBuffer[:0]
			flow.Segments = 0
		}
	}
	flow.NextSeq = seq + uint32(len(payload))
	flow.SeqInit = true

	// Fast path: the whole ClientHello record is in this one segment.
	clientHello, err := parseClientHello(payload)
	if err != nil {
		return nil, err
	}
	if clientHello != nil {
		flow.State = JA4_DONE
		flow.HelloBuffer = clientHello
		flow.Segments = 0
		return newClientHelloFromFlow(flow, key, clientHello), nil
	}

	if flow.State != WAIT_CLIENT_HELLO && flow.State != NEW {
		return nil, nil
	}

	// dropFlow stops tracking this flow after a limit was exceeded.
	// flow.mu is released before p.mu is taken to respect the p.mu -> flow.mu
	// lock order used by cleanup. The identity check avoids deleting a
	// different flow that was registered under the same key in between.
	dropFlow := func(counter *uint64) {
		atomic.AddUint64(counter, 1)
		flowMuLocked = false
		flow.mu.Unlock()
		p.mu.Lock()
		if p.flows[key] == flow {
			delete(p.flows, key)
		}
		p.mu.Unlock()
	}
	if flow.Segments >= p.maxHelloSegments {
		dropFlow(&p.segmentExceededCount)
		return nil, nil
	}
	if len(flow.HelloBuffer)+len(payload) > p.maxHelloBufferBytes {
		dropFlow(&p.bufferExceededCount)
		return nil, nil
	}

	// Slow path: accumulate and retry on the reassembled bytes.
	flow.State = WAIT_CLIENT_HELLO
	flow.HelloBuffer = append(flow.HelloBuffer, payload...)
	flow.Segments++
	flow.LastSeen = time.Now()

	record, err := parseClientHello(flow.HelloBuffer)
	if err != nil {
		return nil, err
	}
	if record == nil {
		return nil, nil // No complete ClientHello yet
	}

	// Detach the record from the flow's buffer so the returned payload does
	// not alias parser-owned memory.
	clientHello = append([]byte(nil), record...)
	flow.State = JA4_DONE
	flow.Segments = 0
	return newClientHelloFromFlow(flow, key, clientHello), nil
}

// decodeIPTCPLayers decodes the packet down to its IP and TCP layers.
// It returns a nil layer for whatever could not be found.
//
// For link type 101 the first 16 bytes are treated as a Linux cooked-capture
// (SLL) header and the remainder is decoded as raw IP. Any other link type is
// decoded as Ethernet first and, failing that, as raw IPv4/IPv6 based on the
// version nibble of the first byte.
//
// NOTE(review): in the pcap registry 101 is LINKTYPE_RAW and LINKTYPE_LINUX_SLL
// is 113 - confirm what the capture layer stores in RawPacket.LinkType.
func decodeIPTCPLayers(pkt api.RawPacket) (ipLayer, tcpLayer gopacket.Layer) {
	const (
		linkTypeLinuxSLL = 101
		sllHeaderLen     = 16
		etherTypeIPv4    = 0x0800
		etherTypeIPv6    = 0x86DD
	)
	data := pkt.Data
	if len(data) == 0 {
		return nil, nil
	}

	if pkt.LinkType == linkTypeLinuxSLL && len(data) >= sllHeaderLen {
		// The SLL header stores the protocol type in its last two bytes
		// (offset 14). Offset 12 is still accepted as a fallback because
		// that is where earlier versions of this parser read it.
		proto := uint16(data[14])<<8 | uint16(data[15])
		if proto != etherTypeIPv4 && proto != etherTypeIPv6 {
			proto = uint16(data[12])<<8 | uint16(data[13])
		}
		switch proto {
		case etherTypeIPv4:
			return decodeFromIPLayer(data[sllHeaderLen:], layers.LayerTypeIPv4)
		case etherTypeIPv6:
			return decodeFromIPLayer(data[sllHeaderLen:], layers.LayerTypeIPv6)
		}
		return nil, nil // Unsupported SLL protocol
	}

	packet := gopacket.NewPacket(data, layers.LinkTypeEthernet, gopacket.Default)
	ipLayer = packet.Layer(layers.LayerTypeIPv4)
	if ipLayer == nil {
		ipLayer = packet.Layer(layers.LayerTypeIPv6)
	}
	tcpLayer = packet.Layer(layers.LayerTypeTCP)
	if ipLayer != nil && tcpLayer != nil {
		return ipLayer, tcpLayer
	}

	// Ethernet decoding did not yield IP+TCP: retry as raw IP, starting the
	// decoder directly at the IP header.
	switch data[0] >> 4 {
	case 4:
		return decodeFromIPLayer(data, layers.LayerTypeIPv4)
	case 6:
		return decodeFromIPLayer(data, layers.LayerTypeIPv6)
	}
	return ipLayer, tcpLayer
}

// decodeFromIPLayer decodes data as a packet starting at an IP header of the
// given type and returns its IP and TCP layers. Both are nil when the IP
// layer cannot be decoded; tcpLayer alone is nil for non-TCP packets.
func decodeFromIPLayer(data []byte, ipType gopacket.LayerType) (ipLayer, tcpLayer gopacket.Layer) {
	packet := gopacket.NewPacket(data, ipType, gopacket.Default)
	ipLayer = packet.Layer(ipType)
	if ipLayer == nil {
		return nil, nil
	}
	return ipLayer, packet.Layer(layers.LayerTypeTCP)
}

// newClientHelloFromFlow assembles the API result for a completed ClientHello
// record. Addressing and IP/TCP metadata come from the flow (i.e. from the
// packet that created it), connID is the flow key. The caller must hold
// flow.mu.
func newClientHelloFromFlow(flow *ConnectionFlow, connID string, record []byte) *api.TLSClientHello {
	// The error is intentionally not propagated: a ClientHello whose
	// extensions cannot be fully parsed is still reported. Whatever partial
	// information came back with the error (TLS version, SNI) is kept.
	extInfo, _ := extractTLSExtensions(record)
	if extInfo == nil {
		extInfo = &TLSExtensionInfo{}
	}

	// Time from flow creation (the SYN, when it was captured) to ClientHello.
	synToCH := uint32(time.Since(flow.CreatedAt).Milliseconds())

	return &api.TLSClientHello{
		SrcIP:      flow.SrcIP,
		SrcPort:    flow.SrcPort,
		DstIP:      flow.DstIP,
		DstPort:    flow.DstPort,
		Payload:    record,
		IPMeta:     flow.IPMeta,
		TCPMeta:    flow.TCPMeta,
		ConnID:     connID,
		SNI:        extInfo.SNI,
		ALPN:       joinStringSlice(extInfo.ALPN, ","),
		TLSVersion: extInfo.TLSVersion,
		SynToCHMs:  &synToCH,
	}
}
// getOrCreateFlow returns the flow registered under key, refreshing its
// LastSeen timestamp, or registers a new flow in state NEW built from the
// given addressing and metadata.
//
// When the flow table is full it first evicts all finished (JA4_DONE) flows
// and, if that is not enough, all flows idle for longer than flowTimeout.
// It returns nil when the table is still full afterwards.
//
// p.mu is held for the whole call; flow mutexes are only taken briefly while
// p.mu is held.
func (p *ParserImpl) getOrCreateFlow(key string, srcIP string, srcPort uint16, dstIP string, dstPort uint16, ipMeta api.IPMeta, tcpMeta api.TCPMeta) *ConnectionFlow {
	p.mu.Lock()
	defer p.mu.Unlock()

	if known, ok := p.flows[key]; ok {
		known.mu.Lock()
		known.LastSeen = time.Now()
		known.mu.Unlock()
		return known
	}

	tableFull := func() bool { return len(p.flows) >= p.maxTrackedFlows }

	// evict removes every flow for which victim reports true; victim is
	// called with the flow's mutex held.
	evict := func(victim func(*ConnectionFlow) bool) {
		for id, f := range p.flows {
			f.mu.Lock()
			remove := victim(f)
			f.mu.Unlock()
			if remove {
				delete(p.flows, id)
			}
		}
	}

	if tableFull() {
		// Finished flows are the cheapest to give up: they were already reported.
		evict(func(f *ConnectionFlow) bool { return f.State == JA4_DONE })

		if tableFull() {
			deadline := time.Now()
			evict(func(f *ConnectionFlow) bool {
				return deadline.Sub(f.LastSeen) > p.flowTimeout
			})
		}

		if tableFull() {
			return nil
		}
	}

	created := &ConnectionFlow{
		State:       NEW,
		CreatedAt:   time.Now(),
		LastSeen:    time.Now(),
		SrcIP:       srcIP,
		SrcPort:     srcPort,
		DstIP:       dstIP,
		DstPort:     dstPort,
		IPMeta:      ipMeta,
		TCPMeta:     tcpMeta,
		HelloBuffer: make([]byte, 0),
	}
	p.flows[key] = created
	return created
}
// GetFilterStats reports how many packets were skipped by the source-IP
// filter. hasFilter is false (and filteredCount 0) when no filter is
// configured.
func (p *ParserImpl) GetFilterStats() (filteredCount uint64, hasFilter bool) {
	hasFilter = p.sourceIPFilter != nil
	if hasFilter {
		filteredCount = atomic.LoadUint64(&p.filteredCount)
	}
	return filteredCount, hasFilter
}
// GetMetrics returns a snapshot of the parser's reassembly counters:
// ignored retransmissions, detected sequence gaps, and flows removed for
// exceeding the buffer-size or segment-count limit. Each counter is read
// atomically, but the four values are not read as one consistent set.
func (p *ParserImpl) GetMetrics() (retransmit, gapDetected, bufferExceeded, segmentExceeded uint64) {
	retransmit = atomic.LoadUint64(&p.retransmitCount)
	gapDetected = atomic.LoadUint64(&p.gapDetectedCount)
	bufferExceeded = atomic.LoadUint64(&p.bufferExceededCount)
	segmentExceeded = atomic.LoadUint64(&p.segmentExceededCount)
	return retransmit, gapDetected, bufferExceeded, segmentExceeded
}
// Close stops the background cleanup goroutine and waits for it to exit.
// It is safe to call more than once; only the first call does any work.
// The returned error is always nil.
func (p *ParserImpl) Close() error {
	shutdown := func() {
		close(p.cleanupClose) // ask cleanupLoop to stop
		<-p.cleanupDone       // wait until it has
	}
	p.closeOnce.Do(shutdown)
	return nil
}
// extractIPMeta builds the IP metadata record from an IPv4 or IPv6 layer.
// Any other layer type yields the zero value.
func extractIPMeta(ipLayer gopacket.Layer) api.IPMeta {
	switch hdr := ipLayer.(type) {
	case *layers.IPv4:
		return api.IPMeta{
			TTL:         hdr.TTL,
			TotalLength: hdr.Length,
			IPID:        hdr.Id,
			DF:          hdr.Flags&layers.IPv4DontFragment != 0,
		}
	case *layers.IPv6:
		return api.IPMeta{
			TTL:         hdr.HopLimit,
			TotalLength: uint16(hdr.Length),
			IPID:        0,    // IPv6 has no identification field in its base header
			DF:          true, // reported as "don't fragment" for IPv6
		}
	}
	return api.IPMeta{}
}
// extractTCPMeta builds the TCP metadata record from a TCP layer: the window
// size, the MSS and window-scale values when present, and the ordered list of
// option names ("MSS", "WS", "SACK", "TS", or "OPT<kind>" for anything else).
// End-of-list and NOP options are not listed.
func extractTCPMeta(tcp *layers.TCP) api.TCPMeta {
	meta := api.TCPMeta{WindowSize: tcp.Window}
	names := make([]string, 0, len(tcp.Options))

	for _, option := range tcp.Options {
		var label string
		switch option.OptionType {
		case layers.TCPOptionKindEndList, layers.TCPOptionKindNop:
			// Padding / terminator: carries no fingerprint information.
			continue
		case layers.TCPOptionKindMSS:
			if len(option.OptionData) >= 2 {
				meta.MSS = binary.BigEndian.Uint16(option.OptionData[:2])
				label = "MSS"
			} else {
				label = "MSS_INVALID"
			}
		case layers.TCPOptionKindWindowScale:
			if len(option.OptionData) > 0 {
				meta.WindowScale = option.OptionData[0]
			}
			label = "WS"
		case layers.TCPOptionKindSACKPermitted, layers.TCPOptionKindSACK:
			// SACK-permitted and actual SACK blocks share one label.
			label = "SACK"
		case layers.TCPOptionKindTimestamps:
			label = "TS"
		default:
			label = fmt.Sprintf("OPT%d", option.OptionType)
		}
		names = append(names, label)
	}

	meta.Options = names
	return meta
}
// TLSExtensionInfo contains the ClientHello details extracted by
// extractTLSExtensions. Fields that could not be determined are left at their
// zero value.
type TLSExtensionInfo struct {
SNI string // Server name from the SNI extension, "" if absent
ALPN []string // ALPN protocol names in the order offered, nil if absent
TLSVersion string // TLS version as "1.0".."1.3", "" if unknown
}
// parseClientHello reports whether payload begins with a complete TLS
// handshake record whose first handshake message is a ClientHello.
//
// On success it returns the record (5-byte header plus body) as a prefix of
// payload; trailing bytes are not included. In every other case (too short,
// not a handshake record, record version outside TLS 1.0-1.3, incomplete
// record, other handshake type) it returns nil. The error is always nil.
func parseClientHello(payload []byte) ([]byte, error) {
	const (
		recordHeaderLen  = 5  // content type (1) + version (2) + length (2)
		contentHandshake = 22 // TLS record content type "handshake"
		typeClientHello  = 1  // handshake message type "client_hello"
	)

	if len(payload) < recordHeaderLen || payload[0] != contentHandshake {
		return nil, nil
	}

	switch binary.BigEndian.Uint16(payload[1:3]) {
	case 0x0301, 0x0302, 0x0303, 0x0304: // TLS 1.0 .. 1.3
	default:
		return nil, nil
	}

	recordEnd := recordHeaderLen + int(binary.BigEndian.Uint16(payload[3:5]))
	switch {
	case recordEnd > len(payload): // record not fully captured yet
		return nil, nil
	case recordEnd == recordHeaderLen: // empty record body, no handshake type byte
		return nil, nil
	case payload[recordHeaderLen] != typeClientHello:
		return nil, nil
	}

	return payload[:recordEnd], nil
}
// extractTLSExtensions extracts SNI, ALPN and the TLS version from a complete
// ClientHello record (record header included).
//
// ALPN and version come from the tlsfingerprint library; the SNI value is
// parsed manually. If the library rejects the record, parsing is retried on a
// sanitized copy (see sanitizeTLSRecord).
//
// Return values:
//   - (nil, nil) when payload is not a complete handshake record starting
//     with a ClientHello;
//   - (info, nil) on success;
//   - (info, err) when the library could not parse the record: info then
//     carries only what manual parsing provides (record-layer version, SNI).
func extractTLSExtensions(payload []byte) (*TLSExtensionInfo, error) {
	if len(payload) < 5 || payload[0] != 22 {
		return nil, nil // Too short, or not a handshake record
	}
	recordVersion := binary.BigEndian.Uint16(payload[1:3])
	recordEnd := 5 + int(binary.BigEndian.Uint16(payload[3:5]))
	if len(payload) < recordEnd {
		return nil, nil // Incomplete record
	}
	handshake := payload[5:recordEnd]
	if len(handshake) < 1 || handshake[0] != 1 {
		return nil, nil // Not a ClientHello
	}

	// manualInfo is the best-effort result used when the library fails.
	manualInfo := func() *TLSExtensionInfo {
		return &TLSExtensionInfo{
			TLSVersion: tlsVersionToString(recordVersion),
			SNI:        extractSNIFromPayload(handshake),
		}
	}

	fp, err := tlsfingerprint.ParseClientHello(payload)
	if err != nil {
		sanitized := sanitizeTLSRecord(payload)
		if sanitized == nil {
			return manualInfo(), fmt.Errorf("tlsfingerprint.ParseClientHello failed and sanitization unavailable")
		}
		if fp, err = tlsfingerprint.ParseClientHello(sanitized); err != nil {
			return manualInfo(), fmt.Errorf("tlsfingerprint.ParseClientHello failed: %w", err)
		}
	}

	info := &TLSExtensionInfo{}
	if fp != nil {
		if len(fp.ALPNProtocols) > 0 {
			info.ALPN = fp.ALPNProtocols
		}
		info.TLSVersion = tlsVersionToString(fp.Version)
	}
	if info.TLSVersion == "" {
		// Version from the library is unknown: use the record-layer version.
		info.TLSVersion = tlsVersionToString(recordVersion)
	}
	// The library only reports whether SNI is present, not its value.
	if name := extractSNIFromPayload(handshake); name != "" {
		info.SNI = name
	}
	return info, nil
}
// extractSNIFromPayload returns the host name carried in the server_name
// (SNI) extension of a ClientHello, or "" if there is none or the message is
// truncated/malformed.
//
// handshakePayload starts at the handshake type byte:
//
//	[0]      handshake type (0x01 for ClientHello)
//	[1:4]    handshake length
//	[4:6]    client version
//	[6:38]   random
//	[38]     session ID length, followed by the session ID
//	...      cipher suites (2-byte length), compression methods (1-byte
//	         length), extensions (2-byte length)
//
// Within the SNI extension the server name list is walked entry by entry and
// the first entry of type host_name (0) is returned; entries of other types
// are skipped rather than being mistaken for a host name. The declared list
// length is not enforced: entries are read up to the end of the extension.
func extractSNIFromPayload(handshakePayload []byte) string {
	const (
		extTypeServerName = 0
		nameTypeHostName  = 0
	)

	// type(1) + len(3) + version(2) + random(32) + sessionIDLen(1) + 1
	if len(handshakePayload) < 40 {
		return ""
	}

	// Session ID: 1-byte length at offset 38, then the ID itself.
	offset := 38
	offset += 1 + int(handshakePayload[offset])

	// Cipher suites: 2-byte length, then the suites.
	if offset+2 > len(handshakePayload) {
		return ""
	}
	offset += 2 + int(binary.BigEndian.Uint16(handshakePayload[offset:offset+2]))

	// Compression methods: 1-byte length, then the methods.
	if offset >= len(handshakePayload) {
		return ""
	}
	offset += 1 + int(handshakePayload[offset])

	// Extensions block: 2-byte length; absent in extension-less hellos.
	if offset+2 > len(handshakePayload) {
		return ""
	}
	extensionsLen := int(binary.BigEndian.Uint16(handshakePayload[offset : offset+2]))
	offset += 2
	if extensionsLen == 0 || offset+extensionsLen > len(handshakePayload) {
		return ""
	}
	extensionsEnd := offset + extensionsLen

	for offset < extensionsEnd {
		if offset+4 > len(handshakePayload) {
			break
		}
		extType := binary.BigEndian.Uint16(handshakePayload[offset : offset+2])
		extLen := int(binary.BigEndian.Uint16(handshakePayload[offset+2 : offset+4]))
		offset += 4
		if offset+extLen > len(handshakePayload) {
			break
		}
		extData := handshakePayload[offset : offset+extLen]
		offset += extLen

		// Smallest useful SNI body: list length (2) + name type (1) + name length (2).
		if extType != extTypeServerName || len(extData) < 5 {
			continue
		}

		// Each list entry is: name_type (1), length (2), value.
		entries := extData[2:]
		for len(entries) >= 3 {
			nameLen := int(binary.BigEndian.Uint16(entries[1:3]))
			if 3+nameLen > len(entries) {
				break // truncated entry
			}
			if entries[0] == nameTypeHostName {
				return string(entries[3 : 3+nameLen])
			}
			entries = entries[3+nameLen:]
		}
	}
	return ""
}
// tlsVersionToString maps a TLS wire version (0x0301..0x0304) to its
// human-readable form ("1.0".."1.3"). Any other value yields "".
func tlsVersionToString(version uint16) string {
	const (
		firstKnown = 0x0301 // TLS 1.0
		lastKnown  = 0x0304 // TLS 1.3
	)
	if version < firstKnown || version > lastKnown {
		return ""
	}
	// The known versions are consecutive, so the offset indexes the table.
	names := [...]string{"1.0", "1.1", "1.2", "1.3"}
	return names[version-firstKnown]
}
// IsClientHello reports whether payload starts with a complete TLS handshake
// record (record version TLS 1.0-1.3) whose first handshake message is a
// ClientHello.
func IsClientHello(payload []byte) bool {
	// Record header (5 bytes) plus at least the handshake type byte.
	if len(payload) < 6 || payload[0] != 22 {
		return false
	}
	if v := binary.BigEndian.Uint16(payload[1:3]); v < 0x0301 || v > 0x0304 {
		return false
	}
	bodyLen := int(binary.BigEndian.Uint16(payload[3:5]))
	// The record must be non-empty, fully present, and begin with
	// handshake type 1 (ClientHello).
	return bodyLen >= 1 && len(payload) >= 5+bodyLen && payload[5] == 1
}
// joinStringSlice concatenates the elements of slice, placing sep between
// them. An empty or nil slice yields "". Kept for backward compatibility.
//
// Deprecated: Use strings.Join instead.
func joinStringSlice(slice []string, sep string) string {
	// strings.Join already returns "" for an empty slice.
	return strings.Join(slice, sep)
}
// sanitizeTLSRecord repairs a ClientHello record whose extension block ends
// in an incomplete extension. It returns a copy of the record in which the
// extensions length, the handshake length and the record length have all been
// reduced so that they cover only the complete extensions, and which is cut
// off after the shortened record.
//
// It returns nil when data is not a complete handshake record holding a
// ClientHello, when the ClientHello is too short to reach its extension
// block, or when every extension is already complete (nothing to fix).
// data itself is never modified.
func sanitizeTLSRecord(data []byte) []byte {
	const (
		recordHeaderLen    = 5 // content type + version + length
		handshakeHeaderLen = 4 // handshake type + 24-bit length
	)

	if len(data) < recordHeaderLen || data[0] != 0x16 {
		return nil
	}
	recordLen := int(binary.BigEndian.Uint16(data[3:5]))
	if recordHeaderLen+recordLen > len(data) {
		return nil
	}

	handshake := data[recordHeaderLen : recordHeaderLen+recordLen]
	if len(handshake) < handshakeHeaderLen || handshake[0] != 0x01 {
		return nil
	}
	helloLen := int(handshake[1])<<16 | int(handshake[2])<<8 | int(handshake[3])
	if handshakeHeaderLen+helloLen > len(handshake) {
		return nil
	}
	body := handshake[handshakeHeaderLen : handshakeHeaderLen+helloLen]

	// Walk the variable-length fields in front of the extensions.
	cursor := 2 + 32 // client version + random
	if cursor+1 > len(body) {
		return nil
	}
	cursor += 1 + int(body[cursor]) // session ID
	if cursor+2 > len(body) {
		return nil
	}
	cursor += 2 + int(binary.BigEndian.Uint16(body[cursor:cursor+2])) // cipher suites
	if cursor+1 > len(body) {
		return nil
	}
	cursor += 1 + int(body[cursor]) // compression methods
	if cursor+2 > len(body) {
		return nil
	}

	extLenPos := cursor // where the 2-byte extensions length lives, relative to body
	declaredLen := int(binary.BigEndian.Uint16(body[cursor : cursor+2]))
	extBegin := cursor + 2
	if extBegin+declaredLen > len(body) {
		return nil
	}
	extensions := body[extBegin : extBegin+declaredLen]

	// completeLen ends up as the length of the longest prefix made of whole
	// extensions (4-byte header plus the body length that header declares).
	completeLen := 0
	for completeLen+4 <= len(extensions) {
		bodyLen := int(binary.BigEndian.Uint16(extensions[completeLen+2 : completeLen+4]))
		next := completeLen + 4 + bodyLen
		if next > len(extensions) {
			break // truncated extension
		}
		completeLen = next
	}
	if completeLen == declaredLen {
		return nil // no truncation, nothing to fix
	}

	// Patch the three nested length fields in a private copy.
	repaired := make([]byte, len(data))
	copy(repaired, data)

	removed := declaredLen - completeLen
	binary.BigEndian.PutUint16(repaired[recordHeaderLen+handshakeHeaderLen+extLenPos:], uint16(completeLen))

	shortenedHello := helloLen - removed
	repaired[recordHeaderLen+1] = byte(shortenedHello >> 16)
	repaired[recordHeaderLen+2] = byte(shortenedHello >> 8)
	repaired[recordHeaderLen+3] = byte(shortenedHello)

	shortenedRecord := recordLen - removed
	binary.BigEndian.PutUint16(repaired[3:5], uint16(shortenedRecord))

	return repaired[:recordHeaderLen+shortenedRecord]
}