fix(ja4ebpf): split bpf2go generate into Ja4Tc + Ja4Ssl, fix RPM systemd-rpm-macros

- Use two separate //go:generate directives (Ja4Tc for tc_capture.c, Ja4Ssl
  for uprobe_ssl.c) to avoid duplicate LICENSE symbol and multi-file clang issue
- Update loader.go to hold tcObjs/sslObjs separately with correct field names:
  UprobeSslSetFd, UprobeSslReadEntry, UretprobeSslReadExit,
  KprobeAccept4Entry, KretprobeAccept4Exit
- Add systemd-rpm-macros to all three RPM build stages (el8/el9/el10)
  so that %{_unitdir} macro resolves correctly
- RPMs now build successfully for el8, el9, el10

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
toto
2026-04-11 23:21:11 +02:00
parent a1e4c1dad5
commit 3b047b680a
155 changed files with 197011 additions and 599 deletions

View File

@ -0,0 +1,397 @@
// Package capture provides network packet capture functionality for ja4sentinel
package capture
import (
"fmt"
"log"
"net"
"regexp"
"strings"
"sync"
"sync/atomic"
"github.com/google/gopacket"
"github.com/google/gopacket/pcap"
"github.com/antitbone/ja4/sentinel/api"
)
// Capture configuration constants
const (
// DefaultSnapLen is the default snapshot length for packet capture
// Increased from 1600 to 65535 to capture full packets including large TLS handshakes
DefaultSnapLen = 65535
// DefaultPromiscuous is the default promiscuous mode setting
DefaultPromiscuous = false
// MaxBPFFilterLength is the maximum allowed length for BPF filters
MaxBPFFilterLength = 1024
)
// validBPFPattern checks if a BPF filter contains only valid characters
// This is a basic validation to prevent injection attacks
var validBPFPattern = regexp.MustCompile(`^[a-zA-Z0-9\s\(\)\-\_\.\*\+\?\:\=\!\&\|\<\>\[\]\/\@,]+$`)
// CaptureImpl implements the capture.Capture interface for packet capture
type CaptureImpl struct {
handle *pcap.Handle
mu sync.Mutex
snapLen int
promisc bool
isClosed bool
localIPs []string // Local IPs to filter (dst host)
linkType int // Link type from pcap handle
interfaceName string // Interface name (for diagnostics)
bpfFilter string // Applied BPF filter (for diagnostics)
// Metrics counters (atomic)
packetsReceived uint64 // Total packets received from interface
packetsSent uint64 // Total packets sent to channel
packetsDropped uint64 // Total packets dropped (channel full)
}
// New creates a new capture instance
func New() *CaptureImpl {
return &CaptureImpl{
snapLen: DefaultSnapLen,
promisc: DefaultPromiscuous,
}
}
// NewWithSnapLen creates a new capture instance with custom snapshot length
func NewWithSnapLen(snapLen int) *CaptureImpl {
if snapLen <= 0 || snapLen > 65535 {
snapLen = DefaultSnapLen
}
return &CaptureImpl{
snapLen: snapLen,
promisc: DefaultPromiscuous,
}
}
// Run starts network packet capture according to the configuration
func (c *CaptureImpl) Run(cfg api.Config, out chan<- api.RawPacket) error {
// Validate interface name (basic check)
if cfg.Interface == "" {
return fmt.Errorf("interface cannot be empty")
}
// Find available interfaces to validate the interface exists
ifaces, err := pcap.FindAllDevs()
if err != nil {
return fmt.Errorf("failed to list network interfaces: %w", err)
}
// Special handling for "any" interface
interfaceFound := cfg.Interface == "any"
if !interfaceFound {
for _, iface := range ifaces {
if iface.Name == cfg.Interface {
interfaceFound = true
break
}
}
}
if !interfaceFound {
return fmt.Errorf("interface %s not found (available: %v)", cfg.Interface, getInterfaceNames(ifaces))
}
handle, err := pcap.OpenLive(cfg.Interface, int32(c.snapLen), c.promisc, pcap.BlockForever)
if err != nil {
return fmt.Errorf("failed to open interface %s: %w", cfg.Interface, err)
}
c.mu.Lock()
c.handle = handle
c.mu.Unlock()
defer func() {
c.mu.Lock()
if c.handle != nil && !c.isClosed {
c.handle.Close()
c.handle = nil
}
c.mu.Unlock()
}()
// Store interface name for diagnostics
c.interfaceName = cfg.Interface
// Resolve local IPs for filtering (if not manually specified)
localIPs := cfg.LocalIPs
if len(localIPs) == 0 {
localIPs, err = c.detectLocalIPs(cfg.Interface)
if err != nil {
return fmt.Errorf("failed to detect local IPs: %w", err)
}
if len(localIPs) == 0 {
// NAT/VIP: destination IP may not be assigned to this interface.
// Fall back to port-only BPF filter instead of aborting.
log.Printf("WARN capture: no local IPs found on interface %s; using port-only BPF filter (NAT/VIP mode)", cfg.Interface)
}
}
c.localIPs = localIPs
// Build and apply BPF filter
bpfFilter := cfg.BPFFilter
if bpfFilter == "" {
bpfFilter = c.buildBPFFilter(cfg.ListenPorts, localIPs)
}
c.bpfFilter = bpfFilter
// Validate BPF filter before applying
if err := validateBPFFilter(bpfFilter); err != nil {
return fmt.Errorf("invalid BPF filter: %w", err)
}
err = handle.SetBPFFilter(bpfFilter)
if err != nil {
return fmt.Errorf("failed to set BPF filter '%s': %w", bpfFilter, err)
}
// Store link type once, after the handle is fully configured (BPF filter applied).
// A single write avoids the race where packetToRawPacket reads a stale value
// that existed before the BPF filter was set.
c.mu.Lock()
c.linkType = int(handle.LinkType())
c.mu.Unlock()
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
for packet := range packetSource.Packets() {
// Convert packet to RawPacket
rawPkt := c.packetToRawPacket(packet)
if rawPkt != nil {
atomic.AddUint64(&c.packetsReceived, 1)
select {
case out <- *rawPkt:
// Packet sent successfully
atomic.AddUint64(&c.packetsSent, 1)
default:
// Channel full, drop packet
atomic.AddUint64(&c.packetsDropped, 1)
}
}
}
return nil
}
// validateBPFFilter performs basic validation of BPF filter strings
func validateBPFFilter(filter string) error {
if filter == "" {
return nil
}
if len(filter) > MaxBPFFilterLength {
return fmt.Errorf("BPF filter too long (max %d characters)", MaxBPFFilterLength)
}
// Check for potentially dangerous patterns
if !validBPFPattern.MatchString(filter) {
return fmt.Errorf("BPF filter contains invalid characters")
}
// Check for unbalanced parentheses
openParens := 0
for _, ch := range filter {
if ch == '(' {
openParens++
} else if ch == ')' {
openParens--
if openParens < 0 {
return fmt.Errorf("BPF filter has unbalanced parentheses")
}
}
}
if openParens != 0 {
return fmt.Errorf("BPF filter has unbalanced parentheses")
}
return nil
}
// getInterfaceNames extracts interface names from a list of devices
func getInterfaceNames(ifaces []pcap.Interface) []string {
names := make([]string, len(ifaces))
for i, iface := range ifaces {
names[i] = iface.Name
}
return names
}
// detectLocalIPs detects local IP addresses on the specified interface
// Excludes loopback addresses (127.0.0.0/8, ::1) and IPv6 link-local (fe80::)
func (c *CaptureImpl) detectLocalIPs(interfaceName string) ([]string, error) {
var localIPs []string
// Special case: "any" interface - get all non-loopback IPs
if interfaceName == "any" {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("failed to list interfaces: %w", err)
}
for _, iface := range ifaces {
// Skip loopback interfaces
if iface.Flags&net.FlagLoopback != 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue // Skip this interface, try others
}
for _, addr := range addrs {
ip := extractIP(addr)
if ip != nil && !ip.IsLoopback() && !ip.IsLinkLocalUnicast() {
localIPs = append(localIPs, ip.String())
}
}
}
return localIPs, nil
}
// Specific interface - get IPs from that interface only
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return nil, fmt.Errorf("failed to get interface %s: %w", interfaceName, err)
}
addrs, err := iface.Addrs()
if err != nil {
return nil, fmt.Errorf("failed to get addresses for %s: %w", interfaceName, err)
}
for _, addr := range addrs {
ip := extractIP(addr)
if ip != nil && !ip.IsLoopback() && !ip.IsLinkLocalUnicast() {
localIPs = append(localIPs, ip.String())
}
}
return localIPs, nil
}
// extractIP extracts the IP address from a net.Addr
func extractIP(addr net.Addr) net.IP {
switch v := addr.(type) {
case *net.IPNet:
ip := v.IP
// Return IPv4 as 4-byte, IPv6 as 16-byte
if ip4 := ip.To4(); ip4 != nil {
return ip4
}
return ip
case *net.IPAddr:
ip := v.IP
if ip4 := ip.To4(); ip4 != nil {
return ip4
}
return ip
}
return nil
}
// buildBPFFilter builds a BPF filter for the specified ports and local IPs
// Filter: (tcp dst port 443 or tcp dst port 8443) and (dst host 192.168.1.10 or dst host 10.0.0.5)
// Uses "tcp dst port" to only capture client→server traffic (not server→client responses)
func (c *CaptureImpl) buildBPFFilter(ports []uint16, localIPs []string) string {
if len(ports) == 0 {
return "tcp"
}
// Build port filter (dst port only to avoid capturing server responses)
portParts := make([]string, len(ports))
for i, port := range ports {
portParts[i] = fmt.Sprintf("tcp dst port %d", port)
}
portFilter := "(" + strings.Join(portParts, ") or (") + ")"
// Build destination host filter
if len(localIPs) == 0 {
return portFilter
}
hostParts := make([]string, len(localIPs))
for i, ip := range localIPs {
// Handle IPv6 addresses
if strings.Contains(ip, ":") {
hostParts[i] = fmt.Sprintf("dst host %s", ip)
} else {
hostParts[i] = fmt.Sprintf("dst host %s", ip)
}
}
hostFilter := "(" + strings.Join(hostParts, ") or (") + ")"
// Combine port and host filters
return portFilter + " and " + hostFilter
}
// joinString joins strings with a separator (kept for backward compatibility)
func joinString(parts []string, sep string) string {
if len(parts) == 0 {
return ""
}
result := parts[0]
for _, part := range parts[1:] {
result += sep + part
}
return result
}
// packetToRawPacket converts a gopacket packet to RawPacket
// Uses the raw packet bytes from the link layer
func (c *CaptureImpl) packetToRawPacket(packet gopacket.Packet) *api.RawPacket {
// Try to get link layer contents + payload for full packet
var data []byte
linkLayer := packet.LinkLayer()
if linkLayer != nil {
// Combine link layer contents with payload to get full packet
data = append(data, linkLayer.LayerContents()...)
data = append(data, linkLayer.LayerPayload()...)
} else {
// Fallback to packet.Data()
data = packet.Data()
}
if len(data) == 0 {
return nil
}
return &api.RawPacket{
Data: data,
Timestamp: packet.Metadata().Timestamp.UnixNano(),
LinkType: c.linkType,
}
}
// Close properly closes the capture handle
func (c *CaptureImpl) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.handle != nil && !c.isClosed {
c.handle.Close()
c.handle = nil
c.isClosed = true
return nil
}
c.isClosed = true
return nil
}
// GetStats returns capture statistics (for monitoring/debugging)
func (c *CaptureImpl) GetStats() (received, sent, dropped uint64) {
return atomic.LoadUint64(&c.packetsReceived),
atomic.LoadUint64(&c.packetsSent),
atomic.LoadUint64(&c.packetsDropped)
}
// GetDiagnostics returns capture diagnostics information (for debugging)
func (c *CaptureImpl) GetDiagnostics() (interfaceName string, localIPs []string, bpfFilter string, linkType int) {
c.mu.Lock()
defer c.mu.Unlock()
return c.interfaceName, c.localIPs, c.bpfFilter, c.linkType
}

View File

@ -0,0 +1,661 @@
package capture
import (
"net"
"strings"
"testing"
"time"
"github.com/antitbone/ja4/sentinel/api"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/pcap"
)
func TestCaptureImpl_Run_EmptyInterface(t *testing.T) {
c := New()
if c == nil {
t.Fatal("New() returned nil")
}
cfg := api.Config{
Interface: "",
ListenPorts: []uint16{443},
}
out := make(chan api.RawPacket, 10)
err := c.Run(cfg, out)
if err == nil {
t.Error("Run() with empty interface should return error")
}
if err.Error() != "interface cannot be empty" {
t.Errorf("Run() error = %v, want 'interface cannot be empty'", err)
}
}
func TestCaptureImpl_Run_NonExistentInterface(t *testing.T) {
c := New()
if c == nil {
t.Fatal("New() returned nil")
}
cfg := api.Config{
Interface: "nonexistent_interface_xyz123",
ListenPorts: []uint16{443},
}
out := make(chan api.RawPacket, 10)
err := c.Run(cfg, out)
if err == nil {
t.Error("Run() with non-existent interface should return error")
}
}
func TestCaptureImpl_Run_InvalidBPFFilter(t *testing.T) {
// Get a real interface name
ifaces, err := pcap.FindAllDevs()
if err != nil || len(ifaces) == 0 {
t.Skip("No network interfaces available for testing")
}
c := New()
cfg := api.Config{
Interface: ifaces[0].Name,
ListenPorts: []uint16{443},
BPFFilter: "invalid; rm -rf /", // Invalid characters
}
out := make(chan api.RawPacket, 10)
err = c.Run(cfg, out)
if err == nil {
t.Error("Run() with invalid BPF filter should return error")
}
}
func TestCaptureImpl_Run_ChannelFull_DropsPackets(t *testing.T) {
// This test verifies that when the output channel is full,
// packets are dropped gracefully (non-blocking write)
// We can't easily test the full Run() loop without real interfaces,
// but we can verify the channel behavior with a small buffer
out := make(chan api.RawPacket, 1)
// Fill the channel
out <- api.RawPacket{Data: []byte{1, 2, 3}, Timestamp: time.Now().UnixNano()}
// Channel should be full now, select default should trigger
done := make(chan bool)
go func() {
select {
case out <- api.RawPacket{Data: []byte{4, 5, 6}, Timestamp: time.Now().UnixNano()}:
done <- false // Would block
default:
done <- true // Dropped as expected
}
}()
dropped := <-done
if !dropped {
t.Error("Expected packet to be dropped when channel is full")
}
}
func TestPacketToRawPacket(t *testing.T) {
t.Run("valid_packet", func(t *testing.T) {
// Create a simple TCP packet
eth := layers.Ethernet{
SrcMAC: []byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55},
DstMAC: []byte{0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB},
EthernetType: layers.EthernetTypeIPv4,
}
ip := layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: []byte{192, 168, 1, 1},
DstIP: []byte{10, 0, 0, 1},
}
tcp := layers.TCP{
SrcPort: 12345,
DstPort: 443,
}
tcp.SetNetworkLayerForChecksum(&ip)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{}
gopacket.SerializeLayers(buf, opts, &eth, &ip, &tcp)
packet := gopacket.NewPacket(buf.Bytes(), layers.LinkTypeEthernet, gopacket.Default)
// Create capture instance for method call
c := New()
rawPkt := c.packetToRawPacket(packet)
if rawPkt == nil {
t.Fatal("packetToRawPacket() returned nil for valid packet")
}
if len(rawPkt.Data) == 0 {
t.Error("packetToRawPacket() returned empty data")
}
if rawPkt.Timestamp == 0 {
t.Error("packetToRawPacket() returned zero timestamp")
}
})
t.Run("empty_packet", func(t *testing.T) {
// Create packet with no data
packet := gopacket.NewPacket([]byte{}, layers.LinkTypeEthernet, gopacket.Default)
c := New()
rawPkt := c.packetToRawPacket(packet)
if rawPkt != nil {
t.Error("packetToRawPacket() should return nil for empty packet")
}
})
t.Run("nil_packet", func(t *testing.T) {
// packetToRawPacket will panic with nil packet due to Metadata() call
// This is expected behavior - the function is not designed to handle nil
defer func() {
if r := recover(); r == nil {
t.Error("packetToRawPacket() with nil packet should panic")
}
}()
c := New()
var packet gopacket.Packet
_ = c.packetToRawPacket(packet)
})
}
func TestGetInterfaceNames(t *testing.T) {
t.Run("empty_list", func(t *testing.T) {
names := getInterfaceNames([]pcap.Interface{})
if len(names) != 0 {
t.Errorf("getInterfaceNames() with empty list = %v, want []", names)
}
})
t.Run("single_interface", func(t *testing.T) {
ifaces := []pcap.Interface{
{Name: "eth0"},
}
names := getInterfaceNames(ifaces)
if len(names) != 1 || names[0] != "eth0" {
t.Errorf("getInterfaceNames() = %v, want [eth0]", names)
}
})
t.Run("multiple_interfaces", func(t *testing.T) {
ifaces := []pcap.Interface{
{Name: "eth0"},
{Name: "lo"},
{Name: "docker0"},
}
names := getInterfaceNames(ifaces)
if len(names) != 3 {
t.Errorf("getInterfaceNames() returned %d names, want 3", len(names))
}
expected := []string{"eth0", "lo", "docker0"}
for i, name := range names {
if name != expected[i] {
t.Errorf("getInterfaceNames()[%d] = %s, want %s", i, name, expected[i])
}
}
})
}
func TestValidateBPFFilter(t *testing.T) {
tests := []struct {
name string
filter string
wantErr bool
}{
{
name: "empty filter",
filter: "",
wantErr: false,
},
{
name: "valid simple filter",
filter: "tcp port 443",
wantErr: false,
},
{
name: "valid complex filter",
filter: "(tcp port 443) or (tcp port 8443)",
wantErr: false,
},
{
name: "filter with special chars",
filter: "tcp port 443 and host 192.168.1.1",
wantErr: false,
},
{
name: "too long filter",
filter: string(make([]byte, MaxBPFFilterLength+1)),
wantErr: true,
},
{
name: "unbalanced parentheses - extra open",
filter: "(tcp port 443",
wantErr: true,
},
{
name: "unbalanced parentheses - extra close",
filter: "tcp port 443)",
wantErr: true,
},
{
name: "invalid characters - semicolon",
filter: "tcp port 443; rm -rf /",
wantErr: true,
},
{
name: "invalid characters - backtick",
filter: "tcp port `whoami`",
wantErr: true,
},
{
name: "invalid characters - dollar",
filter: "tcp port $HOME",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateBPFFilter(tt.filter)
if (err != nil) != tt.wantErr {
t.Errorf("validateBPFFilter() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestJoinString(t *testing.T) {
tests := []struct {
name string
parts []string
sep string
want string
}{
{
name: "empty slice",
parts: []string{},
sep: ") or (",
want: "",
},
{
name: "single element",
parts: []string{"tcp port 443"},
sep: ") or (",
want: "tcp port 443",
},
{
name: "multiple elements",
parts: []string{"tcp port 443", "tcp port 8443"},
sep: ") or (",
want: "tcp port 443) or (tcp port 8443",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := joinString(tt.parts, tt.sep)
if got != tt.want {
t.Errorf("joinString() = %v, want %v", got, tt.want)
}
})
}
}
func TestNewCapture(t *testing.T) {
c := New()
if c == nil {
t.Fatal("New() returned nil")
}
if c.snapLen != DefaultSnapLen {
t.Errorf("snapLen = %d, want %d", c.snapLen, DefaultSnapLen)
}
if c.promisc != DefaultPromiscuous {
t.Errorf("promisc = %v, want %v", c.promisc, DefaultPromiscuous)
}
}
func TestNewWithSnapLen(t *testing.T) {
tests := []struct {
name string
snapLen int
wantSnapLen int
}{
{
name: "valid snapLen",
snapLen: 2048,
wantSnapLen: 2048,
},
{
name: "zero snapLen uses default",
snapLen: 0,
wantSnapLen: DefaultSnapLen,
},
{
name: "negative snapLen uses default",
snapLen: -100,
wantSnapLen: DefaultSnapLen,
},
{
name: "too large snapLen uses default",
snapLen: 100000,
wantSnapLen: DefaultSnapLen,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := NewWithSnapLen(tt.snapLen)
if c == nil {
t.Fatal("NewWithSnapLen() returned nil")
}
if c.snapLen != tt.wantSnapLen {
t.Errorf("snapLen = %d, want %d", c.snapLen, tt.wantSnapLen)
}
})
}
}
func TestCaptureImpl_Close(t *testing.T) {
c := New()
if c == nil {
t.Fatal("New() returned nil")
}
// Close should not panic on fresh instance
if err := c.Close(); err != nil {
t.Errorf("Close() error = %v", err)
}
// Multiple closes should be safe
if err := c.Close(); err != nil {
t.Errorf("Close() second call error = %v", err)
}
}
func TestValidateBPFFilter_BalancedParentheses(t *testing.T) {
// Test various balanced parentheses scenarios
validFilters := []string{
"(tcp port 443)",
"((tcp port 443))",
"(tcp port 443) or (tcp port 8443)",
"((tcp port 443) or (tcp port 8443))",
"(tcp port 443 and host 1.2.3.4) or (tcp port 8443)",
}
for _, filter := range validFilters {
t.Run(filter, func(t *testing.T) {
if err := validateBPFFilter(filter); err != nil {
t.Errorf("validateBPFFilter(%q) unexpected error = %v", filter, err)
}
})
}
}
func TestCaptureImpl_detectLocalIPs(t *testing.T) {
c := New()
if c == nil {
t.Fatal("New() returned nil")
}
t.Run("any_interface", func(t *testing.T) {
ips, err := c.detectLocalIPs("any")
if err != nil {
t.Errorf("detectLocalIPs(any) error = %v", err)
}
// Should return at least one non-loopback IP or empty if none available
for _, ip := range ips {
if ip == "127.0.0.1" || ip == "::1" {
t.Errorf("detectLocalIPs(any) should exclude loopback, got %s", ip)
}
}
})
t.Run("loopback_excluded", func(t *testing.T) {
ips, err := c.detectLocalIPs("any")
if err != nil {
t.Skipf("Skipping loopback test: %v", err)
}
// Verify no loopback addresses are included
for _, ip := range ips {
if ip == "127.0.0.1" {
t.Error("detectLocalIPs should exclude 127.0.0.1")
}
}
})
}
func TestCaptureImpl_detectLocalIPs_SpecificInterface(t *testing.T) {
c := New()
if c == nil {
t.Fatal("New() returned nil")
}
// Test with a non-existent interface
_, err := c.detectLocalIPs("nonexistent_interface_xyz")
if err == nil {
t.Error("detectLocalIPs with non-existent interface should return error")
}
}
func TestCaptureImpl_extractIP(t *testing.T) {
tests := []struct {
name string
addr net.Addr
wantIPv4 bool
wantIPv6 bool
}{
{
name: "IPv4",
addr: &net.IPNet{
IP: net.ParseIP("192.168.1.10"),
Mask: net.CIDRMask(24, 32),
},
wantIPv4: true,
},
{
name: "IPv6",
addr: &net.IPNet{
IP: net.ParseIP("2001:db8::1"),
Mask: net.CIDRMask(64, 128),
},
wantIPv6: true,
},
{
name: "nil",
addr: nil,
wantIPv4: false,
wantIPv6: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractIP(tt.addr)
if tt.wantIPv4 {
if got == nil || got.To4() == nil {
t.Error("extractIP() should return IPv4 address")
}
}
if tt.wantIPv6 {
if got == nil || got.To4() != nil {
t.Error("extractIP() should return IPv6 address")
}
}
if !tt.wantIPv4 && !tt.wantIPv6 {
if got != nil {
t.Error("extractIP() should return nil for nil address")
}
}
})
}
}
func TestCaptureImpl_buildBPFFilter(t *testing.T) {
c := New()
if c == nil {
t.Fatal("New() returned nil")
}
tests := []struct {
name string
ports []uint16
localIPs []string
wantParts []string // Parts that should be in the filter
}{
{
name: "no ports",
ports: []uint16{},
localIPs: []string{},
wantParts: []string{"tcp"},
},
{
name: "single port no IPs",
ports: []uint16{443},
localIPs: []string{},
wantParts: []string{"tcp dst port 443"},
},
{
name: "single port with single IP",
ports: []uint16{443},
localIPs: []string{"192.168.1.10"},
wantParts: []string{"tcp dst port 443", "dst host 192.168.1.10"},
},
{
name: "multiple ports with multiple IPs",
ports: []uint16{443, 8443},
localIPs: []string{"192.168.1.10", "10.0.0.5"},
wantParts: []string{"tcp dst port 443", "tcp dst port 8443", "dst host 192.168.1.10", "dst host 10.0.0.5"},
},
{
name: "IPv6 address",
ports: []uint16{443},
localIPs: []string{"2001:db8::1"},
wantParts: []string{"tcp dst port 443", "dst host 2001:db8::1"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.buildBPFFilter(tt.ports, tt.localIPs)
for _, part := range tt.wantParts {
if !strings.Contains(got, part) {
t.Errorf("buildBPFFilter() = %q, should contain %q", got, part)
}
}
})
}
}
func TestCaptureImpl_Run_AnyInterface(t *testing.T) {
t.Skip("integration: pcap on 'any' interface blocks until close; run with -run=Integration in a real network env")
c := New()
if c == nil {
t.Fatal("New() returned nil")
}
cfg := api.Config{
Interface: "any",
ListenPorts: []uint16{443},
LocalIPs: []string{"192.168.1.10"},
}
out := make(chan api.RawPacket, 10)
errCh := make(chan error, 1)
go func() { errCh <- c.Run(cfg, out) }()
// Allow up to 300ms for the handle to open (or fail immediately)
select {
case err := <-errCh:
// Immediate error: permission or "not found"
if err != nil && strings.Contains(err.Error(), "not found") {
t.Errorf("Run() with 'any' interface should be valid, got: %v", err)
}
case <-time.After(300 * time.Millisecond):
// Run() started successfully (blocking on packets) — close to stop it
c.Close()
}
}
func TestCaptureImpl_Run_WithManualLocalIPs(t *testing.T) {
t.Skip("integration: pcap on 'any' interface blocks until close; run with -run=Integration in a real network env")
c := New()
if c == nil {
t.Fatal("New() returned nil")
}
cfg := api.Config{
Interface: "any",
ListenPorts: []uint16{443},
LocalIPs: []string{"192.168.1.10", "10.0.0.5"},
}
out := make(chan api.RawPacket, 10)
errCh := make(chan error, 1)
go func() { errCh <- c.Run(cfg, out) }()
select {
case err := <-errCh:
if err != nil && strings.Contains(err.Error(), "not found") {
t.Errorf("Run() with manual LocalIPs should be valid, got: %v", err)
}
case <-time.After(300 * time.Millisecond):
c.Close()
}
}
// TestCaptureImpl_LinkTypeInitializedOnce verifies that linkType is set exactly once,
// after the BPF filter is applied (Bug 2 fix: removed the redundant early assignment).
func TestCaptureImpl_LinkTypeInitializedOnce(t *testing.T) {
c := New()
// Fresh instance: linkType must be zero before Run() is called.
if c.linkType != 0 {
t.Errorf("new CaptureImpl should have linkType=0, got %d", c.linkType)
}
// GetDiagnostics reflects linkType correctly.
_, _, _, lt := c.GetDiagnostics()
if lt != 0 {
t.Errorf("GetDiagnostics() linkType before Run() should be 0, got %d", lt)
}
// Simulate what Run() does: set linkType once under the mutex.
c.mu.Lock()
c.linkType = 1 // 1 = Ethernet
c.mu.Unlock()
_, _, _, lt = c.GetDiagnostics()
if lt != 1 {
t.Errorf("GetDiagnostics() linkType after set = %d, want 1", lt)
}
}
// TestBuildBPFFilter_NoLocalIPs verifies Bug 3 fix: when no local IPs are
// available (NAT/VIP), buildBPFFilter returns a port-only filter.
func TestBuildBPFFilter_NoLocalIPs(t *testing.T) {
c := New()
filter := c.buildBPFFilter([]uint16{443}, nil)
if strings.Contains(filter, "dst host") {
t.Errorf("port-only filter expected when localIPs nil, got: %s", filter)
}
if !strings.Contains(filter, "tcp dst port 443") {
t.Errorf("expected tcp dst port 443, got: %s", filter)
}
}
func TestBuildBPFFilter_EmptyLocalIPs(t *testing.T) {
c := New()
filter := c.buildBPFFilter([]uint16{443, 8443}, []string{})
if strings.Contains(filter, "dst host") {
t.Errorf("port-only filter expected when localIPs empty, got: %s", filter)
}
if !strings.Contains(filter, "tcp dst port 443") || !strings.Contains(filter, "tcp dst port 8443") {
t.Errorf("expected both ports in filter, got: %s", filter)
}
}

View File

@ -0,0 +1,295 @@
// Package config provides configuration loading and validation for ja4sentinel
package config
import (
"encoding/json"
"errors"
"fmt"
"net"
"os"
"strconv"
"strings"
"gopkg.in/yaml.v3"
"github.com/antitbone/ja4/sentinel/api"
)
// LoaderImpl implements the api.Loader interface for configuration loading
type LoaderImpl struct {
configPath string
}
// NewLoader creates a new configuration loader
func NewLoader(configPath string) *LoaderImpl {
return &LoaderImpl{
configPath: configPath,
}
}
// Load reads and merges configuration from file, environment variables, and CLI
func (l *LoaderImpl) Load() (api.AppConfig, error) {
config := api.DefaultConfig()
path := l.configPath
explicit := path != ""
if !explicit {
path = "config.yml"
}
fileConfig, err := l.loadFromFile(path)
if err == nil {
config = mergeConfigs(config, fileConfig)
} else if !(!explicit && errors.Is(err, os.ErrNotExist)) {
return config, fmt.Errorf("failed to load config file: %w", err)
}
// Override with environment variables
config = l.loadFromEnv(config)
// Validate the final configuration
if err := l.validate(config); err != nil {
return config, fmt.Errorf("invalid configuration: %w", err)
}
return config, nil
}
// loadFromFile reads configuration from a YAML file
func (l *LoaderImpl) loadFromFile(path string) (api.AppConfig, error) {
config := api.AppConfig{}
data, err := os.ReadFile(path)
if err != nil {
return config, fmt.Errorf("failed to read config file: %w", err)
}
err = yaml.Unmarshal(data, &config)
if err != nil {
return config, fmt.Errorf("failed to parse config file: %w", err)
}
return config, nil
}
// loadFromEnv overrides configuration with environment variables
func (l *LoaderImpl) loadFromEnv(config api.AppConfig) api.AppConfig {
// JA4SENTINEL_INTERFACE
if val := os.Getenv("JA4SENTINEL_INTERFACE"); val != "" {
config.Core.Interface = val
}
// JA4SENTINEL_PORTS (comma-separated list)
if val := os.Getenv("JA4SENTINEL_PORTS"); val != "" {
ports := parsePorts(val)
if len(ports) > 0 {
config.Core.ListenPorts = ports
}
}
// JA4SENTINEL_BPF_FILTER
if val := os.Getenv("JA4SENTINEL_BPF_FILTER"); val != "" {
config.Core.BPFFilter = val
}
// JA4SENTINEL_FLOW_TIMEOUT (in seconds)
if val := os.Getenv("JA4SENTINEL_FLOW_TIMEOUT"); val != "" {
if timeout, err := strconv.Atoi(val); err == nil && timeout > 0 {
config.Core.FlowTimeoutSec = timeout
}
}
// JA4SENTINEL_PACKET_BUFFER_SIZE
if val := os.Getenv("JA4SENTINEL_PACKET_BUFFER_SIZE"); val != "" {
if size, err := strconv.Atoi(val); err == nil && size > 0 {
config.Core.PacketBufferSize = size
}
}
// Note: JA4SENTINEL_LOG_LEVEL is intentionally NOT loaded from env.
// log_level must be configured exclusively via the YAML config file.
return config
}
// parsePorts parses a comma-separated list of ports
func parsePorts(s string) []uint16 {
if s == "" {
return nil
}
parts := strings.Split(s, ",")
ports := make([]uint16, 0, len(parts))
seen := make(map[uint16]struct{}, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
port, err := strconv.ParseUint(part, 10, 16)
if err != nil {
continue
}
p := uint16(port)
if p == 0 {
continue
}
if _, exists := seen[p]; exists {
continue
}
seen[p] = struct{}{}
ports = append(ports, p)
}
return ports
}
// mergeConfigs merges two configs, with override taking precedence
func mergeConfigs(base, override api.AppConfig) api.AppConfig {
result := base
if override.Core.Interface != "" {
result.Core.Interface = override.Core.Interface
}
if len(override.Core.ListenPorts) > 0 {
result.Core.ListenPorts = override.Core.ListenPorts
}
if override.Core.BPFFilter != "" {
result.Core.BPFFilter = override.Core.BPFFilter
}
if override.Core.FlowTimeoutSec > 0 {
result.Core.FlowTimeoutSec = override.Core.FlowTimeoutSec
}
if override.Core.PacketBufferSize > 0 {
result.Core.PacketBufferSize = override.Core.PacketBufferSize
}
if override.Core.LogLevel != "" {
result.Core.LogLevel = override.Core.LogLevel
}
// Merge exclude_source_ips (override takes precedence)
if len(override.Core.ExcludeSourceIPs) > 0 {
result.Core.ExcludeSourceIPs = override.Core.ExcludeSourceIPs
}
if len(override.Outputs) > 0 {
result.Outputs = override.Outputs
}
return result
}
// validate checks if the configuration is valid
func (l *LoaderImpl) validate(config api.AppConfig) error {
if strings.TrimSpace(config.Core.Interface) == "" {
return fmt.Errorf("interface cannot be empty")
}
if len(config.Core.ListenPorts) == 0 {
return fmt.Errorf("at least one listen port is required")
}
for _, p := range config.Core.ListenPorts {
if p == 0 {
return fmt.Errorf("listen port 0 is invalid")
}
}
if config.Core.FlowTimeoutSec <= 0 || config.Core.FlowTimeoutSec > 300 {
return fmt.Errorf("flow_timeout_sec must be between 1 and 300")
}
if config.Core.PacketBufferSize <= 0 || config.Core.PacketBufferSize > 1_000_000 {
return fmt.Errorf("packet_buffer_size must be between 1 and 1000000")
}
// Validate log level
validLogLevels := map[string]struct{}{
"debug": {},
"info": {},
"warn": {},
"error": {},
}
if config.Core.LogLevel != "" {
if _, ok := validLogLevels[config.Core.LogLevel]; !ok {
return fmt.Errorf("log_level must be one of: debug, info, warn, error")
}
}
// Validate exclude_source_ips (if provided)
if len(config.Core.ExcludeSourceIPs) > 0 {
for i, ip := range config.Core.ExcludeSourceIPs {
if ip == "" {
return fmt.Errorf("exclude_source_ips[%d]: entry cannot be empty", i)
}
// Basic validation: check if it looks like an IP or CIDR
if !strings.Contains(ip, "/") {
// Single IP - basic check
if !isValidIP(ip) {
return fmt.Errorf("exclude_source_ips[%d]: invalid IP address %q", i, ip)
}
} else {
// CIDR - basic check
if !isValidCIDR(ip) {
return fmt.Errorf("exclude_source_ips[%d]: invalid CIDR %q", i, ip)
}
}
}
}
allowedTypes := map[string]struct{}{
"stdout": {},
"file": {},
"unix_socket": {},
}
// Validate outputs
for i, output := range config.Outputs {
outputType := strings.TrimSpace(output.Type)
if outputType == "" {
return fmt.Errorf("output[%d]: type cannot be empty", i)
}
if _, ok := allowedTypes[outputType]; !ok {
return fmt.Errorf("output[%d]: unknown type %q", i, outputType)
}
switch outputType {
case "file":
if strings.TrimSpace(output.Params["path"]) == "" {
return fmt.Errorf("output[%d]: file output requires non-empty path", i)
}
case "unix_socket":
if strings.TrimSpace(output.Params["socket_path"]) == "" {
return fmt.Errorf("output[%d]: unix_socket output requires non-empty socket_path", i)
}
}
}
return nil
}
// ToJSON converts config to JSON string for debugging
func ToJSON(config api.AppConfig) string {
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return fmt.Sprintf("error marshaling config: %v", err)
}
return string(data)
}
// isValidIP checks if a string is a valid IP address using net.ParseIP
func isValidIP(ip string) bool {
return net.ParseIP(ip) != nil
}
// isValidCIDR checks if a string is a valid CIDR notation using net.ParseCIDR
func isValidCIDR(cidr string) bool {
_, _, err := net.ParseCIDR(cidr)
return err == nil
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,192 @@
// Package fingerprint provides JA4/JA3 fingerprint generation for TLS ClientHello
package fingerprint
import (
"encoding/binary"
"fmt"
"strconv"
"strings"
"github.com/antitbone/ja4/sentinel/api"
tlsfingerprint "github.com/psanford/tlsfingerprint"
)
// EngineImpl implements the api.Engine interface for fingerprint generation
type EngineImpl struct{}
// NewEngine creates a new fingerprint engine
func NewEngine() *EngineImpl {
return &EngineImpl{}
}
// FromClientHello generates JA4 (and optionally JA3) fingerprints from a TLS ClientHello
// Note: JA4 hash portion is extracted for internal use but NOT serialized to LogRecord
// as the JA4 format already includes its own hash portions (per architecture.yml)
func (e *EngineImpl) FromClientHello(ch api.TLSClientHello) (*api.Fingerprints, error) {
if len(ch.Payload) == 0 {
return nil, fmt.Errorf("empty ClientHello payload from %s:%d -> %s:%d",
ch.SrcIP, ch.SrcPort, ch.DstIP, ch.DstPort)
}
// Parse the ClientHello using tlsfingerprint
fp, err := tlsfingerprint.ParseClientHello(ch.Payload)
if err != nil {
// Try to sanitize truncated extensions and retry
sanitized := sanitizeClientHelloExtensions(ch.Payload)
if sanitized != nil {
fp, err = tlsfingerprint.ParseClientHello(sanitized)
}
if err != nil {
sanitizeStatus := "unavailable"
if sanitized != nil {
sanitizeStatus = "failed"
}
return nil, fmt.Errorf("fingerprint generation failed for %s:%d -> %s:%d (conn_id=%s, payload_len=%d, tls_version=%s, sni=%s, sanitization=%s): %w",
ch.SrcIP, ch.SrcPort, ch.DstIP, ch.DstPort, ch.ConnID, len(ch.Payload), ch.TLSVersion, ch.SNI, sanitizeStatus, err)
}
}
// Generate JA4 fingerprint
// Note: JA4 string format already includes the hash portion
// e.g., "t13d1516h2_8daaf6152771_02cb136f2775" where the last part is the SHA256 hash
ja4 := fp.JA4String()
// Generate JA3 fingerprint and its MD5 hash
ja3 := fp.JA3String()
ja3Hash := fp.JA3Hash()
// Extract JA4 hash portion (last segment after underscore)
// JA4 format: <tls_ver><ciphers><extensions>_<sni_hash>_<cipher_extension_hash>
// This is kept for internal use but NOT serialized to LogRecord
ja4Hash := extractJA4Hash(ja4)
// Generate JA4T fingerprint from TCP SYN parameters
ja4t := computeJA4T(ch.TCPMeta)
return &api.Fingerprints{
JA4: ja4,
JA4Hash: ja4Hash, // Internal use only - not serialized to LogRecord
JA4T: ja4t,
JA3: ja3,
JA3Hash: ja3Hash,
}, nil
}
// computeJA4T génère l'empreinte JA4T à partir des métadonnées TCP SYN.
// Format : {WindowSize}_{OptionKinds}_{WindowScale}_{MSS}
func computeJA4T(tcp api.TCPMeta) string {
optStr := ""
if len(tcp.OptionKinds) > 0 {
parts := make([]string, len(tcp.OptionKinds))
for i, k := range tcp.OptionKinds {
parts[i] = strconv.Itoa(int(k))
}
optStr = strings.Join(parts, "-")
}
return fmt.Sprintf("%d_%s_%d_%d", tcp.WindowSize, optStr, tcp.WindowScale, tcp.MSS)
}
// extractJA4Hash extracts the hash portion from a JA4 string
// JA4 format: <base>_<sni_hash>_<cipher_hash> -> returns "<sni_hash>_<cipher_hash>"
func extractJA4Hash(ja4 string) string {
// JA4 string format: t13d1516h2_8daaf6152771_02cb136f2775
// We extract everything after the first underscore as the "hash" portion
for i, c := range ja4 {
if c == '_' {
return ja4[i+1:]
}
}
return ""
}
// sanitizeClientHelloExtensions fixes ClientHellos with truncated extension data
// by adjusting the extensions length to include only complete extensions.
// Returns a corrected copy, or nil if the payload cannot be fixed.
func sanitizeClientHelloExtensions(data []byte) []byte {
if len(data) < 5 || data[0] != 0x16 {
return nil
}
recordLen := int(data[3])<<8 | int(data[4])
if len(data) < 5+recordLen {
return nil
}
payload := data[5 : 5+recordLen]
if len(payload) < 4 || payload[0] != 0x01 {
return nil
}
helloLen := int(payload[1])<<16 | int(payload[2])<<8 | int(payload[3])
if len(payload) < 4+helloLen {
return nil
}
hello := payload[4 : 4+helloLen]
// Skip through ClientHello fields to reach extensions
offset := 2 + 32 // version + random
if len(hello) < offset+1 {
return nil
}
offset += 1 + int(hello[offset]) // session ID
if len(hello) < offset+2 {
return nil
}
csLen := int(hello[offset])<<8 | int(hello[offset+1])
offset += 2 + csLen // cipher suites
if len(hello) < offset+1 {
return nil
}
offset += 1 + int(hello[offset]) // compression methods
if len(hello) < offset+2 {
return nil
}
extLenOffset := offset // position of extensions length field
declaredExtLen := int(hello[offset])<<8 | int(hello[offset+1])
offset += 2
extStart := offset
if len(hello) < extStart+declaredExtLen {
return nil
}
extData := hello[extStart : extStart+declaredExtLen]
// Walk extensions, find how many complete ones exist
validLen := 0
pos := 0
for pos < len(extData) {
if pos+4 > len(extData) {
break
}
extBodyLen := int(extData[pos+2])<<8 | int(extData[pos+3])
if pos+4+extBodyLen > len(extData) {
break // this extension is truncated
}
pos += 4 + extBodyLen
validLen = pos
}
if validLen == declaredExtLen {
return nil // no truncation found, nothing to fix
}
// Build a corrected copy with adjusted extensions length
fixed := make([]byte, len(data))
copy(fixed, data)
// Absolute offset of extensions length field within data
extLenAbs := 5 + 4 + extLenOffset
diff := declaredExtLen - validLen
// Update extensions length
binary.BigEndian.PutUint16(fixed[extLenAbs:], uint16(validLen))
// Update ClientHello handshake length
newHelloLen := helloLen - diff
fixed[5+1] = byte(newHelloLen >> 16)
fixed[5+2] = byte(newHelloLen >> 8)
fixed[5+3] = byte(newHelloLen)
// Update TLS record length
newRecordLen := recordLen - diff
binary.BigEndian.PutUint16(fixed[3:5], uint16(newRecordLen))
return fixed[:5+newRecordLen]
}

View File

@ -0,0 +1,585 @@
package fingerprint
import (
"strings"
"testing"
"github.com/antitbone/ja4/sentinel/api"
tlsfingerprint "github.com/psanford/tlsfingerprint"
)
func TestFromClientHello(t *testing.T) {
tests := []struct {
name string
ch api.TLSClientHello
wantErr bool
}{
{
name: "empty payload",
ch: api.TLSClientHello{
Payload: []byte{},
},
wantErr: true,
},
{
name: "invalid payload",
ch: api.TLSClientHello{
Payload: []byte{0x00, 0x01, 0x02},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine := NewEngine()
_, err := engine.FromClientHello(tt.ch)
if (err != nil) != tt.wantErr {
t.Errorf("FromClientHello() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestNewEngine(t *testing.T) {
engine := NewEngine()
if engine == nil {
t.Error("NewEngine() returned nil")
}
}
func TestFromClientHello_ValidPayload(t *testing.T) {
// Use a minimal valid TLS 1.2 ClientHello with extensions
// Build a proper ClientHello using the same structure as parser tests
clientHello := buildMinimalClientHelloForTest()
ch := api.TLSClientHello{
SrcIP: "192.168.1.100",
SrcPort: 54321,
DstIP: "10.0.0.1",
DstPort: 443,
Payload: clientHello,
}
engine := NewEngine()
fp, err := engine.FromClientHello(ch)
if err != nil {
t.Fatalf("FromClientHello() error = %v", err)
}
if fp == nil {
t.Fatal("FromClientHello() returned nil")
}
// Verify JA4 is populated (format: t13d... or t12d...)
if fp.JA4 == "" {
t.Error("JA4 should not be empty")
}
// JA4Hash is populated for internal use (but not serialized to LogRecord)
// It contains the hash portions of the JA4 string
if fp.JA4Hash == "" {
t.Error("JA4Hash should be populated for internal use")
}
}
// buildMinimalClientHelloForTest creates a minimal valid TLS 1.2 ClientHello
func buildMinimalClientHelloForTest() []byte {
// Cipher suites (minimal set)
cipherSuites := []byte{0x00, 0x04, 0x13, 0x01, 0x13, 0x02, 0xc0, 0x2f}
// Compression methods (null only)
compressionMethods := []byte{0x01, 0x00}
// No extensions
extensions := []byte{}
extLen := len(extensions)
// Build ClientHello handshake body
handshakeBody := []byte{
0x03, 0x03, // Version: TLS 1.2
// Random (32 bytes)
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, // Session ID length: 0
}
// Add cipher suites (with length prefix)
cipherSuiteLen := len(cipherSuites)
handshakeBody = append(handshakeBody, byte(cipherSuiteLen>>8), byte(cipherSuiteLen))
handshakeBody = append(handshakeBody, cipherSuites...)
// Add compression methods (with length prefix)
handshakeBody = append(handshakeBody, compressionMethods...)
// Add extensions (with length prefix)
handshakeBody = append(handshakeBody, byte(extLen>>8), byte(extLen))
handshakeBody = append(handshakeBody, extensions...)
// Now build full handshake with type and length
handshakeLen := len(handshakeBody)
handshake := append([]byte{
0x01, // Handshake type: ClientHello
byte(handshakeLen >> 16), byte(handshakeLen >> 8), byte(handshakeLen), // Handshake length
}, handshakeBody...)
// Build TLS record
recordLen := len(handshake)
record := make([]byte, 5+recordLen)
record[0] = 0x16 // Handshake
record[1] = 0x03 // Version: TLS 1.2
record[2] = 0x03
record[3] = byte(recordLen >> 8)
record[4] = byte(recordLen)
copy(record[5:], handshake)
return record
}
// TestExtractJA4Hash tests the extractJA4Hash helper function
func TestExtractJA4Hash(t *testing.T) {
tests := []struct {
name string
ja4 string
want string
}{
{
name: "standard_ja4_format",
ja4: "t13d1516h2_8daaf6152771_02cb136f2775",
want: "8daaf6152771_02cb136f2775",
},
{
name: "ja4_with_single_underscore",
ja4: "t12d1234h1_abcdef123456",
want: "abcdef123456",
},
{
name: "ja4_no_underscore_returns_empty",
ja4: "t13d1516h2",
want: "",
},
{
name: "empty_ja4_returns_empty",
ja4: "",
want: "",
},
{
name: "underscore_at_start",
ja4: "_hash1_hash2",
want: "hash1_hash2",
},
{
name: "multiple_underscores_returns_after_first",
ja4: "base_part1_part2_part3",
want: "part1_part2_part3",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractJA4Hash(tt.ja4)
if got != tt.want {
t.Errorf("extractJA4Hash(%q) = %q, want %q", tt.ja4, got, tt.want)
}
})
}
}
// TestFromClientHello_NilPayload tests error handling for nil payload
func TestFromClientHello_NilPayload(t *testing.T) {
engine := NewEngine()
ch := api.TLSClientHello{
Payload: nil,
}
_, err := engine.FromClientHello(ch)
if err == nil {
t.Error("FromClientHello() with nil payload should return error")
}
if !strings.HasPrefix(err.Error(), "empty ClientHello payload") {
t.Errorf("FromClientHello() error = %v, should start with 'empty ClientHello payload'", err)
}
}
// TestFromClientHello_JA3Hash tests that JA3Hash is correctly populated
func TestFromClientHello_JA3Hash(t *testing.T) {
clientHello := buildMinimalClientHelloForTest()
ch := api.TLSClientHello{
Payload: clientHello,
}
engine := NewEngine()
fp, err := engine.FromClientHello(ch)
if err != nil {
t.Fatalf("FromClientHello() error = %v", err)
}
// JA3Hash should be populated (MD5 hash of JA3 string)
if fp.JA3Hash == "" {
t.Error("JA3Hash should be populated")
}
// JA3 should also be populated
if fp.JA3 == "" {
t.Error("JA3 should be populated")
}
}
// TestFromClientHello_EmptyJA4Hash tests behavior when JA4 has no underscore
func TestFromClientHello_EmptyJA4Hash(t *testing.T) {
// This test verifies that even if JA4 format changes, the code handles it gracefully
engine := NewEngine()
// Use a valid ClientHello - the library should produce a proper JA4
clientHello := buildMinimalClientHelloForTest()
ch := api.TLSClientHello{
Payload: clientHello,
}
fp, err := engine.FromClientHello(ch)
if err != nil {
t.Fatalf("FromClientHello() error = %v", err)
}
// JA4 should always be populated
if fp.JA4 == "" {
t.Error("JA4 should be populated")
}
// JA4Hash may be empty if the JA4 format doesn't include underscores
// This is acceptable behavior
}
// buildClientHelloWithTruncatedExtension creates a ClientHello where the last
// extension declares more data than actually present.
func buildClientHelloWithTruncatedExtension() []byte {
// Build a valid SNI extension first
sniHostname := []byte("example.com")
sniExt := []byte{
0x00, 0x00, // Extension type: server_name
}
sniData := []byte{0x00}
sniListLen := 1 + 2 + len(sniHostname) // type(1) + len(2) + hostname
sniData = append(sniData, byte(sniListLen>>8), byte(sniListLen))
sniData = append(sniData, 0x00) // hostname type
sniData = append(sniData, byte(len(sniHostname)>>8), byte(len(sniHostname)))
sniData = append(sniData, sniHostname...)
sniExt = append(sniExt, byte(len(sniData)>>8), byte(len(sniData)))
sniExt = append(sniExt, sniData...)
// Build a truncated extension: declares 100 bytes but only has 5
truncatedExt := []byte{
0x00, 0x15, // Extension type: padding
0x00, 0x64, // Extension data length: 100 (but we only provide 5)
0x00, 0x00, 0x00, 0x00, 0x00, // Only 5 bytes of padding
}
// Extensions = valid SNI + truncated padding
extensions := append(sniExt, truncatedExt...)
// But the extensions length field claims the full size (including the bad extension)
extLen := len(extensions)
// Cipher suites
cipherSuites := []byte{0x00, 0x04, 0x13, 0x01, 0x13, 0x02, 0xc0, 0x2f}
compressionMethods := []byte{0x01, 0x00}
handshakeBody := []byte{0x03, 0x03}
for i := 0; i < 32; i++ {
handshakeBody = append(handshakeBody, 0x01)
}
handshakeBody = append(handshakeBody, 0x00) // session ID length: 0
handshakeBody = append(handshakeBody, byte(len(cipherSuites)>>8), byte(len(cipherSuites)))
handshakeBody = append(handshakeBody, cipherSuites...)
handshakeBody = append(handshakeBody, compressionMethods...)
handshakeBody = append(handshakeBody, byte(extLen>>8), byte(extLen))
handshakeBody = append(handshakeBody, extensions...)
handshakeLen := len(handshakeBody)
handshake := append([]byte{
0x01,
byte(handshakeLen >> 16), byte(handshakeLen >> 8), byte(handshakeLen),
}, handshakeBody...)
recordLen := len(handshake)
record := make([]byte, 5+recordLen)
record[0] = 0x16
record[1] = 0x03
record[2] = 0x03
record[3] = byte(recordLen >> 8)
record[4] = byte(recordLen)
copy(record[5:], handshake)
return record
}
func TestFromClientHello_TruncatedExtension_StillGeneratesFingerprint(t *testing.T) {
payload := buildClientHelloWithTruncatedExtension()
ch := api.TLSClientHello{
SrcIP: "4.251.36.192",
SrcPort: 19346,
DstIP: "212.95.72.88",
DstPort: 443,
Payload: payload,
ConnID: "4.251.36.192:19346->212.95.72.88:443",
}
engine := NewEngine()
fp, err := engine.FromClientHello(ch)
if err != nil {
t.Fatalf("FromClientHello() should succeed after sanitization, got error: %v", err)
}
if fp == nil {
t.Fatal("FromClientHello() returned nil fingerprint")
}
if fp.JA4 == "" {
t.Error("JA4 should be populated even with truncated extension")
}
if fp.JA3 == "" {
t.Error("JA3 should be populated even with truncated extension")
}
}
func TestSanitizeClientHelloExtensions(t *testing.T) {
t.Run("valid payload returns nil", func(t *testing.T) {
valid := buildMinimalClientHelloForTest()
result := sanitizeClientHelloExtensions(valid)
if result != nil {
t.Error("should return nil for valid payload (no fix needed)")
}
})
t.Run("truncated extension is fixed", func(t *testing.T) {
truncated := buildClientHelloWithTruncatedExtension()
result := sanitizeClientHelloExtensions(truncated)
if result == nil {
t.Fatal("should return sanitized payload")
}
// The sanitized payload should be parseable by the library
fp, err := tlsfingerprint.ParseClientHello(result)
if err != nil {
t.Fatalf("sanitized payload should parse without error, got: %v", err)
}
if fp == nil {
t.Fatal("sanitized payload should produce a fingerprint")
}
})
t.Run("too short returns nil", func(t *testing.T) {
if sanitizeClientHelloExtensions([]byte{0x16}) != nil {
t.Error("should return nil for short payload")
}
})
t.Run("non-TLS returns nil", func(t *testing.T) {
if sanitizeClientHelloExtensions([]byte{0x15, 0x03, 0x03, 0x00, 0x01, 0x00}) != nil {
t.Error("should return nil for non-TLS payload")
}
})
}
// TestExtractJA4Hash_Standard tests the hash extraction from a standard JA4 string.
func TestExtractJA4Hash_Standard(t *testing.T) {
ja4 := "t13d1516h2_8daaf6152771_02cb136f2775"
got := extractJA4Hash(ja4)
expected := "8daaf6152771_02cb136f2775"
if got != expected {
t.Errorf("extractJA4Hash(%q) = %q, want %q", ja4, got, expected)
}
}
// TestExtractJA4Hash_NoUnderscore tests that no underscore returns empty string.
func TestExtractJA4Hash_NoUnderscore(t *testing.T) {
got := extractJA4Hash("nounderscore")
if got != "" {
t.Errorf("expected empty string for no underscore, got %q", got)
}
}
// TestExtractJA4Hash_Empty tests that empty string returns empty string.
func TestExtractJA4Hash_Empty(t *testing.T) {
got := extractJA4Hash("")
if got != "" {
t.Errorf("expected empty string for empty input, got %q", got)
}
}
// TestFromClientHello_NilPayloadExplicit tests that nil payload (empty) returns error.
func TestFromClientHello_NilPayloadExplicit(t *testing.T) {
engine := NewEngine()
_, err := engine.FromClientHello(api.TLSClientHello{
SrcIP: "1.2.3.4",
SrcPort: 12345,
DstIP: "5.6.7.8",
DstPort: 443,
Payload: nil,
})
if err == nil {
t.Error("expected error for nil payload")
}
}
// TestFromClientHello_SingleByte tests that single byte payload returns error.
func TestFromClientHello_SingleByte(t *testing.T) {
engine := NewEngine()
_, err := engine.FromClientHello(api.TLSClientHello{
Payload: []byte{0x16},
})
if err == nil {
t.Error("expected error for single-byte payload")
}
}
// TestFromClientHello_ErrorContainsAddresses tests that error message includes addresses.
func TestFromClientHello_ErrorContainsAddresses(t *testing.T) {
engine := NewEngine()
_, err := engine.FromClientHello(api.TLSClientHello{
SrcIP: "192.168.1.100",
SrcPort: 54321,
DstIP: "10.0.0.1",
DstPort: 443,
ConnID: "test-conn-id",
Payload: []byte{0x01, 0x02, 0x03}, // invalid
})
if err == nil {
t.Fatal("expected error for invalid payload")
}
if !strings.Contains(err.Error(), "192.168.1.100") {
t.Errorf("expected error to contain src IP, got: %v", err)
}
}
// TestSanitizeClientHelloExtensions_NilInput tests nil input returns nil.
func TestSanitizeClientHelloExtensions_NilInput(t *testing.T) {
if sanitizeClientHelloExtensions(nil) != nil {
t.Error("nil input should return nil")
}
}
// TestSanitizeClientHelloExtensions_EmptyInput tests empty input returns nil.
func TestSanitizeClientHelloExtensions_EmptyInput(t *testing.T) {
if sanitizeClientHelloExtensions([]byte{}) != nil {
t.Error("empty input should return nil")
}
}
// TestJA4HashExtraction_ConsistentWithFullParse verifies JA4Hash is the tail of JA4 string.
func TestJA4HashExtraction_ConsistentWithFullParse(t *testing.T) {
// Any JA4 string with exactly one underscore should work
ja4 := "t12d4562h0_somehash"
hash := extractJA4Hash(ja4)
if !strings.HasPrefix(ja4, "t12") {
t.Skip("precondition failed")
}
if hash != "somehash" {
t.Errorf("expected 'somehash', got %q", hash)
}
}
// Compile-time check: EngineImpl satisfies api.Engine.
var _ interface {
FromClientHello(api.TLSClientHello) (*api.Fingerprints, error)
} = (*EngineImpl)(nil)
// TestComputeJA4T tests the JA4T fingerprint generation.
func TestComputeJA4T(t *testing.T) {
tests := []struct {
name string
tcp api.TCPMeta
want string
}{
{
name: "linux_5x_typical",
tcp: api.TCPMeta{
WindowSize: 64240,
OptionKinds: []uint8{2, 4, 8, 1, 3},
WindowScale: 7,
MSS: 1460,
},
want: "64240_2-4-8-1-3_7_1460",
},
{
name: "windows_11_typical",
tcp: api.TCPMeta{
WindowSize: 64240,
OptionKinds: []uint8{2, 4, 8, 1, 3},
WindowScale: 8,
MSS: 1460,
},
want: "64240_2-4-8-1-3_8_1460",
},
{
name: "macos_14_typical",
tcp: api.TCPMeta{
WindowSize: 65535,
OptionKinds: []uint8{2, 4, 8, 1, 3},
WindowScale: 6,
MSS: 1460,
},
want: "65535_2-4-8-1-3_6_1460",
},
{
name: "no_options",
tcp: api.TCPMeta{
WindowSize: 8192,
OptionKinds: nil,
WindowScale: 0,
MSS: 0,
},
want: "8192__0_0",
},
{
name: "windows_no_ts",
tcp: api.TCPMeta{
WindowSize: 8192,
OptionKinds: []uint8{2, 4, 1, 3},
WindowScale: 2,
MSS: 1460,
},
want: "8192_2-4-1-3_2_1460",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := computeJA4T(tt.tcp)
if got != tt.want {
t.Errorf("computeJA4T() = %q, want %q", got, tt.want)
}
})
}
}
// TestFromClientHello_JA4T_Populated tests that JA4T is populated in FromClientHello.
func TestFromClientHello_JA4T_Populated(t *testing.T) {
clientHello := buildMinimalClientHelloForTest()
ch := api.TLSClientHello{
Payload: clientHello,
TCPMeta: api.TCPMeta{
WindowSize: 64240,
MSS: 1460,
WindowScale: 7,
OptionKinds: []uint8{2, 4, 8, 1, 3},
Options: []string{"MSS", "SACK", "TS", "NOP", "WS"},
},
}
engine := NewEngine()
fp, err := engine.FromClientHello(ch)
if err != nil {
t.Fatalf("FromClientHello() error = %v", err)
}
expected := "64240_2-4-8-1-3_7_1460"
if fp.JA4T != expected {
t.Errorf("JA4T = %q, want %q", fp.JA4T, expected)
}
}

View File

@ -0,0 +1,402 @@
// Package integration provides integration tests for the full ja4sentinel pipeline
package integration
import (
"encoding/json"
"os"
"testing"
"time"
"github.com/antitbone/ja4/sentinel/api"
"github.com/antitbone/ja4/sentinel/internal/fingerprint"
"github.com/antitbone/ja4/sentinel/internal/output"
"github.com/antitbone/ja4/sentinel/internal/tlsparse"
)
// TestFullPipeline_TLSClientHelloToFingerprint tests the pipeline from TLS ClientHello to fingerprint
func TestFullPipeline_TLSClientHelloToFingerprint(t *testing.T) {
// Create a minimal TLS 1.2 ClientHello for testing
clientHello := buildMinimalTLSClientHello()
// Step 1: Parse the ClientHello
parser := tlsparse.NewParser()
if parser == nil {
t.Fatal("NewParser() returned nil")
}
defer parser.Close()
// Create a raw packet with the ClientHello
rawPacket := api.RawPacket{
Data: buildEthernetIPPacket(clientHello),
Timestamp: time.Now().UnixNano(),
}
// Process the packet
ch, err := parser.Process(rawPacket)
if err != nil {
t.Fatalf("Process() error = %v", err)
}
if ch == nil {
t.Fatal("Process() returned nil ClientHello")
}
// Step 2: Generate fingerprints
engine := fingerprint.NewEngine()
if engine == nil {
t.Fatal("NewEngine() returned nil")
}
fp, err := engine.FromClientHello(*ch)
if err != nil {
t.Fatalf("FromClientHello() error = %v", err)
}
if fp == nil {
t.Fatal("FromClientHello() returned nil")
}
// Verify fingerprints are populated
if fp.JA4 == "" {
t.Error("JA4 should be populated")
}
if fp.JA3 == "" {
t.Error("JA3 should be populated")
}
if fp.JA3Hash == "" {
t.Error("JA3Hash should be populated")
}
}
// TestFullPipeline_FingerprintToOutput tests the pipeline from fingerprint to output
func TestFullPipeline_FingerprintToOutput(t *testing.T) {
// Create test data
clientHello := api.TLSClientHello{
SrcIP: "192.168.1.100",
SrcPort: 54321,
DstIP: "10.0.0.1",
DstPort: 443,
IPMeta: api.IPMeta{
TTL: 64,
TotalLength: 512,
IPID: 12345,
DF: true,
},
TCPMeta: api.TCPMeta{
WindowSize: 65535,
MSS: 1460,
WindowScale: 7,
Options: []string{"MSS", "SACK", "TS", "WS"},
},
ConnID: "test-flow-123",
SNI: "example.com",
ALPN: "h2",
TLSVersion: "1.3",
SynToCHMs: uint32Ptr(50),
}
// Create fingerprints
fingerprints := &api.Fingerprints{
JA4: "t13d1516h2_8daaf6152771_02cb136f2775",
JA4Hash: "8daaf6152771_02cb136f2775",
JA3: "771,4865-4866-4867,0-23-65281-10-11-35-16-5-13-18-51-45-43-27-17513,29-23-24,0",
JA3Hash: "a0e6f06c7a6d15e5e3f0f0e6f06c7a6d",
}
// Step 1: Create LogRecord
logRecord := api.NewLogRecord(clientHello, fingerprints)
logRecord.SensorID = "test-sensor"
// Step 2: Write to output (stdout writer for testing)
writer := output.NewStdoutWriter()
if writer == nil {
t.Fatal("NewStdoutWriter() returned nil")
}
// Capture stdout by using a buffer (we can't easily test stdout, so we verify the record)
// Instead, verify the LogRecord is valid JSON
data, err := json.Marshal(logRecord)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
// Verify JSON is valid and contains expected fields
var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
// Verify key fields
if result["src_ip"] != "192.168.1.100" {
t.Errorf("src_ip = %v, want 192.168.1.100", result["src_ip"])
}
if result["src_port"] != float64(54321) {
t.Errorf("src_port = %v, want 54321", result["src_port"])
}
if result["ja4"] != "t13d1516h2_8daaf6152771_02cb136f2775" {
t.Errorf("ja4 = %v, want t13d1516h2_8daaf6152771_02cb136f2775", result["ja4"])
}
if result["tls_sni"] != "example.com" {
t.Errorf("tls_sni = %v, want example.com", result["tls_sni"])
}
if result["sensor_id"] != "test-sensor" {
t.Errorf("sensor_id = %v, want test-sensor", result["sensor_id"])
}
}
// TestFullPipeline_EndToEnd tests the complete pipeline with file output
func TestFullPipeline_EndToEnd(t *testing.T) {
tmpDir := t.TempDir()
outputPath := tmpDir + "/output.log"
// Create test ClientHello
clientHello := buildMinimalTLSClientHello()
// Step 1: Parse
parser := tlsparse.NewParser()
defer parser.Close()
rawPacket := api.RawPacket{
Data: buildEthernetIPPacket(clientHello),
Timestamp: time.Now().UnixNano(),
}
ch, err := parser.Process(rawPacket)
if err != nil {
t.Fatalf("Process() error = %v", err)
}
// Step 2: Fingerprint
engine := fingerprint.NewEngine()
fp, err := engine.FromClientHello(*ch)
if err != nil {
t.Fatalf("FromClientHello() error = %v", err)
}
// Step 3: Create LogRecord
logRecord := api.NewLogRecord(*ch, fp)
logRecord.SensorID = "test-sensor-e2e"
// Step 4: Write to file
fileWriter, err := output.NewFileWriter(outputPath)
if err != nil {
t.Fatalf("NewFileWriter() error = %v", err)
}
defer fileWriter.Close()
err = fileWriter.Write(logRecord)
if err != nil {
t.Errorf("Write() error = %v", err)
}
// Verify output file
data, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("ReadFile() error = %v", err)
}
if len(data) == 0 {
t.Fatal("Output file is empty")
}
// Parse and verify
var result api.LogRecord
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
if result.SensorID != "test-sensor-e2e" {
t.Errorf("SensorID = %v, want test-sensor-e2e", result.SensorID)
}
if result.JA4 == "" {
t.Error("JA4 should be populated")
}
}
// TestFullPipeline_MultiOutput tests writing to multiple outputs simultaneously
func TestFullPipeline_MultiOutput(t *testing.T) {
tmpDir := t.TempDir()
filePath := tmpDir + "/multi.log"
// Create multi-writer
multiWriter := output.NewMultiWriter()
multiWriter.Add(output.NewStdoutWriter())
fileWriter, err := output.NewFileWriter(filePath)
if err != nil {
t.Fatalf("NewFileWriter() error = %v", err)
}
multiWriter.Add(fileWriter)
// Create test record
logRecord := api.LogRecord{
SrcIP: "192.168.1.1",
SrcPort: 12345,
JA4: "test-multi-output",
}
// Write to all outputs
err = multiWriter.Write(logRecord)
if err != nil {
t.Errorf("Write() error = %v", err)
}
// Verify file output
data, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("ReadFile() error = %v", err)
}
if len(data) == 0 {
t.Fatal("File output is empty")
}
}
// TestFullPipeline_ConfigToOutput tests building output from config
func TestFullPipeline_ConfigToOutput(t *testing.T) {
tmpDir := t.TempDir()
// Create config with multiple outputs
config := api.AppConfig{
Core: api.Config{
Interface: "eth0",
ListenPorts: []uint16{443},
},
Outputs: []api.OutputConfig{
{
Type: "stdout",
Enabled: true,
AsyncBuffer: 1000,
},
{
Type: "file",
Enabled: true,
AsyncBuffer: 1000,
Params: map[string]string{"path": tmpDir + "/config-output.log"},
},
},
}
// Build writer from config
builder := output.NewBuilder()
writer, err := builder.NewFromConfig(config)
if err != nil {
t.Fatalf("NewFromConfig() error = %v", err)
}
// Verify writer is MultiWriter
_, ok := writer.(*output.MultiWriter)
if !ok {
t.Fatal("Expected MultiWriter")
}
// Test writing
logRecord := api.LogRecord{
SrcIP: "192.168.1.1",
JA4: "test-config-output",
}
err = writer.Write(logRecord)
if err != nil {
t.Errorf("Write() error = %v", err)
}
}
// Helper functions
// buildMinimalTLSClientHello creates a minimal TLS 1.2 ClientHello for testing
func buildMinimalTLSClientHello() []byte {
// Cipher suites
cipherSuites := []byte{0x00, 0x04, 0x13, 0x01, 0x13, 0x02, 0xc0, 0x2f}
compressionMethods := []byte{0x01, 0x00}
extensions := []byte{}
extLen := len(extensions)
handshakeBody := []byte{
0x03, 0x03, // Version: TLS 1.2
}
// Random (32 bytes)
for i := 0; i < 32; i++ {
handshakeBody = append(handshakeBody, 0x00)
}
handshakeBody = append(handshakeBody, 0x00) // Session ID length
// Cipher suites
cipherSuiteLen := len(cipherSuites)
handshakeBody = append(handshakeBody, byte(cipherSuiteLen>>8), byte(cipherSuiteLen))
handshakeBody = append(handshakeBody, cipherSuites...)
// Compression methods
handshakeBody = append(handshakeBody, compressionMethods...)
// Extensions
handshakeBody = append(handshakeBody, byte(extLen>>8), byte(extLen))
handshakeBody = append(handshakeBody, extensions...)
// Build handshake
handshakeLen := len(handshakeBody)
handshake := append([]byte{
0x01, // Handshake type: ClientHello
byte(handshakeLen >> 16), byte(handshakeLen >> 8), byte(handshakeLen),
}, handshakeBody...)
// Build TLS record
recordLen := len(handshake)
record := make([]byte, 5+recordLen)
record[0] = 0x16 // Handshake
record[1] = 0x03 // Version: TLS 1.2
record[2] = 0x03
record[3] = byte(recordLen >> 8)
record[4] = byte(recordLen)
copy(record[5:], handshake)
return record
}
// buildEthernetIPPacket wraps a TLS payload in Ethernet/IP/TCP headers
func buildEthernetIPPacket(tlsPayload []byte) []byte {
// This is a simplified packet structure for testing
// Real packets would have proper Ethernet, IP, and TCP headers
// Ethernet header (14 bytes)
eth := make([]byte, 14)
eth[12] = 0x08 // EtherType: IPv4
eth[13] = 0x00
// IP header (20 bytes)
ip := make([]byte, 20)
ip[0] = 0x45 // Version 4, IHL 5
ip[1] = 0x00 // DSCP/ECN
ip[2] = byte((20 + 20 + len(tlsPayload)) >> 8) // Total length
ip[3] = byte((20 + 20 + len(tlsPayload)) & 0xFF)
ip[8] = 64 // TTL
ip[9] = 6 // Protocol: TCP
ip[12] = 192
ip[13] = 168
ip[14] = 1
ip[15] = 100 // Src IP: 192.168.1.100
ip[16] = 10
ip[17] = 0
ip[18] = 0
ip[19] = 1 // Dst IP: 10.0.0.1
// TCP header (20 bytes)
tcp := make([]byte, 20)
tcp[0] = byte(54321 >> 8) // Src port high
tcp[1] = byte(54321 & 0xFF) // Src port low
tcp[2] = byte(443 >> 8) // Dst port high
tcp[3] = byte(443 & 0xFF) // Dst port low
tcp[12] = 0x50 // Data offset (5 * 4 = 20 bytes)
tcp[13] = 0x18 // Flags: ACK, PSH
// Combine all headers with payload
packet := make([]byte, len(eth)+len(ip)+len(tcp)+len(tlsPayload))
copy(packet, eth)
copy(packet[len(eth):], ip)
copy(packet[len(eth)+len(ip):], tcp)
copy(packet[len(eth)+len(ip)+len(tcp):], tlsPayload)
return packet
}
func uint32Ptr(v uint32) *uint32 {
return &v
}

View File

@ -0,0 +1,15 @@
// Package ipfilter provides IP address and CIDR range matching for filtering.
// Implementation is delegated to shared/go/ja4common/ipfilter to avoid duplication.
package ipfilter
import jaipfilter "github.com/antitbone/ja4/ja4common/ipfilter"
// Filter is a type alias for ja4common/ipfilter.Filter.
// All methods (New, ShouldExclude, Count) are inherited from the shared module.
type Filter = jaipfilter.Filter
// New creates a new IP filter from a list of IP addresses or CIDR ranges.
// Accepts formats like: "192.168.1.1", "10.0.0.0/8", "2001:db8::/32"
func New(excludeList []string) (*Filter, error) {
return jaipfilter.New(excludeList)
}

View File

@ -0,0 +1,160 @@
package ipfilter
import (
"testing"
)
func TestFilter_New(t *testing.T) {
tests := []struct {
name string
list []string
wantErr bool
}{
{
name: "empty list",
list: []string{},
wantErr: false,
},
{
name: "single IP",
list: []string{"192.168.1.1"},
wantErr: false,
},
{
name: "single CIDR",
list: []string{"10.0.0.0/8"},
wantErr: false,
},
{
name: "mixed IPs and CIDRs",
list: []string{"192.168.1.1", "10.0.0.0/8", "172.16.0.0/12"},
wantErr: false,
},
{
name: "invalid IP",
list: []string{"999.999.999.999"},
wantErr: true,
},
{
name: "invalid CIDR",
list: []string{"10.0.0.0/33"},
wantErr: true,
},
{
name: "IPv6 address",
list: []string{"2001:db8::1"},
wantErr: false,
},
{
name: "IPv6 CIDR",
list: []string{"2001:db8::/32"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f, err := New(tt.list)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil && f == nil {
t.Error("New() should return non-nil filter on success")
}
})
}
}
func TestFilter_ShouldExclude(t *testing.T) {
f, err := New([]string{
"192.168.1.1",
"10.0.0.0/8",
"172.16.0.0/12",
"2001:db8::1",
"fc00::/7",
})
if err != nil {
t.Fatalf("New() error = %v", err)
}
tests := []struct {
name string
ip string
want bool
}{
// Exact IP matches
{"exact match", "192.168.1.1", true},
{"exact IPv6 match", "2001:db8::1", true},
// CIDR matches
{"CIDR match 10.0.0.1", "10.0.0.1", true},
{"CIDR match 10.255.255.255", "10.255.255.255", true},
{"CIDR match 172.16.0.1", "172.16.0.1", true},
{"CIDR match 172.31.255.255", "172.31.255.255", true},
{"CIDR IPv6 match", "fc00::1", true},
// No matches
{"no match 192.168.2.1", "192.168.2.1", false},
{"no match 11.0.0.1", "11.0.0.1", false},
{"no match 172.32.0.1", "172.32.0.1", false},
{"no match 8.8.8.8", "8.8.8.8", false},
// Invalid IP
{"invalid IP", "invalid", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := f.ShouldExclude(tt.ip); got != tt.want {
t.Errorf("ShouldExclude(%q) = %v, want %v", tt.ip, got, tt.want)
}
})
}
}
func TestFilter_ShouldExclude_NilFilter(t *testing.T) {
var f *Filter
if f.ShouldExclude("192.168.1.1") {
t.Error("ShouldExclude on nil filter should return false")
}
}
func TestFilter_Count(t *testing.T) {
f, err := New([]string{
"192.168.1.1",
"10.0.0.1",
"10.0.0.0/8",
"172.16.0.0/12",
})
if err != nil {
t.Fatalf("New() error = %v", err)
}
ips, networks := f.Count()
if ips != 2 {
t.Errorf("Count() ips = %d, want 2", ips)
}
if networks != 2 {
t.Errorf("Count() networks = %d, want 2", networks)
}
}
func TestFilter_EmptyEntries(t *testing.T) {
f, err := New([]string{"", "192.168.1.1", ""})
if err != nil {
t.Fatalf("New() error = %v", err)
}
ips, _ := f.Count()
if ips != 1 {
t.Errorf("Count() ips = %d, want 1 (empty entries should be skipped)", ips)
}
if !f.ShouldExclude("192.168.1.1") {
t.Error("Should exclude 192.168.1.1")
}
if f.ShouldExclude("192.168.1.2") {
t.Error("Should not exclude 192.168.1.2")
}
}

View File

@ -0,0 +1,19 @@
// Package logging provides a factory for creating loggers
package logging
import (
"github.com/antitbone/ja4/sentinel/api"
)
// LoggerFactory creates logger instances
type LoggerFactory struct{}
// NewLogger creates a new logger based on configuration
func (f *LoggerFactory) NewLogger(level string) api.Logger {
return NewServiceLogger(level)
}
// NewDefaultLogger creates a logger with default settings
func (f *LoggerFactory) NewDefaultLogger() api.Logger {
return NewServiceLogger("info")
}

View File

@ -0,0 +1,47 @@
// Package logging provides structured logging for the sentinel service.
// Implementation is delegated to shared/go/ja4common/logger to avoid duplication.
package logging
import (
jalogger "github.com/antitbone/ja4/ja4common/logger"
"github.com/antitbone/ja4/sentinel/api"
)
// ServiceLogger satisfies api.Logger using ja4common/logger.ComponentLogger.
// This avoids duplicating logging logic that is now shared across all ja4-platform services.
type ServiceLogger struct {
inner *jalogger.ComponentLogger
}
// NewServiceLogger creates a new ServiceLogger backed by ja4common.
func NewServiceLogger(level string) *ServiceLogger {
return &ServiceLogger{inner: jalogger.NewComponentLogger(level)}
}
// Log emits a structured log entry for the given component.
func (l *ServiceLogger) Log(component, level, message string, details map[string]string) {
l.inner.Log(component, level, message, details)
}
// Debug logs a debug entry for the given component.
func (l *ServiceLogger) Debug(component, message string, details map[string]string) {
l.inner.Debug(component, message, details)
}
// Info logs an info entry for the given component.
func (l *ServiceLogger) Info(component, message string, details map[string]string) {
l.inner.Info(component, message, details)
}
// Warn logs a warning entry for the given component.
func (l *ServiceLogger) Warn(component, message string, details map[string]string) {
l.inner.Warn(component, message, details)
}
// Error logs an error entry for the given component.
func (l *ServiceLogger) Error(component, message string, details map[string]string) {
l.inner.Error(component, message, details)
}
// compile-time check: ServiceLogger must satisfy api.Logger
var _ api.Logger = (*ServiceLogger)(nil)

View File

@ -0,0 +1,79 @@
// Package logging tests — behavioral tests for ServiceLogger.
// Since ServiceLogger delegates to ja4common/logger.ComponentLogger,
// we test behavior (no-panic, interface satisfaction, level filtering)
// rather than internal output buffering.
package logging_test
import (
"testing"
"github.com/antitbone/ja4/sentinel/api"
"github.com/antitbone/ja4/sentinel/internal/logging"
)
func TestNewServiceLogger_NonNil(t *testing.T) {
logger := logging.NewServiceLogger("info")
if logger == nil {
t.Fatal("expected non-nil logger")
}
}
func TestServiceLogger_ImplementsApiLogger(t *testing.T) {
logger := logging.NewServiceLogger("debug")
var _ api.Logger = logger // compile-time check
}
func TestServiceLogger_AllLevels_NoPanic(t *testing.T) {
levels := []string{"debug", "info", "warn", "error", "invalid"}
for _, level := range levels {
t.Run(level, func(t *testing.T) {
logger := logging.NewServiceLogger(level)
logger.Debug("comp", "debug msg", map[string]string{"k": "v"})
logger.Info("comp", "info msg", nil)
logger.Warn("comp", "warn msg", map[string]string{"x": "y"})
logger.Error("comp", "error msg", nil)
})
}
}
func TestServiceLogger_WithDetails(t *testing.T) {
logger := logging.NewServiceLogger("debug")
details := map[string]string{"error": "test error", "trace_id": "abc123"}
logger.Info("service", "test message", details)
}
func TestServiceLogger_NilDetails(t *testing.T) {
logger := logging.NewServiceLogger("debug")
logger.Info("service", "test message", nil)
}
func TestServiceLogger_ConcurrentLogging(t *testing.T) {
logger := logging.NewServiceLogger("debug")
done := make(chan bool)
for i := 0; i < 10; i++ {
go func(id int) {
logger.Info("service", "concurrent message", map[string]string{"id": string(rune('0'+id))})
done <- true
}(i)
}
for i := 0; i < 10; i++ {
<-done
}
}
func TestLoggerFactory(t *testing.T) {
factory := &logging.LoggerFactory{}
levels := []string{"debug", "info", "warn", "error"}
for _, level := range levels {
t.Run(level, func(t *testing.T) {
logger := factory.NewLogger(level)
if logger == nil {
t.Fatalf("NewLogger(%q) returned nil", level)
}
})
}
logger := factory.NewDefaultLogger()
if logger == nil {
t.Fatal("NewDefaultLogger() returned nil")
}
}

View File

@ -0,0 +1,644 @@
// Package output provides writers for ja4sentinel log records
package output
import (
"encoding/json"
"fmt"
"io"
"net"
"os"
"path/filepath"
"sync"
"time"
"github.com/antitbone/ja4/sentinel/api"
)
// Socket configuration constants
const (
// DefaultDialTimeout is the default timeout for socket connections
DefaultDialTimeout = 5 * time.Second
// DefaultWriteTimeout is the default timeout for socket writes
DefaultWriteTimeout = 5 * time.Second
// DefaultMaxReconnectAttempts is the maximum number of reconnection attempts
DefaultMaxReconnectAttempts = 3
// DefaultReconnectBackoff is the initial backoff duration for reconnection
DefaultReconnectBackoff = 100 * time.Millisecond
// DefaultMaxReconnectBackoff is the maximum backoff duration
DefaultMaxReconnectBackoff = 2 * time.Second
// DefaultQueueSize is the size of the write queue for async writes
DefaultQueueSize = 1000
// DefaultMaxFileSize is the default maximum file size in bytes before rotation (100MB)
DefaultMaxFileSize = 100 * 1024 * 1024
// DefaultMaxBackups is the default number of backup files to keep
DefaultMaxBackups = 3
)
// StdoutWriter writes log records to stdout
type StdoutWriter struct {
encoder *json.Encoder
mutex sync.Mutex
}
// NewStdoutWriter creates a new stdout writer
func NewStdoutWriter() *StdoutWriter {
return &StdoutWriter{
encoder: json.NewEncoder(os.Stdout),
}
}
// Write writes a log record to stdout
func (w *StdoutWriter) Write(rec api.LogRecord) error {
w.mutex.Lock()
defer w.mutex.Unlock()
return w.encoder.Encode(rec)
}
// Close closes the writer (no-op for stdout)
func (w *StdoutWriter) Close() error {
return nil
}
// FileWriter writes log records to a file with rotation support
type FileWriter struct {
file *os.File
encoder *json.Encoder
mutex sync.Mutex
path string
maxSize int64
maxBackups int
currentSize int64
errorCallback ErrorCallback
failuresMu sync.Mutex
failures int
}
// FileWriterOption is a function type for configuring FileWriter
type FileWriterOption func(*FileWriter)
// WithFileErrorCallback sets an error callback for file write errors
func WithFileErrorCallback(cb ErrorCallback) FileWriterOption {
return func(w *FileWriter) {
w.errorCallback = cb
}
}
// NewFileWriter creates a new file writer with rotation
func NewFileWriter(path string) (*FileWriter, error) {
return NewFileWriterWithConfig(path, DefaultMaxFileSize, DefaultMaxBackups)
}
// NewFileWriterWithConfig creates a new file writer with custom rotation config
func NewFileWriterWithConfig(path string, maxSize int64, maxBackups int, opts ...FileWriterOption) (*FileWriter, error) {
// Create directory if it doesn't exist
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create directory %s: %w", dir, err)
}
// Open file with secure permissions (owner read/write only)
file, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return nil, fmt.Errorf("failed to open file %s: %w", path, err)
}
// Get current file size
info, err := file.Stat()
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to stat file: %w", err)
}
w := &FileWriter{
file: file,
encoder: json.NewEncoder(file),
path: path,
maxSize: maxSize,
maxBackups: maxBackups,
currentSize: info.Size(),
}
// Apply options (for error callback)
for _, opt := range opts {
opt(w)
}
return w, nil
}
// rotate rotates the log file if it exceeds the max size
func (w *FileWriter) rotate() error {
if err := w.file.Close(); err != nil {
return fmt.Errorf("failed to close file: %w", err)
}
// Rotate existing backups
for i := w.maxBackups; i > 1; i-- {
oldPath := fmt.Sprintf("%s.%d", w.path, i-1)
newPath := fmt.Sprintf("%s.%d", w.path, i)
os.Rename(oldPath, newPath) // Ignore errors - file may not exist
}
// Move current file to .1
backupPath := fmt.Sprintf("%s.1", w.path)
if err := os.Rename(w.path, backupPath); err != nil {
// If rename fails, just truncate
if err := os.Truncate(w.path, 0); err != nil {
return fmt.Errorf("failed to truncate file: %w", err)
}
}
// Open new file
newFile, err := os.OpenFile(w.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return fmt.Errorf("failed to open new file: %w", err)
}
w.file = newFile
w.encoder = json.NewEncoder(newFile)
w.currentSize = 0
return nil
}
// Write writes a log record to the file
func (w *FileWriter) Write(rec api.LogRecord) error {
w.mutex.Lock()
defer w.mutex.Unlock()
// Check if rotation is needed
if w.currentSize >= w.maxSize {
if err := w.rotate(); err != nil {
w.reportError(fmt.Errorf("failed to rotate file: %w", err))
return fmt.Errorf("failed to rotate file: %w", err)
}
}
// Encode to buffer first to get size
data, err := json.Marshal(rec)
if err != nil {
return fmt.Errorf("failed to marshal record: %w", err)
}
data = append(data, '\n')
// Write to file
n, err := w.file.Write(data)
if err != nil {
w.reportError(fmt.Errorf("failed to write to file: %w", err))
return fmt.Errorf("failed to write to file: %w", err)
}
w.currentSize += int64(n)
return nil
}
// reportError reports a file write error via the configured callback
func (w *FileWriter) reportError(err error) {
if w.errorCallback != nil {
w.failuresMu.Lock()
w.failures++
failures := w.failures
w.failuresMu.Unlock()
w.errorCallback(w.path, err, failures)
}
}
// Close closes the file
func (w *FileWriter) Close() error {
w.mutex.Lock()
defer w.mutex.Unlock()
if w.file != nil {
return w.file.Close()
}
return nil
}
// Reopen reopens the log file (for logrotate support)
func (w *FileWriter) Reopen() error {
w.mutex.Lock()
defer w.mutex.Unlock()
if err := w.file.Close(); err != nil {
return fmt.Errorf("failed to close file during reopen: %w", err)
}
// Open new file
newFile, err := os.OpenFile(w.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return fmt.Errorf("failed to reopen file %s: %w", w.path, err)
}
w.file = newFile
w.encoder = json.NewEncoder(newFile)
w.currentSize = 0
return nil
}
// ErrorCallback is a function type for reporting socket connection errors
type ErrorCallback func(socketPath string, err error, attempt int)
// UnixSocketWriter writes log records to a UNIX socket with reconnection logic
// No internal logging - only LogRecord JSON data is sent to the socket
type UnixSocketWriter struct {
socketPath string
conn net.Conn
mutex sync.Mutex
dialTimeout time.Duration
writeTimeout time.Duration
maxReconnects int
reconnectBackoff time.Duration
maxBackoff time.Duration
queue chan []byte
queueClose chan struct{}
queueDone chan struct{}
closeOnce sync.Once
isClosed bool
pendingWrites [][]byte
pendingMu sync.Mutex
errorCallback ErrorCallback
consecutiveFailures int
failuresMu sync.Mutex
networkType string // "unix" for STREAM, "unixgram" for DGRAM
}
// NewUnixSocketWriter creates a new UNIX socket writer with reconnection logic
func NewUnixSocketWriter(socketPath string) (*UnixSocketWriter, error) {
return NewUnixSocketWriterWithConfig(socketPath, DefaultDialTimeout, DefaultWriteTimeout, DefaultQueueSize)
}
// UnixSocketWriterOption is a function type for configuring UnixSocketWriter
type UnixSocketWriterOption func(*UnixSocketWriter)
// WithErrorCallback sets an error callback for socket connection errors
func WithErrorCallback(cb ErrorCallback) UnixSocketWriterOption {
return func(w *UnixSocketWriter) {
w.errorCallback = cb
}
}
// NewUnixSocketWriterWithConfig creates a new UNIX socket writer with custom configuration
func NewUnixSocketWriterWithConfig(socketPath string, dialTimeout, writeTimeout time.Duration, queueSize int, opts ...UnixSocketWriterOption) (*UnixSocketWriter, error) {
w := &UnixSocketWriter{
socketPath: socketPath,
dialTimeout: dialTimeout,
writeTimeout: writeTimeout,
maxReconnects: DefaultMaxReconnectAttempts,
reconnectBackoff: DefaultReconnectBackoff,
maxBackoff: DefaultMaxReconnectBackoff,
queue: make(chan []byte, queueSize),
queueClose: make(chan struct{}),
queueDone: make(chan struct{}),
pendingWrites: make([][]byte, 0),
}
// Apply options
for _, opt := range opts {
opt(w)
}
// Start the queue processor
go w.processQueue()
// Try initial connection silently (socket may not exist yet - that's okay)
// Use unixgram (DGRAM) for connectionless UDP-like socket communication
conn, err := net.DialTimeout("unixgram", socketPath, w.dialTimeout)
if err == nil {
w.conn = conn
}
return w, nil
}
// processQueue handles queued writes with reconnection logic
func (w *UnixSocketWriter) processQueue() {
defer close(w.queueDone)
backoff := w.reconnectBackoff
for {
select {
case data, ok := <-w.queue:
if !ok {
// Channel closed, drain remaining data
w.flushPendingData()
return
}
if err := w.writeWithReconnect(data); err != nil {
w.failuresMu.Lock()
w.consecutiveFailures++
failures := w.consecutiveFailures
w.failuresMu.Unlock()
// Report error via callback if configured
w.reportError(err, failures)
// Queue for retry
w.pendingMu.Lock()
if len(w.pendingWrites) < DefaultQueueSize {
w.pendingWrites = append(w.pendingWrites, data)
}
w.pendingMu.Unlock()
// Exponential backoff
if failures > w.maxReconnects {
time.Sleep(backoff)
backoff *= 2
if backoff > w.maxBackoff {
backoff = w.maxBackoff
}
}
} else {
w.failuresMu.Lock()
w.consecutiveFailures = 0
w.failuresMu.Unlock()
backoff = w.reconnectBackoff
// Try to flush pending data
w.flushPendingData()
}
case <-w.queueClose:
w.flushPendingData()
return
}
}
}
// reportError reports a socket connection error via the configured callback
func (w *UnixSocketWriter) reportError(err error, attempt int) {
if w.errorCallback != nil {
w.errorCallback(w.socketPath, err, attempt)
}
}
// flushPendingData attempts to write any pending data
func (w *UnixSocketWriter) flushPendingData() {
w.pendingMu.Lock()
pending := w.pendingWrites
w.pendingWrites = make([][]byte, 0)
w.pendingMu.Unlock()
for _, data := range pending {
if err := w.writeWithReconnect(data); err != nil {
// Put it back for next flush attempt
w.pendingMu.Lock()
if len(w.pendingWrites) < DefaultQueueSize {
w.pendingWrites = append(w.pendingWrites, data)
}
w.pendingMu.Unlock()
break
}
}
}
// writeWithReconnect attempts to write data with reconnection logic
func (w *UnixSocketWriter) writeWithReconnect(data []byte) error {
w.mutex.Lock()
defer w.mutex.Unlock()
ensureConn := func() error {
if w.conn != nil {
return nil
}
// Use unixgram (DGRAM) for connectionless UDP-like socket communication
conn, err := net.DialTimeout("unixgram", w.socketPath, w.dialTimeout)
if err != nil {
return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err)
}
w.conn = conn
return nil
}
if err := ensureConn(); err != nil {
return err
}
if err := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil {
return fmt.Errorf("failed to set write deadline: %w", err)
}
if _, err := w.conn.Write(data); err == nil {
return nil
}
// Connection failed, try to reconnect
_ = w.conn.Close()
w.conn = nil
if err := ensureConn(); err != nil {
return fmt.Errorf("failed to reconnect: %w", err)
}
if err := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil {
_ = w.conn.Close()
w.conn = nil
return fmt.Errorf("failed to set write deadline after reconnect: %w", err)
}
if _, err := w.conn.Write(data); err != nil {
_ = w.conn.Close()
w.conn = nil
return fmt.Errorf("failed to write after reconnect: %w", err)
}
return nil
}
// Write writes a log record to the UNIX socket (non-blocking with queue).
// Bug 12 fix: marshal JSON outside the lock, then hold mutex through both the
// isClosed check AND the non-blocking channel send so Close() cannot close the
// channel between those two operations.
func (w *UnixSocketWriter) Write(rec api.LogRecord) error {
data, err := json.Marshal(rec)
if err != nil {
return fmt.Errorf("failed to marshal record: %w", err)
}
data = append(data, '\n')
w.mutex.Lock()
defer w.mutex.Unlock()
if w.isClosed {
return fmt.Errorf("writer is closed")
}
select {
case w.queue <- data:
return nil
default:
return fmt.Errorf("write queue is full, dropping message")
}
}
// Close closes the UNIX socket connection and stops the queue processor.
// Bug 12 fix: set isClosed=true under mutex BEFORE closing the channel so a
// concurrent Write() sees the flag and returns early instead of panicking on
// a send-on-closed-channel.
func (w *UnixSocketWriter) Close() error {
w.closeOnce.Do(func() {
w.mutex.Lock()
w.isClosed = true
w.mutex.Unlock()
close(w.queueClose)
<-w.queueDone
close(w.queue)
w.mutex.Lock()
if w.conn != nil {
w.conn.Close()
w.conn = nil
}
w.mutex.Unlock()
})
return nil
}
// MultiWriter combines multiple writers
type MultiWriter struct {
writers []api.Writer
mutex sync.Mutex
}
// NewMultiWriter creates a new multi-writer
func NewMultiWriter() *MultiWriter {
return &MultiWriter{
writers: make([]api.Writer, 0),
}
}
// Write writes a log record to all writers
func (mw *MultiWriter) Write(rec api.LogRecord) error {
mw.mutex.Lock()
defer mw.mutex.Unlock()
var lastErr error
for _, w := range mw.writers {
if err := w.Write(rec); err != nil {
lastErr = err
}
}
return lastErr
}
// Add adds a writer to the multi-writer
func (mw *MultiWriter) Add(writer api.Writer) {
mw.mutex.Lock()
defer mw.mutex.Unlock()
mw.writers = append(mw.writers, writer)
}
// CloseAll closes all writers
func (mw *MultiWriter) CloseAll() error {
mw.mutex.Lock()
defer mw.mutex.Unlock()
var lastErr error
for _, w := range mw.writers {
if closer, ok := w.(io.Closer); ok {
if err := closer.Close(); err != nil {
lastErr = err
}
}
}
return lastErr
}
// Reopen reopens all writers that support log rotation
func (mw *MultiWriter) Reopen() error {
mw.mutex.Lock()
defer mw.mutex.Unlock()
var lastErr error
for _, w := range mw.writers {
if reopenable, ok := w.(api.Reopenable); ok {
if err := reopenable.Reopen(); err != nil {
lastErr = err
}
}
}
return lastErr
}
// BuilderImpl implements the api.Builder interface
type BuilderImpl struct {
errorCallback ErrorCallback
}
// NewBuilder creates a new output builder
func NewBuilder() *BuilderImpl {
return &BuilderImpl{}
}
// WithErrorCallback sets an error callback for all unix_socket and file writers created by this builder
func (b *BuilderImpl) WithErrorCallback(cb ErrorCallback) *BuilderImpl {
b.errorCallback = cb
return b
}
// NewFromConfig constructs writers from AppConfig
// Uses AsyncBuffer from OutputConfig if specified, otherwise uses DefaultQueueSize
func (b *BuilderImpl) NewFromConfig(cfg api.AppConfig) (api.Writer, error) {
multiWriter := NewMultiWriter()
for _, outputCfg := range cfg.Outputs {
if !outputCfg.Enabled {
continue
}
var writer api.Writer
var err error
// Determine queue size: use AsyncBuffer if specified, otherwise default
queueSize := DefaultQueueSize
if outputCfg.AsyncBuffer > 0 {
queueSize = outputCfg.AsyncBuffer
}
switch outputCfg.Type {
case "stdout":
writer = NewStdoutWriter()
case "file":
path := outputCfg.Params["path"]
if path == "" {
return nil, fmt.Errorf("file output requires 'path' parameter")
}
// Build options list for file writer
var fileOpts []FileWriterOption
if b.errorCallback != nil {
fileOpts = append(fileOpts, WithFileErrorCallback(b.errorCallback))
}
writer, err = NewFileWriterWithConfig(path, DefaultMaxFileSize, DefaultMaxBackups, fileOpts...)
if err != nil {
return nil, err
}
case "unix_socket":
socketPath := outputCfg.Params["socket_path"]
if socketPath == "" {
return nil, fmt.Errorf("unix_socket output requires 'socket_path' parameter")
}
// Build options list
var opts []UnixSocketWriterOption
if b.errorCallback != nil {
opts = append(opts, WithErrorCallback(b.errorCallback))
}
writer, err = NewUnixSocketWriterWithConfig(socketPath, DefaultDialTimeout, DefaultWriteTimeout, queueSize, opts...)
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("unknown output type: %s", outputCfg.Type)
}
multiWriter.Add(writer)
}
// If no outputs configured, default to stdout
if len(multiWriter.writers) == 0 {
multiWriter.Add(NewStdoutWriter())
}
return multiWriter, nil
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff