// Package capture provides network packet capture functionality for ja4sentinel package capture import ( "fmt" "log" "net" "regexp" "strings" "sync" "sync/atomic" "github.com/google/gopacket" "github.com/google/gopacket/pcap" "ja4sentinel/api" ) // Capture configuration constants const ( // DefaultSnapLen is the default snapshot length for packet capture // Increased from 1600 to 65535 to capture full packets including large TLS handshakes DefaultSnapLen = 65535 // DefaultPromiscuous is the default promiscuous mode setting DefaultPromiscuous = false // MaxBPFFilterLength is the maximum allowed length for BPF filters MaxBPFFilterLength = 1024 ) // validBPFPattern checks if a BPF filter contains only valid characters // This is a basic validation to prevent injection attacks var validBPFPattern = regexp.MustCompile(`^[a-zA-Z0-9\s\(\)\-\_\.\*\+\?\:\=\!\&\|\<\>\[\]\/\@,]+$`) // CaptureImpl implements the capture.Capture interface for packet capture type CaptureImpl struct { handle *pcap.Handle mu sync.Mutex snapLen int promisc bool isClosed bool localIPs []string // Local IPs to filter (dst host) linkType int // Link type from pcap handle interfaceName string // Interface name (for diagnostics) bpfFilter string // Applied BPF filter (for diagnostics) // Metrics counters (atomic) packetsReceived uint64 // Total packets received from interface packetsSent uint64 // Total packets sent to channel packetsDropped uint64 // Total packets dropped (channel full) } // New creates a new capture instance func New() *CaptureImpl { return &CaptureImpl{ snapLen: DefaultSnapLen, promisc: DefaultPromiscuous, } } // NewWithSnapLen creates a new capture instance with custom snapshot length func NewWithSnapLen(snapLen int) *CaptureImpl { if snapLen <= 0 || snapLen > 65535 { snapLen = DefaultSnapLen } return &CaptureImpl{ snapLen: snapLen, promisc: DefaultPromiscuous, } } // Run starts network packet capture according to the configuration func (c *CaptureImpl) Run(cfg api.Config, out chan<- api.RawPacket) error { // Validate interface name (basic check) if cfg.Interface == "" { return fmt.Errorf("interface cannot be empty") } // Find available interfaces to validate the interface exists ifaces, err := pcap.FindAllDevs() if err != nil { return fmt.Errorf("failed to list network interfaces: %w", err) } // Special handling for "any" interface interfaceFound := cfg.Interface == "any" if !interfaceFound { for _, iface := range ifaces { if iface.Name == cfg.Interface { interfaceFound = true break } } } if !interfaceFound { return fmt.Errorf("interface %s not found (available: %v)", cfg.Interface, getInterfaceNames(ifaces)) } handle, err := pcap.OpenLive(cfg.Interface, int32(c.snapLen), c.promisc, pcap.BlockForever) if err != nil { return fmt.Errorf("failed to open interface %s: %w", cfg.Interface, err) } c.mu.Lock() c.handle = handle c.mu.Unlock() defer func() { c.mu.Lock() if c.handle != nil && !c.isClosed { c.handle.Close() c.handle = nil } c.mu.Unlock() }() // Store interface name for diagnostics c.interfaceName = cfg.Interface // Resolve local IPs for filtering (if not manually specified) localIPs := cfg.LocalIPs if len(localIPs) == 0 { localIPs, err = c.detectLocalIPs(cfg.Interface) if err != nil { return fmt.Errorf("failed to detect local IPs: %w", err) } if len(localIPs) == 0 { // NAT/VIP: destination IP may not be assigned to this interface. // Fall back to port-only BPF filter instead of aborting. log.Printf("WARN capture: no local IPs found on interface %s; using port-only BPF filter (NAT/VIP mode)", cfg.Interface) } } c.localIPs = localIPs // Build and apply BPF filter bpfFilter := cfg.BPFFilter if bpfFilter == "" { bpfFilter = c.buildBPFFilter(cfg.ListenPorts, localIPs) } c.bpfFilter = bpfFilter // Validate BPF filter before applying if err := validateBPFFilter(bpfFilter); err != nil { return fmt.Errorf("invalid BPF filter: %w", err) } err = handle.SetBPFFilter(bpfFilter) if err != nil { return fmt.Errorf("failed to set BPF filter '%s': %w", bpfFilter, err) } // Store link type once, after the handle is fully configured (BPF filter applied). // A single write avoids the race where packetToRawPacket reads a stale value // that existed before the BPF filter was set. c.mu.Lock() c.linkType = int(handle.LinkType()) c.mu.Unlock() packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) for packet := range packetSource.Packets() { // Convert packet to RawPacket rawPkt := c.packetToRawPacket(packet) if rawPkt != nil { atomic.AddUint64(&c.packetsReceived, 1) select { case out <- *rawPkt: // Packet sent successfully atomic.AddUint64(&c.packetsSent, 1) default: // Channel full, drop packet atomic.AddUint64(&c.packetsDropped, 1) } } } return nil } // validateBPFFilter performs basic validation of BPF filter strings func validateBPFFilter(filter string) error { if filter == "" { return nil } if len(filter) > MaxBPFFilterLength { return fmt.Errorf("BPF filter too long (max %d characters)", MaxBPFFilterLength) } // Check for potentially dangerous patterns if !validBPFPattern.MatchString(filter) { return fmt.Errorf("BPF filter contains invalid characters") } // Check for unbalanced parentheses openParens := 0 for _, ch := range filter { if ch == '(' { openParens++ } else if ch == ')' { openParens-- if openParens < 0 { return fmt.Errorf("BPF filter has unbalanced parentheses") } } } if openParens != 0 { return fmt.Errorf("BPF filter has unbalanced parentheses") } return nil } // getInterfaceNames extracts interface names from a list of devices func getInterfaceNames(ifaces []pcap.Interface) []string { names := make([]string, len(ifaces)) for i, iface := range ifaces { names[i] = iface.Name } return names } // detectLocalIPs detects local IP addresses on the specified interface // Excludes loopback addresses (127.0.0.0/8, ::1) and IPv6 link-local (fe80::) func (c *CaptureImpl) detectLocalIPs(interfaceName string) ([]string, error) { var localIPs []string // Special case: "any" interface - get all non-loopback IPs if interfaceName == "any" { ifaces, err := net.Interfaces() if err != nil { return nil, fmt.Errorf("failed to list interfaces: %w", err) } for _, iface := range ifaces { // Skip loopback interfaces if iface.Flags&net.FlagLoopback != 0 { continue } addrs, err := iface.Addrs() if err != nil { continue // Skip this interface, try others } for _, addr := range addrs { ip := extractIP(addr) if ip != nil && !ip.IsLoopback() && !ip.IsLinkLocalUnicast() { localIPs = append(localIPs, ip.String()) } } } return localIPs, nil } // Specific interface - get IPs from that interface only iface, err := net.InterfaceByName(interfaceName) if err != nil { return nil, fmt.Errorf("failed to get interface %s: %w", interfaceName, err) } addrs, err := iface.Addrs() if err != nil { return nil, fmt.Errorf("failed to get addresses for %s: %w", interfaceName, err) } for _, addr := range addrs { ip := extractIP(addr) if ip != nil && !ip.IsLoopback() && !ip.IsLinkLocalUnicast() { localIPs = append(localIPs, ip.String()) } } return localIPs, nil } // extractIP extracts the IP address from a net.Addr func extractIP(addr net.Addr) net.IP { switch v := addr.(type) { case *net.IPNet: ip := v.IP // Return IPv4 as 4-byte, IPv6 as 16-byte if ip4 := ip.To4(); ip4 != nil { return ip4 } return ip case *net.IPAddr: ip := v.IP if ip4 := ip.To4(); ip4 != nil { return ip4 } return ip } return nil } // buildBPFFilter builds a BPF filter for the specified ports and local IPs // Filter: (tcp dst port 443 or tcp dst port 8443) and (dst host 192.168.1.10 or dst host 10.0.0.5) // Uses "tcp dst port" to only capture client→server traffic (not server→client responses) func (c *CaptureImpl) buildBPFFilter(ports []uint16, localIPs []string) string { if len(ports) == 0 { return "tcp" } // Build port filter (dst port only to avoid capturing server responses) portParts := make([]string, len(ports)) for i, port := range ports { portParts[i] = fmt.Sprintf("tcp dst port %d", port) } portFilter := "(" + strings.Join(portParts, ") or (") + ")" // Build destination host filter if len(localIPs) == 0 { return portFilter } hostParts := make([]string, len(localIPs)) for i, ip := range localIPs { // Handle IPv6 addresses if strings.Contains(ip, ":") { hostParts[i] = fmt.Sprintf("dst host %s", ip) } else { hostParts[i] = fmt.Sprintf("dst host %s", ip) } } hostFilter := "(" + strings.Join(hostParts, ") or (") + ")" // Combine port and host filters return portFilter + " and " + hostFilter } // joinString joins strings with a separator (kept for backward compatibility) func joinString(parts []string, sep string) string { if len(parts) == 0 { return "" } result := parts[0] for _, part := range parts[1:] { result += sep + part } return result } // packetToRawPacket converts a gopacket packet to RawPacket // Uses the raw packet bytes from the link layer func (c *CaptureImpl) packetToRawPacket(packet gopacket.Packet) *api.RawPacket { // Try to get link layer contents + payload for full packet var data []byte linkLayer := packet.LinkLayer() if linkLayer != nil { // Combine link layer contents with payload to get full packet data = append(data, linkLayer.LayerContents()...) data = append(data, linkLayer.LayerPayload()...) } else { // Fallback to packet.Data() data = packet.Data() } if len(data) == 0 { return nil } return &api.RawPacket{ Data: data, Timestamp: packet.Metadata().Timestamp.UnixNano(), LinkType: c.linkType, } } // Close properly closes the capture handle func (c *CaptureImpl) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.handle != nil && !c.isClosed { c.handle.Close() c.handle = nil c.isClosed = true return nil } c.isClosed = true return nil } // GetStats returns capture statistics (for monitoring/debugging) func (c *CaptureImpl) GetStats() (received, sent, dropped uint64) { return atomic.LoadUint64(&c.packetsReceived), atomic.LoadUint64(&c.packetsSent), atomic.LoadUint64(&c.packetsDropped) } // GetDiagnostics returns capture diagnostics information (for debugging) func (c *CaptureImpl) GetDiagnostics() (interfaceName string, localIPs []string, bpfFilter string, linkType int) { c.mu.Lock() defer c.mu.Unlock() return c.interfaceName, c.localIPs, c.bpfFilter, c.linkType }