Files
ja4sentinel/internal/tlsparse/parser.go
Jacquin Antoine 965720a183
Some checks failed
Build RPM Package / Build RPM Packages (CentOS 7, Rocky 8/9/10) (push) Has been cancelled
release: version 1.0.9 - Add SNI, ALPN, TLS version extraction and architecture.yml compliance
New features:
- Extract SNI (Server Name Indication) from TLS ClientHello
- Extract ALPN (Application-Layer Protocol Negotiation) protocols
- Detect TLS version from ClientHello using tlsfingerprint library
- Add ConnID field for TCP flow correlation
- Add SensorID field for multi-sensor deployments
- Add SynToCHMs timing field for behavioral detection
- Add AsyncBuffer configuration for output queue sizing

Architecture changes:
- Remove JA4Hash from LogRecord (JA4 format includes its own hash portions)
- Update api.TLSClientHello with new TLS metadata fields
- Update api.LogRecord with correlation, TLS, and timing fields
- Ensure 100% compliance with architecture.yml specification

Tests:
- Add unit tests for TLS extension extraction (SNI, ALPN, Version)
- Update tests for new LogRecord schema without JA4Hash
- Add tests for AsyncBuffer configuration

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
2026-03-02 19:32:16 +01:00

666 lines
17 KiB
Go

// Package tlsparse provides TLS ClientHello extraction from captured packets.
package tlsparse
import (
"encoding/binary"
"fmt"
"strings"
"sync"
"time"
"ja4sentinel/api"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
tlsfingerprint "github.com/psanford/tlsfingerprint"
)
// ConnectionState is the position of a tracked flow in the ClientHello
// extraction state machine: NEW -> WAIT_CLIENT_HELLO -> JA4_DONE.
type ConnectionState int

// Flow states. The values are spelled out explicitly and equal the ordering
// NEW=0, WAIT_CLIENT_HELLO=1, JA4_DONE=2.
const (
	// NEW: flow entry just created by the parser; no ClientHello bytes have
	// been buffered for it yet.
	NEW ConnectionState = 0
	// WAIT_CLIENT_HELLO: accumulating the segments of a ClientHello that is
	// split across several TCP segments.
	WAIT_CLIENT_HELLO ConnectionState = 1
	// JA4_DONE: ClientHello extracted; further segments of the flow are
	// ignored and the next cleanup sweep removes the flow.
	JA4_DONE ConnectionState = 2
)
// Parser configuration defaults, applied by NewParserWithTimeout.
const (
	// DefaultMaxTrackedFlows caps how many flows a parser tracks at once;
	// additional flows are not tracked while the table is full.
	DefaultMaxTrackedFlows = 50000
	// DefaultMaxHelloBufferBytes caps the bytes buffered per flow while
	// reassembling a fragmented ClientHello (256 KiB).
	DefaultMaxHelloBufferBytes = 256 << 10
	// DefaultCleanupInterval is how often finished and idle flows are swept.
	DefaultCleanupInterval = time.Second * 10
)
// ConnectionFlow tracks a single TCP flow for TLS handshake extraction.
// Only tracks incoming traffic from client to the local machine; the parser
// stores it in ParserImpl.flows under flowKey(SrcIP, SrcPort, DstIP, DstPort).
//
// Locking: getOrCreateFlow and cleanupExpiredFlows acquire ParserImpl.mu
// before a flow's mu; keep that order everywhere to avoid deadlocks.
type ConnectionFlow struct {
mu sync.Mutex // Protects all fields below
State ConnectionState // Position in the NEW -> WAIT_CLIENT_HELLO -> JA4_DONE state machine
// CreatedAt is when the flow entry was created; SynToCHMs is measured from it.
// NOTE(review): flows are only created for a segment carrying a TLS handshake
// payload (never for the payload-less SYN), so this is not the SYN time.
CreatedAt time.Time
LastSeen time.Time // Last activity; flows idle longer than the parser's flowTimeout are removed by cleanup
SrcIP string // Client IP
SrcPort uint16 // Client port
DstIP string // Server IP (local machine)
DstPort uint16 // Server port (local machine)
IPMeta api.IPMeta // IP metadata of the packet that created the flow
TCPMeta api.TCPMeta // TCP metadata of the packet that created the flow
HelloBuffer []byte // Accumulated ClientHello bytes; bounded by the parser's maxHelloBufferBytes
}
// ParserImpl implements the api.Parser interface for TLS parsing.
// Construct it with NewParser or NewParserWithTimeout, which start a
// background cleanup goroutine; call Close to stop that goroutine.
//
// Locking: getOrCreateFlow and cleanupExpiredFlows acquire mu before any
// ConnectionFlow.mu; keep that order everywhere to avoid deadlocks.
type ParserImpl struct {
mu sync.RWMutex // Guards the flows map
flows map[string]*ConnectionFlow // Tracked flows keyed by flowKey
flowTimeout time.Duration // Idle time after which cleanupExpiredFlows removes a flow
cleanupDone chan struct{} // Closed by cleanupLoop when it exits
cleanupClose chan struct{} // Closed by Close to ask cleanupLoop to exit
closeOnce sync.Once // Makes Close safe to call more than once
maxTrackedFlows int // getOrCreateFlow refuses new flows once this many are tracked
maxHelloBufferBytes int // Per-flow cap on buffered ClientHello bytes; flows exceeding it are dropped
}
// NewParser returns a TLS parser whose tracked flows expire after 30 seconds
// of inactivity. It starts the background cleanup goroutine; call Close when
// the parser is no longer needed.
func NewParser() *ParserImpl {
	const defaultFlowTimeout = 30 * time.Second
	return NewParserWithTimeout(defaultFlowTimeout)
}
// NewParserWithTimeout returns a TLS parser whose tracked flows expire after
// timeout of inactivity, using the package defaults for the flow-table size
// and per-flow ClientHello buffer. It starts the background cleanup
// goroutine; call Close to stop it.
func NewParserWithTimeout(timeout time.Duration) *ParserImpl {
	parser := new(ParserImpl)
	parser.flows = make(map[string]*ConnectionFlow)
	parser.flowTimeout = timeout
	parser.cleanupDone = make(chan struct{})
	parser.cleanupClose = make(chan struct{})
	parser.maxTrackedFlows = DefaultMaxTrackedFlows
	parser.maxHelloBufferBytes = DefaultMaxHelloBufferBytes

	go parser.cleanupLoop()
	return parser
}
// flowKey builds the key (also reported as ConnID) for the client->server
// direction of a TCP flow, formatted as "srcIP:srcPort->dstIP:dstPort".
// IPv6 addresses are not bracketed; the port is always the text after the
// last ':' on each side of "->".
func flowKey(srcIP string, srcPort uint16, dstIP string, dstPort uint16) string {
	const layout = "%s:%d->%s:%d"
	return fmt.Sprintf(layout, srcIP, srcPort, dstIP, dstPort)
}
// cleanupLoop sweeps finished and idle flows every DefaultCleanupInterval
// until cleanupClose is closed, then closes cleanupDone so Close can return.
func (p *ParserImpl) cleanupLoop() {
	ticker := time.NewTicker(DefaultCleanupInterval)
	defer ticker.Stop()
	// Registered after ticker.Stop, so it runs first on return.
	defer close(p.cleanupDone)

	for {
		select {
		case <-p.cleanupClose:
			return
		case <-ticker.C:
			p.cleanupExpiredFlows()
		}
	}
}
// cleanupExpiredFlows removes every flow that is finished (JA4_DONE) or has
// been idle for longer than p.flowTimeout. It holds p.mu for the whole sweep
// and briefly locks each flow in turn (p.mu before flow.mu).
func (p *ParserImpl) cleanupExpiredFlows() {
	p.mu.Lock()
	defer p.mu.Unlock()

	now := time.Now()
	for key, flow := range p.flows {
		flow.mu.Lock()
		finished := flow.State == JA4_DONE
		idle := now.Sub(flow.LastSeen) > p.flowTimeout
		flow.mu.Unlock()

		if finished || idle {
			delete(p.flows, key) // deleting the current key while ranging is allowed
		}
	}
}
// Process inspects one captured frame and returns the TLS ClientHello it
// completes, if any.
//
// The frame is decoded as Ethernet. Only TCP segments with a payload are
// considered. A flow starts being tracked when its first payload begins with
// the TLS handshake content type (22). Segments are then buffered, up to
// maxHelloBufferBytes, until a complete ClientHello record is available. Each
// flow yields at most one ClientHello.
//
// It returns (nil, nil) for packets that are not, or do not yet complete, a
// ClientHello. It returns an error for empty input or when the TCP layer
// cannot be decoded.
//
// NOTE(review): flows are created on the first handshake payload rather than
// on the TCP SYN, so SynToCHMs measures time since that segment, not since
// the SYN. Confirm this matches the intended metric.
func (p *ParserImpl) Process(pkt api.RawPacket) (*api.TLSClientHello, error) {
	if len(pkt.Data) == 0 {
		return nil, fmt.Errorf("empty packet data")
	}

	packet := gopacket.NewPacket(pkt.Data, layers.LinkTypeEthernet, gopacket.Default)

	ipLayer := packet.Layer(layers.LayerTypeIPv4)
	if ipLayer == nil {
		ipLayer = packet.Layer(layers.LayerTypeIPv6)
	}
	if ipLayer == nil {
		return nil, nil // Not an IP packet
	}
	tcpLayer := packet.Layer(layers.LayerTypeTCP)
	if tcpLayer == nil {
		return nil, nil // Not a TCP packet
	}
	tcp, ok := tcpLayer.(*layers.TCP)
	if !ok {
		return nil, fmt.Errorf("failed to cast TCP layer")
	}

	ipMeta := extractIPMeta(ipLayer)
	tcpMeta := extractTCPMeta(tcp)

	var srcIP, dstIP string
	switch v := ipLayer.(type) {
	case *layers.IPv4:
		srcIP, dstIP = v.SrcIP.String(), v.DstIP.String()
	case *layers.IPv6:
		srcIP, dstIP = v.SrcIP.String(), v.DstIP.String()
	}
	srcPort := uint16(tcp.SrcPort)
	dstPort := uint16(tcp.DstPort)

	payload := tcp.Payload
	if len(payload) == 0 {
		return nil, nil // No TLS data in this segment
	}

	key := flowKey(srcIP, srcPort, dstIP, dstPort)

	// Only start tracking a flow whose first payload looks like a TLS
	// handshake record; unrelated traffic never enters the flow table.
	p.mu.RLock()
	_, flowExists := p.flows[key]
	p.mu.RUnlock()
	if !flowExists && payload[0] != 22 {
		return nil, nil
	}

	flow := p.getOrCreateFlow(key, srcIP, srcPort, dstIP, dstPort, ipMeta, tcpMeta)
	if flow == nil {
		return nil, nil // Flow table is full
	}

	flow.mu.Lock()
	hello, drop, err := p.feedClientHelloSegment(flow, payload)
	createdAt := flow.CreatedAt
	flow.mu.Unlock()

	if drop {
		// p.mu is taken only after flow.mu has been released. getOrCreateFlow
		// and cleanupExpiredFlows lock p.mu before a flow's mu, so taking p.mu
		// while still holding flow.mu inverts that order and can deadlock.
		p.mu.Lock()
		if p.flows[key] == flow {
			delete(p.flows, key)
		}
		p.mu.Unlock()
		return nil, nil
	}
	if err != nil || hello == nil {
		return nil, err
	}

	ch := &api.TLSClientHello{
		SrcIP:   srcIP,
		SrcPort: srcPort,
		DstIP:   dstIP,
		DstPort: dstPort,
		Payload: hello,
		IPMeta:  ipMeta,
		TCPMeta: tcpMeta,
		ConnID:  key,
	}
	// Extension metadata is best-effort: a record that cannot be parsed for
	// SNI/ALPN/version is still reported, just without those fields.
	if ext, extErr := extractTLSExtensions(hello); extErr == nil && ext != nil {
		ch.SNI = ext.SNI
		ch.ALPN = joinStringSlice(ext.ALPN, ",")
		ch.TLSVersion = ext.TLSVersion
	}
	synToCH := uint32(time.Since(createdAt).Milliseconds())
	ch.SynToCHMs = &synToCH
	return ch, nil
}

// feedClientHelloSegment advances flow's state machine with one TCP payload.
// The caller must hold flow.mu and must not hold p.mu.
//
// Once a complete ClientHello record is available, it marks the flow JA4_DONE
// and returns that record. If buffering payload would exceed
// p.maxHelloBufferBytes, it marks the flow JA4_DONE, releases its buffer and
// returns drop=true. The caller is then expected to remove the flow from
// p.flows.
func (p *ParserImpl) feedClientHelloSegment(flow *ConnectionFlow, payload []byte) (hello []byte, drop bool, err error) {
	if flow.State == JA4_DONE {
		return nil, false, nil // Already reported (or abandoned) this flow
	}

	// Fast path: the whole ClientHello record arrived in this segment.
	hello, err = parseClientHello(payload)
	if err != nil {
		return nil, false, err
	}
	if hello != nil {
		flow.State = JA4_DONE
		flow.HelloBuffer = hello
		return hello, false, nil
	}

	// Slow path: the ClientHello spans several segments; accumulate them.
	if len(flow.HelloBuffer)+len(payload) > p.maxHelloBufferBytes {
		flow.State = JA4_DONE
		flow.HelloBuffer = nil
		return nil, true, nil
	}
	flow.State = WAIT_CLIENT_HELLO
	flow.HelloBuffer = append(flow.HelloBuffer, payload...)
	flow.LastSeen = time.Now()

	hello, err = parseClientHello(flow.HelloBuffer)
	if err != nil || hello == nil {
		return nil, false, err
	}
	flow.State = JA4_DONE
	// Copy so the returned record does not alias the flow's reassembly buffer.
	return append([]byte(nil), hello...), false, nil
}
// getOrCreateFlow returns the tracked flow for key after refreshing its
// LastSeen time, or registers a new flow in state NEW. It returns nil when
// key is unknown and p.maxTrackedFlows flows are already tracked.
// Only tracks incoming traffic from client to the local machine.
//
// Lock order: p.mu is acquired before the flow's mu.
func (p *ParserImpl) getOrCreateFlow(key string, srcIP string, srcPort uint16, dstIP string, dstPort uint16, ipMeta api.IPMeta, tcpMeta api.TCPMeta) *ConnectionFlow {
	p.mu.Lock()
	defer p.mu.Unlock()

	existing, found := p.flows[key]
	switch {
	case found:
		existing.mu.Lock()
		existing.LastSeen = time.Now()
		existing.mu.Unlock()
		return existing
	case len(p.flows) >= p.maxTrackedFlows:
		return nil // Table full: refuse to track another flow
	}

	created := &ConnectionFlow{
		State:       NEW,
		CreatedAt:   time.Now(),
		LastSeen:    time.Now(),
		SrcIP:       srcIP,   // Client IP
		SrcPort:     srcPort, // Client port
		DstIP:       dstIP,   // Server IP (local machine)
		DstPort:     dstPort, // Server port (local machine)
		IPMeta:      ipMeta,
		TCPMeta:     tcpMeta,
		HelloBuffer: make([]byte, 0),
	}
	p.flows[key] = created
	return created
}
// Close stops the background cleanup goroutine and waits for it to exit.
// Calling it more than once is safe; it always returns nil.
func (p *ParserImpl) Close() error {
	stopCleanup := func() {
		close(p.cleanupClose) // ask cleanupLoop to return
		<-p.cleanupDone       // wait until it has
	}
	p.closeOnce.Do(stopCleanup)
	return nil
}
// extractIPMeta collects TTL, length, IP ID and the DF flag from an IPv4 or
// IPv6 layer. Any other layer type yields a zero IPMeta.
//
// For IPv6 there is no IP ID field, so IPID is 0. DF is reported as true
// because IPv6 routers never fragment packets in transit (only the source
// may). TotalLength is copied from gopacket's IPv6.Length, which is the
// Payload Length field and excludes the 40-byte fixed header. This differs
// from IPv4, where Length is the total datagram length.
// NOTE(review): confirm consumers expect that asymmetry.
func extractIPMeta(ipLayer gopacket.Layer) api.IPMeta {
	var meta api.IPMeta

	if v4, ok := ipLayer.(*layers.IPv4); ok {
		meta.TTL = v4.TTL
		meta.TotalLength = v4.Length
		meta.IPID = v4.Id
		meta.DF = v4.Flags&layers.IPv4DontFragment != 0
		return meta
	}

	if v6, ok := ipLayer.(*layers.IPv6); ok {
		meta.TTL = v6.HopLimit
		meta.TotalLength = uint16(v6.Length)
		meta.IPID = 0
		meta.DF = true
	}
	return meta
}
// extractTCPMeta summarizes a TCP header: the window size, the MSS and
// window-scale values, and one label per option in header order. The labels
// are "MSS", "MSS_INVALID" (truncated MSS), "WS", "SACK", "TS", and
// "OPT<kind>" for unrecognized kinds. End-of-list and NOP options are
// omitted. Both SACK-permitted (kind 4) and SACK blocks (kind 5) are
// labelled "SACK".
func extractTCPMeta(tcp *layers.TCP) api.TCPMeta {
	meta := api.TCPMeta{
		WindowSize: tcp.Window,
		Options:    make([]string, 0, len(tcp.Options)),
	}

	for _, opt := range tcp.Options {
		kind, data := opt.OptionType, opt.OptionData

		var label string
		switch kind {
		case layers.TCPOptionKindEndList, layers.TCPOptionKindNop:
			continue // Terminator/padding: no fingerprinting signal
		case layers.TCPOptionKindMSS:
			if len(data) < 2 {
				label = "MSS_INVALID"
			} else {
				meta.MSS = binary.BigEndian.Uint16(data[:2])
				label = "MSS"
			}
		case layers.TCPOptionKindWindowScale:
			if len(data) > 0 {
				meta.WindowScale = data[0]
			}
			label = "WS"
		case layers.TCPOptionKindSACKPermitted, layers.TCPOptionKindSACK:
			label = "SACK"
		case layers.TCPOptionKindTimestamps:
			label = "TS"
		default:
			label = fmt.Sprintf("OPT%d", kind)
		}
		meta.Options = append(meta.Options, label)
	}
	return meta
}
// TLSExtensionInfo contains parsed TLS extension information from a
// ClientHello, as produced by extractTLSExtensions.
type TLSExtensionInfo struct {
SNI string // server_name host name; empty when absent
ALPN []string // ALPN protocol names as reported by tlsfingerprint; nil when none
TLSVersion string // "1.0".."1.3" (see tlsVersionToString), or empty when unknown
}
// parseClientHello reports whether payload begins with a complete TLS
// handshake record whose first handshake message is a ClientHello.
//
// On success it returns that record, including its 5-byte header, as a
// sub-slice of payload; any bytes after the record are ignored. It returns
// (nil, nil) when the payload is too short, is not a handshake record, uses a
// record version outside TLS 1.0-1.3, is still incomplete, or carries another
// handshake type. The error result is always nil.
func parseClientHello(payload []byte) ([]byte, error) {
	const (
		recordHeaderLen      = 5
		contentTypeHandshake = 22
		handshakeClientHello = 1
	)

	if len(payload) < recordHeaderLen || payload[0] != contentTypeHandshake {
		return nil, nil
	}
	if v := binary.BigEndian.Uint16(payload[1:3]); v < 0x0301 || v > 0x0304 {
		return nil, nil // Unknown record version
	}

	end := recordHeaderLen + int(binary.BigEndian.Uint16(payload[3:5]))
	switch {
	case len(payload) < end:
		return nil, nil // Record not fully received yet
	case end == recordHeaderLen:
		return nil, nil // Empty record: no handshake type byte
	case payload[recordHeaderLen] != handshakeClientHello:
		return nil, nil // Some other handshake message
	}
	return payload[:end], nil
}
// extractTLSExtensions extracts SNI, ALPN and a TLS version from a complete
// ClientHello record that includes its 5-byte record header, as returned by
// parseClientHello.
//
// ALPN and the version come from tlsfingerprint.ParseClientHello. The SNI host
// name is parsed by hand, because tlsfingerprint only reports whether SNI is
// present. It returns (nil, nil) when payload is not a complete handshake
// record carrying a ClientHello. The error result is currently always nil.
func extractTLSExtensions(payload []byte) (*TLSExtensionInfo, error) {
	if len(payload) < 5 || payload[0] != 22 {
		return nil, nil // Not a TLS handshake record
	}
	recordVersion := binary.BigEndian.Uint16(payload[1:3])
	recordLength := int(binary.BigEndian.Uint16(payload[3:5]))
	if len(payload) < 5+recordLength {
		return nil, nil // Incomplete record
	}
	handshake := payload[5 : 5+recordLength]
	if len(handshake) < 1 || handshake[0] != 1 {
		return nil, nil // Not a ClientHello
	}

	info := &TLSExtensionInfo{}

	if fp, err := tlsfingerprint.ParseClientHello(payload); err == nil && fp != nil {
		if len(fp.ALPNProtocols) > 0 {
			info.ALPN = fp.ALPNProtocols
		}
		info.TLSVersion = tlsVersionToString(fp.Version)
	}

	if info.TLSVersion == "" {
		// Fallback when tlsfingerprint gave no known version. Prefer the
		// ClientHello legacy_version (handshake bytes 4-5) over the record
		// version: clients commonly send 0x0301 in the record header for
		// compatibility (RFC 8446 §5.1), which would misreport modern clients
		// as TLS 1.0. TLS 1.3 clients still send 0x0303 here; their real
		// version is only in supported_versions, which is not parsed here.
		if len(handshake) >= 6 {
			info.TLSVersion = tlsVersionToString(binary.BigEndian.Uint16(handshake[4:6]))
		}
		if info.TLSVersion == "" {
			info.TLSVersion = tlsVersionToString(recordVersion)
		}
	}

	info.SNI = extractSNIFromPayload(handshake)
	return info, nil
}
// extractSNIFromPayload returns the host_name carried in the server_name
// extension (RFC 6066 §3) of a ClientHello handshake message. It returns ""
// when the extension is absent, names something other than a host name, or
// is malformed.
//
// handshakePayload starts at the handshake type byte (0x01 for ClientHello):
//
//	[0]      handshake type
//	[1:4]    handshake length (not checked here)
//	[4:6]    legacy_version
//	[6:38]   random
//	[38]     session ID length, followed by the session ID
//	         cipher suites: uint16 length + data
//	         compression methods: uint8 length + data
//	         extensions: uint16 length + entries of type(2) len(2) data
//
// Every length is bounds-checked. Extension parsing never reads past the
// declared extensions block, so bytes after it are not mistaken for an
// extension.
func extractSNIFromPayload(handshakePayload []byte) string {
	const (
		extServerName = 0x0000
		nameTypeHost  = 0x00
	)
	hp := handshakePayload

	// Skip type(1) + length(3) + legacy_version(2) + random(32).
	offset := 38
	if len(hp) <= offset {
		return ""
	}
	offset += 1 + int(hp[offset]) // session ID

	if offset+2 > len(hp) {
		return ""
	}
	offset += 2 + int(binary.BigEndian.Uint16(hp[offset:offset+2])) // cipher suites

	if offset >= len(hp) {
		return ""
	}
	offset += 1 + int(hp[offset]) // compression methods

	if offset+2 > len(hp) {
		return "" // No extensions block
	}
	extensionsLen := int(binary.BigEndian.Uint16(hp[offset : offset+2]))
	offset += 2
	end := offset + extensionsLen
	if extensionsLen == 0 || end > len(hp) {
		return ""
	}

	for offset+4 <= end {
		extType := binary.BigEndian.Uint16(hp[offset : offset+2])
		extLen := int(binary.BigEndian.Uint16(hp[offset+2 : offset+4]))
		offset += 4
		if offset+extLen > end {
			break // Extension overruns the declared extensions block
		}
		extData := hp[offset : offset+extLen]
		offset += extLen

		// server_name data: list_len(2), name_type(1), name_len(2), name.
		if extType != extServerName || len(extData) < 5 || extData[2] != nameTypeHost {
			continue
		}
		nameLen := int(binary.BigEndian.Uint16(extData[3:5]))
		if 5+nameLen <= len(extData) {
			return string(extData[5 : 5+nameLen])
		}
	}
	return ""
}
// tlsVersionToString maps a TLS protocol version code to "1.0".."1.3".
// Any other value (SSL 3.0, drafts, GREASE, garbage) yields "".
func tlsVersionToString(version uint16) string {
	const (
		tls10 = 0x0301
		tls11 = 0x0302
		tls12 = 0x0303
		tls13 = 0x0304
	)

	name := ""
	switch version {
	case tls13:
		name = "1.3"
	case tls12:
		name = "1.2"
	case tls11:
		name = "1.1"
	case tls10:
		name = "1.0"
	}
	return name
}
// IsClientHello reports whether payload starts with a complete TLS handshake
// record (record version TLS 1.0-1.3) whose first handshake message is a
// ClientHello (type 1).
func IsClientHello(payload []byte) bool {
	// Header (5 bytes) plus at least the handshake type byte.
	if len(payload) < 6 || payload[0] != 22 {
		return false
	}
	if v := binary.BigEndian.Uint16(payload[1:3]); v < 0x0301 || v > 0x0304 {
		return false
	}
	recordLength := int(binary.BigEndian.Uint16(payload[3:5]))
	return recordLength >= 1 &&
		len(payload) >= 5+recordLength &&
		payload[5] == 1
}
// joinStringSlice concatenates the elements of slice with sep between them;
// a nil or empty slice yields "". Kept for backward compatibility.
//
// Deprecated: Use strings.Join instead.
func joinStringSlice(slice []string, sep string) string {
	return strings.Join(slice, sep)
}