feat: ja4-platform monorepo — 5 services unified, tests & RPM builds standardized
Services: - ja4sentinel: TLS/JA4 fingerprint capture daemon (Go, libpcap) - logcorrelator: JA4 log correlation engine (Go, ClickHouse) - mod_reqin_log: Apache module (C, JSON request logging) - bot_detector: ML bot detection pipeline (Python) - dashboard: FastAPI/Streamlit analytics UI (Python) Shared libraries: - shared/go/ja4common: logger, config, shutdown, ipfilter (Go module) - shared/python/ja4_common: ClickHouseClient, ClickHouseSettings (Python package) - shared/clickhouse/: canonical SQL migrations (10 files) Build & packaging: - Unified 3-stage Dockerfile.package for Go RPMs (el8/el9/el10) - go.work workspace linking sentinel, correlator, ja4common - Makefile with test-all, build-all, rpm-* targets Fixes applied: - go.work: 1.21 → 1.24.6 (required by sentinel) - correlator Dockerfiles: golang:1.21 → golang:1.24 - replace directives in go.mod for ja4common local path - pyproject.toml: setuptools.backends → setuptools.build_meta - Removed static libpcap linking (unavailable on Rocky 9) - Fixed data races in output/writers_test.go (sync.Mutex + atomic.Int32) - Rewrote corrupted test files (logger_test.go × 2) Test coverage: - correlator: 67.1% total (unixsocket 80.5%, config 91.7%, app 83.3%, multi 87.7%, stdout 100%) - sentinel: all 10 packages pass (api, capture, config, fingerprint, ipfilter, logging, output, tlsparse) Documentation: - README.md + docs/ (architecture, development, 5 services, shared libs, DB schema & migrations) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
397
services/sentinel/internal/capture/capture.go
Normal file
397
services/sentinel/internal/capture/capture.go
Normal file
@ -0,0 +1,397 @@
|
||||
// Package capture provides network packet capture functionality for ja4sentinel
|
||||
package capture
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/pcap"
|
||||
|
||||
"github.com/antitbone/ja4/sentinel/api"
|
||||
)
|
||||
|
||||
// Capture configuration constants
|
||||
const (
|
||||
// DefaultSnapLen is the default snapshot length for packet capture
|
||||
// Increased from 1600 to 65535 to capture full packets including large TLS handshakes
|
||||
DefaultSnapLen = 65535
|
||||
// DefaultPromiscuous is the default promiscuous mode setting
|
||||
DefaultPromiscuous = false
|
||||
// MaxBPFFilterLength is the maximum allowed length for BPF filters
|
||||
MaxBPFFilterLength = 1024
|
||||
)
|
||||
|
||||
// validBPFPattern checks if a BPF filter contains only valid characters
|
||||
// This is a basic validation to prevent injection attacks
|
||||
var validBPFPattern = regexp.MustCompile(`^[a-zA-Z0-9\s\(\)\-\_\.\*\+\?\:\=\!\&\|\<\>\[\]\/\@,]+$`)
|
||||
|
||||
// CaptureImpl implements the capture.Capture interface for packet capture
|
||||
type CaptureImpl struct {
|
||||
handle *pcap.Handle
|
||||
mu sync.Mutex
|
||||
snapLen int
|
||||
promisc bool
|
||||
isClosed bool
|
||||
localIPs []string // Local IPs to filter (dst host)
|
||||
linkType int // Link type from pcap handle
|
||||
interfaceName string // Interface name (for diagnostics)
|
||||
bpfFilter string // Applied BPF filter (for diagnostics)
|
||||
// Metrics counters (atomic)
|
||||
packetsReceived uint64 // Total packets received from interface
|
||||
packetsSent uint64 // Total packets sent to channel
|
||||
packetsDropped uint64 // Total packets dropped (channel full)
|
||||
}
|
||||
|
||||
// New creates a new capture instance
|
||||
func New() *CaptureImpl {
|
||||
return &CaptureImpl{
|
||||
snapLen: DefaultSnapLen,
|
||||
promisc: DefaultPromiscuous,
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithSnapLen creates a new capture instance with custom snapshot length
|
||||
func NewWithSnapLen(snapLen int) *CaptureImpl {
|
||||
if snapLen <= 0 || snapLen > 65535 {
|
||||
snapLen = DefaultSnapLen
|
||||
}
|
||||
return &CaptureImpl{
|
||||
snapLen: snapLen,
|
||||
promisc: DefaultPromiscuous,
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts network packet capture according to the configuration
|
||||
func (c *CaptureImpl) Run(cfg api.Config, out chan<- api.RawPacket) error {
|
||||
// Validate interface name (basic check)
|
||||
if cfg.Interface == "" {
|
||||
return fmt.Errorf("interface cannot be empty")
|
||||
}
|
||||
|
||||
// Find available interfaces to validate the interface exists
|
||||
ifaces, err := pcap.FindAllDevs()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list network interfaces: %w", err)
|
||||
}
|
||||
|
||||
// Special handling for "any" interface
|
||||
interfaceFound := cfg.Interface == "any"
|
||||
if !interfaceFound {
|
||||
for _, iface := range ifaces {
|
||||
if iface.Name == cfg.Interface {
|
||||
interfaceFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !interfaceFound {
|
||||
return fmt.Errorf("interface %s not found (available: %v)", cfg.Interface, getInterfaceNames(ifaces))
|
||||
}
|
||||
|
||||
handle, err := pcap.OpenLive(cfg.Interface, int32(c.snapLen), c.promisc, pcap.BlockForever)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open interface %s: %w", cfg.Interface, err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.handle = handle
|
||||
c.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
c.mu.Lock()
|
||||
if c.handle != nil && !c.isClosed {
|
||||
c.handle.Close()
|
||||
c.handle = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Store interface name for diagnostics
|
||||
c.interfaceName = cfg.Interface
|
||||
|
||||
// Resolve local IPs for filtering (if not manually specified)
|
||||
localIPs := cfg.LocalIPs
|
||||
if len(localIPs) == 0 {
|
||||
localIPs, err = c.detectLocalIPs(cfg.Interface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to detect local IPs: %w", err)
|
||||
}
|
||||
if len(localIPs) == 0 {
|
||||
// NAT/VIP: destination IP may not be assigned to this interface.
|
||||
// Fall back to port-only BPF filter instead of aborting.
|
||||
log.Printf("WARN capture: no local IPs found on interface %s; using port-only BPF filter (NAT/VIP mode)", cfg.Interface)
|
||||
}
|
||||
}
|
||||
c.localIPs = localIPs
|
||||
|
||||
// Build and apply BPF filter
|
||||
bpfFilter := cfg.BPFFilter
|
||||
if bpfFilter == "" {
|
||||
bpfFilter = c.buildBPFFilter(cfg.ListenPorts, localIPs)
|
||||
}
|
||||
c.bpfFilter = bpfFilter
|
||||
|
||||
// Validate BPF filter before applying
|
||||
if err := validateBPFFilter(bpfFilter); err != nil {
|
||||
return fmt.Errorf("invalid BPF filter: %w", err)
|
||||
}
|
||||
|
||||
err = handle.SetBPFFilter(bpfFilter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set BPF filter '%s': %w", bpfFilter, err)
|
||||
}
|
||||
|
||||
// Store link type once, after the handle is fully configured (BPF filter applied).
|
||||
// A single write avoids the race where packetToRawPacket reads a stale value
|
||||
// that existed before the BPF filter was set.
|
||||
c.mu.Lock()
|
||||
c.linkType = int(handle.LinkType())
|
||||
c.mu.Unlock()
|
||||
|
||||
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
|
||||
|
||||
for packet := range packetSource.Packets() {
|
||||
// Convert packet to RawPacket
|
||||
rawPkt := c.packetToRawPacket(packet)
|
||||
if rawPkt != nil {
|
||||
atomic.AddUint64(&c.packetsReceived, 1)
|
||||
select {
|
||||
case out <- *rawPkt:
|
||||
// Packet sent successfully
|
||||
atomic.AddUint64(&c.packetsSent, 1)
|
||||
default:
|
||||
// Channel full, drop packet
|
||||
atomic.AddUint64(&c.packetsDropped, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateBPFFilter performs basic validation of BPF filter strings
|
||||
func validateBPFFilter(filter string) error {
|
||||
if filter == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(filter) > MaxBPFFilterLength {
|
||||
return fmt.Errorf("BPF filter too long (max %d characters)", MaxBPFFilterLength)
|
||||
}
|
||||
|
||||
// Check for potentially dangerous patterns
|
||||
if !validBPFPattern.MatchString(filter) {
|
||||
return fmt.Errorf("BPF filter contains invalid characters")
|
||||
}
|
||||
|
||||
// Check for unbalanced parentheses
|
||||
openParens := 0
|
||||
for _, ch := range filter {
|
||||
if ch == '(' {
|
||||
openParens++
|
||||
} else if ch == ')' {
|
||||
openParens--
|
||||
if openParens < 0 {
|
||||
return fmt.Errorf("BPF filter has unbalanced parentheses")
|
||||
}
|
||||
}
|
||||
}
|
||||
if openParens != 0 {
|
||||
return fmt.Errorf("BPF filter has unbalanced parentheses")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getInterfaceNames extracts interface names from a list of devices
|
||||
func getInterfaceNames(ifaces []pcap.Interface) []string {
|
||||
names := make([]string, len(ifaces))
|
||||
for i, iface := range ifaces {
|
||||
names[i] = iface.Name
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// detectLocalIPs detects local IP addresses on the specified interface
|
||||
// Excludes loopback addresses (127.0.0.0/8, ::1) and IPv6 link-local (fe80::)
|
||||
func (c *CaptureImpl) detectLocalIPs(interfaceName string) ([]string, error) {
|
||||
var localIPs []string
|
||||
|
||||
// Special case: "any" interface - get all non-loopback IPs
|
||||
if interfaceName == "any" {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list interfaces: %w", err)
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
// Skip loopback interfaces
|
||||
if iface.Flags&net.FlagLoopback != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue // Skip this interface, try others
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
ip := extractIP(addr)
|
||||
if ip != nil && !ip.IsLoopback() && !ip.IsLinkLocalUnicast() {
|
||||
localIPs = append(localIPs, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return localIPs, nil
|
||||
}
|
||||
|
||||
// Specific interface - get IPs from that interface only
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get interface %s: %w", interfaceName, err)
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get addresses for %s: %w", interfaceName, err)
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
ip := extractIP(addr)
|
||||
if ip != nil && !ip.IsLoopback() && !ip.IsLinkLocalUnicast() {
|
||||
localIPs = append(localIPs, ip.String())
|
||||
}
|
||||
}
|
||||
|
||||
return localIPs, nil
|
||||
}
|
||||
|
||||
// extractIP extracts the IP address from a net.Addr
|
||||
func extractIP(addr net.Addr) net.IP {
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip := v.IP
|
||||
// Return IPv4 as 4-byte, IPv6 as 16-byte
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return ip4
|
||||
}
|
||||
return ip
|
||||
case *net.IPAddr:
|
||||
ip := v.IP
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return ip4
|
||||
}
|
||||
return ip
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildBPFFilter builds a BPF filter for the specified ports and local IPs
|
||||
// Filter: (tcp dst port 443 or tcp dst port 8443) and (dst host 192.168.1.10 or dst host 10.0.0.5)
|
||||
// Uses "tcp dst port" to only capture client→server traffic (not server→client responses)
|
||||
func (c *CaptureImpl) buildBPFFilter(ports []uint16, localIPs []string) string {
|
||||
if len(ports) == 0 {
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
// Build port filter (dst port only to avoid capturing server responses)
|
||||
portParts := make([]string, len(ports))
|
||||
for i, port := range ports {
|
||||
portParts[i] = fmt.Sprintf("tcp dst port %d", port)
|
||||
}
|
||||
portFilter := "(" + strings.Join(portParts, ") or (") + ")"
|
||||
|
||||
// Build destination host filter
|
||||
if len(localIPs) == 0 {
|
||||
return portFilter
|
||||
}
|
||||
|
||||
hostParts := make([]string, len(localIPs))
|
||||
for i, ip := range localIPs {
|
||||
// Handle IPv6 addresses
|
||||
if strings.Contains(ip, ":") {
|
||||
hostParts[i] = fmt.Sprintf("dst host %s", ip)
|
||||
} else {
|
||||
hostParts[i] = fmt.Sprintf("dst host %s", ip)
|
||||
}
|
||||
}
|
||||
hostFilter := "(" + strings.Join(hostParts, ") or (") + ")"
|
||||
|
||||
// Combine port and host filters
|
||||
return portFilter + " and " + hostFilter
|
||||
}
|
||||
|
||||
// joinString joins strings with a separator (kept for backward compatibility)
|
||||
func joinString(parts []string, sep string) string {
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
result := parts[0]
|
||||
for _, part := range parts[1:] {
|
||||
result += sep + part
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// packetToRawPacket converts a gopacket packet to RawPacket
|
||||
// Uses the raw packet bytes from the link layer
|
||||
func (c *CaptureImpl) packetToRawPacket(packet gopacket.Packet) *api.RawPacket {
|
||||
// Try to get link layer contents + payload for full packet
|
||||
var data []byte
|
||||
|
||||
linkLayer := packet.LinkLayer()
|
||||
if linkLayer != nil {
|
||||
// Combine link layer contents with payload to get full packet
|
||||
data = append(data, linkLayer.LayerContents()...)
|
||||
data = append(data, linkLayer.LayerPayload()...)
|
||||
} else {
|
||||
// Fallback to packet.Data()
|
||||
data = packet.Data()
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &api.RawPacket{
|
||||
Data: data,
|
||||
Timestamp: packet.Metadata().Timestamp.UnixNano(),
|
||||
LinkType: c.linkType,
|
||||
}
|
||||
}
|
||||
|
||||
// Close properly closes the capture handle
|
||||
func (c *CaptureImpl) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.handle != nil && !c.isClosed {
|
||||
c.handle.Close()
|
||||
c.handle = nil
|
||||
c.isClosed = true
|
||||
return nil
|
||||
}
|
||||
c.isClosed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats returns capture statistics (for monitoring/debugging)
|
||||
func (c *CaptureImpl) GetStats() (received, sent, dropped uint64) {
|
||||
return atomic.LoadUint64(&c.packetsReceived),
|
||||
atomic.LoadUint64(&c.packetsSent),
|
||||
atomic.LoadUint64(&c.packetsDropped)
|
||||
}
|
||||
|
||||
// GetDiagnostics returns capture diagnostics information (for debugging)
|
||||
func (c *CaptureImpl) GetDiagnostics() (interfaceName string, localIPs []string, bpfFilter string, linkType int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.interfaceName, c.localIPs, c.bpfFilter, c.linkType
|
||||
}
|
||||
661
services/sentinel/internal/capture/capture_test.go
Normal file
661
services/sentinel/internal/capture/capture_test.go
Normal file
@ -0,0 +1,661 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/antitbone/ja4/sentinel/api"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/gopacket/pcap"
|
||||
)
|
||||
|
||||
func TestCaptureImpl_Run_EmptyInterface(t *testing.T) {
|
||||
c := New()
|
||||
if c == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
|
||||
cfg := api.Config{
|
||||
Interface: "",
|
||||
ListenPorts: []uint16{443},
|
||||
}
|
||||
|
||||
out := make(chan api.RawPacket, 10)
|
||||
err := c.Run(cfg, out)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Run() with empty interface should return error")
|
||||
}
|
||||
if err.Error() != "interface cannot be empty" {
|
||||
t.Errorf("Run() error = %v, want 'interface cannot be empty'", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptureImpl_Run_NonExistentInterface(t *testing.T) {
|
||||
c := New()
|
||||
if c == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
|
||||
cfg := api.Config{
|
||||
Interface: "nonexistent_interface_xyz123",
|
||||
ListenPorts: []uint16{443},
|
||||
}
|
||||
|
||||
out := make(chan api.RawPacket, 10)
|
||||
err := c.Run(cfg, out)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Run() with non-existent interface should return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptureImpl_Run_InvalidBPFFilter(t *testing.T) {
|
||||
// Get a real interface name
|
||||
ifaces, err := pcap.FindAllDevs()
|
||||
if err != nil || len(ifaces) == 0 {
|
||||
t.Skip("No network interfaces available for testing")
|
||||
}
|
||||
|
||||
c := New()
|
||||
cfg := api.Config{
|
||||
Interface: ifaces[0].Name,
|
||||
ListenPorts: []uint16{443},
|
||||
BPFFilter: "invalid; rm -rf /", // Invalid characters
|
||||
}
|
||||
|
||||
out := make(chan api.RawPacket, 10)
|
||||
err = c.Run(cfg, out)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Run() with invalid BPF filter should return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptureImpl_Run_ChannelFull_DropsPackets(t *testing.T) {
|
||||
// This test verifies that when the output channel is full,
|
||||
// packets are dropped gracefully (non-blocking write)
|
||||
|
||||
// We can't easily test the full Run() loop without real interfaces,
|
||||
// but we can verify the channel behavior with a small buffer
|
||||
out := make(chan api.RawPacket, 1)
|
||||
|
||||
// Fill the channel
|
||||
out <- api.RawPacket{Data: []byte{1, 2, 3}, Timestamp: time.Now().UnixNano()}
|
||||
|
||||
// Channel should be full now, select default should trigger
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
select {
|
||||
case out <- api.RawPacket{Data: []byte{4, 5, 6}, Timestamp: time.Now().UnixNano()}:
|
||||
done <- false // Would block
|
||||
default:
|
||||
done <- true // Dropped as expected
|
||||
}
|
||||
}()
|
||||
|
||||
dropped := <-done
|
||||
if !dropped {
|
||||
t.Error("Expected packet to be dropped when channel is full")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketToRawPacket(t *testing.T) {
|
||||
t.Run("valid_packet", func(t *testing.T) {
|
||||
// Create a simple TCP packet
|
||||
eth := layers.Ethernet{
|
||||
SrcMAC: []byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55},
|
||||
DstMAC: []byte{0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB},
|
||||
EthernetType: layers.EthernetTypeIPv4,
|
||||
}
|
||||
ip := layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: []byte{192, 168, 1, 1},
|
||||
DstIP: []byte{10, 0, 0, 1},
|
||||
}
|
||||
tcp := layers.TCP{
|
||||
SrcPort: 12345,
|
||||
DstPort: 443,
|
||||
}
|
||||
tcp.SetNetworkLayerForChecksum(&ip)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{}
|
||||
gopacket.SerializeLayers(buf, opts, ð, &ip, &tcp)
|
||||
|
||||
packet := gopacket.NewPacket(buf.Bytes(), layers.LinkTypeEthernet, gopacket.Default)
|
||||
|
||||
// Create capture instance for method call
|
||||
c := New()
|
||||
rawPkt := c.packetToRawPacket(packet)
|
||||
|
||||
if rawPkt == nil {
|
||||
t.Fatal("packetToRawPacket() returned nil for valid packet")
|
||||
}
|
||||
if len(rawPkt.Data) == 0 {
|
||||
t.Error("packetToRawPacket() returned empty data")
|
||||
}
|
||||
if rawPkt.Timestamp == 0 {
|
||||
t.Error("packetToRawPacket() returned zero timestamp")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty_packet", func(t *testing.T) {
|
||||
// Create packet with no data
|
||||
packet := gopacket.NewPacket([]byte{}, layers.LinkTypeEthernet, gopacket.Default)
|
||||
|
||||
c := New()
|
||||
rawPkt := c.packetToRawPacket(packet)
|
||||
|
||||
if rawPkt != nil {
|
||||
t.Error("packetToRawPacket() should return nil for empty packet")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil_packet", func(t *testing.T) {
|
||||
// packetToRawPacket will panic with nil packet due to Metadata() call
|
||||
// This is expected behavior - the function is not designed to handle nil
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("packetToRawPacket() with nil packet should panic")
|
||||
}
|
||||
}()
|
||||
c := New()
|
||||
var packet gopacket.Packet
|
||||
_ = c.packetToRawPacket(packet)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetInterfaceNames(t *testing.T) {
|
||||
t.Run("empty_list", func(t *testing.T) {
|
||||
names := getInterfaceNames([]pcap.Interface{})
|
||||
if len(names) != 0 {
|
||||
t.Errorf("getInterfaceNames() with empty list = %v, want []", names)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("single_interface", func(t *testing.T) {
|
||||
ifaces := []pcap.Interface{
|
||||
{Name: "eth0"},
|
||||
}
|
||||
names := getInterfaceNames(ifaces)
|
||||
if len(names) != 1 || names[0] != "eth0" {
|
||||
t.Errorf("getInterfaceNames() = %v, want [eth0]", names)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple_interfaces", func(t *testing.T) {
|
||||
ifaces := []pcap.Interface{
|
||||
{Name: "eth0"},
|
||||
{Name: "lo"},
|
||||
{Name: "docker0"},
|
||||
}
|
||||
names := getInterfaceNames(ifaces)
|
||||
if len(names) != 3 {
|
||||
t.Errorf("getInterfaceNames() returned %d names, want 3", len(names))
|
||||
}
|
||||
expected := []string{"eth0", "lo", "docker0"}
|
||||
for i, name := range names {
|
||||
if name != expected[i] {
|
||||
t.Errorf("getInterfaceNames()[%d] = %s, want %s", i, name, expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateBPFFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filter string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty filter",
|
||||
filter: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid simple filter",
|
||||
filter: "tcp port 443",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid complex filter",
|
||||
filter: "(tcp port 443) or (tcp port 8443)",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "filter with special chars",
|
||||
filter: "tcp port 443 and host 192.168.1.1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "too long filter",
|
||||
filter: string(make([]byte, MaxBPFFilterLength+1)),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unbalanced parentheses - extra open",
|
||||
filter: "(tcp port 443",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unbalanced parentheses - extra close",
|
||||
filter: "tcp port 443)",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid characters - semicolon",
|
||||
filter: "tcp port 443; rm -rf /",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid characters - backtick",
|
||||
filter: "tcp port `whoami`",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid characters - dollar",
|
||||
filter: "tcp port $HOME",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateBPFFilter(tt.filter)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateBPFFilter() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
parts []string
|
||||
sep string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty slice",
|
||||
parts: []string{},
|
||||
sep: ") or (",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "single element",
|
||||
parts: []string{"tcp port 443"},
|
||||
sep: ") or (",
|
||||
want: "tcp port 443",
|
||||
},
|
||||
{
|
||||
name: "multiple elements",
|
||||
parts: []string{"tcp port 443", "tcp port 8443"},
|
||||
sep: ") or (",
|
||||
want: "tcp port 443) or (tcp port 8443",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := joinString(tt.parts, tt.sep)
|
||||
if got != tt.want {
|
||||
t.Errorf("joinString() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCapture(t *testing.T) {
|
||||
c := New()
|
||||
if c == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
if c.snapLen != DefaultSnapLen {
|
||||
t.Errorf("snapLen = %d, want %d", c.snapLen, DefaultSnapLen)
|
||||
}
|
||||
if c.promisc != DefaultPromiscuous {
|
||||
t.Errorf("promisc = %v, want %v", c.promisc, DefaultPromiscuous)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWithSnapLen(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
snapLen int
|
||||
wantSnapLen int
|
||||
}{
|
||||
{
|
||||
name: "valid snapLen",
|
||||
snapLen: 2048,
|
||||
wantSnapLen: 2048,
|
||||
},
|
||||
{
|
||||
name: "zero snapLen uses default",
|
||||
snapLen: 0,
|
||||
wantSnapLen: DefaultSnapLen,
|
||||
},
|
||||
{
|
||||
name: "negative snapLen uses default",
|
||||
snapLen: -100,
|
||||
wantSnapLen: DefaultSnapLen,
|
||||
},
|
||||
{
|
||||
name: "too large snapLen uses default",
|
||||
snapLen: 100000,
|
||||
wantSnapLen: DefaultSnapLen,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := NewWithSnapLen(tt.snapLen)
|
||||
if c == nil {
|
||||
t.Fatal("NewWithSnapLen() returned nil")
|
||||
}
|
||||
if c.snapLen != tt.wantSnapLen {
|
||||
t.Errorf("snapLen = %d, want %d", c.snapLen, tt.wantSnapLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptureImpl_Close(t *testing.T) {
|
||||
c := New()
|
||||
if c == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
|
||||
// Close should not panic on fresh instance
|
||||
if err := c.Close(); err != nil {
|
||||
t.Errorf("Close() error = %v", err)
|
||||
}
|
||||
|
||||
// Multiple closes should be safe
|
||||
if err := c.Close(); err != nil {
|
||||
t.Errorf("Close() second call error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateBPFFilter_BalancedParentheses(t *testing.T) {
|
||||
// Test various balanced parentheses scenarios
|
||||
validFilters := []string{
|
||||
"(tcp port 443)",
|
||||
"((tcp port 443))",
|
||||
"(tcp port 443) or (tcp port 8443)",
|
||||
"((tcp port 443) or (tcp port 8443))",
|
||||
"(tcp port 443 and host 1.2.3.4) or (tcp port 8443)",
|
||||
}
|
||||
|
||||
for _, filter := range validFilters {
|
||||
t.Run(filter, func(t *testing.T) {
|
||||
if err := validateBPFFilter(filter); err != nil {
|
||||
t.Errorf("validateBPFFilter(%q) unexpected error = %v", filter, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptureImpl_detectLocalIPs(t *testing.T) {
|
||||
c := New()
|
||||
if c == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
|
||||
t.Run("any_interface", func(t *testing.T) {
|
||||
ips, err := c.detectLocalIPs("any")
|
||||
if err != nil {
|
||||
t.Errorf("detectLocalIPs(any) error = %v", err)
|
||||
}
|
||||
// Should return at least one non-loopback IP or empty if none available
|
||||
for _, ip := range ips {
|
||||
if ip == "127.0.0.1" || ip == "::1" {
|
||||
t.Errorf("detectLocalIPs(any) should exclude loopback, got %s", ip)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("loopback_excluded", func(t *testing.T) {
|
||||
ips, err := c.detectLocalIPs("any")
|
||||
if err != nil {
|
||||
t.Skipf("Skipping loopback test: %v", err)
|
||||
}
|
||||
// Verify no loopback addresses are included
|
||||
for _, ip := range ips {
|
||||
if ip == "127.0.0.1" {
|
||||
t.Error("detectLocalIPs should exclude 127.0.0.1")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCaptureImpl_detectLocalIPs_SpecificInterface(t *testing.T) {
|
||||
c := New()
|
||||
if c == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
|
||||
// Test with a non-existent interface
|
||||
_, err := c.detectLocalIPs("nonexistent_interface_xyz")
|
||||
if err == nil {
|
||||
t.Error("detectLocalIPs with non-existent interface should return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptureImpl_extractIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
wantIPv4 bool
|
||||
wantIPv6 bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4",
|
||||
addr: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.10"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
wantIPv4: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6",
|
||||
addr: &net.IPNet{
|
||||
IP: net.ParseIP("2001:db8::1"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
wantIPv6: true,
|
||||
},
|
||||
{
|
||||
name: "nil",
|
||||
addr: nil,
|
||||
wantIPv4: false,
|
||||
wantIPv6: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractIP(tt.addr)
|
||||
if tt.wantIPv4 {
|
||||
if got == nil || got.To4() == nil {
|
||||
t.Error("extractIP() should return IPv4 address")
|
||||
}
|
||||
}
|
||||
if tt.wantIPv6 {
|
||||
if got == nil || got.To4() != nil {
|
||||
t.Error("extractIP() should return IPv6 address")
|
||||
}
|
||||
}
|
||||
if !tt.wantIPv4 && !tt.wantIPv6 {
|
||||
if got != nil {
|
||||
t.Error("extractIP() should return nil for nil address")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptureImpl_buildBPFFilter(t *testing.T) {
|
||||
c := New()
|
||||
if c == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ports []uint16
|
||||
localIPs []string
|
||||
wantParts []string // Parts that should be in the filter
|
||||
}{
|
||||
{
|
||||
name: "no ports",
|
||||
ports: []uint16{},
|
||||
localIPs: []string{},
|
||||
wantParts: []string{"tcp"},
|
||||
},
|
||||
{
|
||||
name: "single port no IPs",
|
||||
ports: []uint16{443},
|
||||
localIPs: []string{},
|
||||
wantParts: []string{"tcp dst port 443"},
|
||||
},
|
||||
{
|
||||
name: "single port with single IP",
|
||||
ports: []uint16{443},
|
||||
localIPs: []string{"192.168.1.10"},
|
||||
wantParts: []string{"tcp dst port 443", "dst host 192.168.1.10"},
|
||||
},
|
||||
{
|
||||
name: "multiple ports with multiple IPs",
|
||||
ports: []uint16{443, 8443},
|
||||
localIPs: []string{"192.168.1.10", "10.0.0.5"},
|
||||
wantParts: []string{"tcp dst port 443", "tcp dst port 8443", "dst host 192.168.1.10", "dst host 10.0.0.5"},
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
ports: []uint16{443},
|
||||
localIPs: []string{"2001:db8::1"},
|
||||
wantParts: []string{"tcp dst port 443", "dst host 2001:db8::1"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.buildBPFFilter(tt.ports, tt.localIPs)
|
||||
for _, part := range tt.wantParts {
|
||||
if !strings.Contains(got, part) {
|
||||
t.Errorf("buildBPFFilter() = %q, should contain %q", got, part)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptureImpl_Run_AnyInterface(t *testing.T) {
|
||||
t.Skip("integration: pcap on 'any' interface blocks until close; run with -run=Integration in a real network env")
|
||||
c := New()
|
||||
if c == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
cfg := api.Config{
|
||||
Interface: "any",
|
||||
ListenPorts: []uint16{443},
|
||||
LocalIPs: []string{"192.168.1.10"},
|
||||
}
|
||||
out := make(chan api.RawPacket, 10)
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- c.Run(cfg, out) }()
|
||||
// Allow up to 300ms for the handle to open (or fail immediately)
|
||||
select {
|
||||
case err := <-errCh:
|
||||
// Immediate error: permission or "not found"
|
||||
if err != nil && strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("Run() with 'any' interface should be valid, got: %v", err)
|
||||
}
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
// Run() started successfully (blocking on packets) — close to stop it
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptureImpl_Run_WithManualLocalIPs(t *testing.T) {
|
||||
t.Skip("integration: pcap on 'any' interface blocks until close; run with -run=Integration in a real network env")
|
||||
c := New()
|
||||
if c == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
cfg := api.Config{
|
||||
Interface: "any",
|
||||
ListenPorts: []uint16{443},
|
||||
LocalIPs: []string{"192.168.1.10", "10.0.0.5"},
|
||||
}
|
||||
out := make(chan api.RawPacket, 10)
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- c.Run(cfg, out) }()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("Run() with manual LocalIPs should be valid, got: %v", err)
|
||||
}
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// TestCaptureImpl_LinkTypeInitializedOnce verifies that linkType is set exactly once,
|
||||
// after the BPF filter is applied (Bug 2 fix: removed the redundant early assignment).
|
||||
func TestCaptureImpl_LinkTypeInitializedOnce(t *testing.T) {
|
||||
c := New()
|
||||
// Fresh instance: linkType must be zero before Run() is called.
|
||||
if c.linkType != 0 {
|
||||
t.Errorf("new CaptureImpl should have linkType=0, got %d", c.linkType)
|
||||
}
|
||||
|
||||
// GetDiagnostics reflects linkType correctly.
|
||||
_, _, _, lt := c.GetDiagnostics()
|
||||
if lt != 0 {
|
||||
t.Errorf("GetDiagnostics() linkType before Run() should be 0, got %d", lt)
|
||||
}
|
||||
|
||||
// Simulate what Run() does: set linkType once under the mutex.
|
||||
c.mu.Lock()
|
||||
c.linkType = 1 // 1 = Ethernet
|
||||
c.mu.Unlock()
|
||||
|
||||
_, _, _, lt = c.GetDiagnostics()
|
||||
if lt != 1 {
|
||||
t.Errorf("GetDiagnostics() linkType after set = %d, want 1", lt)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildBPFFilter_NoLocalIPs verifies Bug 3 fix: when no local IPs are
|
||||
// available (NAT/VIP), buildBPFFilter returns a port-only filter.
|
||||
func TestBuildBPFFilter_NoLocalIPs(t *testing.T) {
|
||||
c := New()
|
||||
filter := c.buildBPFFilter([]uint16{443}, nil)
|
||||
if strings.Contains(filter, "dst host") {
|
||||
t.Errorf("port-only filter expected when localIPs nil, got: %s", filter)
|
||||
}
|
||||
if !strings.Contains(filter, "tcp dst port 443") {
|
||||
t.Errorf("expected tcp dst port 443, got: %s", filter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBPFFilter_EmptyLocalIPs(t *testing.T) {
|
||||
c := New()
|
||||
filter := c.buildBPFFilter([]uint16{443, 8443}, []string{})
|
||||
if strings.Contains(filter, "dst host") {
|
||||
t.Errorf("port-only filter expected when localIPs empty, got: %s", filter)
|
||||
}
|
||||
if !strings.Contains(filter, "tcp dst port 443") || !strings.Contains(filter, "tcp dst port 8443") {
|
||||
t.Errorf("expected both ports in filter, got: %s", filter)
|
||||
}
|
||||
}
|
||||
295
services/sentinel/internal/config/loader.go
Normal file
295
services/sentinel/internal/config/loader.go
Normal file
@ -0,0 +1,295 @@
|
||||
// Package config provides configuration loading and validation for ja4sentinel
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
"github.com/antitbone/ja4/sentinel/api"
|
||||
)
|
||||
|
||||
// LoaderImpl implements the api.Loader interface for configuration loading
|
||||
type LoaderImpl struct {
|
||||
configPath string
|
||||
}
|
||||
|
||||
// NewLoader creates a new configuration loader
|
||||
func NewLoader(configPath string) *LoaderImpl {
|
||||
return &LoaderImpl{
|
||||
configPath: configPath,
|
||||
}
|
||||
}
|
||||
|
||||
// Load reads and merges configuration from file, environment variables, and CLI
|
||||
func (l *LoaderImpl) Load() (api.AppConfig, error) {
|
||||
config := api.DefaultConfig()
|
||||
|
||||
path := l.configPath
|
||||
explicit := path != ""
|
||||
if !explicit {
|
||||
path = "config.yml"
|
||||
}
|
||||
|
||||
fileConfig, err := l.loadFromFile(path)
|
||||
if err == nil {
|
||||
config = mergeConfigs(config, fileConfig)
|
||||
} else if !(!explicit && errors.Is(err, os.ErrNotExist)) {
|
||||
return config, fmt.Errorf("failed to load config file: %w", err)
|
||||
}
|
||||
|
||||
// Override with environment variables
|
||||
config = l.loadFromEnv(config)
|
||||
|
||||
// Validate the final configuration
|
||||
if err := l.validate(config); err != nil {
|
||||
return config, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// loadFromFile reads configuration from a YAML file
|
||||
func (l *LoaderImpl) loadFromFile(path string) (api.AppConfig, error) {
|
||||
config := api.AppConfig{}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return config, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return config, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// loadFromEnv overrides configuration with environment variables
|
||||
func (l *LoaderImpl) loadFromEnv(config api.AppConfig) api.AppConfig {
|
||||
// JA4SENTINEL_INTERFACE
|
||||
if val := os.Getenv("JA4SENTINEL_INTERFACE"); val != "" {
|
||||
config.Core.Interface = val
|
||||
}
|
||||
|
||||
// JA4SENTINEL_PORTS (comma-separated list)
|
||||
if val := os.Getenv("JA4SENTINEL_PORTS"); val != "" {
|
||||
ports := parsePorts(val)
|
||||
if len(ports) > 0 {
|
||||
config.Core.ListenPorts = ports
|
||||
}
|
||||
}
|
||||
|
||||
// JA4SENTINEL_BPF_FILTER
|
||||
if val := os.Getenv("JA4SENTINEL_BPF_FILTER"); val != "" {
|
||||
config.Core.BPFFilter = val
|
||||
}
|
||||
|
||||
// JA4SENTINEL_FLOW_TIMEOUT (in seconds)
|
||||
if val := os.Getenv("JA4SENTINEL_FLOW_TIMEOUT"); val != "" {
|
||||
if timeout, err := strconv.Atoi(val); err == nil && timeout > 0 {
|
||||
config.Core.FlowTimeoutSec = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// JA4SENTINEL_PACKET_BUFFER_SIZE
|
||||
if val := os.Getenv("JA4SENTINEL_PACKET_BUFFER_SIZE"); val != "" {
|
||||
if size, err := strconv.Atoi(val); err == nil && size > 0 {
|
||||
config.Core.PacketBufferSize = size
|
||||
}
|
||||
}
|
||||
|
||||
// Note: JA4SENTINEL_LOG_LEVEL is intentionally NOT loaded from env.
|
||||
// log_level must be configured exclusively via the YAML config file.
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// parsePorts parses a comma-separated list of ports
|
||||
func parsePorts(s string) []uint16 {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(s, ",")
|
||||
ports := make([]uint16, 0, len(parts))
|
||||
seen := make(map[uint16]struct{}, len(parts))
|
||||
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
port, err := strconv.ParseUint(part, 10, 16)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
p := uint16(port)
|
||||
if p == 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[p]; exists {
|
||||
continue
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
ports = append(ports, p)
|
||||
}
|
||||
|
||||
return ports
|
||||
}
|
||||
|
||||
// mergeConfigs merges two configs, with override taking precedence
|
||||
func mergeConfigs(base, override api.AppConfig) api.AppConfig {
|
||||
result := base
|
||||
|
||||
if override.Core.Interface != "" {
|
||||
result.Core.Interface = override.Core.Interface
|
||||
}
|
||||
|
||||
if len(override.Core.ListenPorts) > 0 {
|
||||
result.Core.ListenPorts = override.Core.ListenPorts
|
||||
}
|
||||
|
||||
if override.Core.BPFFilter != "" {
|
||||
result.Core.BPFFilter = override.Core.BPFFilter
|
||||
}
|
||||
|
||||
if override.Core.FlowTimeoutSec > 0 {
|
||||
result.Core.FlowTimeoutSec = override.Core.FlowTimeoutSec
|
||||
}
|
||||
|
||||
if override.Core.PacketBufferSize > 0 {
|
||||
result.Core.PacketBufferSize = override.Core.PacketBufferSize
|
||||
}
|
||||
|
||||
if override.Core.LogLevel != "" {
|
||||
result.Core.LogLevel = override.Core.LogLevel
|
||||
}
|
||||
|
||||
// Merge exclude_source_ips (override takes precedence)
|
||||
if len(override.Core.ExcludeSourceIPs) > 0 {
|
||||
result.Core.ExcludeSourceIPs = override.Core.ExcludeSourceIPs
|
||||
}
|
||||
|
||||
if len(override.Outputs) > 0 {
|
||||
result.Outputs = override.Outputs
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// validate checks if the configuration is valid
|
||||
func (l *LoaderImpl) validate(config api.AppConfig) error {
|
||||
if strings.TrimSpace(config.Core.Interface) == "" {
|
||||
return fmt.Errorf("interface cannot be empty")
|
||||
}
|
||||
|
||||
if len(config.Core.ListenPorts) == 0 {
|
||||
return fmt.Errorf("at least one listen port is required")
|
||||
}
|
||||
for _, p := range config.Core.ListenPorts {
|
||||
if p == 0 {
|
||||
return fmt.Errorf("listen port 0 is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
if config.Core.FlowTimeoutSec <= 0 || config.Core.FlowTimeoutSec > 300 {
|
||||
return fmt.Errorf("flow_timeout_sec must be between 1 and 300")
|
||||
}
|
||||
|
||||
if config.Core.PacketBufferSize <= 0 || config.Core.PacketBufferSize > 1_000_000 {
|
||||
return fmt.Errorf("packet_buffer_size must be between 1 and 1000000")
|
||||
}
|
||||
|
||||
// Validate log level
|
||||
validLogLevels := map[string]struct{}{
|
||||
"debug": {},
|
||||
"info": {},
|
||||
"warn": {},
|
||||
"error": {},
|
||||
}
|
||||
if config.Core.LogLevel != "" {
|
||||
if _, ok := validLogLevels[config.Core.LogLevel]; !ok {
|
||||
return fmt.Errorf("log_level must be one of: debug, info, warn, error")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate exclude_source_ips (if provided)
|
||||
if len(config.Core.ExcludeSourceIPs) > 0 {
|
||||
for i, ip := range config.Core.ExcludeSourceIPs {
|
||||
if ip == "" {
|
||||
return fmt.Errorf("exclude_source_ips[%d]: entry cannot be empty", i)
|
||||
}
|
||||
// Basic validation: check if it looks like an IP or CIDR
|
||||
if !strings.Contains(ip, "/") {
|
||||
// Single IP - basic check
|
||||
if !isValidIP(ip) {
|
||||
return fmt.Errorf("exclude_source_ips[%d]: invalid IP address %q", i, ip)
|
||||
}
|
||||
} else {
|
||||
// CIDR - basic check
|
||||
if !isValidCIDR(ip) {
|
||||
return fmt.Errorf("exclude_source_ips[%d]: invalid CIDR %q", i, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allowedTypes := map[string]struct{}{
|
||||
"stdout": {},
|
||||
"file": {},
|
||||
"unix_socket": {},
|
||||
}
|
||||
|
||||
// Validate outputs
|
||||
for i, output := range config.Outputs {
|
||||
outputType := strings.TrimSpace(output.Type)
|
||||
if outputType == "" {
|
||||
return fmt.Errorf("output[%d]: type cannot be empty", i)
|
||||
}
|
||||
if _, ok := allowedTypes[outputType]; !ok {
|
||||
return fmt.Errorf("output[%d]: unknown type %q", i, outputType)
|
||||
}
|
||||
|
||||
switch outputType {
|
||||
case "file":
|
||||
if strings.TrimSpace(output.Params["path"]) == "" {
|
||||
return fmt.Errorf("output[%d]: file output requires non-empty path", i)
|
||||
}
|
||||
case "unix_socket":
|
||||
if strings.TrimSpace(output.Params["socket_path"]) == "" {
|
||||
return fmt.Errorf("output[%d]: unix_socket output requires non-empty socket_path", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToJSON converts config to JSON string for debugging
|
||||
func ToJSON(config api.AppConfig) string {
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error marshaling config: %v", err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// isValidIP checks if a string is a valid IP address using net.ParseIP
|
||||
func isValidIP(ip string) bool {
|
||||
return net.ParseIP(ip) != nil
|
||||
}
|
||||
|
||||
// isValidCIDR checks if a string is a valid CIDR notation using net.ParseCIDR
|
||||
func isValidCIDR(cidr string) bool {
|
||||
_, _, err := net.ParseCIDR(cidr)
|
||||
return err == nil
|
||||
}
|
||||
1008
services/sentinel/internal/config/loader_test.go
Normal file
1008
services/sentinel/internal/config/loader_test.go
Normal file
File diff suppressed because it is too large
Load Diff
171
services/sentinel/internal/fingerprint/engine.go
Normal file
171
services/sentinel/internal/fingerprint/engine.go
Normal file
@ -0,0 +1,171 @@
|
||||
// Package fingerprint provides JA4/JA3 fingerprint generation for TLS ClientHello
|
||||
package fingerprint
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/antitbone/ja4/sentinel/api"
|
||||
|
||||
tlsfingerprint "github.com/psanford/tlsfingerprint"
|
||||
)
|
||||
|
||||
// EngineImpl implements the api.Engine interface for fingerprint generation
|
||||
type EngineImpl struct{}
|
||||
|
||||
// NewEngine creates a new fingerprint engine
|
||||
func NewEngine() *EngineImpl {
|
||||
return &EngineImpl{}
|
||||
}
|
||||
|
||||
// FromClientHello generates JA4 (and optionally JA3) fingerprints from a TLS ClientHello
|
||||
// Note: JA4 hash portion is extracted for internal use but NOT serialized to LogRecord
|
||||
// as the JA4 format already includes its own hash portions (per architecture.yml)
|
||||
func (e *EngineImpl) FromClientHello(ch api.TLSClientHello) (*api.Fingerprints, error) {
|
||||
if len(ch.Payload) == 0 {
|
||||
return nil, fmt.Errorf("empty ClientHello payload from %s:%d -> %s:%d",
|
||||
ch.SrcIP, ch.SrcPort, ch.DstIP, ch.DstPort)
|
||||
}
|
||||
|
||||
// Parse the ClientHello using tlsfingerprint
|
||||
fp, err := tlsfingerprint.ParseClientHello(ch.Payload)
|
||||
if err != nil {
|
||||
// Try to sanitize truncated extensions and retry
|
||||
sanitized := sanitizeClientHelloExtensions(ch.Payload)
|
||||
if sanitized != nil {
|
||||
fp, err = tlsfingerprint.ParseClientHello(sanitized)
|
||||
}
|
||||
if err != nil {
|
||||
sanitizeStatus := "unavailable"
|
||||
if sanitized != nil {
|
||||
sanitizeStatus = "failed"
|
||||
}
|
||||
return nil, fmt.Errorf("fingerprint generation failed for %s:%d -> %s:%d (conn_id=%s, payload_len=%d, tls_version=%s, sni=%s, sanitization=%s): %w",
|
||||
ch.SrcIP, ch.SrcPort, ch.DstIP, ch.DstPort, ch.ConnID, len(ch.Payload), ch.TLSVersion, ch.SNI, sanitizeStatus, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate JA4 fingerprint
|
||||
// Note: JA4 string format already includes the hash portion
|
||||
// e.g., "t13d1516h2_8daaf6152771_02cb136f2775" where the last part is the SHA256 hash
|
||||
ja4 := fp.JA4String()
|
||||
|
||||
// Generate JA3 fingerprint and its MD5 hash
|
||||
ja3 := fp.JA3String()
|
||||
ja3Hash := fp.JA3Hash()
|
||||
|
||||
// Extract JA4 hash portion (last segment after underscore)
|
||||
// JA4 format: <tls_ver><ciphers><extensions>_<sni_hash>_<cipher_extension_hash>
|
||||
// This is kept for internal use but NOT serialized to LogRecord
|
||||
ja4Hash := extractJA4Hash(ja4)
|
||||
|
||||
return &api.Fingerprints{
|
||||
JA4: ja4,
|
||||
JA4Hash: ja4Hash, // Internal use only - not serialized to LogRecord
|
||||
JA3: ja3,
|
||||
JA3Hash: ja3Hash,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractJA4Hash extracts the hash portion from a JA4 string
|
||||
// JA4 format: <base>_<sni_hash>_<cipher_hash> -> returns "<sni_hash>_<cipher_hash>"
|
||||
func extractJA4Hash(ja4 string) string {
|
||||
// JA4 string format: t13d1516h2_8daaf6152771_02cb136f2775
|
||||
// We extract everything after the first underscore as the "hash" portion
|
||||
for i, c := range ja4 {
|
||||
if c == '_' {
|
||||
return ja4[i+1:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// sanitizeClientHelloExtensions fixes ClientHellos with truncated extension data
|
||||
// by adjusting the extensions length to include only complete extensions.
|
||||
// Returns a corrected copy, or nil if the payload cannot be fixed.
|
||||
func sanitizeClientHelloExtensions(data []byte) []byte {
|
||||
if len(data) < 5 || data[0] != 0x16 {
|
||||
return nil
|
||||
}
|
||||
recordLen := int(data[3])<<8 | int(data[4])
|
||||
if len(data) < 5+recordLen {
|
||||
return nil
|
||||
}
|
||||
payload := data[5 : 5+recordLen]
|
||||
if len(payload) < 4 || payload[0] != 0x01 {
|
||||
return nil
|
||||
}
|
||||
helloLen := int(payload[1])<<16 | int(payload[2])<<8 | int(payload[3])
|
||||
if len(payload) < 4+helloLen {
|
||||
return nil
|
||||
}
|
||||
hello := payload[4 : 4+helloLen]
|
||||
|
||||
// Skip through ClientHello fields to reach extensions
|
||||
offset := 2 + 32 // version + random
|
||||
if len(hello) < offset+1 {
|
||||
return nil
|
||||
}
|
||||
offset += 1 + int(hello[offset]) // session ID
|
||||
if len(hello) < offset+2 {
|
||||
return nil
|
||||
}
|
||||
csLen := int(hello[offset])<<8 | int(hello[offset+1])
|
||||
offset += 2 + csLen // cipher suites
|
||||
if len(hello) < offset+1 {
|
||||
return nil
|
||||
}
|
||||
offset += 1 + int(hello[offset]) // compression methods
|
||||
if len(hello) < offset+2 {
|
||||
return nil
|
||||
}
|
||||
extLenOffset := offset // position of extensions length field
|
||||
declaredExtLen := int(hello[offset])<<8 | int(hello[offset+1])
|
||||
offset += 2
|
||||
extStart := offset
|
||||
|
||||
if len(hello) < extStart+declaredExtLen {
|
||||
return nil
|
||||
}
|
||||
extData := hello[extStart : extStart+declaredExtLen]
|
||||
|
||||
// Walk extensions, find how many complete ones exist
|
||||
validLen := 0
|
||||
pos := 0
|
||||
for pos < len(extData) {
|
||||
if pos+4 > len(extData) {
|
||||
break
|
||||
}
|
||||
extBodyLen := int(extData[pos+2])<<8 | int(extData[pos+3])
|
||||
if pos+4+extBodyLen > len(extData) {
|
||||
break // this extension is truncated
|
||||
}
|
||||
pos += 4 + extBodyLen
|
||||
validLen = pos
|
||||
}
|
||||
|
||||
if validLen == declaredExtLen {
|
||||
return nil // no truncation found, nothing to fix
|
||||
}
|
||||
|
||||
// Build a corrected copy with adjusted extensions length
|
||||
fixed := make([]byte, len(data))
|
||||
copy(fixed, data)
|
||||
|
||||
// Absolute offset of extensions length field within data
|
||||
extLenAbs := 5 + 4 + extLenOffset
|
||||
diff := declaredExtLen - validLen
|
||||
|
||||
// Update extensions length
|
||||
binary.BigEndian.PutUint16(fixed[extLenAbs:], uint16(validLen))
|
||||
// Update ClientHello handshake length
|
||||
newHelloLen := helloLen - diff
|
||||
fixed[5+1] = byte(newHelloLen >> 16)
|
||||
fixed[5+2] = byte(newHelloLen >> 8)
|
||||
fixed[5+3] = byte(newHelloLen)
|
||||
// Update TLS record length
|
||||
newRecordLen := recordLen - diff
|
||||
binary.BigEndian.PutUint16(fixed[3:5], uint16(newRecordLen))
|
||||
|
||||
return fixed[:5+newRecordLen]
|
||||
}
|
||||
489
services/sentinel/internal/fingerprint/engine_test.go
Normal file
489
services/sentinel/internal/fingerprint/engine_test.go
Normal file
@ -0,0 +1,489 @@
|
||||
package fingerprint
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/antitbone/ja4/sentinel/api"
|
||||
|
||||
tlsfingerprint "github.com/psanford/tlsfingerprint"
|
||||
)
|
||||
|
||||
func TestFromClientHello(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ch api.TLSClientHello
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty payload",
|
||||
ch: api.TLSClientHello{
|
||||
Payload: []byte{},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid payload",
|
||||
ch: api.TLSClientHello{
|
||||
Payload: []byte{0x00, 0x01, 0x02},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine := NewEngine()
|
||||
_, err := engine.FromClientHello(tt.ch)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("FromClientHello() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEngine(t *testing.T) {
|
||||
engine := NewEngine()
|
||||
if engine == nil {
|
||||
t.Error("NewEngine() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromClientHello_ValidPayload(t *testing.T) {
|
||||
// Use a minimal valid TLS 1.2 ClientHello with extensions
|
||||
// Build a proper ClientHello using the same structure as parser tests
|
||||
clientHello := buildMinimalClientHelloForTest()
|
||||
|
||||
ch := api.TLSClientHello{
|
||||
SrcIP: "192.168.1.100",
|
||||
SrcPort: 54321,
|
||||
DstIP: "10.0.0.1",
|
||||
DstPort: 443,
|
||||
Payload: clientHello,
|
||||
}
|
||||
|
||||
engine := NewEngine()
|
||||
fp, err := engine.FromClientHello(ch)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("FromClientHello() error = %v", err)
|
||||
}
|
||||
if fp == nil {
|
||||
t.Fatal("FromClientHello() returned nil")
|
||||
}
|
||||
|
||||
// Verify JA4 is populated (format: t13d... or t12d...)
|
||||
if fp.JA4 == "" {
|
||||
t.Error("JA4 should not be empty")
|
||||
}
|
||||
|
||||
// JA4Hash is populated for internal use (but not serialized to LogRecord)
|
||||
// It contains the hash portions of the JA4 string
|
||||
if fp.JA4Hash == "" {
|
||||
t.Error("JA4Hash should be populated for internal use")
|
||||
}
|
||||
}
|
||||
|
||||
// buildMinimalClientHelloForTest creates a minimal valid TLS 1.2 ClientHello
|
||||
func buildMinimalClientHelloForTest() []byte {
|
||||
// Cipher suites (minimal set)
|
||||
cipherSuites := []byte{0x00, 0x04, 0x13, 0x01, 0x13, 0x02, 0xc0, 0x2f}
|
||||
// Compression methods (null only)
|
||||
compressionMethods := []byte{0x01, 0x00}
|
||||
// No extensions
|
||||
extensions := []byte{}
|
||||
extLen := len(extensions)
|
||||
|
||||
// Build ClientHello handshake body
|
||||
handshakeBody := []byte{
|
||||
0x03, 0x03, // Version: TLS 1.2
|
||||
// Random (32 bytes)
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, // Session ID length: 0
|
||||
}
|
||||
|
||||
// Add cipher suites (with length prefix)
|
||||
cipherSuiteLen := len(cipherSuites)
|
||||
handshakeBody = append(handshakeBody, byte(cipherSuiteLen>>8), byte(cipherSuiteLen))
|
||||
handshakeBody = append(handshakeBody, cipherSuites...)
|
||||
|
||||
// Add compression methods (with length prefix)
|
||||
handshakeBody = append(handshakeBody, compressionMethods...)
|
||||
|
||||
// Add extensions (with length prefix)
|
||||
handshakeBody = append(handshakeBody, byte(extLen>>8), byte(extLen))
|
||||
handshakeBody = append(handshakeBody, extensions...)
|
||||
|
||||
// Now build full handshake with type and length
|
||||
handshakeLen := len(handshakeBody)
|
||||
handshake := append([]byte{
|
||||
0x01, // Handshake type: ClientHello
|
||||
byte(handshakeLen >> 16), byte(handshakeLen >> 8), byte(handshakeLen), // Handshake length
|
||||
}, handshakeBody...)
|
||||
|
||||
// Build TLS record
|
||||
recordLen := len(handshake)
|
||||
record := make([]byte, 5+recordLen)
|
||||
record[0] = 0x16 // Handshake
|
||||
record[1] = 0x03 // Version: TLS 1.2
|
||||
record[2] = 0x03
|
||||
record[3] = byte(recordLen >> 8)
|
||||
record[4] = byte(recordLen)
|
||||
copy(record[5:], handshake)
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
// TestExtractJA4Hash tests the extractJA4Hash helper function
|
||||
func TestExtractJA4Hash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ja4 string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "standard_ja4_format",
|
||||
ja4: "t13d1516h2_8daaf6152771_02cb136f2775",
|
||||
want: "8daaf6152771_02cb136f2775",
|
||||
},
|
||||
{
|
||||
name: "ja4_with_single_underscore",
|
||||
ja4: "t12d1234h1_abcdef123456",
|
||||
want: "abcdef123456",
|
||||
},
|
||||
{
|
||||
name: "ja4_no_underscore_returns_empty",
|
||||
ja4: "t13d1516h2",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty_ja4_returns_empty",
|
||||
ja4: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "underscore_at_start",
|
||||
ja4: "_hash1_hash2",
|
||||
want: "hash1_hash2",
|
||||
},
|
||||
{
|
||||
name: "multiple_underscores_returns_after_first",
|
||||
ja4: "base_part1_part2_part3",
|
||||
want: "part1_part2_part3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractJA4Hash(tt.ja4)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractJA4Hash(%q) = %q, want %q", tt.ja4, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromClientHello_NilPayload tests error handling for nil payload
|
||||
func TestFromClientHello_NilPayload(t *testing.T) {
|
||||
engine := NewEngine()
|
||||
ch := api.TLSClientHello{
|
||||
Payload: nil,
|
||||
}
|
||||
|
||||
_, err := engine.FromClientHello(ch)
|
||||
|
||||
if err == nil {
|
||||
t.Error("FromClientHello() with nil payload should return error")
|
||||
}
|
||||
if !strings.HasPrefix(err.Error(), "empty ClientHello payload") {
|
||||
t.Errorf("FromClientHello() error = %v, should start with 'empty ClientHello payload'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromClientHello_JA3Hash tests that JA3Hash is correctly populated
|
||||
func TestFromClientHello_JA3Hash(t *testing.T) {
|
||||
clientHello := buildMinimalClientHelloForTest()
|
||||
|
||||
ch := api.TLSClientHello{
|
||||
Payload: clientHello,
|
||||
}
|
||||
|
||||
engine := NewEngine()
|
||||
fp, err := engine.FromClientHello(ch)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("FromClientHello() error = %v", err)
|
||||
}
|
||||
|
||||
// JA3Hash should be populated (MD5 hash of JA3 string)
|
||||
if fp.JA3Hash == "" {
|
||||
t.Error("JA3Hash should be populated")
|
||||
}
|
||||
|
||||
// JA3 should also be populated
|
||||
if fp.JA3 == "" {
|
||||
t.Error("JA3 should be populated")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromClientHello_EmptyJA4Hash tests behavior when JA4 has no underscore
|
||||
func TestFromClientHello_EmptyJA4Hash(t *testing.T) {
|
||||
// This test verifies that even if JA4 format changes, the code handles it gracefully
|
||||
engine := NewEngine()
|
||||
|
||||
// Use a valid ClientHello - the library should produce a proper JA4
|
||||
clientHello := buildMinimalClientHelloForTest()
|
||||
|
||||
ch := api.TLSClientHello{
|
||||
Payload: clientHello,
|
||||
}
|
||||
|
||||
fp, err := engine.FromClientHello(ch)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("FromClientHello() error = %v", err)
|
||||
}
|
||||
|
||||
// JA4 should always be populated
|
||||
if fp.JA4 == "" {
|
||||
t.Error("JA4 should be populated")
|
||||
}
|
||||
|
||||
// JA4Hash may be empty if the JA4 format doesn't include underscores
|
||||
// This is acceptable behavior
|
||||
}
|
||||
|
||||
// buildClientHelloWithTruncatedExtension creates a ClientHello where the last
|
||||
// extension declares more data than actually present.
|
||||
func buildClientHelloWithTruncatedExtension() []byte {
|
||||
// Build a valid SNI extension first
|
||||
sniHostname := []byte("example.com")
|
||||
sniExt := []byte{
|
||||
0x00, 0x00, // Extension type: server_name
|
||||
}
|
||||
sniData := []byte{0x00}
|
||||
sniListLen := 1 + 2 + len(sniHostname) // type(1) + len(2) + hostname
|
||||
sniData = append(sniData, byte(sniListLen>>8), byte(sniListLen))
|
||||
sniData = append(sniData, 0x00) // hostname type
|
||||
sniData = append(sniData, byte(len(sniHostname)>>8), byte(len(sniHostname)))
|
||||
sniData = append(sniData, sniHostname...)
|
||||
sniExt = append(sniExt, byte(len(sniData)>>8), byte(len(sniData)))
|
||||
sniExt = append(sniExt, sniData...)
|
||||
|
||||
// Build a truncated extension: declares 100 bytes but only has 5
|
||||
truncatedExt := []byte{
|
||||
0x00, 0x15, // Extension type: padding
|
||||
0x00, 0x64, // Extension data length: 100 (but we only provide 5)
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, // Only 5 bytes of padding
|
||||
}
|
||||
|
||||
// Extensions = valid SNI + truncated padding
|
||||
extensions := append(sniExt, truncatedExt...)
|
||||
// But the extensions length field claims the full size (including the bad extension)
|
||||
extLen := len(extensions)
|
||||
|
||||
// Cipher suites
|
||||
cipherSuites := []byte{0x00, 0x04, 0x13, 0x01, 0x13, 0x02, 0xc0, 0x2f}
|
||||
compressionMethods := []byte{0x01, 0x00}
|
||||
|
||||
handshakeBody := []byte{0x03, 0x03}
|
||||
for i := 0; i < 32; i++ {
|
||||
handshakeBody = append(handshakeBody, 0x01)
|
||||
}
|
||||
handshakeBody = append(handshakeBody, 0x00) // session ID length: 0
|
||||
handshakeBody = append(handshakeBody, byte(len(cipherSuites)>>8), byte(len(cipherSuites)))
|
||||
handshakeBody = append(handshakeBody, cipherSuites...)
|
||||
handshakeBody = append(handshakeBody, compressionMethods...)
|
||||
handshakeBody = append(handshakeBody, byte(extLen>>8), byte(extLen))
|
||||
handshakeBody = append(handshakeBody, extensions...)
|
||||
|
||||
handshakeLen := len(handshakeBody)
|
||||
handshake := append([]byte{
|
||||
0x01,
|
||||
byte(handshakeLen >> 16), byte(handshakeLen >> 8), byte(handshakeLen),
|
||||
}, handshakeBody...)
|
||||
|
||||
recordLen := len(handshake)
|
||||
record := make([]byte, 5+recordLen)
|
||||
record[0] = 0x16
|
||||
record[1] = 0x03
|
||||
record[2] = 0x03
|
||||
record[3] = byte(recordLen >> 8)
|
||||
record[4] = byte(recordLen)
|
||||
copy(record[5:], handshake)
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
func TestFromClientHello_TruncatedExtension_StillGeneratesFingerprint(t *testing.T) {
|
||||
payload := buildClientHelloWithTruncatedExtension()
|
||||
|
||||
ch := api.TLSClientHello{
|
||||
SrcIP: "4.251.36.192",
|
||||
SrcPort: 19346,
|
||||
DstIP: "212.95.72.88",
|
||||
DstPort: 443,
|
||||
Payload: payload,
|
||||
ConnID: "4.251.36.192:19346->212.95.72.88:443",
|
||||
}
|
||||
|
||||
engine := NewEngine()
|
||||
fp, err := engine.FromClientHello(ch)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("FromClientHello() should succeed after sanitization, got error: %v", err)
|
||||
}
|
||||
if fp == nil {
|
||||
t.Fatal("FromClientHello() returned nil fingerprint")
|
||||
}
|
||||
if fp.JA4 == "" {
|
||||
t.Error("JA4 should be populated even with truncated extension")
|
||||
}
|
||||
if fp.JA3 == "" {
|
||||
t.Error("JA3 should be populated even with truncated extension")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeClientHelloExtensions(t *testing.T) {
|
||||
t.Run("valid payload returns nil", func(t *testing.T) {
|
||||
valid := buildMinimalClientHelloForTest()
|
||||
result := sanitizeClientHelloExtensions(valid)
|
||||
if result != nil {
|
||||
t.Error("should return nil for valid payload (no fix needed)")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("truncated extension is fixed", func(t *testing.T) {
|
||||
truncated := buildClientHelloWithTruncatedExtension()
|
||||
result := sanitizeClientHelloExtensions(truncated)
|
||||
if result == nil {
|
||||
t.Fatal("should return sanitized payload")
|
||||
}
|
||||
// The sanitized payload should be parseable by the library
|
||||
fp, err := tlsfingerprint.ParseClientHello(result)
|
||||
if err != nil {
|
||||
t.Fatalf("sanitized payload should parse without error, got: %v", err)
|
||||
}
|
||||
if fp == nil {
|
||||
t.Fatal("sanitized payload should produce a fingerprint")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("too short returns nil", func(t *testing.T) {
|
||||
if sanitizeClientHelloExtensions([]byte{0x16}) != nil {
|
||||
t.Error("should return nil for short payload")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-TLS returns nil", func(t *testing.T) {
|
||||
if sanitizeClientHelloExtensions([]byte{0x15, 0x03, 0x03, 0x00, 0x01, 0x00}) != nil {
|
||||
t.Error("should return nil for non-TLS payload")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestExtractJA4Hash_Standard tests the hash extraction from a standard JA4 string.
|
||||
func TestExtractJA4Hash_Standard(t *testing.T) {
|
||||
ja4 := "t13d1516h2_8daaf6152771_02cb136f2775"
|
||||
got := extractJA4Hash(ja4)
|
||||
expected := "8daaf6152771_02cb136f2775"
|
||||
if got != expected {
|
||||
t.Errorf("extractJA4Hash(%q) = %q, want %q", ja4, got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractJA4Hash_NoUnderscore tests that no underscore returns empty string.
|
||||
func TestExtractJA4Hash_NoUnderscore(t *testing.T) {
|
||||
got := extractJA4Hash("nounderscore")
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string for no underscore, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractJA4Hash_Empty tests that empty string returns empty string.
|
||||
func TestExtractJA4Hash_Empty(t *testing.T) {
|
||||
got := extractJA4Hash("")
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string for empty input, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromClientHello_NilPayloadExplicit tests that nil payload (empty) returns error.
|
||||
func TestFromClientHello_NilPayloadExplicit(t *testing.T) {
|
||||
engine := NewEngine()
|
||||
_, err := engine.FromClientHello(api.TLSClientHello{
|
||||
SrcIP: "1.2.3.4",
|
||||
SrcPort: 12345,
|
||||
DstIP: "5.6.7.8",
|
||||
DstPort: 443,
|
||||
Payload: nil,
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for nil payload")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromClientHello_SingleByte tests that single byte payload returns error.
|
||||
func TestFromClientHello_SingleByte(t *testing.T) {
|
||||
engine := NewEngine()
|
||||
_, err := engine.FromClientHello(api.TLSClientHello{
|
||||
Payload: []byte{0x16},
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for single-byte payload")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFromClientHello_ErrorContainsAddresses tests that error message includes addresses.
|
||||
func TestFromClientHello_ErrorContainsAddresses(t *testing.T) {
|
||||
engine := NewEngine()
|
||||
_, err := engine.FromClientHello(api.TLSClientHello{
|
||||
SrcIP: "192.168.1.100",
|
||||
SrcPort: 54321,
|
||||
DstIP: "10.0.0.1",
|
||||
DstPort: 443,
|
||||
ConnID: "test-conn-id",
|
||||
Payload: []byte{0x01, 0x02, 0x03}, // invalid
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid payload")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "192.168.1.100") {
|
||||
t.Errorf("expected error to contain src IP, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSanitizeClientHelloExtensions_NilInput tests nil input returns nil.
|
||||
func TestSanitizeClientHelloExtensions_NilInput(t *testing.T) {
|
||||
if sanitizeClientHelloExtensions(nil) != nil {
|
||||
t.Error("nil input should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSanitizeClientHelloExtensions_EmptyInput tests empty input returns nil.
|
||||
func TestSanitizeClientHelloExtensions_EmptyInput(t *testing.T) {
|
||||
if sanitizeClientHelloExtensions([]byte{}) != nil {
|
||||
t.Error("empty input should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestJA4HashExtraction_ConsistentWithFullParse verifies JA4Hash is the tail of JA4 string.
|
||||
func TestJA4HashExtraction_ConsistentWithFullParse(t *testing.T) {
|
||||
// Any JA4 string with exactly one underscore should work
|
||||
ja4 := "t12d4562h0_somehash"
|
||||
hash := extractJA4Hash(ja4)
|
||||
if !strings.HasPrefix(ja4, "t12") {
|
||||
t.Skip("precondition failed")
|
||||
}
|
||||
if hash != "somehash" {
|
||||
t.Errorf("expected 'somehash', got %q", hash)
|
||||
}
|
||||
}
|
||||
|
||||
// Compile-time check: EngineImpl satisfies api.Engine.
|
||||
var _ interface {
|
||||
FromClientHello(api.TLSClientHello) (*api.Fingerprints, error)
|
||||
} = (*EngineImpl)(nil)
|
||||
402
services/sentinel/internal/integration/pipeline_test.go
Normal file
402
services/sentinel/internal/integration/pipeline_test.go
Normal file
@ -0,0 +1,402 @@
|
||||
// Package integration provides integration tests for the full ja4sentinel pipeline
|
||||
package integration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/antitbone/ja4/sentinel/api"
|
||||
"github.com/antitbone/ja4/sentinel/internal/fingerprint"
|
||||
"github.com/antitbone/ja4/sentinel/internal/output"
|
||||
"github.com/antitbone/ja4/sentinel/internal/tlsparse"
|
||||
)
|
||||
|
||||
// TestFullPipeline_TLSClientHelloToFingerprint tests the pipeline from TLS ClientHello to fingerprint
|
||||
func TestFullPipeline_TLSClientHelloToFingerprint(t *testing.T) {
|
||||
// Create a minimal TLS 1.2 ClientHello for testing
|
||||
clientHello := buildMinimalTLSClientHello()
|
||||
|
||||
// Step 1: Parse the ClientHello
|
||||
parser := tlsparse.NewParser()
|
||||
if parser == nil {
|
||||
t.Fatal("NewParser() returned nil")
|
||||
}
|
||||
defer parser.Close()
|
||||
|
||||
// Create a raw packet with the ClientHello
|
||||
rawPacket := api.RawPacket{
|
||||
Data: buildEthernetIPPacket(clientHello),
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
// Process the packet
|
||||
ch, err := parser.Process(rawPacket)
|
||||
if err != nil {
|
||||
t.Fatalf("Process() error = %v", err)
|
||||
}
|
||||
if ch == nil {
|
||||
t.Fatal("Process() returned nil ClientHello")
|
||||
}
|
||||
|
||||
// Step 2: Generate fingerprints
|
||||
engine := fingerprint.NewEngine()
|
||||
if engine == nil {
|
||||
t.Fatal("NewEngine() returned nil")
|
||||
}
|
||||
|
||||
fp, err := engine.FromClientHello(*ch)
|
||||
if err != nil {
|
||||
t.Fatalf("FromClientHello() error = %v", err)
|
||||
}
|
||||
if fp == nil {
|
||||
t.Fatal("FromClientHello() returned nil")
|
||||
}
|
||||
|
||||
// Verify fingerprints are populated
|
||||
if fp.JA4 == "" {
|
||||
t.Error("JA4 should be populated")
|
||||
}
|
||||
if fp.JA3 == "" {
|
||||
t.Error("JA3 should be populated")
|
||||
}
|
||||
if fp.JA3Hash == "" {
|
||||
t.Error("JA3Hash should be populated")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFullPipeline_FingerprintToOutput tests the pipeline from fingerprint to output
|
||||
func TestFullPipeline_FingerprintToOutput(t *testing.T) {
|
||||
// Create test data
|
||||
clientHello := api.TLSClientHello{
|
||||
SrcIP: "192.168.1.100",
|
||||
SrcPort: 54321,
|
||||
DstIP: "10.0.0.1",
|
||||
DstPort: 443,
|
||||
IPMeta: api.IPMeta{
|
||||
TTL: 64,
|
||||
TotalLength: 512,
|
||||
IPID: 12345,
|
||||
DF: true,
|
||||
},
|
||||
TCPMeta: api.TCPMeta{
|
||||
WindowSize: 65535,
|
||||
MSS: 1460,
|
||||
WindowScale: 7,
|
||||
Options: []string{"MSS", "SACK", "TS", "WS"},
|
||||
},
|
||||
ConnID: "test-flow-123",
|
||||
SNI: "example.com",
|
||||
ALPN: "h2",
|
||||
TLSVersion: "1.3",
|
||||
SynToCHMs: uint32Ptr(50),
|
||||
}
|
||||
|
||||
// Create fingerprints
|
||||
fingerprints := &api.Fingerprints{
|
||||
JA4: "t13d1516h2_8daaf6152771_02cb136f2775",
|
||||
JA4Hash: "8daaf6152771_02cb136f2775",
|
||||
JA3: "771,4865-4866-4867,0-23-65281-10-11-35-16-5-13-18-51-45-43-27-17513,29-23-24,0",
|
||||
JA3Hash: "a0e6f06c7a6d15e5e3f0f0e6f06c7a6d",
|
||||
}
|
||||
|
||||
// Step 1: Create LogRecord
|
||||
logRecord := api.NewLogRecord(clientHello, fingerprints)
|
||||
logRecord.SensorID = "test-sensor"
|
||||
|
||||
// Step 2: Write to output (stdout writer for testing)
|
||||
writer := output.NewStdoutWriter()
|
||||
if writer == nil {
|
||||
t.Fatal("NewStdoutWriter() returned nil")
|
||||
}
|
||||
|
||||
// Capture stdout by using a buffer (we can't easily test stdout, so we verify the record)
|
||||
// Instead, verify the LogRecord is valid JSON
|
||||
data, err := json.Marshal(logRecord)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify JSON is valid and contains expected fields
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify key fields
|
||||
if result["src_ip"] != "192.168.1.100" {
|
||||
t.Errorf("src_ip = %v, want 192.168.1.100", result["src_ip"])
|
||||
}
|
||||
if result["src_port"] != float64(54321) {
|
||||
t.Errorf("src_port = %v, want 54321", result["src_port"])
|
||||
}
|
||||
if result["ja4"] != "t13d1516h2_8daaf6152771_02cb136f2775" {
|
||||
t.Errorf("ja4 = %v, want t13d1516h2_8daaf6152771_02cb136f2775", result["ja4"])
|
||||
}
|
||||
if result["tls_sni"] != "example.com" {
|
||||
t.Errorf("tls_sni = %v, want example.com", result["tls_sni"])
|
||||
}
|
||||
if result["sensor_id"] != "test-sensor" {
|
||||
t.Errorf("sensor_id = %v, want test-sensor", result["sensor_id"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestFullPipeline_EndToEnd tests the complete pipeline with file output
|
||||
func TestFullPipeline_EndToEnd(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
outputPath := tmpDir + "/output.log"
|
||||
|
||||
// Create test ClientHello
|
||||
clientHello := buildMinimalTLSClientHello()
|
||||
|
||||
// Step 1: Parse
|
||||
parser := tlsparse.NewParser()
|
||||
defer parser.Close()
|
||||
|
||||
rawPacket := api.RawPacket{
|
||||
Data: buildEthernetIPPacket(clientHello),
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
ch, err := parser.Process(rawPacket)
|
||||
if err != nil {
|
||||
t.Fatalf("Process() error = %v", err)
|
||||
}
|
||||
|
||||
// Step 2: Fingerprint
|
||||
engine := fingerprint.NewEngine()
|
||||
fp, err := engine.FromClientHello(*ch)
|
||||
if err != nil {
|
||||
t.Fatalf("FromClientHello() error = %v", err)
|
||||
}
|
||||
|
||||
// Step 3: Create LogRecord
|
||||
logRecord := api.NewLogRecord(*ch, fp)
|
||||
logRecord.SensorID = "test-sensor-e2e"
|
||||
|
||||
// Step 4: Write to file
|
||||
fileWriter, err := output.NewFileWriter(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileWriter() error = %v", err)
|
||||
}
|
||||
defer fileWriter.Close()
|
||||
|
||||
err = fileWriter.Write(logRecord)
|
||||
if err != nil {
|
||||
t.Errorf("Write() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify output file
|
||||
data, err := os.ReadFile(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile() error = %v", err)
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
t.Fatal("Output file is empty")
|
||||
}
|
||||
|
||||
// Parse and verify
|
||||
var result api.LogRecord
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
if result.SensorID != "test-sensor-e2e" {
|
||||
t.Errorf("SensorID = %v, want test-sensor-e2e", result.SensorID)
|
||||
}
|
||||
if result.JA4 == "" {
|
||||
t.Error("JA4 should be populated")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFullPipeline_MultiOutput tests writing to multiple outputs simultaneously
|
||||
func TestFullPipeline_MultiOutput(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
filePath := tmpDir + "/multi.log"
|
||||
|
||||
// Create multi-writer
|
||||
multiWriter := output.NewMultiWriter()
|
||||
multiWriter.Add(output.NewStdoutWriter())
|
||||
|
||||
fileWriter, err := output.NewFileWriter(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFileWriter() error = %v", err)
|
||||
}
|
||||
multiWriter.Add(fileWriter)
|
||||
|
||||
// Create test record
|
||||
logRecord := api.LogRecord{
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 12345,
|
||||
JA4: "test-multi-output",
|
||||
}
|
||||
|
||||
// Write to all outputs
|
||||
err = multiWriter.Write(logRecord)
|
||||
if err != nil {
|
||||
t.Errorf("Write() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify file output
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile() error = %v", err)
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
t.Fatal("File output is empty")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFullPipeline_ConfigToOutput tests building output from config
|
||||
func TestFullPipeline_ConfigToOutput(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create config with multiple outputs
|
||||
config := api.AppConfig{
|
||||
Core: api.Config{
|
||||
Interface: "eth0",
|
||||
ListenPorts: []uint16{443},
|
||||
},
|
||||
Outputs: []api.OutputConfig{
|
||||
{
|
||||
Type: "stdout",
|
||||
Enabled: true,
|
||||
AsyncBuffer: 1000,
|
||||
},
|
||||
{
|
||||
Type: "file",
|
||||
Enabled: true,
|
||||
AsyncBuffer: 1000,
|
||||
Params: map[string]string{"path": tmpDir + "/config-output.log"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Build writer from config
|
||||
builder := output.NewBuilder()
|
||||
writer, err := builder.NewFromConfig(config)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFromConfig() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify writer is MultiWriter
|
||||
_, ok := writer.(*output.MultiWriter)
|
||||
if !ok {
|
||||
t.Fatal("Expected MultiWriter")
|
||||
}
|
||||
|
||||
// Test writing
|
||||
logRecord := api.LogRecord{
|
||||
SrcIP: "192.168.1.1",
|
||||
JA4: "test-config-output",
|
||||
}
|
||||
|
||||
err = writer.Write(logRecord)
|
||||
if err != nil {
|
||||
t.Errorf("Write() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// buildMinimalTLSClientHello creates a minimal TLS 1.2 ClientHello for testing
|
||||
func buildMinimalTLSClientHello() []byte {
|
||||
// Cipher suites
|
||||
cipherSuites := []byte{0x00, 0x04, 0x13, 0x01, 0x13, 0x02, 0xc0, 0x2f}
|
||||
compressionMethods := []byte{0x01, 0x00}
|
||||
extensions := []byte{}
|
||||
extLen := len(extensions)
|
||||
|
||||
handshakeBody := []byte{
|
||||
0x03, 0x03, // Version: TLS 1.2
|
||||
}
|
||||
// Random (32 bytes)
|
||||
for i := 0; i < 32; i++ {
|
||||
handshakeBody = append(handshakeBody, 0x00)
|
||||
}
|
||||
handshakeBody = append(handshakeBody, 0x00) // Session ID length
|
||||
|
||||
// Cipher suites
|
||||
cipherSuiteLen := len(cipherSuites)
|
||||
handshakeBody = append(handshakeBody, byte(cipherSuiteLen>>8), byte(cipherSuiteLen))
|
||||
handshakeBody = append(handshakeBody, cipherSuites...)
|
||||
|
||||
// Compression methods
|
||||
handshakeBody = append(handshakeBody, compressionMethods...)
|
||||
|
||||
// Extensions
|
||||
handshakeBody = append(handshakeBody, byte(extLen>>8), byte(extLen))
|
||||
handshakeBody = append(handshakeBody, extensions...)
|
||||
|
||||
// Build handshake
|
||||
handshakeLen := len(handshakeBody)
|
||||
handshake := append([]byte{
|
||||
0x01, // Handshake type: ClientHello
|
||||
byte(handshakeLen >> 16), byte(handshakeLen >> 8), byte(handshakeLen),
|
||||
}, handshakeBody...)
|
||||
|
||||
// Build TLS record
|
||||
recordLen := len(handshake)
|
||||
record := make([]byte, 5+recordLen)
|
||||
record[0] = 0x16 // Handshake
|
||||
record[1] = 0x03 // Version: TLS 1.2
|
||||
record[2] = 0x03
|
||||
record[3] = byte(recordLen >> 8)
|
||||
record[4] = byte(recordLen)
|
||||
copy(record[5:], handshake)
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
// buildEthernetIPPacket wraps a TLS payload in Ethernet/IP/TCP headers
|
||||
func buildEthernetIPPacket(tlsPayload []byte) []byte {
|
||||
// This is a simplified packet structure for testing
|
||||
// Real packets would have proper Ethernet, IP, and TCP headers
|
||||
|
||||
// Ethernet header (14 bytes)
|
||||
eth := make([]byte, 14)
|
||||
eth[12] = 0x08 // EtherType: IPv4
|
||||
eth[13] = 0x00
|
||||
|
||||
// IP header (20 bytes)
|
||||
ip := make([]byte, 20)
|
||||
ip[0] = 0x45 // Version 4, IHL 5
|
||||
ip[1] = 0x00 // DSCP/ECN
|
||||
ip[2] = byte((20 + 20 + len(tlsPayload)) >> 8) // Total length
|
||||
ip[3] = byte((20 + 20 + len(tlsPayload)) & 0xFF)
|
||||
ip[8] = 64 // TTL
|
||||
ip[9] = 6 // Protocol: TCP
|
||||
ip[12] = 192
|
||||
ip[13] = 168
|
||||
ip[14] = 1
|
||||
ip[15] = 100 // Src IP: 192.168.1.100
|
||||
ip[16] = 10
|
||||
ip[17] = 0
|
||||
ip[18] = 0
|
||||
ip[19] = 1 // Dst IP: 10.0.0.1
|
||||
|
||||
// TCP header (20 bytes)
|
||||
tcp := make([]byte, 20)
|
||||
tcp[0] = byte(54321 >> 8) // Src port high
|
||||
tcp[1] = byte(54321 & 0xFF) // Src port low
|
||||
tcp[2] = byte(443 >> 8) // Dst port high
|
||||
tcp[3] = byte(443 & 0xFF) // Dst port low
|
||||
tcp[12] = 0x50 // Data offset (5 * 4 = 20 bytes)
|
||||
tcp[13] = 0x18 // Flags: ACK, PSH
|
||||
|
||||
// Combine all headers with payload
|
||||
packet := make([]byte, len(eth)+len(ip)+len(tcp)+len(tlsPayload))
|
||||
copy(packet, eth)
|
||||
copy(packet[len(eth):], ip)
|
||||
copy(packet[len(eth)+len(ip):], tcp)
|
||||
copy(packet[len(eth)+len(ip)+len(tcp):], tlsPayload)
|
||||
|
||||
return packet
|
||||
}
|
||||
|
||||
func uint32Ptr(v uint32) *uint32 {
|
||||
return &v
|
||||
}
|
||||
15
services/sentinel/internal/ipfilter/ipfilter.go
Normal file
15
services/sentinel/internal/ipfilter/ipfilter.go
Normal file
@ -0,0 +1,15 @@
|
||||
// Package ipfilter provides IP address and CIDR range matching for filtering.
|
||||
// Implementation is delegated to shared/go/ja4common/ipfilter to avoid duplication.
|
||||
package ipfilter
|
||||
|
||||
import jaipfilter "github.com/antitbone/ja4/ja4common/ipfilter"
|
||||
|
||||
// Filter is a type alias for ja4common/ipfilter.Filter.
|
||||
// All methods (New, ShouldExclude, Count) are inherited from the shared module.
|
||||
type Filter = jaipfilter.Filter
|
||||
|
||||
// New creates a new IP filter from a list of IP addresses or CIDR ranges.
|
||||
// Accepts formats like: "192.168.1.1", "10.0.0.0/8", "2001:db8::/32"
|
||||
func New(excludeList []string) (*Filter, error) {
|
||||
return jaipfilter.New(excludeList)
|
||||
}
|
||||
160
services/sentinel/internal/ipfilter/ipfilter_test.go
Normal file
160
services/sentinel/internal/ipfilter/ipfilter_test.go
Normal file
@ -0,0 +1,160 @@
|
||||
package ipfilter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFilter_New(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
list []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty list",
|
||||
list: []string{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "single IP",
|
||||
list: []string{"192.168.1.1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "single CIDR",
|
||||
list: []string{"10.0.0.0/8"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "mixed IPs and CIDRs",
|
||||
list: []string{"192.168.1.1", "10.0.0.0/8", "172.16.0.0/12"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid IP",
|
||||
list: []string{"999.999.999.999"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid CIDR",
|
||||
list: []string{"10.0.0.0/33"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
list: []string{"2001:db8::1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 CIDR",
|
||||
list: []string{"2001:db8::/32"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f, err := New(tt.list)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err == nil && f == nil {
|
||||
t.Error("New() should return non-nil filter on success")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter_ShouldExclude(t *testing.T) {
|
||||
f, err := New([]string{
|
||||
"192.168.1.1",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"2001:db8::1",
|
||||
"fc00::/7",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
want bool
|
||||
}{
|
||||
// Exact IP matches
|
||||
{"exact match", "192.168.1.1", true},
|
||||
{"exact IPv6 match", "2001:db8::1", true},
|
||||
|
||||
// CIDR matches
|
||||
{"CIDR match 10.0.0.1", "10.0.0.1", true},
|
||||
{"CIDR match 10.255.255.255", "10.255.255.255", true},
|
||||
{"CIDR match 172.16.0.1", "172.16.0.1", true},
|
||||
{"CIDR match 172.31.255.255", "172.31.255.255", true},
|
||||
{"CIDR IPv6 match", "fc00::1", true},
|
||||
|
||||
// No matches
|
||||
{"no match 192.168.2.1", "192.168.2.1", false},
|
||||
{"no match 11.0.0.1", "11.0.0.1", false},
|
||||
{"no match 172.32.0.1", "172.32.0.1", false},
|
||||
{"no match 8.8.8.8", "8.8.8.8", false},
|
||||
|
||||
// Invalid IP
|
||||
{"invalid IP", "invalid", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := f.ShouldExclude(tt.ip); got != tt.want {
|
||||
t.Errorf("ShouldExclude(%q) = %v, want %v", tt.ip, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter_ShouldExclude_NilFilter(t *testing.T) {
|
||||
var f *Filter
|
||||
if f.ShouldExclude("192.168.1.1") {
|
||||
t.Error("ShouldExclude on nil filter should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter_Count(t *testing.T) {
|
||||
f, err := New([]string{
|
||||
"192.168.1.1",
|
||||
"10.0.0.1",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ips, networks := f.Count()
|
||||
if ips != 2 {
|
||||
t.Errorf("Count() ips = %d, want 2", ips)
|
||||
}
|
||||
if networks != 2 {
|
||||
t.Errorf("Count() networks = %d, want 2", networks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter_EmptyEntries(t *testing.T) {
|
||||
f, err := New([]string{"", "192.168.1.1", ""})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ips, _ := f.Count()
|
||||
if ips != 1 {
|
||||
t.Errorf("Count() ips = %d, want 1 (empty entries should be skipped)", ips)
|
||||
}
|
||||
|
||||
if !f.ShouldExclude("192.168.1.1") {
|
||||
t.Error("Should exclude 192.168.1.1")
|
||||
}
|
||||
if f.ShouldExclude("192.168.1.2") {
|
||||
t.Error("Should not exclude 192.168.1.2")
|
||||
}
|
||||
}
|
||||
19
services/sentinel/internal/logging/logger_factory.go
Normal file
19
services/sentinel/internal/logging/logger_factory.go
Normal file
@ -0,0 +1,19 @@
|
||||
// Package logging provides a factory for creating loggers
|
||||
package logging
|
||||
|
||||
import (
|
||||
"github.com/antitbone/ja4/sentinel/api"
|
||||
)
|
||||
|
||||
// LoggerFactory creates logger instances
|
||||
type LoggerFactory struct{}
|
||||
|
||||
// NewLogger creates a new logger based on configuration
|
||||
func (f *LoggerFactory) NewLogger(level string) api.Logger {
|
||||
return NewServiceLogger(level)
|
||||
}
|
||||
|
||||
// NewDefaultLogger creates a logger with default settings
|
||||
func (f *LoggerFactory) NewDefaultLogger() api.Logger {
|
||||
return NewServiceLogger("info")
|
||||
}
|
||||
47
services/sentinel/internal/logging/service_logger.go
Normal file
47
services/sentinel/internal/logging/service_logger.go
Normal file
@ -0,0 +1,47 @@
|
||||
// Package logging provides structured logging for the sentinel service.
|
||||
// Implementation is delegated to shared/go/ja4common/logger to avoid duplication.
|
||||
package logging
|
||||
|
||||
import (
|
||||
jalogger "github.com/antitbone/ja4/ja4common/logger"
|
||||
"github.com/antitbone/ja4/sentinel/api"
|
||||
)
|
||||
|
||||
// ServiceLogger satisfies api.Logger using ja4common/logger.ComponentLogger.
|
||||
// This avoids duplicating logging logic that is now shared across all ja4-platform services.
|
||||
type ServiceLogger struct {
|
||||
inner *jalogger.ComponentLogger
|
||||
}
|
||||
|
||||
// NewServiceLogger creates a new ServiceLogger backed by ja4common.
|
||||
func NewServiceLogger(level string) *ServiceLogger {
|
||||
return &ServiceLogger{inner: jalogger.NewComponentLogger(level)}
|
||||
}
|
||||
|
||||
// Log emits a structured log entry for the given component.
|
||||
func (l *ServiceLogger) Log(component, level, message string, details map[string]string) {
|
||||
l.inner.Log(component, level, message, details)
|
||||
}
|
||||
|
||||
// Debug logs a debug entry for the given component.
|
||||
func (l *ServiceLogger) Debug(component, message string, details map[string]string) {
|
||||
l.inner.Debug(component, message, details)
|
||||
}
|
||||
|
||||
// Info logs an info entry for the given component.
|
||||
func (l *ServiceLogger) Info(component, message string, details map[string]string) {
|
||||
l.inner.Info(component, message, details)
|
||||
}
|
||||
|
||||
// Warn logs a warning entry for the given component.
|
||||
func (l *ServiceLogger) Warn(component, message string, details map[string]string) {
|
||||
l.inner.Warn(component, message, details)
|
||||
}
|
||||
|
||||
// Error logs an error entry for the given component.
|
||||
func (l *ServiceLogger) Error(component, message string, details map[string]string) {
|
||||
l.inner.Error(component, message, details)
|
||||
}
|
||||
|
||||
// compile-time check: ServiceLogger must satisfy api.Logger
|
||||
var _ api.Logger = (*ServiceLogger)(nil)
|
||||
79
services/sentinel/internal/logging/service_logger_test.go
Normal file
79
services/sentinel/internal/logging/service_logger_test.go
Normal file
@ -0,0 +1,79 @@
|
||||
// Package logging tests — behavioral tests for ServiceLogger.
|
||||
// Since ServiceLogger delegates to ja4common/logger.ComponentLogger,
|
||||
// we test behavior (no-panic, interface satisfaction, level filtering)
|
||||
// rather than internal output buffering.
|
||||
package logging_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/antitbone/ja4/sentinel/api"
|
||||
"github.com/antitbone/ja4/sentinel/internal/logging"
|
||||
)
|
||||
|
||||
func TestNewServiceLogger_NonNil(t *testing.T) {
|
||||
logger := logging.NewServiceLogger("info")
|
||||
if logger == nil {
|
||||
t.Fatal("expected non-nil logger")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceLogger_ImplementsApiLogger(t *testing.T) {
|
||||
logger := logging.NewServiceLogger("debug")
|
||||
var _ api.Logger = logger // compile-time check
|
||||
}
|
||||
|
||||
func TestServiceLogger_AllLevels_NoPanic(t *testing.T) {
|
||||
levels := []string{"debug", "info", "warn", "error", "invalid"}
|
||||
for _, level := range levels {
|
||||
t.Run(level, func(t *testing.T) {
|
||||
logger := logging.NewServiceLogger(level)
|
||||
logger.Debug("comp", "debug msg", map[string]string{"k": "v"})
|
||||
logger.Info("comp", "info msg", nil)
|
||||
logger.Warn("comp", "warn msg", map[string]string{"x": "y"})
|
||||
logger.Error("comp", "error msg", nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceLogger_WithDetails(t *testing.T) {
|
||||
logger := logging.NewServiceLogger("debug")
|
||||
details := map[string]string{"error": "test error", "trace_id": "abc123"}
|
||||
logger.Info("service", "test message", details)
|
||||
}
|
||||
|
||||
func TestServiceLogger_NilDetails(t *testing.T) {
|
||||
logger := logging.NewServiceLogger("debug")
|
||||
logger.Info("service", "test message", nil)
|
||||
}
|
||||
|
||||
func TestServiceLogger_ConcurrentLogging(t *testing.T) {
|
||||
logger := logging.NewServiceLogger("debug")
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
logger.Info("service", "concurrent message", map[string]string{"id": string(rune('0'+id))})
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerFactory(t *testing.T) {
|
||||
factory := &logging.LoggerFactory{}
|
||||
levels := []string{"debug", "info", "warn", "error"}
|
||||
for _, level := range levels {
|
||||
t.Run(level, func(t *testing.T) {
|
||||
logger := factory.NewLogger(level)
|
||||
if logger == nil {
|
||||
t.Fatalf("NewLogger(%q) returned nil", level)
|
||||
}
|
||||
})
|
||||
}
|
||||
logger := factory.NewDefaultLogger()
|
||||
if logger == nil {
|
||||
t.Fatal("NewDefaultLogger() returned nil")
|
||||
}
|
||||
}
|
||||
644
services/sentinel/internal/output/writers.go
Normal file
644
services/sentinel/internal/output/writers.go
Normal file
@ -0,0 +1,644 @@
|
||||
// Package output provides writers for ja4sentinel log records
|
||||
package output
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/antitbone/ja4/sentinel/api"
|
||||
)
|
||||
|
||||
// Socket configuration constants
|
||||
const (
|
||||
// DefaultDialTimeout is the default timeout for socket connections
|
||||
DefaultDialTimeout = 5 * time.Second
|
||||
// DefaultWriteTimeout is the default timeout for socket writes
|
||||
DefaultWriteTimeout = 5 * time.Second
|
||||
// DefaultMaxReconnectAttempts is the maximum number of reconnection attempts
|
||||
DefaultMaxReconnectAttempts = 3
|
||||
// DefaultReconnectBackoff is the initial backoff duration for reconnection
|
||||
DefaultReconnectBackoff = 100 * time.Millisecond
|
||||
// DefaultMaxReconnectBackoff is the maximum backoff duration
|
||||
DefaultMaxReconnectBackoff = 2 * time.Second
|
||||
// DefaultQueueSize is the size of the write queue for async writes
|
||||
DefaultQueueSize = 1000
|
||||
// DefaultMaxFileSize is the default maximum file size in bytes before rotation (100MB)
|
||||
DefaultMaxFileSize = 100 * 1024 * 1024
|
||||
// DefaultMaxBackups is the default number of backup files to keep
|
||||
DefaultMaxBackups = 3
|
||||
)
|
||||
|
||||
// StdoutWriter writes log records to stdout
|
||||
type StdoutWriter struct {
|
||||
encoder *json.Encoder
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewStdoutWriter creates a new stdout writer
|
||||
func NewStdoutWriter() *StdoutWriter {
|
||||
return &StdoutWriter{
|
||||
encoder: json.NewEncoder(os.Stdout),
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes a log record to stdout
|
||||
func (w *StdoutWriter) Write(rec api.LogRecord) error {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
return w.encoder.Encode(rec)
|
||||
}
|
||||
|
||||
// Close closes the writer (no-op for stdout)
|
||||
func (w *StdoutWriter) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FileWriter writes log records to a file with rotation support
|
||||
type FileWriter struct {
|
||||
file *os.File
|
||||
encoder *json.Encoder
|
||||
mutex sync.Mutex
|
||||
path string
|
||||
maxSize int64
|
||||
maxBackups int
|
||||
currentSize int64
|
||||
errorCallback ErrorCallback
|
||||
failuresMu sync.Mutex
|
||||
failures int
|
||||
}
|
||||
|
||||
// FileWriterOption is a function type for configuring FileWriter
|
||||
type FileWriterOption func(*FileWriter)
|
||||
|
||||
// WithFileErrorCallback sets an error callback for file write errors
|
||||
func WithFileErrorCallback(cb ErrorCallback) FileWriterOption {
|
||||
return func(w *FileWriter) {
|
||||
w.errorCallback = cb
|
||||
}
|
||||
}
|
||||
|
||||
// NewFileWriter creates a new file writer with rotation
|
||||
func NewFileWriter(path string) (*FileWriter, error) {
|
||||
return NewFileWriterWithConfig(path, DefaultMaxFileSize, DefaultMaxBackups)
|
||||
}
|
||||
|
||||
// NewFileWriterWithConfig creates a new file writer with custom rotation config
|
||||
func NewFileWriterWithConfig(path string, maxSize int64, maxBackups int, opts ...FileWriterOption) (*FileWriter, error) {
|
||||
// Create directory if it doesn't exist
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
// Open file with secure permissions (owner read/write only)
|
||||
file, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open file %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Get current file size
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, fmt.Errorf("failed to stat file: %w", err)
|
||||
}
|
||||
|
||||
w := &FileWriter{
|
||||
file: file,
|
||||
encoder: json.NewEncoder(file),
|
||||
path: path,
|
||||
maxSize: maxSize,
|
||||
maxBackups: maxBackups,
|
||||
currentSize: info.Size(),
|
||||
}
|
||||
|
||||
// Apply options (for error callback)
|
||||
for _, opt := range opts {
|
||||
opt(w)
|
||||
}
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// rotate rotates the log file if it exceeds the max size
|
||||
func (w *FileWriter) rotate() error {
|
||||
if err := w.file.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close file: %w", err)
|
||||
}
|
||||
|
||||
// Rotate existing backups
|
||||
for i := w.maxBackups; i > 1; i-- {
|
||||
oldPath := fmt.Sprintf("%s.%d", w.path, i-1)
|
||||
newPath := fmt.Sprintf("%s.%d", w.path, i)
|
||||
os.Rename(oldPath, newPath) // Ignore errors - file may not exist
|
||||
}
|
||||
|
||||
// Move current file to .1
|
||||
backupPath := fmt.Sprintf("%s.1", w.path)
|
||||
if err := os.Rename(w.path, backupPath); err != nil {
|
||||
// If rename fails, just truncate
|
||||
if err := os.Truncate(w.path, 0); err != nil {
|
||||
return fmt.Errorf("failed to truncate file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Open new file
|
||||
newFile, err := os.OpenFile(w.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open new file: %w", err)
|
||||
}
|
||||
|
||||
w.file = newFile
|
||||
w.encoder = json.NewEncoder(newFile)
|
||||
w.currentSize = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes a log record to the file
|
||||
func (w *FileWriter) Write(rec api.LogRecord) error {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
|
||||
// Check if rotation is needed
|
||||
if w.currentSize >= w.maxSize {
|
||||
if err := w.rotate(); err != nil {
|
||||
w.reportError(fmt.Errorf("failed to rotate file: %w", err))
|
||||
return fmt.Errorf("failed to rotate file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Encode to buffer first to get size
|
||||
data, err := json.Marshal(rec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal record: %w", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
|
||||
// Write to file
|
||||
n, err := w.file.Write(data)
|
||||
if err != nil {
|
||||
w.reportError(fmt.Errorf("failed to write to file: %w", err))
|
||||
return fmt.Errorf("failed to write to file: %w", err)
|
||||
}
|
||||
w.currentSize += int64(n)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reportError reports a file write error via the configured callback
|
||||
func (w *FileWriter) reportError(err error) {
|
||||
if w.errorCallback != nil {
|
||||
w.failuresMu.Lock()
|
||||
w.failures++
|
||||
failures := w.failures
|
||||
w.failuresMu.Unlock()
|
||||
w.errorCallback(w.path, err, failures)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the file
|
||||
func (w *FileWriter) Close() error {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
if w.file != nil {
|
||||
return w.file.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reopen reopens the log file (for logrotate support)
|
||||
func (w *FileWriter) Reopen() error {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
|
||||
if err := w.file.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close file during reopen: %w", err)
|
||||
}
|
||||
|
||||
// Open new file
|
||||
newFile, err := os.OpenFile(w.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to reopen file %s: %w", w.path, err)
|
||||
}
|
||||
|
||||
w.file = newFile
|
||||
w.encoder = json.NewEncoder(newFile)
|
||||
w.currentSize = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrorCallback is a function type for reporting socket connection errors
|
||||
type ErrorCallback func(socketPath string, err error, attempt int)
|
||||
|
||||
// UnixSocketWriter writes log records to a UNIX socket with reconnection logic
|
||||
// No internal logging - only LogRecord JSON data is sent to the socket
|
||||
type UnixSocketWriter struct {
|
||||
socketPath string
|
||||
conn net.Conn
|
||||
mutex sync.Mutex
|
||||
dialTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
maxReconnects int
|
||||
reconnectBackoff time.Duration
|
||||
maxBackoff time.Duration
|
||||
queue chan []byte
|
||||
queueClose chan struct{}
|
||||
queueDone chan struct{}
|
||||
closeOnce sync.Once
|
||||
isClosed bool
|
||||
pendingWrites [][]byte
|
||||
pendingMu sync.Mutex
|
||||
errorCallback ErrorCallback
|
||||
consecutiveFailures int
|
||||
failuresMu sync.Mutex
|
||||
networkType string // "unix" for STREAM, "unixgram" for DGRAM
|
||||
}
|
||||
|
||||
// NewUnixSocketWriter creates a new UNIX socket writer with reconnection logic
|
||||
func NewUnixSocketWriter(socketPath string) (*UnixSocketWriter, error) {
|
||||
return NewUnixSocketWriterWithConfig(socketPath, DefaultDialTimeout, DefaultWriteTimeout, DefaultQueueSize)
|
||||
}
|
||||
|
||||
// UnixSocketWriterOption is a function type for configuring UnixSocketWriter
|
||||
type UnixSocketWriterOption func(*UnixSocketWriter)
|
||||
|
||||
// WithErrorCallback sets an error callback for socket connection errors
|
||||
func WithErrorCallback(cb ErrorCallback) UnixSocketWriterOption {
|
||||
return func(w *UnixSocketWriter) {
|
||||
w.errorCallback = cb
|
||||
}
|
||||
}
|
||||
|
||||
// NewUnixSocketWriterWithConfig creates a new UNIX socket writer with custom configuration
|
||||
func NewUnixSocketWriterWithConfig(socketPath string, dialTimeout, writeTimeout time.Duration, queueSize int, opts ...UnixSocketWriterOption) (*UnixSocketWriter, error) {
|
||||
w := &UnixSocketWriter{
|
||||
socketPath: socketPath,
|
||||
dialTimeout: dialTimeout,
|
||||
writeTimeout: writeTimeout,
|
||||
maxReconnects: DefaultMaxReconnectAttempts,
|
||||
reconnectBackoff: DefaultReconnectBackoff,
|
||||
maxBackoff: DefaultMaxReconnectBackoff,
|
||||
queue: make(chan []byte, queueSize),
|
||||
queueClose: make(chan struct{}),
|
||||
queueDone: make(chan struct{}),
|
||||
pendingWrites: make([][]byte, 0),
|
||||
}
|
||||
|
||||
// Apply options
|
||||
for _, opt := range opts {
|
||||
opt(w)
|
||||
}
|
||||
|
||||
// Start the queue processor
|
||||
go w.processQueue()
|
||||
|
||||
// Try initial connection silently (socket may not exist yet - that's okay)
|
||||
// Use unixgram (DGRAM) for connectionless UDP-like socket communication
|
||||
conn, err := net.DialTimeout("unixgram", socketPath, w.dialTimeout)
|
||||
if err == nil {
|
||||
w.conn = conn
|
||||
}
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// processQueue handles queued writes with reconnection logic
|
||||
func (w *UnixSocketWriter) processQueue() {
|
||||
defer close(w.queueDone)
|
||||
|
||||
backoff := w.reconnectBackoff
|
||||
|
||||
for {
|
||||
select {
|
||||
case data, ok := <-w.queue:
|
||||
if !ok {
|
||||
// Channel closed, drain remaining data
|
||||
w.flushPendingData()
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.writeWithReconnect(data); err != nil {
|
||||
w.failuresMu.Lock()
|
||||
w.consecutiveFailures++
|
||||
failures := w.consecutiveFailures
|
||||
w.failuresMu.Unlock()
|
||||
|
||||
// Report error via callback if configured
|
||||
w.reportError(err, failures)
|
||||
|
||||
// Queue for retry
|
||||
w.pendingMu.Lock()
|
||||
if len(w.pendingWrites) < DefaultQueueSize {
|
||||
w.pendingWrites = append(w.pendingWrites, data)
|
||||
}
|
||||
w.pendingMu.Unlock()
|
||||
|
||||
// Exponential backoff
|
||||
if failures > w.maxReconnects {
|
||||
time.Sleep(backoff)
|
||||
backoff *= 2
|
||||
if backoff > w.maxBackoff {
|
||||
backoff = w.maxBackoff
|
||||
}
|
||||
}
|
||||
} else {
|
||||
w.failuresMu.Lock()
|
||||
w.consecutiveFailures = 0
|
||||
w.failuresMu.Unlock()
|
||||
backoff = w.reconnectBackoff
|
||||
// Try to flush pending data
|
||||
w.flushPendingData()
|
||||
}
|
||||
|
||||
case <-w.queueClose:
|
||||
w.flushPendingData()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reportError reports a socket connection error via the configured callback
|
||||
func (w *UnixSocketWriter) reportError(err error, attempt int) {
|
||||
if w.errorCallback != nil {
|
||||
w.errorCallback(w.socketPath, err, attempt)
|
||||
}
|
||||
}
|
||||
|
||||
// flushPendingData attempts to write any pending data
|
||||
func (w *UnixSocketWriter) flushPendingData() {
|
||||
w.pendingMu.Lock()
|
||||
pending := w.pendingWrites
|
||||
w.pendingWrites = make([][]byte, 0)
|
||||
w.pendingMu.Unlock()
|
||||
|
||||
for _, data := range pending {
|
||||
if err := w.writeWithReconnect(data); err != nil {
|
||||
// Put it back for next flush attempt
|
||||
w.pendingMu.Lock()
|
||||
if len(w.pendingWrites) < DefaultQueueSize {
|
||||
w.pendingWrites = append(w.pendingWrites, data)
|
||||
}
|
||||
w.pendingMu.Unlock()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeWithReconnect attempts to write data with reconnection logic
|
||||
func (w *UnixSocketWriter) writeWithReconnect(data []byte) error {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
|
||||
ensureConn := func() error {
|
||||
if w.conn != nil {
|
||||
return nil
|
||||
}
|
||||
// Use unixgram (DGRAM) for connectionless UDP-like socket communication
|
||||
conn, err := net.DialTimeout("unixgram", w.socketPath, w.dialTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err)
|
||||
}
|
||||
w.conn = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := ensureConn(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil {
|
||||
return fmt.Errorf("failed to set write deadline: %w", err)
|
||||
}
|
||||
|
||||
if _, err := w.conn.Write(data); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connection failed, try to reconnect
|
||||
_ = w.conn.Close()
|
||||
w.conn = nil
|
||||
|
||||
if err := ensureConn(); err != nil {
|
||||
return fmt.Errorf("failed to reconnect: %w", err)
|
||||
}
|
||||
|
||||
if err := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil {
|
||||
_ = w.conn.Close()
|
||||
w.conn = nil
|
||||
return fmt.Errorf("failed to set write deadline after reconnect: %w", err)
|
||||
}
|
||||
|
||||
if _, err := w.conn.Write(data); err != nil {
|
||||
_ = w.conn.Close()
|
||||
w.conn = nil
|
||||
return fmt.Errorf("failed to write after reconnect: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes a log record to the UNIX socket (non-blocking with queue).
|
||||
// Bug 12 fix: marshal JSON outside the lock, then hold mutex through both the
|
||||
// isClosed check AND the non-blocking channel send so Close() cannot close the
|
||||
// channel between those two operations.
|
||||
func (w *UnixSocketWriter) Write(rec api.LogRecord) error {
|
||||
data, err := json.Marshal(rec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal record: %w", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
if w.isClosed {
|
||||
return fmt.Errorf("writer is closed")
|
||||
}
|
||||
select {
|
||||
case w.queue <- data:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("write queue is full, dropping message")
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the UNIX socket connection and stops the queue processor.
|
||||
// Bug 12 fix: set isClosed=true under mutex BEFORE closing the channel so a
|
||||
// concurrent Write() sees the flag and returns early instead of panicking on
|
||||
// a send-on-closed-channel.
|
||||
func (w *UnixSocketWriter) Close() error {
|
||||
w.closeOnce.Do(func() {
|
||||
w.mutex.Lock()
|
||||
w.isClosed = true
|
||||
w.mutex.Unlock()
|
||||
|
||||
close(w.queueClose)
|
||||
<-w.queueDone
|
||||
close(w.queue)
|
||||
|
||||
w.mutex.Lock()
|
||||
if w.conn != nil {
|
||||
w.conn.Close()
|
||||
w.conn = nil
|
||||
}
|
||||
w.mutex.Unlock()
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// MultiWriter combines multiple writers
|
||||
type MultiWriter struct {
|
||||
writers []api.Writer
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewMultiWriter creates a new multi-writer
|
||||
func NewMultiWriter() *MultiWriter {
|
||||
return &MultiWriter{
|
||||
writers: make([]api.Writer, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes a log record to all writers
|
||||
func (mw *MultiWriter) Write(rec api.LogRecord) error {
|
||||
mw.mutex.Lock()
|
||||
defer mw.mutex.Unlock()
|
||||
|
||||
var lastErr error
|
||||
for _, w := range mw.writers {
|
||||
if err := w.Write(rec); err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// Add adds a writer to the multi-writer
|
||||
func (mw *MultiWriter) Add(writer api.Writer) {
|
||||
mw.mutex.Lock()
|
||||
defer mw.mutex.Unlock()
|
||||
mw.writers = append(mw.writers, writer)
|
||||
}
|
||||
|
||||
// CloseAll closes all writers
|
||||
func (mw *MultiWriter) CloseAll() error {
|
||||
mw.mutex.Lock()
|
||||
defer mw.mutex.Unlock()
|
||||
|
||||
var lastErr error
|
||||
for _, w := range mw.writers {
|
||||
if closer, ok := w.(io.Closer); ok {
|
||||
if err := closer.Close(); err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// Reopen reopens all writers that support log rotation
|
||||
func (mw *MultiWriter) Reopen() error {
|
||||
mw.mutex.Lock()
|
||||
defer mw.mutex.Unlock()
|
||||
|
||||
var lastErr error
|
||||
for _, w := range mw.writers {
|
||||
if reopenable, ok := w.(api.Reopenable); ok {
|
||||
if err := reopenable.Reopen(); err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// BuilderImpl implements the api.Builder interface
|
||||
type BuilderImpl struct {
|
||||
errorCallback ErrorCallback
|
||||
}
|
||||
|
||||
// NewBuilder creates a new output builder
|
||||
func NewBuilder() *BuilderImpl {
|
||||
return &BuilderImpl{}
|
||||
}
|
||||
|
||||
// WithErrorCallback sets an error callback for all unix_socket and file writers created by this builder
|
||||
func (b *BuilderImpl) WithErrorCallback(cb ErrorCallback) *BuilderImpl {
|
||||
b.errorCallback = cb
|
||||
return b
|
||||
}
|
||||
|
||||
// NewFromConfig constructs writers from AppConfig
|
||||
// Uses AsyncBuffer from OutputConfig if specified, otherwise uses DefaultQueueSize
|
||||
func (b *BuilderImpl) NewFromConfig(cfg api.AppConfig) (api.Writer, error) {
|
||||
multiWriter := NewMultiWriter()
|
||||
|
||||
for _, outputCfg := range cfg.Outputs {
|
||||
if !outputCfg.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
var writer api.Writer
|
||||
var err error
|
||||
|
||||
// Determine queue size: use AsyncBuffer if specified, otherwise default
|
||||
queueSize := DefaultQueueSize
|
||||
if outputCfg.AsyncBuffer > 0 {
|
||||
queueSize = outputCfg.AsyncBuffer
|
||||
}
|
||||
|
||||
switch outputCfg.Type {
|
||||
case "stdout":
|
||||
writer = NewStdoutWriter()
|
||||
case "file":
|
||||
path := outputCfg.Params["path"]
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("file output requires 'path' parameter")
|
||||
}
|
||||
// Build options list for file writer
|
||||
var fileOpts []FileWriterOption
|
||||
if b.errorCallback != nil {
|
||||
fileOpts = append(fileOpts, WithFileErrorCallback(b.errorCallback))
|
||||
}
|
||||
writer, err = NewFileWriterWithConfig(path, DefaultMaxFileSize, DefaultMaxBackups, fileOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "unix_socket":
|
||||
socketPath := outputCfg.Params["socket_path"]
|
||||
if socketPath == "" {
|
||||
return nil, fmt.Errorf("unix_socket output requires 'socket_path' parameter")
|
||||
}
|
||||
// Build options list
|
||||
var opts []UnixSocketWriterOption
|
||||
if b.errorCallback != nil {
|
||||
opts = append(opts, WithErrorCallback(b.errorCallback))
|
||||
}
|
||||
writer, err = NewUnixSocketWriterWithConfig(socketPath, DefaultDialTimeout, DefaultWriteTimeout, queueSize, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown output type: %s", outputCfg.Type)
|
||||
}
|
||||
|
||||
multiWriter.Add(writer)
|
||||
}
|
||||
|
||||
// If no outputs configured, default to stdout
|
||||
if len(multiWriter.writers) == 0 {
|
||||
multiWriter.Add(NewStdoutWriter())
|
||||
}
|
||||
|
||||
return multiWriter, nil
|
||||
}
|
||||
1110
services/sentinel/internal/output/writers_test.go
Normal file
1110
services/sentinel/internal/output/writers_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1008
services/sentinel/internal/tlsparse/parser.go
Normal file
1008
services/sentinel/internal/tlsparse/parser.go
Normal file
File diff suppressed because it is too large
Load Diff
1824
services/sentinel/internal/tlsparse/parser_test.go
Normal file
1824
services/sentinel/internal/tlsparse/parser_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user