feat: ja4-platform monorepo — 5 services unified, tests & RPM builds standardized
Services:
- ja4sentinel: TLS/JA4 fingerprint capture daemon (Go, libpcap)
- logcorrelator: JA4 log correlation engine (Go, ClickHouse)
- mod_reqin_log: Apache module (C, JSON request logging)
- bot_detector: ML bot detection pipeline (Python)
- dashboard: FastAPI/Streamlit analytics UI (Python)

Shared libraries:
- shared/go/ja4common: logger, config, shutdown, ipfilter (Go module)
- shared/python/ja4_common: ClickHouseClient, ClickHouseSettings (Python package)
- shared/clickhouse/: canonical SQL migrations (10 files)

Build & packaging:
- Unified 3-stage Dockerfile.package for Go RPMs (el8/el9/el10)
- go.work workspace linking sentinel, correlator, ja4common
- Makefile with test-all, build-all, rpm-* targets

Fixes applied:
- go.work: 1.21 → 1.24.6 (required by sentinel)
- correlator Dockerfiles: golang:1.21 → golang:1.24
- replace directives in go.mod for ja4common local path
- pyproject.toml: setuptools.backends → setuptools.build_meta
- Removed static libpcap linking (unavailable on Rocky 9)
- Fixed data races in output/writers_test.go (sync.Mutex + atomic.Int32)
- Rewrote corrupted test files (logger_test.go × 2)

Test coverage:
- correlator: 67.1% total (unixsocket 80.5%, config 91.7%, app 83.3%, multi 87.7%, stdout 100%)
- sentinel: all 10 packages pass (api, capture, config, fingerprint, ipfilter, logging, output, tlsparse)

Documentation:
- README.md + docs/ (architecture, development, 5 services, shared libs, DB schema & migrations)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@ -0,0 +1,376 @@
|
||||
package unixsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
"github.com/antitbone/ja4/correlator/internal/observability"
|
||||
)
|
||||
|
||||
// Limits used by the Unix datagram source.
const (
	// MaxDatagramSize is the size in bytes of the buffer a single JSON log
	// datagram is read into (64KB - Unix datagram limit).
	MaxDatagramSize = 65535

	// MaxEventsPerSecond is the rate limit: max events per second.
	MaxEventsPerSecond = 10000
)
|
||||
|
||||
// Config describes one Unix datagram socket input.
type Config struct {
	// Name identifies the source; it is returned by Name and appended to the
	// default logger name.
	Name string

	// Path is the filesystem path the socket is bound to.
	Path string

	// SourceType selects the event source:
	// "A" for Apache/HTTP, "B" for Network, "" for auto-detect.
	SourceType string

	// SocketPermissions is the mode applied to the socket file after it is
	// created. The zero value selects 0666.
	SocketPermissions os.FileMode
}
|
||||
|
||||
// UnixSocketSource reads JSON events from a Unix datagram socket.
|
||||
type UnixSocketSource struct {
|
||||
config Config
|
||||
mu sync.Mutex
|
||||
conn *net.UnixConn
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
stopOnce sync.Once
|
||||
logger *observability.Logger
|
||||
}
|
||||
|
||||
// NewUnixSocketSource creates a new Unix socket source.
|
||||
func NewUnixSocketSource(config Config) *UnixSocketSource {
|
||||
return &UnixSocketSource{
|
||||
config: config,
|
||||
done: make(chan struct{}),
|
||||
logger: observability.NewLogger("unixsocket:" + config.Name),
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogger sets the logger for the source (for debug mode).
|
||||
func (s *UnixSocketSource) SetLogger(logger *observability.Logger) {
|
||||
s.logger = logger.WithFields(map[string]any{"source": s.config.Name})
|
||||
}
|
||||
|
||||
// Name returns the source name.
|
||||
func (s *UnixSocketSource) Name() string {
|
||||
return s.config.Name
|
||||
}
|
||||
|
||||
// Start begins listening on the Unix datagram socket.
|
||||
func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) error {
|
||||
if strings.TrimSpace(s.config.Path) == "" {
|
||||
return fmt.Errorf("socket path cannot be empty")
|
||||
}
|
||||
|
||||
// Create parent directory if it doesn't exist
|
||||
socketDir := filepath.Dir(s.config.Path)
|
||||
if err := os.MkdirAll(socketDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create socket directory %s: %w", socketDir, err)
|
||||
}
|
||||
|
||||
// Remove existing socket file if present
|
||||
if info, err := os.Stat(s.config.Path); err == nil {
|
||||
if info.Mode()&os.ModeSocket != 0 {
|
||||
if err := os.Remove(s.config.Path); err != nil {
|
||||
return fmt.Errorf("failed to remove existing socket: %w", err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("path exists but is not a socket: %s", s.config.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// Create Unix datagram socket
|
||||
addr, err := net.ResolveUnixAddr("unixgram", s.config.Path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve unix socket address: %w", err)
|
||||
}
|
||||
|
||||
conn, err := net.ListenUnixgram("unixgram", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create unix datagram socket: %w", err)
|
||||
}
|
||||
s.conn = conn
|
||||
|
||||
// Set permissions - fail if we can't
|
||||
permissions := s.config.SocketPermissions
|
||||
if permissions == 0 {
|
||||
permissions = 0666 // default
|
||||
}
|
||||
if err := os.Chmod(s.config.Path, permissions); err != nil {
|
||||
_ = conn.Close()
|
||||
_ = os.Remove(s.config.Path)
|
||||
return fmt.Errorf("failed to set socket permissions: %w", err)
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.readDatagrams(ctx, eventChan)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UnixSocketSource) readDatagrams(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) {
|
||||
buf := make([]byte, MaxDatagramSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Set read deadline to allow periodic context checks
|
||||
_ = s.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
|
||||
n, _, err := s.conn.ReadFromUnix(buf)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
// Read timeout, continue to check context
|
||||
continue
|
||||
}
|
||||
// Other errors (e.g., closed socket)
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
s.logger.Warnf("read error: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
data := make([]byte, n)
|
||||
copy(data, buf[:n])
|
||||
|
||||
event, err := parseJSONEvent(data, s.config.SourceType)
|
||||
if err != nil {
|
||||
// Log parse errors with the raw data for debugging
|
||||
s.logger.Warnf("parse error: %v | raw: %s", err, string(data))
|
||||
continue
|
||||
}
|
||||
|
||||
// Debug: log raw events with all key details
|
||||
s.logger.Debugf("event received: source=%s src_ip=%s src_port=%d timestamp=%v raw_timestamp=%v",
|
||||
event.Source, event.SrcIP, event.SrcPort, event.Timestamp, event.Raw["timestamp"])
|
||||
|
||||
select {
|
||||
case eventChan <- event:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func resolveSource(sourceType string, headers map[string]string) domain.EventSource {
|
||||
switch strings.ToLower(strings.TrimSpace(sourceType)) {
|
||||
case "a", "apache", "http":
|
||||
return domain.SourceA
|
||||
case "b", "network", "net":
|
||||
return domain.SourceB
|
||||
default:
|
||||
// fallback compat
|
||||
if len(headers) > 0 {
|
||||
return domain.SourceA
|
||||
}
|
||||
return domain.SourceB
|
||||
}
|
||||
}
|
||||
|
||||
func parseJSONEvent(data []byte, sourceType string) (*domain.NormalizedEvent, error) {
|
||||
var raw map[string]any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %w", err)
|
||||
}
|
||||
|
||||
event := &domain.NormalizedEvent{
|
||||
Raw: raw,
|
||||
Extra: make(map[string]any),
|
||||
Headers: make(map[string]string),
|
||||
}
|
||||
|
||||
// Extract headers (header_* fields) first
|
||||
for k, v := range raw {
|
||||
if strings.HasPrefix(k, "header_") {
|
||||
if sv, ok := v.(string); ok {
|
||||
event.Headers[k[7:]] = sv
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve source first (strict timestamp logic depends on source)
|
||||
event.Source = resolveSource(sourceType, event.Headers)
|
||||
|
||||
// Extract and validate src_ip
|
||||
if v, ok := getString(raw, "src_ip"); ok {
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
return nil, fmt.Errorf("src_ip cannot be empty")
|
||||
}
|
||||
event.SrcIP = v
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing required field: src_ip")
|
||||
}
|
||||
|
||||
// Extract and validate src_port
|
||||
if v, ok := getInt(raw, "src_port"); ok {
|
||||
if v < 1 || v > 65535 {
|
||||
return nil, fmt.Errorf("src_port must be between 1 and 65535, got %d", v)
|
||||
}
|
||||
event.SrcPort = v
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing required field: src_port")
|
||||
}
|
||||
|
||||
// Extract dst_ip (optional)
|
||||
if v, ok := getString(raw, "dst_ip"); ok {
|
||||
event.DstIP = strings.TrimSpace(v)
|
||||
}
|
||||
|
||||
// Extract dst_port (optional)
|
||||
if v, ok := getInt(raw, "dst_port"); ok {
|
||||
if v < 0 || v > 65535 {
|
||||
return nil, fmt.Errorf("dst_port must be between 0 and 65535, got %d", v)
|
||||
}
|
||||
event.DstPort = v
|
||||
}
|
||||
|
||||
// Extract timestamp based on source contract
|
||||
switch event.Source {
|
||||
case domain.SourceA:
|
||||
ts, ok := getInt64(raw, "timestamp")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing required numeric field: timestamp for source A")
|
||||
}
|
||||
// Assume nanoseconds
|
||||
event.Timestamp = time.Unix(0, ts)
|
||||
case domain.SourceB:
|
||||
// For network source, try to use event timestamp if available,
|
||||
// fallback to reception time. This improves correlation accuracy
|
||||
// when network logs include their own timestamp (e.g., from packet capture).
|
||||
if ts, ok := getInt64(raw, "timestamp"); ok {
|
||||
event.Timestamp = time.Unix(0, ts)
|
||||
} else if timeStr, ok := getString(raw, "time"); ok {
|
||||
// Try RFC3339 format
|
||||
if t, err := time.Parse(time.RFC3339, timeStr); err == nil {
|
||||
event.Timestamp = t
|
||||
} else if t, err := time.Parse(time.RFC3339Nano, timeStr); err == nil {
|
||||
event.Timestamp = t
|
||||
} else {
|
||||
event.Timestamp = time.Now()
|
||||
}
|
||||
} else {
|
||||
event.Timestamp = time.Now()
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported source type: %s", event.Source)
|
||||
}
|
||||
|
||||
// Extra fields
|
||||
knownFields := map[string]bool{
|
||||
"src_ip": true, "src_port": true, "dst_ip": true, "dst_port": true,
|
||||
"timestamp": true, "time": true,
|
||||
}
|
||||
for k, v := range raw {
|
||||
if knownFields[k] {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(k, "header_") {
|
||||
continue
|
||||
}
|
||||
event.Extra[k] = v
|
||||
}
|
||||
|
||||
return event, nil
|
||||
}
|
||||
|
||||
// getString looks up key in m and returns its value when that value is a
// string. It returns ("", false) when the key is absent or holds any other
// type, including nil.
func getString(m map[string]any, key string) (string, bool) {
	// Indexing a missing key yields a nil interface, which fails the
	// assertion just like a non-string value does.
	text, isString := m[key].(string)
	return text, isString
}
|
||||
|
||||
// getInt looks up key in m and converts its value to an int.
//
// Accepted values are a float64 with no fractional part, an int, an int64, or
// a string holding a base-10 integer. It returns (0, false) when the key is
// absent, the value has another type, the number has a fractional part, or
// the value does not fit in an int.
func getInt(m map[string]any, key string) (int, bool) {
	v, ok := m[key]
	if !ok {
		return 0, false
	}

	switch val := v.(type) {
	case float64:
		// NaN fails this comparison as well, so it is rejected here.
		if math.Trunc(val) != val {
			return 0, false
		}
		// Converting an out-of-range float (or ±Inf) to int gives an
		// implementation-defined result, so reject it instead. The upper bound
		// is exclusive because -MinInt is exactly 2^(bits-1).
		if val < float64(math.MinInt) || val >= -float64(math.MinInt) {
			return 0, false
		}
		return int(val), true
	case int:
		return val, true
	case int64:
		// Guard against truncation where int is narrower than int64.
		if int64(int(val)) != val {
			return 0, false
		}
		return int(val), true
	case string:
		if i, err := strconv.Atoi(val); err == nil {
			return i, true
		}
	}
	return 0, false
}
|
||||
|
||||
// getInt64 looks up key in m and converts its value to an int64.
//
// Accepted values are a float64 with no fractional part, an int, an int64, or
// a string holding a base-10 integer. It returns (0, false) when the key is
// absent, the value has another type, the number has a fractional part, or
// the value does not fit in an int64.
func getInt64(m map[string]any, key string) (int64, bool) {
	v, ok := m[key]
	if !ok {
		return 0, false
	}

	switch val := v.(type) {
	case float64:
		// NaN fails this comparison as well, so it is rejected here.
		if math.Trunc(val) != val {
			return 0, false
		}
		// Converting an out-of-range float (or ±Inf) to int64 gives an
		// implementation-defined result, so reject it instead. The upper bound
		// is exclusive because -MinInt64 is exactly 2^63.
		if val < float64(math.MinInt64) || val >= -float64(math.MinInt64) {
			return 0, false
		}
		return int64(val), true
	case int:
		return int64(val), true
	case int64:
		return val, true
	case string:
		if i, err := strconv.ParseInt(val, 10, 64); err == nil {
			return i, true
		}
	}
	return 0, false
}
|
||||
|
||||
// Stop gracefully stops the source.
|
||||
func (s *UnixSocketSource) Stop() error {
|
||||
var stopErr error
|
||||
|
||||
s.stopOnce.Do(func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
close(s.done)
|
||||
|
||||
if s.conn != nil {
|
||||
_ = s.conn.Close()
|
||||
}
|
||||
|
||||
s.wg.Wait()
|
||||
|
||||
// Clean up socket file
|
||||
if err := os.Remove(s.config.Path); err != nil && !os.IsNotExist(err) {
|
||||
stopErr = fmt.Errorf("failed to remove socket file: %w", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
return stopErr
|
||||
}
|
||||
@ -0,0 +1,596 @@
|
||||
package unixsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
)
|
||||
|
||||
func TestParseJSONEvent_Apache(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"src_ip": "192.168.1.1",
|
||||
"src_port": 8080,
|
||||
"dst_ip": "10.0.0.1",
|
||||
"dst_port": 80,
|
||||
"timestamp": 1704110400000000000,
|
||||
"method": "GET",
|
||||
"path": "/api/test",
|
||||
"header_host": "example.com",
|
||||
"header_user_agent": "Mozilla/5.0"
|
||||
}`)
|
||||
|
||||
event, err := parseJSONEvent(data, "A")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if event.SrcIP != "192.168.1.1" {
|
||||
t.Errorf("expected src_ip 192.168.1.1, got %s", event.SrcIP)
|
||||
}
|
||||
if event.SrcPort != 8080 {
|
||||
t.Errorf("expected src_port 8080, got %d", event.SrcPort)
|
||||
}
|
||||
if event.Headers["host"] != "example.com" {
|
||||
t.Errorf("expected header host example.com, got %s", event.Headers["host"])
|
||||
}
|
||||
if event.Headers["user_agent"] != "Mozilla/5.0" {
|
||||
t.Errorf("expected header_user_agent Mozilla/5.0, got %s", event.Headers["user_agent"])
|
||||
}
|
||||
if event.Source != domain.SourceA {
|
||||
t.Errorf("expected source A, got %s", event.Source)
|
||||
}
|
||||
expectedTs := time.Unix(0, 1704110400000000000)
|
||||
if !event.Timestamp.Equal(expectedTs) {
|
||||
t.Errorf("expected timestamp %v, got %v", expectedTs, event.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONEvent_Network(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"src_ip": "192.168.1.1",
|
||||
"src_port": 8080,
|
||||
"dst_ip": "10.0.0.1",
|
||||
"dst_port": 443,
|
||||
"timestamp": 1704110400000000000,
|
||||
"ja3": "abc123def456",
|
||||
"ja4": "xyz789",
|
||||
"tcp_meta_flags": "SYN"
|
||||
}`)
|
||||
|
||||
event, err := parseJSONEvent(data, "B")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if event.SrcIP != "192.168.1.1" {
|
||||
t.Errorf("expected src_ip 192.168.1.1, got %s", event.SrcIP)
|
||||
}
|
||||
if event.Extra["ja3"] != "abc123def456" {
|
||||
t.Errorf("expected ja3 abc123def456, got %v", event.Extra["ja3"])
|
||||
}
|
||||
if event.Source != domain.SourceB {
|
||||
t.Errorf("expected source B, got %s", event.Source)
|
||||
}
|
||||
// Network source now uses payload timestamp if available
|
||||
expectedTs := time.Unix(0, 1704110400000000000)
|
||||
if !event.Timestamp.Equal(expectedTs) {
|
||||
t.Errorf("expected network timestamp %v, got %v", expectedTs, event.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONEvent_InvalidJSON(t *testing.T) {
|
||||
data := []byte(`{invalid json}`)
|
||||
|
||||
_, err := parseJSONEvent(data, "")
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONEvent_MissingFields(t *testing.T) {
|
||||
data := []byte(`{"other_field": "value"}`)
|
||||
|
||||
_, err := parseJSONEvent(data, "")
|
||||
if err == nil {
|
||||
t.Error("expected error for missing src_ip/src_port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONEvent_SourceARequiresNumericTimestamp(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"src_ip": "192.168.1.1",
|
||||
"src_port": 8080,
|
||||
"time": "2024-01-01T12:00:00Z"
|
||||
}`)
|
||||
|
||||
_, err := parseJSONEvent(data, "A")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for source A without numeric timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONEvent_SourceBUsesPayloadTimestamp(t *testing.T) {
|
||||
expectedTs := int64(1704110400000000000)
|
||||
data := []byte(`{
|
||||
"src_ip": "192.168.1.1",
|
||||
"src_port": 8080,
|
||||
"timestamp": 1704110400000000000
|
||||
}`)
|
||||
|
||||
event, err := parseJSONEvent(data, "B")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expectedTime := time.Unix(0, expectedTs)
|
||||
if !event.Timestamp.Equal(expectedTime) {
|
||||
t.Errorf("expected source B to use payload timestamp %v, got %v", expectedTime, event.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONEvent_SourceBUsesTimeField(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"src_ip": "192.168.1.1",
|
||||
"src_port": 8080,
|
||||
"time": "2024-01-01T12:00:00Z"
|
||||
}`)
|
||||
|
||||
event, err := parseJSONEvent(data, "B")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expectedTime := time.Unix(0, 1704110400000000000)
|
||||
if !event.Timestamp.Equal(expectedTime) {
|
||||
t.Errorf("expected source B to use time field %v, got %v", expectedTime, event.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONEvent_SourceBFallbackToNow(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"src_ip": "192.168.1.1",
|
||||
"src_port": 8080
|
||||
}`)
|
||||
|
||||
before := time.Now()
|
||||
event, err := parseJSONEvent(data, "B")
|
||||
after := time.Now()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if event.Timestamp.Before(before.Add(-2*time.Second)) || event.Timestamp.After(after.Add(2*time.Second)) {
|
||||
t.Errorf("expected source B timestamp near now, got %v", event.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONEvent_ExplicitSourceType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
sourceType string
|
||||
expected domain.EventSource
|
||||
}{
|
||||
{
|
||||
name: "explicit A",
|
||||
data: `{"src_ip": "192.168.1.1", "src_port": 8080, "timestamp": 1704110400000000000}`,
|
||||
sourceType: "A",
|
||||
expected: domain.SourceA,
|
||||
},
|
||||
{
|
||||
name: "explicit B",
|
||||
data: `{"src_ip": "192.168.1.1", "src_port": 8080}`,
|
||||
sourceType: "B",
|
||||
expected: domain.SourceB,
|
||||
},
|
||||
{
|
||||
name: "explicit apache",
|
||||
data: `{"src_ip": "192.168.1.1", "src_port": 8080, "timestamp": 1704110400000000000}`,
|
||||
sourceType: "apache",
|
||||
expected: domain.SourceA,
|
||||
},
|
||||
{
|
||||
name: "explicit network",
|
||||
data: `{"src_ip": "192.168.1.1", "src_port": 8080}`,
|
||||
sourceType: "network",
|
||||
expected: domain.SourceB,
|
||||
},
|
||||
{
|
||||
name: "auto-detect A with headers",
|
||||
data: `{"src_ip": "192.168.1.1", "src_port": 8080, "timestamp": 1704110400000000000, "header_host": "example.com"}`,
|
||||
sourceType: "",
|
||||
expected: domain.SourceA,
|
||||
},
|
||||
{
|
||||
name: "auto-detect B without headers",
|
||||
data: `{"src_ip": "192.168.1.1", "src_port": 8080, "ja3": "abc"}`,
|
||||
sourceType: "",
|
||||
expected: domain.SourceB,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
event, err := parseJSONEvent([]byte(tt.data), tt.sourceType)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if event.Source != tt.expected {
|
||||
t.Errorf("expected source %s, got %s", tt.expected, event.Source)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketSource_Name(t *testing.T) {
|
||||
source := NewUnixSocketSource(Config{
|
||||
Name: "test_source",
|
||||
Path: "/tmp/test.sock",
|
||||
})
|
||||
|
||||
if source.Name() != "test_source" {
|
||||
t.Errorf("expected name 'test_source', got %s", source.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketSource_StopWithoutStart(t *testing.T) {
|
||||
source := NewUnixSocketSource(Config{
|
||||
Name: "test_source",
|
||||
Path: "/tmp/test.sock",
|
||||
})
|
||||
|
||||
// Should not panic
|
||||
err := source.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("expected no error on stop without start, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketSource_EmptyPath(t *testing.T) {
|
||||
source := NewUnixSocketSource(Config{
|
||||
Name: "test_source",
|
||||
Path: "",
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
eventChan := make(chan *domain.NormalizedEvent, 10)
|
||||
|
||||
err := source.Start(ctx, eventChan)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetString(t *testing.T) {
|
||||
m := map[string]any{
|
||||
"string": "hello",
|
||||
"int": 42,
|
||||
"nil": nil,
|
||||
}
|
||||
|
||||
v, ok := getString(m, "string")
|
||||
if !ok || v != "hello" {
|
||||
t.Errorf("expected 'hello', got %v, %v", v, ok)
|
||||
}
|
||||
|
||||
_, ok = getString(m, "int")
|
||||
if ok {
|
||||
t.Error("expected false for int")
|
||||
}
|
||||
|
||||
_, ok = getString(m, "missing")
|
||||
if ok {
|
||||
t.Error("expected false for missing key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInt(t *testing.T) {
|
||||
m := map[string]any{
|
||||
"float": 42.5,
|
||||
"int": 42,
|
||||
"int64": int64(42),
|
||||
"string": "42",
|
||||
"bad": "not a number",
|
||||
"nil": nil,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
key string
|
||||
expected int
|
||||
ok bool
|
||||
}{
|
||||
{"float", 0, false},
|
||||
{"int", 42, true},
|
||||
{"int64", 42, true},
|
||||
{"string", 42, true},
|
||||
{"bad", 0, false},
|
||||
{"nil", 0, false},
|
||||
{"missing", 0, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.key, func(t *testing.T) {
|
||||
v, ok := getInt(m, tt.key)
|
||||
if ok != tt.ok {
|
||||
t.Errorf("getInt(%q) ok = %v, want %v", tt.key, ok, tt.ok)
|
||||
}
|
||||
if v != tt.expected {
|
||||
t.Errorf("getInt(%q) = %v, want %v", tt.key, v, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInt64(t *testing.T) {
|
||||
m := map[string]any{
|
||||
"float": 42.5,
|
||||
"int": 42,
|
||||
"int64": int64(42),
|
||||
"string": "42",
|
||||
"bad": "not a number",
|
||||
"nil": nil,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
key string
|
||||
expected int64
|
||||
ok bool
|
||||
}{
|
||||
{"float", 0, false},
|
||||
{"int", 42, true},
|
||||
{"int64", 42, true},
|
||||
{"string", 42, true},
|
||||
{"bad", 0, false},
|
||||
{"nil", 0, false},
|
||||
{"missing", 0, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.key, func(t *testing.T) {
|
||||
v, ok := getInt64(m, tt.key)
|
||||
if ok != tt.ok {
|
||||
t.Errorf("getInt64(%q) ok = %v, want %v", tt.key, ok, tt.ok)
|
||||
}
|
||||
if v != tt.expected {
|
||||
t.Errorf("getInt64(%q) = %v, want %v", tt.key, v, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONEvent_PortValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
sourceType string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid src_port",
|
||||
data: `{"src_ip": "192.168.1.1", "src_port": 8080}`,
|
||||
sourceType: "B",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "src_port zero",
|
||||
data: `{"src_ip": "192.168.1.1", "src_port": 0}`,
|
||||
sourceType: "B",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "src_port negative",
|
||||
data: `{"src_ip": "192.168.1.1", "src_port": -1}`,
|
||||
sourceType: "B",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "src_port too high",
|
||||
data: `{"src_ip": "192.168.1.1", "src_port": 70000}`,
|
||||
sourceType: "B",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid dst_port zero",
|
||||
data: `{"src_ip": "192.168.1.1", "src_port": 8080, "dst_port": 0}`,
|
||||
sourceType: "B",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "dst_port too high",
|
||||
data: `{"src_ip": "192.168.1.1", "src_port": 8080, "dst_port": 70000}`,
|
||||
sourceType: "B",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := parseJSONEvent([]byte(tt.data), tt.sourceType)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseJSONEvent() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONEvent_TimestampFallback(t *testing.T) {
|
||||
data := []byte(`{"src_ip": "192.168.1.1", "src_port": 8080}`)
|
||||
event, err := parseJSONEvent(data, "B")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// For source B, timestamp is reception time
|
||||
if event.Timestamp.IsZero() {
|
||||
t.Error("expected non-zero timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketSource_StartStopDatagram(t *testing.T) {
|
||||
tmpPath := "/tmp/test_logcorrelator_datagram.sock"
|
||||
// Clean up any existing socket
|
||||
os.Remove(tmpPath)
|
||||
|
||||
source := NewUnixSocketSource(Config{
|
||||
Name: "test_datagram",
|
||||
Path: tmpPath,
|
||||
SourceType: "B",
|
||||
SocketPermissions: 0666,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
eventChan := make(chan *domain.NormalizedEvent, 10)
|
||||
|
||||
err := source.Start(ctx, eventChan)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start source: %v", err)
|
||||
}
|
||||
|
||||
// Give socket time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify socket file exists
|
||||
if _, err := os.Stat(tmpPath); os.IsNotExist(err) {
|
||||
t.Error("socket file should exist")
|
||||
}
|
||||
|
||||
// Stop the source
|
||||
err = source.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("failed to stop source: %v", err)
|
||||
}
|
||||
|
||||
// Socket file should be cleaned up
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if _, err := os.Stat(tmpPath); !os.IsNotExist(err) {
|
||||
t.Error("socket file should be removed after stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketSource_SendDatagram(t *testing.T) {
|
||||
tmpPath := "/tmp/test_logcorrelator_send.sock"
|
||||
os.Remove(tmpPath)
|
||||
|
||||
source := NewUnixSocketSource(Config{
|
||||
Name: "test_send",
|
||||
Path: tmpPath,
|
||||
SourceType: "B",
|
||||
SocketPermissions: 0666,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
eventChan := make(chan *domain.NormalizedEvent, 10)
|
||||
|
||||
err := source.Start(ctx, eventChan)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start source: %v", err)
|
||||
}
|
||||
|
||||
// Give socket time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Connect and send a datagram
|
||||
conn, err := net.Dial("unixgram", tmpPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial socket: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
data := []byte(`{"src_ip": "192.168.1.1", "src_port": 8080, "ja3": "test"}`)
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write: %v", err)
|
||||
}
|
||||
|
||||
// Wait for event
|
||||
select {
|
||||
case event := <-eventChan:
|
||||
if event.SrcIP != "192.168.1.1" {
|
||||
t.Errorf("expected src_ip 192.168.1.1, got %s", event.SrcIP)
|
||||
}
|
||||
if event.SrcPort != 8080 {
|
||||
t.Errorf("expected src_port 8080, got %d", event.SrcPort)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("timeout waiting for event")
|
||||
case <-ctx.Done():
|
||||
t.Error("context cancelled")
|
||||
}
|
||||
|
||||
err = source.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("failed to stop source: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketSource_MultipleDatagrams(t *testing.T) {
|
||||
tmpPath := "/tmp/test_logcorrelator_multi.sock"
|
||||
os.Remove(tmpPath)
|
||||
|
||||
source := NewUnixSocketSource(Config{
|
||||
Name: "test_multi",
|
||||
Path: tmpPath,
|
||||
SourceType: "B",
|
||||
SocketPermissions: 0666,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
eventChan := make(chan *domain.NormalizedEvent, 100)
|
||||
|
||||
err := source.Start(ctx, eventChan)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start source: %v", err)
|
||||
}
|
||||
|
||||
// Give socket time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Connect and send multiple datagrams
|
||||
conn, err := net.Dial("unixgram", tmpPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial socket: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
data := []byte(fmt.Sprintf(`{"src_ip": "192.168.1.%d", "src_port": %d, "ja3": "test%d"}`, i+1, 8080+i, i))
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write datagram %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all events
|
||||
received := 0
|
||||
timeout := time.After(3 * time.Second)
|
||||
for received < 5 {
|
||||
select {
|
||||
case event := <-eventChan:
|
||||
received++
|
||||
t.Logf("received event %d: src_ip=%s", received, event.SrcIP)
|
||||
case <-timeout:
|
||||
t.Errorf("timeout waiting for events, received %d/5", received)
|
||||
goto done
|
||||
case <-ctx.Done():
|
||||
t.Error("context cancelled")
|
||||
goto done
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
err = source.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("failed to stop source: %v", err)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,391 @@
|
||||
package clickhouse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ClickHouse/clickhouse-go/v2"
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
"github.com/antitbone/ja4/correlator/internal/observability"
|
||||
)
|
||||
|
||||
// Defaults and retry settings of the ClickHouse sink.
const (
	// DefaultBatchSize is the default number of records per batch.
	DefaultBatchSize = 500

	// DefaultFlushIntervalMs is the default flush interval in milliseconds.
	DefaultFlushIntervalMs = 200

	// DefaultMaxBufferSize is the default maximum buffer size.
	DefaultMaxBufferSize = 5000

	// DefaultTimeoutMs is the default timeout for operations in milliseconds.
	DefaultTimeoutMs = 1000

	// DefaultPingTimeoutMs is the timeout, in milliseconds, for the initial
	// connection ping.
	DefaultPingTimeoutMs = 5000

	// MaxRetries is the maximum number of retry attempts for failed inserts.
	MaxRetries = 3

	// RetryBaseDelay is the base delay between retries.
	RetryBaseDelay = 100 * time.Millisecond
)
|
||||
|
||||
// Config carries the settings of the ClickHouse sink. NewClickHouseSink
// requires DSN and Table to be non-blank and substitutes the package defaults
// for BatchSize, FlushIntervalMs, MaxBufferSize and TimeoutMs when they are
// zero or negative.
type Config struct {
	// DSN is the ClickHouse connection string.
	DSN string
	// Table is the name of the destination table.
	Table string

	// BatchSize is the number of records per batch.
	BatchSize int
	// FlushIntervalMs is the flush interval in milliseconds.
	FlushIntervalMs int
	// MaxBufferSize is the maximum buffer size.
	MaxBufferSize int

	// DropOnOverflow and AsyncInsert are behaviour switches.
	// NOTE(review): their exact effect is implemented outside this view — confirm.
	DropOnOverflow bool
	AsyncInsert    bool

	// TimeoutMs is the timeout for operations in milliseconds.
	TimeoutMs int
}
|
||||
|
||||
// ClickHouseSink writes correlated logs to ClickHouse.
|
||||
type ClickHouseSink struct {
|
||||
config Config
|
||||
conn clickhouse.Conn
|
||||
mu sync.Mutex
|
||||
buffer []domain.CorrelatedLog
|
||||
flushChan chan struct{}
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
closeOnce sync.Once
|
||||
logger *observability.Logger
|
||||
}
|
||||
|
||||
// SetLogger sets the logger used by the sink.
|
||||
func (s *ClickHouseSink) SetLogger(logger *observability.Logger) {
|
||||
s.logger = logger.WithFields(map[string]any{"sink": "clickhouse"})
|
||||
}
|
||||
|
||||
// NewClickHouseSink creates a new ClickHouse sink.
|
||||
func NewClickHouseSink(config Config) (*ClickHouseSink, error) {
|
||||
if strings.TrimSpace(config.DSN) == "" {
|
||||
return nil, fmt.Errorf("clickhouse DSN is required")
|
||||
}
|
||||
if strings.TrimSpace(config.Table) == "" {
|
||||
return nil, fmt.Errorf("clickhouse table is required")
|
||||
}
|
||||
|
||||
// Apply defaults
|
||||
if config.BatchSize <= 0 {
|
||||
config.BatchSize = DefaultBatchSize
|
||||
}
|
||||
if config.FlushIntervalMs <= 0 {
|
||||
config.FlushIntervalMs = DefaultFlushIntervalMs
|
||||
}
|
||||
if config.MaxBufferSize <= 0 {
|
||||
config.MaxBufferSize = DefaultMaxBufferSize
|
||||
}
|
||||
if config.TimeoutMs <= 0 {
|
||||
config.TimeoutMs = DefaultTimeoutMs
|
||||
}
|
||||
|
||||
s := &ClickHouseSink{
|
||||
config: config,
|
||||
buffer: make([]domain.CorrelatedLog, 0, config.BatchSize),
|
||||
flushChan: make(chan struct{}, 1),
|
||||
done: make(chan struct{}),
|
||||
logger: observability.NewLogger("clickhouse"),
|
||||
}
|
||||
|
||||
// Parse DSN and create options
|
||||
options, err := clickhouse.ParseDSN(config.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ClickHouse DSN: %w", err)
|
||||
}
|
||||
|
||||
// Connect to ClickHouse using native API
|
||||
conn, err := clickhouse.Open(options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to ClickHouse: %w", err)
|
||||
}
|
||||
|
||||
// Ping with timeout to verify connection
|
||||
pingCtx, pingCancel := context.WithTimeout(context.Background(), time.Duration(DefaultPingTimeoutMs)*time.Millisecond)
|
||||
defer pingCancel()
|
||||
|
||||
if err := conn.Ping(pingCtx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("failed to ping ClickHouse: %w", err)
|
||||
}
|
||||
|
||||
s.conn = conn
|
||||
s.log().Infof("connected to ClickHouse: table=%s batch_size=%d flush_interval_ms=%d",
|
||||
config.Table, config.BatchSize, config.FlushIntervalMs)
|
||||
|
||||
// Start flush goroutine
|
||||
s.wg.Add(1)
|
||||
go s.flushLoop()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Name returns the sink name.
|
||||
func (s *ClickHouseSink) Name() string {
|
||||
return "clickhouse"
|
||||
}
|
||||
|
||||
// log returns the logger, initializing a default one if not set (e.g. in tests).
|
||||
func (s *ClickHouseSink) log() *observability.Logger {
|
||||
if s.logger == nil {
|
||||
s.logger = observability.NewLogger("clickhouse")
|
||||
}
|
||||
return s.logger
|
||||
}
|
||||
|
||||
// Reopen is a no-op for ClickHouse (connection is managed internally).
|
||||
func (s *ClickHouseSink) Reopen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write adds a log to the buffer.
|
||||
func (s *ClickHouseSink) Write(ctx context.Context, log domain.CorrelatedLog) error {
|
||||
deadline := time.Now().Add(time.Duration(s.config.TimeoutMs) * time.Millisecond)
|
||||
|
||||
for {
|
||||
s.mu.Lock()
|
||||
if len(s.buffer) < s.config.MaxBufferSize {
|
||||
s.buffer = append(s.buffer, log)
|
||||
if len(s.buffer) >= s.config.BatchSize {
|
||||
select {
|
||||
case s.flushChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
drop := s.config.DropOnOverflow
|
||||
s.mu.Unlock()
|
||||
|
||||
if drop {
|
||||
s.log().Warnf("buffer full, dropping log: table=%s buffer_size=%d", s.config.Table, s.config.MaxBufferSize)
|
||||
return nil
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
return fmt.Errorf("buffer full, timeout exceeded")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush flushes the buffer to ClickHouse.
|
||||
func (s *ClickHouseSink) Flush(ctx context.Context) error {
|
||||
return s.doFlush(ctx)
|
||||
}
|
||||
|
||||
// Close closes the sink.
|
||||
func (s *ClickHouseSink) Close() error {
|
||||
var closeErr error
|
||||
|
||||
s.closeOnce.Do(func() {
|
||||
if s.done != nil {
|
||||
close(s.done)
|
||||
}
|
||||
s.wg.Wait()
|
||||
|
||||
flushCtx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.TimeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
if err := s.doFlush(flushCtx); err != nil {
|
||||
closeErr = err
|
||||
}
|
||||
|
||||
if s.conn != nil {
|
||||
if err := s.conn.Close(); err != nil && closeErr == nil {
|
||||
closeErr = err
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func (s *ClickHouseSink) flushLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(time.Duration(s.config.FlushIntervalMs) * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.TimeoutMs)*time.Millisecond)
|
||||
if err := s.doFlush(ctx); err != nil {
|
||||
s.log().Error("final flush on close failed", err)
|
||||
}
|
||||
cancel()
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
needsFlush := len(s.buffer) > 0
|
||||
s.mu.Unlock()
|
||||
|
||||
if needsFlush {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.TimeoutMs)*time.Millisecond)
|
||||
if err := s.doFlush(ctx); err != nil {
|
||||
s.log().Error("periodic flush failed", err)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
|
||||
case <-s.flushChan:
|
||||
s.mu.Lock()
|
||||
needsFlush := len(s.buffer) >= s.config.BatchSize
|
||||
s.mu.Unlock()
|
||||
|
||||
if needsFlush {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.TimeoutMs)*time.Millisecond)
|
||||
if err := s.doFlush(ctx); err != nil {
|
||||
s.log().Error("batch flush failed", err)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClickHouseSink) doFlush(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
if len(s.buffer) == 0 {
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy buffer to flush
|
||||
buffer := make([]domain.CorrelatedLog, len(s.buffer))
|
||||
copy(buffer, s.buffer)
|
||||
s.buffer = make([]domain.CorrelatedLog, 0, s.config.BatchSize)
|
||||
s.mu.Unlock()
|
||||
|
||||
if s.conn == nil {
|
||||
return fmt.Errorf("clickhouse connection is not initialized")
|
||||
}
|
||||
|
||||
batchSize := len(buffer)
|
||||
|
||||
// Retry logic with exponential backoff
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
delay := RetryBaseDelay * time.Duration(1<<uint(attempt-1))
|
||||
s.log().Warnf("retrying batch insert: attempt=%d/%d delay=%s rows=%d err=%v",
|
||||
attempt+1, MaxRetries, delay, batchSize, lastErr)
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
lastErr = s.executeBatch(ctx, buffer)
|
||||
if lastErr == nil {
|
||||
s.log().Debugf("batch sent: rows=%d table=%s", batchSize, s.config.Table)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isRetryableError(lastErr) {
|
||||
return fmt.Errorf("non-retryable error: %w", lastErr)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed after %d retries (batch size: %d): %w", MaxRetries, batchSize, lastErr)
|
||||
}
|
||||
|
||||
func (s *ClickHouseSink) executeBatch(ctx context.Context, buffer []domain.CorrelatedLog) error {
|
||||
if s.conn == nil {
|
||||
return fmt.Errorf("clickhouse connection is not initialized")
|
||||
}
|
||||
|
||||
// Table schema: http_logs_raw (raw_json String)
|
||||
// Single column insert - the entire log is serialized as JSON string
|
||||
query := fmt.Sprintf(`INSERT INTO %s (raw_json)`, s.config.Table)
|
||||
|
||||
// Prepare batch using native clickhouse-go/v2 API
|
||||
batch, err := s.conn.PrepareBatch(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare batch: %w", err)
|
||||
}
|
||||
|
||||
for i, log := range buffer {
|
||||
// Marshal the entire CorrelatedLog to JSON
|
||||
logJSON, marshalErr := json.Marshal(log)
|
||||
if marshalErr != nil {
|
||||
return fmt.Errorf("failed to marshal log %d to JSON: %w", i, marshalErr)
|
||||
}
|
||||
|
||||
// Append the JSON string as the raw_json column value
|
||||
appendErr := batch.Append(string(logJSON))
|
||||
if appendErr != nil {
|
||||
return fmt.Errorf("failed to append log %d to batch: %w", i, appendErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Send the batch - DO NOT FORGET this step
|
||||
sendErr := batch.Send()
|
||||
if sendErr != nil {
|
||||
return fmt.Errorf("failed to send batch (%d rows): %w", len(buffer), sendErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isRetryableError reports whether a failed insert is worth retrying.
//
// The checks run in this order: nil is not retryable;
// context.DeadlineExceeded is; context.Canceled is not; a net.Error that
// reports Timeout() is. After that the lower-cased message decides:
// SQL/schema complaints are not retryable, known transient network
// phrases are, and anything else is not.
func isRetryableError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, context.DeadlineExceeded):
		return true
	case errors.Is(err, context.Canceled):
		return false
	}

	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		return true
	}

	msg := strings.ToLower(err.Error())

	// Schema and syntax problems will not fix themselves on retry.
	permanent := strings.Contains(msg, "syntax error") ||
		strings.Contains(msg, "unknown table") ||
		strings.Contains(msg, "unknown column") ||
		(strings.Contains(msg, "table") && strings.Contains(msg, "not found"))
	if permanent {
		return false
	}

	// Message fragments that indicate a transient network condition.
	transient := []string{
		"connection refused",
		"connection reset",
		"timeout",
		"temporary failure",
		"network is unreachable",
		"broken pipe",
		"no route to host",
	}
	for _, fragment := range transient {
		if strings.Contains(msg, fragment) {
			return true
		}
	}

	return false
}
|
||||
@ -0,0 +1,538 @@
|
||||
package clickhouse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
"github.com/antitbone/ja4/correlator/internal/observability"
|
||||
)
|
||||
|
||||
func TestClickHouseSink_Name(t *testing.T) {
|
||||
sink := &ClickHouseSink{
|
||||
config: Config{
|
||||
DSN: "clickhouse://test:test@localhost:9000/test",
|
||||
Table: "test_table",
|
||||
},
|
||||
}
|
||||
|
||||
if sink.Name() != "clickhouse" {
|
||||
t.Errorf("expected name 'clickhouse', got %s", sink.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseSink_ConfigDefaults(t *testing.T) {
|
||||
// Test that defaults are applied correctly
|
||||
config := Config{
|
||||
DSN: "clickhouse://test:test@localhost:9000/test",
|
||||
Table: "test_table",
|
||||
// Other fields are zero, should get defaults
|
||||
}
|
||||
|
||||
// Verify defaults would be applied (we can't actually connect in tests)
|
||||
if config.BatchSize <= 0 {
|
||||
config.BatchSize = DefaultBatchSize
|
||||
}
|
||||
if config.FlushIntervalMs <= 0 {
|
||||
config.FlushIntervalMs = DefaultFlushIntervalMs
|
||||
}
|
||||
if config.MaxBufferSize <= 0 {
|
||||
config.MaxBufferSize = DefaultMaxBufferSize
|
||||
}
|
||||
if config.TimeoutMs <= 0 {
|
||||
config.TimeoutMs = DefaultTimeoutMs
|
||||
}
|
||||
|
||||
if config.BatchSize != DefaultBatchSize {
|
||||
t.Errorf("expected BatchSize %d, got %d", DefaultBatchSize, config.BatchSize)
|
||||
}
|
||||
if config.FlushIntervalMs != DefaultFlushIntervalMs {
|
||||
t.Errorf("expected FlushIntervalMs %d, got %d", DefaultFlushIntervalMs, config.FlushIntervalMs)
|
||||
}
|
||||
if config.MaxBufferSize != DefaultMaxBufferSize {
|
||||
t.Errorf("expected MaxBufferSize %d, got %d", DefaultMaxBufferSize, config.MaxBufferSize)
|
||||
}
|
||||
if config.TimeoutMs != DefaultTimeoutMs {
|
||||
t.Errorf("expected TimeoutMs %d, got %d", DefaultTimeoutMs, config.TimeoutMs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseSink_Write_BufferOverflow(t *testing.T) {
|
||||
// This test verifies the buffer overflow logic without actually connecting
|
||||
config := Config{
|
||||
DSN: "clickhouse://test:test@localhost:9000/test",
|
||||
Table: "test_table",
|
||||
BatchSize: 10,
|
||||
MaxBufferSize: 10,
|
||||
DropOnOverflow: true,
|
||||
TimeoutMs: 100,
|
||||
FlushIntervalMs: 1000,
|
||||
}
|
||||
|
||||
// We can't test actual writes without a ClickHouse instance,
|
||||
// but we can verify the config is valid
|
||||
if config.BatchSize > config.MaxBufferSize {
|
||||
t.Error("BatchSize should not exceed MaxBufferSize")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseSink_IsRetryableError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{"nil error", nil, false},
|
||||
{"connection refused", &mockError{"connection refused"}, true},
|
||||
{"connection reset", &mockError{"connection reset by peer"}, true},
|
||||
{"timeout", &mockError{"timeout waiting for response"}, true},
|
||||
{"network unreachable", &mockError{"network is unreachable"}, true},
|
||||
{"broken pipe", &mockError{"broken pipe"}, true},
|
||||
{"syntax error", &mockError{"syntax error in SQL"}, false},
|
||||
{"table not found", &mockError{"table test not found"}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isRetryableError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseSink_FlushEmpty(t *testing.T) {
|
||||
// Test that flushing an empty buffer doesn't cause issues
|
||||
// (We can't test actual ClickHouse operations without a real instance)
|
||||
|
||||
s := &ClickHouseSink{
|
||||
config: Config{
|
||||
DSN: "clickhouse://test:test@localhost:9000/test",
|
||||
Table: "test_table",
|
||||
},
|
||||
buffer: make([]domain.CorrelatedLog, 0),
|
||||
}
|
||||
|
||||
// Should not panic or error on empty flush
|
||||
ctx := context.Background()
|
||||
err := s.Flush(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error on empty flush, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseSink_CloseWithoutConnect(t *testing.T) {
|
||||
// Test that closing without connecting doesn't panic
|
||||
s := &ClickHouseSink{
|
||||
config: Config{
|
||||
DSN: "clickhouse://test:test@localhost:9000/test",
|
||||
Table: "test_table",
|
||||
},
|
||||
buffer: make([]domain.CorrelatedLog, 0),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
err := s.Close()
|
||||
if err != nil {
|
||||
t.Errorf("expected no error on close without connect, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseSink_Constants(t *testing.T) {
|
||||
// Verify constants have reasonable values
|
||||
if DefaultBatchSize <= 0 {
|
||||
t.Error("DefaultBatchSize should be positive")
|
||||
}
|
||||
if DefaultFlushIntervalMs <= 0 {
|
||||
t.Error("DefaultFlushIntervalMs should be positive")
|
||||
}
|
||||
if DefaultMaxBufferSize <= 0 {
|
||||
t.Error("DefaultMaxBufferSize should be positive")
|
||||
}
|
||||
if DefaultTimeoutMs <= 0 {
|
||||
t.Error("DefaultTimeoutMs should be positive")
|
||||
}
|
||||
if DefaultPingTimeoutMs <= 0 {
|
||||
t.Error("DefaultPingTimeoutMs should be positive")
|
||||
}
|
||||
if MaxRetries <= 0 {
|
||||
t.Error("MaxRetries should be positive")
|
||||
}
|
||||
if RetryBaseDelay <= 0 {
|
||||
t.Error("RetryBaseDelay should be positive")
|
||||
}
|
||||
}
|
||||
|
||||
// mockError is a minimal error whose message is fully controlled by the
// test that builds it.
type mockError struct {
	msg string
}

// Error returns the configured message.
func (m *mockError) Error() string {
	return m.msg
}
|
||||
|
||||
// Test the doFlush function with empty buffer (no actual DB connection)
|
||||
func TestClickHouseSink_DoFlushEmpty(t *testing.T) {
|
||||
s := &ClickHouseSink{
|
||||
config: Config{
|
||||
DSN: "clickhouse://test:test@localhost:9000/test",
|
||||
Table: "test_table",
|
||||
},
|
||||
buffer: make([]domain.CorrelatedLog, 0),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := s.doFlush(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error when flushing empty buffer, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that buffer is properly managed (without actual DB operations)
|
||||
func TestClickHouseSink_BufferManagement(t *testing.T) {
|
||||
log := domain.CorrelatedLog{
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080,
|
||||
Correlated: true,
|
||||
}
|
||||
|
||||
s := &ClickHouseSink{
|
||||
config: Config{
|
||||
DSN: "clickhouse://test:test@localhost:9000/test",
|
||||
Table: "test_table",
|
||||
MaxBufferSize: 100, // Allow more than 1 element
|
||||
DropOnOverflow: false,
|
||||
TimeoutMs: 1000,
|
||||
},
|
||||
buffer: []domain.CorrelatedLog{log},
|
||||
}
|
||||
|
||||
// Verify buffer has data
|
||||
if len(s.buffer) != 1 {
|
||||
t.Fatalf("expected buffer length 1, got %d", len(s.buffer))
|
||||
}
|
||||
|
||||
// Test that Write properly adds to buffer
|
||||
ctx := context.Background()
|
||||
err := s.Write(ctx, log)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error on Write: %v", err)
|
||||
}
|
||||
|
||||
if len(s.buffer) != 2 {
|
||||
t.Errorf("expected buffer length 2 after Write, got %d", len(s.buffer))
|
||||
}
|
||||
}
|
||||
|
||||
// Test Write with context cancellation
|
||||
func TestClickHouseSink_Write_ContextCancel(t *testing.T) {
|
||||
s := &ClickHouseSink{
|
||||
config: Config{
|
||||
DSN: "clickhouse://test:test@localhost:9000/test",
|
||||
Table: "test_table",
|
||||
MaxBufferSize: 1,
|
||||
DropOnOverflow: false,
|
||||
TimeoutMs: 10,
|
||||
},
|
||||
buffer: make([]domain.CorrelatedLog, 0, 1),
|
||||
}
|
||||
|
||||
// Fill the buffer
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1", SrcPort: 8080}
|
||||
s.buffer = append(s.buffer, log)
|
||||
|
||||
// Try to write with cancelled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
err := s.Write(ctx, log)
|
||||
if err == nil {
|
||||
t.Error("expected error when writing with cancelled context")
|
||||
}
|
||||
}
|
||||
|
||||
// Test DropOnOverflow behavior
|
||||
func TestClickHouseSink_Write_DropOnOverflow(t *testing.T) {
|
||||
s := &ClickHouseSink{
|
||||
config: Config{
|
||||
DSN: "clickhouse://test:test@localhost:9000/test",
|
||||
Table: "test_table",
|
||||
MaxBufferSize: 1,
|
||||
DropOnOverflow: true,
|
||||
TimeoutMs: 10,
|
||||
},
|
||||
buffer: make([]domain.CorrelatedLog, 0, 1),
|
||||
}
|
||||
|
||||
// Fill the buffer
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1", SrcPort: 8080}
|
||||
s.buffer = append(s.buffer, log)
|
||||
|
||||
// Try to write when buffer is full - should drop silently
|
||||
ctx := context.Background()
|
||||
err := s.Write(ctx, log)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error when DropOnOverflow is true, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsRetryableError_ContextDeadlineExceeded tests context.DeadlineExceeded is retryable.
|
||||
func TestIsRetryableError_ContextDeadlineExceeded(t *testing.T) {
|
||||
if !isRetryableError(context.DeadlineExceeded) {
|
||||
t.Error("context.DeadlineExceeded should be retryable")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsRetryableError_ContextCanceled tests context.Canceled is NOT retryable.
|
||||
func TestIsRetryableError_ContextCanceled(t *testing.T) {
|
||||
if isRetryableError(context.Canceled) {
|
||||
t.Error("context.Canceled should not be retryable")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsRetryableError_NetTimeout tests net.Error with Timeout() = true is retryable.
|
||||
func TestIsRetryableError_NetTimeout(t *testing.T) {
|
||||
err := &mockNetError{timeout: true, temporary: false}
|
||||
if !isRetryableError(err) {
|
||||
t.Error("net.Error with Timeout()=true should be retryable")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsRetryableError_NetNoTimeout tests net.Error with Timeout() = false is NOT retryable.
|
||||
func TestIsRetryableError_NetNoTimeout(t *testing.T) {
|
||||
err := &mockNetError{timeout: false, temporary: false}
|
||||
if isRetryableError(err) {
|
||||
t.Error("net.Error with Timeout()=false should not be retryable (unless msg matches)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsRetryableError_UnknownTable tests "unknown table" is NOT retryable.
|
||||
func TestIsRetryableError_UnknownTable(t *testing.T) {
|
||||
if isRetryableError(&mockError{"unknown table users"}) {
|
||||
t.Error("unknown table error should not be retryable")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsRetryableError_UnknownColumn tests "unknown column" is NOT retryable.
|
||||
func TestIsRetryableError_UnknownColumn(t *testing.T) {
|
||||
if isRetryableError(&mockError{"unknown column foo"}) {
|
||||
t.Error("unknown column error should not be retryable")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsRetryableError_RandomError tests a random error is NOT retryable.
|
||||
func TestIsRetryableError_RandomError(t *testing.T) {
|
||||
if isRetryableError(&mockError{"some random unrecognized error"}) {
|
||||
t.Error("random error should not be retryable")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsRetryableError_NoRouteToHost tests "no route to host" is retryable.
|
||||
func TestIsRetryableError_NoRouteToHost(t *testing.T) {
|
||||
if !isRetryableError(&mockError{"no route to host"}) {
|
||||
t.Error("'no route to host' should be retryable")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsRetryableError_TemporaryFailure tests "temporary failure" is retryable.
|
||||
func TestIsRetryableError_TemporaryFailure(t *testing.T) {
|
||||
if !isRetryableError(&mockError{"temporary failure in name resolution"}) {
|
||||
t.Error("'temporary failure' should be retryable")
|
||||
}
|
||||
}
|
||||
|
||||
// mockNetError is a test double with the Error, Timeout and Temporary
// methods, each returning the corresponding field.
type mockNetError struct {
	timeout   bool
	temporary bool
	msg       string
}

// Error returns the configured message.
func (m *mockNetError) Error() string {
	return m.msg
}

// Timeout returns the configured timeout flag.
func (m *mockNetError) Timeout() bool {
	return m.timeout
}

// Temporary returns the configured temporary flag.
func (m *mockNetError) Temporary() bool {
	return m.temporary
}
|
||||
|
||||
// TestNewClickHouseSink_EmptyDSN tests that empty DSN returns error.
|
||||
func TestNewClickHouseSink_EmptyDSN(t *testing.T) {
|
||||
_, err := NewClickHouseSink(Config{
|
||||
DSN: "",
|
||||
Table: "test_table",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty DSN")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewClickHouseSink_WhitespaceDSN tests that whitespace DSN returns error.
|
||||
func TestNewClickHouseSink_WhitespaceDSN(t *testing.T) {
|
||||
_, err := NewClickHouseSink(Config{
|
||||
DSN: " ",
|
||||
Table: "test_table",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for whitespace-only DSN")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewClickHouseSink_EmptyTable tests that empty Table returns error.
|
||||
func TestNewClickHouseSink_EmptyTable(t *testing.T) {
|
||||
_, err := NewClickHouseSink(Config{
|
||||
DSN: "clickhouse://localhost:9000/test",
|
||||
Table: "",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty Table")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewClickHouseSink_WhitespaceTable tests that whitespace Table returns error.
|
||||
func TestNewClickHouseSink_WhitespaceTable(t *testing.T) {
|
||||
_, err := NewClickHouseSink(Config{
|
||||
DSN: "clickhouse://localhost:9000/test",
|
||||
Table: " ",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for whitespace-only Table")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewClickHouseSink_InvalidDSN tests that an invalid DSN (no real connection) returns error.
|
||||
func TestNewClickHouseSink_InvalidDSN(t *testing.T) {
|
||||
_, err := NewClickHouseSink(Config{
|
||||
DSN: "not-a-valid-dsn",
|
||||
Table: "test_table",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid DSN")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClickHouseSink_SetLogger tests that SetLogger sets a logger.
|
||||
func TestClickHouseSink_SetLogger(t *testing.T) {
|
||||
s := &ClickHouseSink{
|
||||
config: Config{Table: "test_table"},
|
||||
buffer: make([]domain.CorrelatedLog, 0),
|
||||
}
|
||||
|
||||
testLogger := observability.NewLogger("test")
|
||||
s.SetLogger(testLogger)
|
||||
|
||||
if s.logger == nil {
|
||||
t.Error("expected logger to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClickHouseSink_LogNilLogger tests that log() returns a logger even when s.logger is nil.
|
||||
func TestClickHouseSink_LogNilLogger(t *testing.T) {
|
||||
s := &ClickHouseSink{
|
||||
config: Config{Table: "test_table"},
|
||||
buffer: make([]domain.CorrelatedLog, 0),
|
||||
}
|
||||
s.logger = nil
|
||||
|
||||
// log() should auto-initialize
|
||||
logger := s.log()
|
||||
if logger == nil {
|
||||
t.Error("expected non-nil logger from log()")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClickHouseSink_Reopen tests that Reopen is a no-op and returns nil.
|
||||
func TestClickHouseSink_Reopen(t *testing.T) {
|
||||
s := &ClickHouseSink{
|
||||
config: Config{Table: "test_table"},
|
||||
buffer: make([]domain.CorrelatedLog, 0),
|
||||
}
|
||||
if err := s.Reopen(); err != nil {
|
||||
t.Errorf("Reopen() should return nil, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClickHouseSink_DoFlushNilConn tests doFlush returns error when conn is nil and buffer non-empty.
|
||||
func TestClickHouseSink_DoFlushNilConn(t *testing.T) {
|
||||
log := domain.CorrelatedLog{SrcIP: "1.2.3.4", SrcPort: 1234}
|
||||
s := &ClickHouseSink{
|
||||
config: Config{
|
||||
Table: "test_table",
|
||||
BatchSize: DefaultBatchSize,
|
||||
},
|
||||
buffer: []domain.CorrelatedLog{log},
|
||||
conn: nil,
|
||||
}
|
||||
|
||||
err := s.doFlush(context.Background())
|
||||
if err == nil {
|
||||
t.Error("expected error from doFlush when conn is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClickHouseSink_CloseTwice tests that calling Close() twice does not panic or error.
|
||||
func TestClickHouseSink_CloseTwice(t *testing.T) {
|
||||
s := &ClickHouseSink{
|
||||
config: Config{
|
||||
Table: "test_table",
|
||||
TimeoutMs: DefaultTimeoutMs,
|
||||
},
|
||||
buffer: make([]domain.CorrelatedLog, 0),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("first Close() should not error, got: %v", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("second Close() should not error (closeOnce), got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClickHouseSink_WriteTimeout tests that Write returns error when buffer is full and timeout exceeded.
|
||||
func TestClickHouseSink_Write_Timeout(t *testing.T) {
|
||||
s := &ClickHouseSink{
|
||||
config: Config{
|
||||
Table: "test_table",
|
||||
MaxBufferSize: 1,
|
||||
DropOnOverflow: false,
|
||||
TimeoutMs: 1, // 1ms timeout
|
||||
},
|
||||
buffer: make([]domain.CorrelatedLog, 0, 1),
|
||||
}
|
||||
|
||||
log := domain.CorrelatedLog{SrcIP: "1.2.3.4", SrcPort: 1234}
|
||||
// Fill the buffer
|
||||
s.buffer = append(s.buffer, log)
|
||||
|
||||
ctx := context.Background()
|
||||
err := s.Write(ctx, log)
|
||||
if err == nil {
|
||||
t.Error("expected error when buffer full and timeout exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark Write operation (without actual DB)
|
||||
func BenchmarkClickHouseSink_Write(b *testing.B) {
|
||||
s := &ClickHouseSink{
|
||||
config: Config{
|
||||
DSN: "clickhouse://test:test@localhost:9000/test",
|
||||
Table: "test_table",
|
||||
MaxBufferSize: 10000,
|
||||
DropOnOverflow: true,
|
||||
},
|
||||
buffer: make([]domain.CorrelatedLog, 0, 10000),
|
||||
}
|
||||
|
||||
log := domain.CorrelatedLog{
|
||||
Timestamp: time.Now(),
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080,
|
||||
Correlated: true,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Write(ctx, log)
|
||||
}
|
||||
}
|
||||
191
services/correlator/internal/adapters/outbound/file/sink.go
Normal file
191
services/correlator/internal/adapters/outbound/file/sink.go
Normal file
@ -0,0 +1,191 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
)
|
||||
|
||||
// Permission bits used when the file sink creates its output.
const (
	// DefaultFilePermissions is the mode passed when opening or creating
	// the output file.
	DefaultFilePermissions os.FileMode = 0644

	// DefaultDirPermissions is the mode passed when creating the output
	// file's parent directories.
	DefaultDirPermissions os.FileMode = 0750
)
|
||||
|
||||
// Config carries the settings for a FileSink.
type Config struct {
	// Path is the output file location. It is checked by
	// validateFilePath before the file is opened.
	Path string
}
|
||||
|
||||
// FileSink writes correlated logs to a file as JSON lines.
|
||||
type FileSink struct {
|
||||
config Config
|
||||
mu sync.Mutex
|
||||
file *os.File
|
||||
}
|
||||
|
||||
// NewFileSink creates a new file sink.
|
||||
func NewFileSink(config Config) (*FileSink, error) {
|
||||
// Validate path
|
||||
if err := validateFilePath(config.Path); err != nil {
|
||||
return nil, fmt.Errorf("invalid file path: %w", err)
|
||||
}
|
||||
|
||||
s := &FileSink{
|
||||
config: config,
|
||||
}
|
||||
|
||||
// Open file on creation
|
||||
if err := s.openFile(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Name returns the sink name.
|
||||
func (s *FileSink) Name() string {
|
||||
return "file"
|
||||
}
|
||||
|
||||
// Reopen closes and reopens the file (for log rotation on SIGHUP).
|
||||
func (s *FileSink) Reopen() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.file != nil {
|
||||
if err := s.file.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s.openFile()
|
||||
}
|
||||
|
||||
// Write writes a correlated log to the file.
|
||||
func (s *FileSink) Write(ctx context.Context, log domain.CorrelatedLog) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.file == nil {
|
||||
if err := s.openFile(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(log)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal log: %w", err)
|
||||
}
|
||||
|
||||
line := append(data, '\n')
|
||||
if _, err := s.file.Write(line); err != nil {
|
||||
return fmt.Errorf("failed to write log line: %w", err)
|
||||
}
|
||||
if err := s.file.Sync(); err != nil {
|
||||
return fmt.Errorf("failed to sync log line: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush flushes any buffered data.
|
||||
func (s *FileSink) Flush(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.file != nil {
|
||||
return s.file.Sync()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the sink.
|
||||
func (s *FileSink) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.file != nil {
|
||||
err := s.file.Close()
|
||||
s.file = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileSink) openFile() error {
|
||||
// Validate path again before opening
|
||||
if err := validateFilePath(s.config.Path); err != nil {
|
||||
return fmt.Errorf("invalid file path: %w", err)
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(s.config.Path)
|
||||
if err := os.MkdirAll(dir, DefaultDirPermissions); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(s.config.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, DefaultFilePermissions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
|
||||
s.file = file
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateFilePath reports whether path is an acceptable log destination.
//
// Rules:
//   - an empty or whitespace-only path is rejected;
//   - a relative path is accepted (for testing/dev) only while it stays inside
//     the working directory, i.e. it must not climb out through "..";
//   - an absolute path must be one of the allowed roots or lie beneath one.
//
// The check is purely lexical: symbolic links are not resolved.
func validateFilePath(path string) error {
	if strings.TrimSpace(path) == "" {
		return fmt.Errorf("path cannot be empty")
	}

	cleanPath := filepath.Clean(path)

	// Relative paths are allowed for testing/dev, but must not be usable to
	// walk out of the working directory and sidestep the allow-list below.
	if !filepath.IsAbs(cleanPath) {
		if !filepath.IsLocal(cleanPath) {
			return fmt.Errorf("relative path must not escape the working directory: %q", path)
		}
		return nil
	}

	allowedRoots := []string{
		"/var/log/logcorrelator",
		"/var/log",
		"/tmp",
	}

	for _, root := range allowedRoots {
		// cleanPath and root are both absolute and already clean, so Rel
		// yields "." for the root itself, a local path for anything beneath
		// it, and a ".."-prefixed path for anything outside.
		rel, err := filepath.Rel(root, cleanPath)
		if err != nil {
			continue
		}
		if filepath.IsLocal(rel) {
			return nil
		}
	}

	return fmt.Errorf("path must be under allowed directories: %v", allowedRoots)
}
|
||||
524
services/correlator/internal/adapters/outbound/file/sink_test.go
Normal file
524
services/correlator/internal/adapters/outbound/file/sink_test.go
Normal file
@ -0,0 +1,524 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
)
|
||||
|
||||
func TestFileSink_Write(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
defer sink.Close()
|
||||
|
||||
log := domain.CorrelatedLog{
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080,
|
||||
Correlated: true,
|
||||
}
|
||||
|
||||
if err := sink.Write(context.Background(), log); err != nil {
|
||||
t.Fatalf("failed to write: %v", err)
|
||||
}
|
||||
|
||||
if err := sink.Flush(context.Background()); err != nil {
|
||||
t.Fatalf("failed to flush: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists and contains data
|
||||
data, err := os.ReadFile(testPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file: %v", err)
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
t.Error("expected non-empty file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_WriteImmediatePersist_NoFlushNeeded(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
defer sink.Close()
|
||||
|
||||
log := domain.CorrelatedLog{
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080,
|
||||
Correlated: true,
|
||||
}
|
||||
|
||||
if err := sink.Write(context.Background(), log); err != nil {
|
||||
t.Fatalf("failed to write: %v", err)
|
||||
}
|
||||
|
||||
// Must be visible immediately without Flush()
|
||||
data, err := os.ReadFile(testPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file: %v", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
t.Error("expected data to be present immediately after Write without Flush")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_MultipleWrites(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
defer sink.Close()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
log := domain.CorrelatedLog{
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080 + i,
|
||||
}
|
||||
if err := sink.Write(context.Background(), log); err != nil {
|
||||
t.Fatalf("failed to write: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
sink.Close()
|
||||
|
||||
// Verify file has 5 lines
|
||||
data, err := os.ReadFile(testPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file: %v", err)
|
||||
}
|
||||
|
||||
lines := 0
|
||||
for _, b := range data {
|
||||
if b == '\n' {
|
||||
lines++
|
||||
}
|
||||
}
|
||||
|
||||
if lines != 5 {
|
||||
t.Errorf("expected 5 lines, got %d", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_Name(t *testing.T) {
|
||||
sink, err := NewFileSink(Config{Path: "/tmp/test.log"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
|
||||
if sink.Name() != "file" {
|
||||
t.Errorf("expected name 'file', got %s", sink.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_ValidateFilePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantErr bool
|
||||
}{
|
||||
{"empty path", "", true},
|
||||
{"valid /var/log/logcorrelator", "/var/log/logcorrelator/test.log", false},
|
||||
{"valid /var/log", "/var/log/test.log", false},
|
||||
{"valid /tmp", "/tmp/test.log", false},
|
||||
{"reject lookalike /var/logevil", "/var/logevil/test.log", true},
|
||||
{"invalid directory", "/etc/logcorrelator/test.log", true},
|
||||
{"relative path", "test.log", false}, // Allowed for testing
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateFilePath(tt.path)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateFilePath(%q) error = %v, wantErr %v", tt.path, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_OpenFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "subdir", "test.log")
|
||||
|
||||
sink := &FileSink{
|
||||
config: Config{Path: testPath},
|
||||
}
|
||||
|
||||
err := sink.openFile()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer sink.Close()
|
||||
|
||||
if sink.file == nil {
|
||||
t.Error("expected file to be opened")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_WriteBeforeOpen(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
defer sink.Close()
|
||||
|
||||
// Write should open file automatically
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1", SrcPort: 8080}
|
||||
err = sink.Write(context.Background(), log)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write: %v", err)
|
||||
}
|
||||
|
||||
// Verify file was created
|
||||
if _, err := os.Stat(testPath); os.IsNotExist(err) {
|
||||
t.Error("expected file to be created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_FlushBeforeOpen(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
defer sink.Close()
|
||||
|
||||
// Flush before any write should not error
|
||||
err = sink.Flush(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("expected no error on flush before open, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_InvalidPath(t *testing.T) {
|
||||
// Test with invalid path (outside allowed directories)
|
||||
_, err := NewFileSink(Config{Path: "/etc/../passwd"})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_Reopen(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
|
||||
// Write initial data
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1", SrcPort: 8080}
|
||||
if err := sink.Write(context.Background(), log); err != nil {
|
||||
t.Fatalf("failed to write: %v", err)
|
||||
}
|
||||
|
||||
// Reopen should close and reopen the file
|
||||
err = sink.Reopen()
|
||||
if err != nil {
|
||||
t.Errorf("expected no error on Reopen, got %v", err)
|
||||
}
|
||||
|
||||
// Write after reopen
|
||||
log2 := domain.CorrelatedLog{SrcIP: "192.168.1.2", SrcPort: 8081}
|
||||
if err := sink.Write(context.Background(), log2); err != nil {
|
||||
t.Fatalf("failed to write after reopen: %v", err)
|
||||
}
|
||||
|
||||
sink.Close()
|
||||
|
||||
// Verify both writes are present
|
||||
data, err := os.ReadFile(testPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file: %v", err)
|
||||
}
|
||||
|
||||
lines := 0
|
||||
for _, b := range data {
|
||||
if b == '\n' {
|
||||
lines++
|
||||
}
|
||||
}
|
||||
|
||||
if lines != 2 {
|
||||
t.Errorf("expected 2 lines after reopen, got %d", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_Close(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
|
||||
// Close should succeed
|
||||
err = sink.Close()
|
||||
if err != nil {
|
||||
t.Errorf("expected no error on Close, got %v", err)
|
||||
}
|
||||
|
||||
// Write after close should fail or reopen
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1", SrcPort: 8080}
|
||||
err = sink.Write(context.Background(), log)
|
||||
if err != nil {
|
||||
// Expected - file was closed
|
||||
t.Logf("write after close returned error (expected): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_EmptyPath(t *testing.T) {
|
||||
_, err := NewFileSink(Config{Path: ""})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_WhitespacePath(t *testing.T) {
|
||||
_, err := NewFileSink(Config{Path: " "})
|
||||
if err == nil {
|
||||
t.Error("expected error for whitespace-only path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_ValidateFilePath_AllowedRoots(t *testing.T) {
|
||||
// Test paths under allowed roots
|
||||
allowedPaths := []string{
|
||||
"/var/log/logcorrelator/correlated.log",
|
||||
"/var/log/test.log",
|
||||
"/tmp/test.log",
|
||||
"/tmp/subdir/test.log",
|
||||
"relative/path/test.log",
|
||||
"./test.log",
|
||||
}
|
||||
|
||||
for _, path := range allowedPaths {
|
||||
err := validateFilePath(path)
|
||||
if err != nil {
|
||||
t.Errorf("validateFilePath(%q) unexpected error: %v", path, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_ValidateFilePath_RejectedPaths(t *testing.T) {
|
||||
// Test paths that should be rejected
|
||||
rejectedPaths := []string{
|
||||
"",
|
||||
" ",
|
||||
"/etc/passwd",
|
||||
"/etc/logcorrelator/test.log",
|
||||
"/root/test.log",
|
||||
"/home/user/test.log",
|
||||
"/var/logevil/test.log",
|
||||
}
|
||||
|
||||
for _, path := range rejectedPaths {
|
||||
err := validateFilePath(path)
|
||||
if err == nil {
|
||||
t.Errorf("validateFilePath(%q) should have been rejected", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_ConcurrentWrites(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
defer sink.Close()
|
||||
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(n int) {
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1", SrcPort: 8080 + n}
|
||||
sink.Write(context.Background(), log)
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify all writes completed
|
||||
data, err := os.ReadFile(testPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file: %v", err)
|
||||
}
|
||||
|
||||
lines := 0
|
||||
for _, b := range data {
|
||||
if b == '\n' {
|
||||
lines++
|
||||
}
|
||||
}
|
||||
|
||||
if lines != 10 {
|
||||
t.Errorf("expected 10 lines from concurrent writes, got %d", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_Flush(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
defer sink.Close()
|
||||
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1", SrcPort: 8080}
|
||||
if err := sink.Write(context.Background(), log); err != nil {
|
||||
t.Fatalf("failed to write: %v", err)
|
||||
}
|
||||
|
||||
// Flush should succeed
|
||||
err = sink.Flush(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("expected no error on Flush, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_MarshalError(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
defer sink.Close()
|
||||
|
||||
// Create a log with unmarshalable data (channel)
|
||||
log := domain.CorrelatedLog{
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080,
|
||||
Fields: map[string]any{"chan": make(chan int)},
|
||||
}
|
||||
|
||||
err = sink.Write(context.Background(), log)
|
||||
if err == nil {
|
||||
t.Error("expected error when marshaling unmarshalable data")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileSink_CloseTwice tests that closing an already-closed sink does not error.
|
||||
func TestFileSink_CloseTwice(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
|
||||
if err := sink.Close(); err != nil {
|
||||
t.Errorf("first Close() should not error, got: %v", err)
|
||||
}
|
||||
|
||||
// After close, file is nil, so second close should return nil
|
||||
if err := sink.Close(); err != nil {
|
||||
t.Errorf("second Close() on already-closed sink should not error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileSink_WriteAfterClose tests that Write after Close re-opens the file.
|
||||
func TestFileSink_WriteAfterCloseReopens(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
|
||||
if err := sink.Close(); err != nil {
|
||||
t.Fatalf("Close() failed: %v", err)
|
||||
}
|
||||
|
||||
// Write after close: FileSink.Write reopens the file when file == nil
|
||||
log := domain.CorrelatedLog{SrcIP: "1.2.3.4", SrcPort: 80}
|
||||
if err := sink.Write(context.Background(), log); err != nil {
|
||||
t.Errorf("Write after close should succeed (auto-reopen), got: %v", err)
|
||||
}
|
||||
|
||||
// Verify data was written
|
||||
data, err := os.ReadFile(testPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file: %v", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
t.Error("expected data to be present after write on re-opened file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileSink_ReopenAfterWrite tests Reopen then write produces correct output.
|
||||
func TestFileSink_ReopenThenWrite(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
defer sink.Close()
|
||||
|
||||
// Write before reopen
|
||||
log1 := domain.CorrelatedLog{SrcIP: "1.1.1.1", SrcPort: 80}
|
||||
if err := sink.Write(context.Background(), log1); err != nil {
|
||||
t.Fatalf("first Write failed: %v", err)
|
||||
}
|
||||
|
||||
// Simulate log rotation
|
||||
if err := sink.Reopen(); err != nil {
|
||||
t.Fatalf("Reopen failed: %v", err)
|
||||
}
|
||||
|
||||
// Write after reopen
|
||||
log2 := domain.CorrelatedLog{SrcIP: "2.2.2.2", SrcPort: 443}
|
||||
if err := sink.Write(context.Background(), log2); err != nil {
|
||||
t.Fatalf("second Write failed: %v", err)
|
||||
}
|
||||
|
||||
sink.Close()
|
||||
|
||||
data, err := os.ReadFile(testPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file: %v", err)
|
||||
}
|
||||
|
||||
lines := 0
|
||||
for _, b := range data {
|
||||
if b == '\n' {
|
||||
lines++
|
||||
}
|
||||
}
|
||||
if lines != 2 {
|
||||
t.Errorf("expected 2 lines after reopen+write, got %d", lines)
|
||||
}
|
||||
}
|
||||
137
services/correlator/internal/adapters/outbound/multi/sink.go
Normal file
137
services/correlator/internal/adapters/outbound/multi/sink.go
Normal file
@ -0,0 +1,137 @@
|
||||
package multi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
"github.com/antitbone/ja4/correlator/internal/ports"
|
||||
)
|
||||
|
||||
// MultiSink fans out correlated logs to multiple sinks.
|
||||
type MultiSink struct {
|
||||
mu sync.RWMutex
|
||||
sinks []ports.CorrelatedLogSink
|
||||
}
|
||||
|
||||
// NewMultiSink creates a new multi-sink.
|
||||
func NewMultiSink(sinks ...ports.CorrelatedLogSink) *MultiSink {
|
||||
return &MultiSink{
|
||||
sinks: sinks,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the sink name.
|
||||
func (s *MultiSink) Name() string {
|
||||
return "multi"
|
||||
}
|
||||
|
||||
// AddSink adds a sink to the fan-out.
|
||||
func (s *MultiSink) AddSink(sink ports.CorrelatedLogSink) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sinks = append(s.sinks, sink)
|
||||
}
|
||||
|
||||
// Write writes a correlated log to all sinks concurrently.
|
||||
// Returns the first error encountered (but all sinks are attempted).
|
||||
func (s *MultiSink) Write(ctx context.Context, log domain.CorrelatedLog) error {
|
||||
s.mu.RLock()
|
||||
sinks := make([]ports.CorrelatedLogSink, len(s.sinks))
|
||||
copy(sinks, s.sinks)
|
||||
s.mu.RUnlock()
|
||||
|
||||
if len(sinks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var firstErr error
|
||||
var firstErrMu sync.Mutex
|
||||
errChan := make(chan error, len(sinks))
|
||||
|
||||
for _, sink := range sinks {
|
||||
wg.Add(1)
|
||||
go func(sk ports.CorrelatedLogSink) {
|
||||
defer wg.Done()
|
||||
if err := sk.Write(ctx, log); err != nil {
|
||||
// Non-blocking send to errChan
|
||||
select {
|
||||
case errChan <- err:
|
||||
default:
|
||||
// Channel full, error will be handled via firstErr
|
||||
}
|
||||
}
|
||||
}(sink)
|
||||
}
|
||||
|
||||
// Wait for all writes to complete in a separate goroutine
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Collect errors with timeout
|
||||
select {
|
||||
case <-done:
|
||||
close(errChan)
|
||||
// Collect first error
|
||||
for err := range errChan {
|
||||
if err != nil {
|
||||
firstErrMu.Lock()
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
firstErrMu.Unlock()
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
firstErrMu.Lock()
|
||||
defer firstErrMu.Unlock()
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// Flush flushes all sinks.
|
||||
func (s *MultiSink) Flush(ctx context.Context) error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
for _, sink := range s.sinks {
|
||||
if err := sink.Flush(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes all sinks.
|
||||
func (s *MultiSink) Close() error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var firstErr error
|
||||
for _, sink := range s.sinks {
|
||||
if err := sink.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// Reopen reopens all sinks (for log rotation on SIGHUP).
|
||||
func (s *MultiSink) Reopen() error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var firstErr error
|
||||
for _, sink := range s.sinks {
|
||||
if err := sink.Reopen(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
@ -0,0 +1,233 @@
|
||||
package multi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
)
|
||||
|
||||
type mockSink struct {
|
||||
name string
|
||||
mu sync.Mutex
|
||||
writeFunc func(domain.CorrelatedLog) error
|
||||
flushFunc func() error
|
||||
closeFunc func() error
|
||||
reopenFunc func() error
|
||||
}
|
||||
|
||||
func (m *mockSink) Name() string { return m.name }
|
||||
func (m *mockSink) Write(ctx context.Context, log domain.CorrelatedLog) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.writeFunc(log)
|
||||
}
|
||||
func (m *mockSink) Flush(ctx context.Context) error { return m.flushFunc() }
|
||||
func (m *mockSink) Close() error { return m.closeFunc() }
|
||||
func (m *mockSink) Reopen() error {
|
||||
if m.reopenFunc != nil {
|
||||
return m.reopenFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMultiSink_Write(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
writeCount := 0
|
||||
|
||||
sink1 := &mockSink{
|
||||
name: "sink1",
|
||||
writeFunc: func(log domain.CorrelatedLog) error {
|
||||
mu.Lock()
|
||||
writeCount++
|
||||
mu.Unlock()
|
||||
return nil
|
||||
},
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
sink2 := &mockSink{
|
||||
name: "sink2",
|
||||
writeFunc: func(log domain.CorrelatedLog) error {
|
||||
mu.Lock()
|
||||
writeCount++
|
||||
mu.Unlock()
|
||||
return nil
|
||||
},
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
ms := NewMultiSink(sink1, sink2)
|
||||
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1"}
|
||||
err := ms.Write(context.Background(), log)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if writeCount != 2 {
|
||||
t.Errorf("expected 2 writes, got %d", writeCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSink_Write_OneFails(t *testing.T) {
|
||||
sink1 := &mockSink{
|
||||
name: "sink1",
|
||||
writeFunc: func(log domain.CorrelatedLog) error {
|
||||
return nil
|
||||
},
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
sink2 := &mockSink{
|
||||
name: "sink2",
|
||||
writeFunc: func(log domain.CorrelatedLog) error {
|
||||
return context.Canceled
|
||||
},
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
ms := NewMultiSink(sink1, sink2)
|
||||
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1"}
|
||||
err := ms.Write(context.Background(), log)
|
||||
if err == nil {
|
||||
t.Error("expected error when one sink fails")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSink_AddSink(t *testing.T) {
|
||||
ms := NewMultiSink()
|
||||
|
||||
sink := &mockSink{
|
||||
name: "dynamic",
|
||||
writeFunc: func(log domain.CorrelatedLog) error { return nil },
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
ms.AddSink(sink)
|
||||
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1"}
|
||||
err := ms.Write(context.Background(), log)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSink_Name(t *testing.T) {
|
||||
ms := NewMultiSink()
|
||||
if ms.Name() != "multi" {
|
||||
t.Errorf("expected name 'multi', got %s", ms.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSink_Flush(t *testing.T) {
|
||||
flushed := false
|
||||
sink := &mockSink{
|
||||
name: "test",
|
||||
writeFunc: func(log domain.CorrelatedLog) error { return nil },
|
||||
flushFunc: func() error {
|
||||
flushed = true
|
||||
return nil
|
||||
},
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
ms := NewMultiSink(sink)
|
||||
err := ms.Flush(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !flushed {
|
||||
t.Error("expected sink to be flushed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSink_Flush_Error(t *testing.T) {
|
||||
sink := &mockSink{
|
||||
name: "test",
|
||||
writeFunc: func(log domain.CorrelatedLog) error { return nil },
|
||||
flushFunc: func() error { return context.Canceled },
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
ms := NewMultiSink(sink)
|
||||
err := ms.Flush(context.Background())
|
||||
if err != context.Canceled {
|
||||
t.Errorf("expected context.Canceled error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSink_Close(t *testing.T) {
|
||||
closed := false
|
||||
sink := &mockSink{
|
||||
name: "test",
|
||||
writeFunc: func(log domain.CorrelatedLog) error { return nil },
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error {
|
||||
closed = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
ms := NewMultiSink(sink)
|
||||
err := ms.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !closed {
|
||||
t.Error("expected sink to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSink_Close_Error(t *testing.T) {
|
||||
sink := &mockSink{
|
||||
name: "test",
|
||||
writeFunc: func(log domain.CorrelatedLog) error { return nil },
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error { return context.Canceled },
|
||||
}
|
||||
|
||||
ms := NewMultiSink(sink)
|
||||
err := ms.Close()
|
||||
if err != context.Canceled {
|
||||
t.Errorf("expected context.Canceled error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSink_Write_EmptySinks(t *testing.T) {
|
||||
ms := NewMultiSink()
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1"}
|
||||
err := ms.Write(context.Background(), log)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error with empty sinks: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSink_Write_ContextCancelled(t *testing.T) {
|
||||
sink := &mockSink{
|
||||
name: "test",
|
||||
writeFunc: func(log domain.CorrelatedLog) error {
|
||||
<-context.Background().Done()
|
||||
return nil
|
||||
},
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
ms := NewMultiSink(sink)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1"}
|
||||
err := ms.Write(ctx, log)
|
||||
if err != context.Canceled {
|
||||
t.Errorf("expected context.Canceled error, got %v", err)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,46 @@
|
||||
package stdout
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
)
|
||||
|
||||
// Config holds the stdout sink configuration.
type Config struct {
	// Enabled is carried in the configuration; NewStdoutSink does not read it.
	Enabled bool
}
|
||||
|
||||
// StdoutSink is a sink that discards everything it is given. Operational logs
// go to stderr via the observability.Logger; correlated data must never
// appear on stdout.
type StdoutSink struct{}
|
||||
|
||||
// NewStdoutSink creates a new stdout sink.
|
||||
func NewStdoutSink(config Config) *StdoutSink {
|
||||
return &StdoutSink{}
|
||||
}
|
||||
|
||||
// Name returns the sink name.
|
||||
func (s *StdoutSink) Name() string {
|
||||
return "stdout"
|
||||
}
|
||||
|
||||
// Reopen is a no-op for stdout.
|
||||
func (s *StdoutSink) Reopen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write is a no-op: correlated data must never be written to stdout.
|
||||
func (s *StdoutSink) Write(_ context.Context, _ domain.CorrelatedLog) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush is a no-op for stdout.
|
||||
func (s *StdoutSink) Flush(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close is a no-op for stdout.
|
||||
func (s *StdoutSink) Close() error {
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,81 @@
|
||||
package stdout
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
)
|
||||
|
||||
func makeLog(correlated bool) domain.CorrelatedLog {
|
||||
return domain.CorrelatedLog{
|
||||
Timestamp: time.Unix(1700000000, 0),
|
||||
SrcIP: "1.2.3.4",
|
||||
SrcPort: 12345,
|
||||
Correlated: correlated,
|
||||
}
|
||||
}
|
||||
|
||||
// captureStdout temporarily redirects os.Stdout to a pipe, runs fn, and
// returns everything fn wrote to os.Stdout.
//
// os.Stdout is restored even if fn panics or ends the test (e.g. through
// t.Fatalf). The pipe is drained concurrently while fn runs, so fn can write
// more than a pipe buffer's worth of data without blocking. A failure to
// create or read the pipe fails the test.
func captureStdout(t *testing.T, fn func()) string {
	t.Helper()

	r, w, err := os.Pipe()
	if err != nil {
		t.Fatalf("os.Pipe: %v", err)
	}
	defer r.Close()

	// Start reading before fn runs so writes never stall on a full pipe.
	var buf bytes.Buffer
	readDone := make(chan error, 1)
	go func() {
		_, readErr := buf.ReadFrom(r)
		readDone <- readErr
	}()

	old := os.Stdout
	os.Stdout = w
	func() {
		// Restore stdout and close the write end no matter how fn exits;
		// closing w is what lets the reader above see EOF.
		defer func() {
			os.Stdout = old
			w.Close()
		}()
		fn()
	}()

	if readErr := <-readDone; readErr != nil {
		t.Fatalf("reading captured stdout: %v", readErr)
	}
	return buf.String()
}
|
||||
|
||||
func TestStdoutSink_Name(t *testing.T) {
|
||||
s := NewStdoutSink(Config{Enabled: true})
|
||||
if s.Name() != "stdout" {
|
||||
t.Errorf("expected name 'stdout', got %q", s.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestStdoutSink_WriteDoesNotProduceOutput verifies that no JSON data
|
||||
// (correlated or not) is ever written to stdout.
|
||||
func TestStdoutSink_WriteDoesNotProduceOutput(t *testing.T) {
|
||||
s := NewStdoutSink(Config{Enabled: true})
|
||||
|
||||
got := captureStdout(t, func() {
|
||||
if err := s.Write(context.Background(), makeLog(true)); err != nil {
|
||||
t.Fatalf("Write(correlated) returned error: %v", err)
|
||||
}
|
||||
if err := s.Write(context.Background(), makeLog(false)); err != nil {
|
||||
t.Fatalf("Write(orphan) returned error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
if got != "" {
|
||||
t.Errorf("stdout must be empty but got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStdoutSink_NoopMethods(t *testing.T) {
|
||||
s := NewStdoutSink(Config{Enabled: true})
|
||||
|
||||
if err := s.Flush(context.Background()); err != nil {
|
||||
t.Errorf("Flush returned error: %v", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("Close returned error: %v", err)
|
||||
}
|
||||
if err := s.Reopen(); err != nil {
|
||||
t.Errorf("Reopen returned error: %v", err)
|
||||
}
|
||||
}
|
||||
160
services/correlator/internal/app/orchestrator.go
Normal file
160
services/correlator/internal/app/orchestrator.go
Normal file
@ -0,0 +1,160 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
"github.com/antitbone/ja4/correlator/internal/ports"
|
||||
)
|
||||
|
||||
const (
	// DefaultEventChannelBufferSize is the buffer capacity of the event
	// channel created for each source.
	DefaultEventChannelBufferSize = 1000

	// OrphanTickInterval is the period at which the orchestrator drains
	// pending orphans. It is half the default emit delay (500ms / 2), so
	// orphans are emitted promptly even when no new events arrive.
	OrphanTickInterval = 250 * time.Millisecond
)
|
||||
|
||||
// OrchestratorConfig holds the orchestrator configuration.
|
||||
type OrchestratorConfig struct {
|
||||
Sources []ports.EventSource
|
||||
Sink ports.CorrelatedLogSink
|
||||
}
|
||||
|
||||
// Orchestrator connects sources to the correlation service and sinks.
|
||||
type Orchestrator struct {
|
||||
config OrchestratorConfig
|
||||
correlationSvc ports.CorrelationProcessor
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
running atomic.Bool
|
||||
}
|
||||
|
||||
// NewOrchestrator creates a new orchestrator.
|
||||
func NewOrchestrator(config OrchestratorConfig, correlationSvc ports.CorrelationProcessor) *Orchestrator {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Orchestrator{
|
||||
config: config,
|
||||
correlationSvc: correlationSvc,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the orchestration.
|
||||
func (o *Orchestrator) Start() error {
|
||||
if !o.running.CompareAndSwap(false, true) {
|
||||
return nil // Already running
|
||||
}
|
||||
|
||||
// Start each source
|
||||
for _, source := range o.config.Sources {
|
||||
eventChan := make(chan *domain.NormalizedEvent, DefaultEventChannelBufferSize)
|
||||
|
||||
o.wg.Add(1)
|
||||
go func(src ports.EventSource, evChan chan *domain.NormalizedEvent) {
|
||||
defer o.wg.Done()
|
||||
|
||||
// Start the source in a separate goroutine
|
||||
sourceErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := src.Start(o.ctx, evChan); err != nil {
|
||||
sourceErr <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// Process events in the current goroutine
|
||||
o.processEvents(evChan)
|
||||
|
||||
// Check for source start errors
|
||||
if err := <-sourceErr; err != nil {
|
||||
// Source failed to start, log error and exit
|
||||
return
|
||||
}
|
||||
}(source, eventChan)
|
||||
}
|
||||
|
||||
// Start a periodic ticker to drain pending orphan A events independently of the
|
||||
// event flow. Without this, orphans are only emitted when a new event arrives,
|
||||
// causing them to accumulate silently when the source goes quiet.
|
||||
o.wg.Add(1)
|
||||
go func() {
|
||||
defer o.wg.Done()
|
||||
ticker := time.NewTicker(OrphanTickInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-o.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
logs := o.correlationSvc.EmitPendingOrphans()
|
||||
for _, log := range logs {
|
||||
o.config.Sink.Write(o.ctx, log) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Orchestrator) processEvents(eventChan <-chan *domain.NormalizedEvent) {
|
||||
for {
|
||||
select {
|
||||
case <-o.ctx.Done():
|
||||
// Drain remaining events before exiting
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-eventChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
logs := o.correlationSvc.ProcessEvent(event)
|
||||
for _, log := range logs {
|
||||
o.config.Sink.Write(o.ctx, log)
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
case event, ok := <-eventChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Process through correlation service
|
||||
logs := o.correlationSvc.ProcessEvent(event)
|
||||
|
||||
// Write correlated logs to sink
|
||||
for _, log := range logs {
|
||||
if err := o.config.Sink.Write(o.ctx, log); err != nil {
|
||||
// Log error but continue processing
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully stops the orchestrator.
|
||||
// It stops all sources and closes sinks immediately without waiting for queue drainage.
|
||||
// systemd TimeoutStopSec handles forced termination if needed.
|
||||
func (o *Orchestrator) Stop() error {
|
||||
if !o.running.CompareAndSwap(true, false) {
|
||||
return nil // Not running
|
||||
}
|
||||
|
||||
// Cancel context to stop accepting new events immediately
|
||||
o.cancel()
|
||||
|
||||
// Close sink (flush skipped - in-flight events are dropped)
|
||||
if err := o.config.Sink.Close(); err != nil {
|
||||
// Log error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
300
services/correlator/internal/app/orchestrator_test.go
Normal file
300
services/correlator/internal/app/orchestrator_test.go
Normal file
@ -0,0 +1,300 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
"github.com/antitbone/ja4/correlator/internal/ports"
|
||||
)
|
||||
|
||||
type mockEventSource struct {
|
||||
name string
|
||||
mu sync.RWMutex
|
||||
eventChan chan<- *domain.NormalizedEvent
|
||||
started bool
|
||||
stopped bool
|
||||
}
|
||||
|
||||
func (m *mockEventSource) Name() string { return m.name }
|
||||
func (m *mockEventSource) Start(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) error {
|
||||
m.mu.Lock()
|
||||
m.started = true
|
||||
m.eventChan = eventChan
|
||||
m.mu.Unlock()
|
||||
<-ctx.Done()
|
||||
m.mu.Lock()
|
||||
m.stopped = true
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
func (m *mockEventSource) Stop() error { return nil }
|
||||
|
||||
func (m *mockEventSource) getEventChan() chan<- *domain.NormalizedEvent {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.eventChan
|
||||
}
|
||||
|
||||
func (m *mockEventSource) isStarted() bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.started
|
||||
}
|
||||
|
||||
type mockSink struct {
|
||||
mu sync.Mutex
|
||||
written []domain.CorrelatedLog
|
||||
}
|
||||
|
||||
func (m *mockSink) Name() string { return "mock" }
|
||||
func (m *mockSink) Write(ctx context.Context, log domain.CorrelatedLog) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.written = append(m.written, log)
|
||||
return nil
|
||||
}
|
||||
func (m *mockSink) Flush(ctx context.Context) error { return nil }
|
||||
func (m *mockSink) Close() error { return nil }
|
||||
func (m *mockSink) Reopen() error { return nil }
|
||||
|
||||
func (m *mockSink) getWritten() []domain.CorrelatedLog {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make([]domain.CorrelatedLog, len(m.written))
|
||||
copy(result, m.written)
|
||||
return result
|
||||
}
|
||||
|
||||
func TestOrchestrator_StartStop(t *testing.T) {
|
||||
source := &mockEventSource{name: "test"}
|
||||
sink := &mockSink{}
|
||||
|
||||
corrConfig := domain.CorrelationConfig{
|
||||
TimeWindow: time.Second,
|
||||
ApacheAlwaysEmit: true,
|
||||
NetworkEmit: false,
|
||||
}
|
||||
correlationSvc := domain.NewCorrelationService(corrConfig, &domain.RealTimeProvider{})
|
||||
|
||||
orchestrator := NewOrchestrator(OrchestratorConfig{
|
||||
Sources: []ports.EventSource{source},
|
||||
Sink: sink,
|
||||
}, correlationSvc)
|
||||
|
||||
if err := orchestrator.Start(); err != nil {
|
||||
t.Fatalf("failed to start: %v", err)
|
||||
}
|
||||
|
||||
// Let it run briefly
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if err := orchestrator.Stop(); err != nil {
|
||||
t.Fatalf("failed to stop: %v", err)
|
||||
}
|
||||
|
||||
if !source.isStarted() {
|
||||
t.Error("expected source to be started")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrchestrator_ProcessEvent(t *testing.T) {
|
||||
source := &mockEventSource{name: "test"}
|
||||
sink := &mockSink{}
|
||||
|
||||
corrConfig := domain.CorrelationConfig{
|
||||
TimeWindow: time.Second,
|
||||
ApacheAlwaysEmit: true,
|
||||
NetworkEmit: false,
|
||||
}
|
||||
correlationSvc := domain.NewCorrelationService(corrConfig, &domain.RealTimeProvider{})
|
||||
|
||||
orchestrator := NewOrchestrator(OrchestratorConfig{
|
||||
Sources: []ports.EventSource{source},
|
||||
Sink: sink,
|
||||
}, correlationSvc)
|
||||
|
||||
if err := orchestrator.Start(); err != nil {
|
||||
t.Fatalf("failed to start: %v", err)
|
||||
}
|
||||
|
||||
// Wait for source to start and get the channel
|
||||
var eventChan chan<- *domain.NormalizedEvent
|
||||
for i := 0; i < 50; i++ {
|
||||
eventChan = source.getEventChan()
|
||||
if eventChan != nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if eventChan == nil {
|
||||
t.Fatal("source did not start properly")
|
||||
}
|
||||
|
||||
// Send an event through the source
|
||||
event := &domain.NormalizedEvent{
|
||||
Source: domain.SourceA,
|
||||
Timestamp: time.Now(),
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080,
|
||||
Raw: map[string]any{"method": "GET"},
|
||||
}
|
||||
|
||||
// Send event
|
||||
eventChan <- event
|
||||
|
||||
// Give it time to process
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if err := orchestrator.Stop(); err != nil {
|
||||
t.Fatalf("failed to stop: %v", err)
|
||||
}
|
||||
|
||||
// Should have written at least one log (the orphan A)
|
||||
written := sink.getWritten()
|
||||
if len(written) == 0 {
|
||||
t.Error("expected at least one log to be written")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOrchestrator_StartTwice tests that calling Start() twice is a no-op (already running).
|
||||
func TestOrchestrator_StartTwice(t *testing.T) {
|
||||
source := &mockEventSource{name: "test"}
|
||||
sink := &mockSink{}
|
||||
|
||||
corrConfig := domain.CorrelationConfig{
|
||||
TimeWindow: time.Second,
|
||||
ApacheAlwaysEmit: true,
|
||||
}
|
||||
correlationSvc := domain.NewCorrelationService(corrConfig, &domain.RealTimeProvider{})
|
||||
|
||||
o := NewOrchestrator(OrchestratorConfig{
|
||||
Sources: []ports.EventSource{source},
|
||||
Sink: sink,
|
||||
}, correlationSvc)
|
||||
|
||||
if err := o.Start(); err != nil {
|
||||
t.Fatalf("first Start() failed: %v", err)
|
||||
}
|
||||
if err := o.Start(); err != nil {
|
||||
t.Errorf("second Start() should be no-op, got: %v", err)
|
||||
}
|
||||
|
||||
o.Stop()
|
||||
}
|
||||
|
||||
// TestOrchestrator_StopTwice tests that calling Stop() twice is a no-op.
|
||||
func TestOrchestrator_StopTwice(t *testing.T) {
|
||||
source := &mockEventSource{name: "test"}
|
||||
sink := &mockSink{}
|
||||
|
||||
corrConfig := domain.CorrelationConfig{
|
||||
TimeWindow: time.Second,
|
||||
ApacheAlwaysEmit: true,
|
||||
}
|
||||
correlationSvc := domain.NewCorrelationService(corrConfig, &domain.RealTimeProvider{})
|
||||
|
||||
o := NewOrchestrator(OrchestratorConfig{
|
||||
Sources: []ports.EventSource{source},
|
||||
Sink: sink,
|
||||
}, correlationSvc)
|
||||
|
||||
o.Start()
|
||||
|
||||
if err := o.Stop(); err != nil {
|
||||
t.Errorf("first Stop() failed: %v", err)
|
||||
}
|
||||
if err := o.Stop(); err != nil {
|
||||
t.Errorf("second Stop() should be no-op, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOrchestrator_NoSources tests that Orchestrator works with no sources.
|
||||
func TestOrchestrator_NoSources(t *testing.T) {
|
||||
sink := &mockSink{}
|
||||
|
||||
corrConfig := domain.CorrelationConfig{TimeWindow: time.Second}
|
||||
correlationSvc := domain.NewCorrelationService(corrConfig, &domain.RealTimeProvider{})
|
||||
|
||||
o := NewOrchestrator(OrchestratorConfig{
|
||||
Sources: []ports.EventSource{},
|
||||
Sink: sink,
|
||||
}, correlationSvc)
|
||||
|
||||
if err := o.Start(); err != nil {
|
||||
t.Fatalf("Start() with no sources failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if err := o.Stop(); err != nil {
|
||||
t.Errorf("Stop() failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOrchestrator_OrphanEmission tests that orphan A events are emitted via tick.
|
||||
func TestOrchestrator_OrphanEmission(t *testing.T) {
|
||||
source := &mockEventSource{name: "test"}
|
||||
sink := &mockSink{}
|
||||
|
||||
corrConfig := domain.CorrelationConfig{
|
||||
TimeWindow: 50 * time.Millisecond,
|
||||
ApacheAlwaysEmit: true,
|
||||
ApacheEmitDelayMs: 10, // Very short delay so orphans emit quickly
|
||||
}
|
||||
correlationSvc := domain.NewCorrelationService(corrConfig, &domain.RealTimeProvider{})
|
||||
|
||||
o := NewOrchestrator(OrchestratorConfig{
|
||||
Sources: []ports.EventSource{source},
|
||||
Sink: sink,
|
||||
}, correlationSvc)
|
||||
|
||||
if err := o.Start(); err != nil {
|
||||
t.Fatalf("Start() failed: %v", err)
|
||||
}
|
||||
|
||||
// Wait for source to be ready
|
||||
var eventChan chan<- *domain.NormalizedEvent
|
||||
for i := 0; i < 50; i++ {
|
||||
eventChan = source.getEventChan()
|
||||
if eventChan != nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
if eventChan == nil {
|
||||
t.Fatal("source did not start")
|
||||
}
|
||||
|
||||
// Send a source A event (Apache/HTTP)
|
||||
eventChan <- &domain.NormalizedEvent{
|
||||
Source: domain.SourceA,
|
||||
Timestamp: time.Now(),
|
||||
SrcIP: "10.0.0.1",
|
||||
SrcPort: 12345,
|
||||
Raw: map[string]any{"method": "GET"},
|
||||
}
|
||||
|
||||
// Allow time for orphan ticker to fire (OrphanTickInterval = 250ms, but emit delay is 10ms)
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
o.Stop()
|
||||
|
||||
written := sink.getWritten()
|
||||
if len(written) == 0 {
|
||||
t.Error("expected at least one orphan log to be emitted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOrchestrator_Constants tests that constants have reasonable values.
|
||||
func TestOrchestrator_Constants(t *testing.T) {
|
||||
if DefaultEventChannelBufferSize <= 0 {
|
||||
t.Error("DefaultEventChannelBufferSize should be positive")
|
||||
}
|
||||
if OrphanTickInterval <= 0 {
|
||||
t.Error("OrphanTickInterval should be positive")
|
||||
}
|
||||
}
|
||||
406
services/correlator/internal/config/config.go
Normal file
406
services/correlator/internal/config/config.go
Normal file
@ -0,0 +1,406 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config holds the complete application configuration.
|
||||
type Config struct {
|
||||
Log LogConfig `yaml:"log"`
|
||||
Inputs InputsConfig `yaml:"inputs"`
|
||||
Outputs OutputsConfig `yaml:"outputs"`
|
||||
Correlation CorrelationConfig `yaml:"correlation"`
|
||||
Metrics MetricsConfig `yaml:"metrics"`
|
||||
}
|
||||
|
||||
// MetricsConfig describes the metrics server.
type MetricsConfig struct {
	// Enabled switches the metrics server on.
	Enabled bool `yaml:"enabled"`
	// Addr is the listen address, e.g., ":8080" or "localhost:8080".
	Addr string `yaml:"addr"`
}
|
||||
|
||||
// LogConfig holds logging configuration.
type LogConfig struct {
	Level string `yaml:"level"` // DEBUG, INFO, WARN, ERROR
}

// GetLevel returns the configured log level in upper case with surrounding
// whitespace removed. It defaults to "INFO" when the level is empty or
// contains only whitespace.
func (c *LogConfig) GetLevel() string {
	level := strings.TrimSpace(c.Level)
	if level == "" {
		return "INFO"
	}
	return strings.ToUpper(level)
}
|
||||
|
||||
// ServiceConfig carries service-level metadata.
type ServiceConfig struct {
	Name     string `yaml:"name"`     // service name
	Language string `yaml:"language"` // service language
}
|
||||
|
||||
// InputsConfig holds input sources configuration.
|
||||
type InputsConfig struct {
|
||||
UnixSockets []UnixSocketConfig `yaml:"unix_sockets"`
|
||||
}
|
||||
|
||||
// UnixSocketConfig describes a single Unix socket source.
type UnixSocketConfig struct {
	// Name identifies the source; Config.Validate requires it to be unique.
	Name string `yaml:"name"`
	// Path is the socket path; Config.Validate requires it to be unique.
	Path string `yaml:"path"`
	// Format names the payload format of the source.
	Format string `yaml:"format"`
	// SourceType is "A" for Apache/HTTP or "B" for Network.
	SourceType string `yaml:"source_type"`
	// SocketPermissions is an octal string, e.g., "0660" or "0666".
	SocketPermissions string `yaml:"socket_permissions"`
}
|
||||
|
||||
// OutputsConfig holds output sinks configuration.
|
||||
type OutputsConfig struct {
|
||||
File FileOutputConfig `yaml:"file"`
|
||||
ClickHouse ClickHouseOutputConfig `yaml:"clickhouse"`
|
||||
Stdout StdoutOutputConfig `yaml:"stdout"`
|
||||
}
|
||||
|
||||
// FileOutputConfig describes the file sink.
type FileOutputConfig struct {
	// Enabled switches the file sink on. Config.Validate only counts it as an
	// active output when Path is also non-empty.
	Enabled bool `yaml:"enabled"`
	// Path is the output file path.
	Path string `yaml:"path"`
}
|
||||
|
||||
// ClickHouseOutputConfig describes the ClickHouse sink. When Enabled is true,
// Config.Validate requires DSN and Table to be non-blank and BatchSize,
// MaxBufferSize and TimeoutMs to be positive.
type ClickHouseOutputConfig struct {
	Enabled         bool   `yaml:"enabled"`
	DSN             string `yaml:"dsn"`
	Table           string `yaml:"table"`
	BatchSize       int    `yaml:"batch_size"`
	FlushIntervalMs int    `yaml:"flush_interval_ms"` // milliseconds
	MaxBufferSize   int    `yaml:"max_buffer_size"`
	DropOnOverflow  bool   `yaml:"drop_on_overflow"`
	AsyncInsert     bool   `yaml:"async_insert"`
	TimeoutMs       int    `yaml:"timeout_ms"` // milliseconds
}
|
||||
|
||||
// StdoutOutputConfig describes the stdout sink.
type StdoutOutputConfig struct {
	// Enabled switches the stdout sink on.
	Enabled bool `yaml:"enabled"`
	// Level filters output verbosity: DEBUG, INFO, WARN or ERROR.
	Level string `yaml:"level"`
}
|
||||
|
||||
// CorrelationConfig holds correlation configuration.
|
||||
type CorrelationConfig struct {
|
||||
TimeWindow TimeWindowConfig `yaml:"time_window"`
|
||||
OrphanPolicy OrphanPolicyConfig `yaml:"orphan_policy"`
|
||||
Matching MatchingConfig `yaml:"matching"`
|
||||
Buffers BuffersConfig `yaml:"buffers"`
|
||||
TTL TTLConfig `yaml:"ttl"`
|
||||
ExcludeSourceIPs []string `yaml:"exclude_source_ips"` // List of source IPs or CIDR ranges to exclude
|
||||
IncludeDestPorts []int `yaml:"include_dest_ports"` // If non-empty, only correlate events matching these destination ports
|
||||
// Deprecated: Use TimeWindow.Value instead
|
||||
TimeWindowS int `yaml:"time_window_s"`
|
||||
// Deprecated: Use OrphanPolicy.ApacheAlwaysEmit instead
|
||||
EmitOrphans bool `yaml:"emit_orphans"`
|
||||
}
|
||||
|
||||
// TimeWindowConfig holds time window configuration.
type TimeWindowConfig struct {
	Value int    `yaml:"value"`
	Unit  string `yaml:"unit"` // "ms"/"millisecond(s)" or "s"/"sec"/"second(s)"
}

// GetDuration returns the time window as a duration.
//
// A Value that is zero or negative is treated as 1. The unit is matched
// case-insensitively and ignoring surrounding whitespace: "ms", "millisecond"
// and "milliseconds" select milliseconds; every other unit, including the
// empty string, selects seconds.
func (c *TimeWindowConfig) GetDuration() time.Duration {
	value := c.Value
	if value <= 0 {
		value = 1
	}

	// Normalise first so that "MS" or " ms " is not silently read as seconds,
	// which would make the window 1000 times larger than intended.
	switch strings.ToLower(strings.TrimSpace(c.Unit)) {
	case "ms", "millisecond", "milliseconds":
		return time.Duration(value) * time.Millisecond
	default:
		return time.Duration(value) * time.Second
	}
}
|
||||
|
||||
// OrphanPolicyConfig controls how uncorrelated (orphan) events are handled.
type OrphanPolicyConfig struct {
	// ApacheAlwaysEmit enables emission of Apache events.
	ApacheAlwaysEmit bool `yaml:"apache_always_emit"`
	// ApacheEmitDelayMs is the delay in ms before emitting orphan A.
	ApacheEmitDelayMs int `yaml:"apache_emit_delay_ms"`
	// NetworkEmit is the network-side orphan emission flag.
	NetworkEmit bool `yaml:"network_emit"`
}
|
||||
|
||||
// MatchingConfig selects the matching mode.
type MatchingConfig struct {
	// Mode is one_to_one or one_to_many.
	Mode string `yaml:"mode"`
}
|
||||
|
||||
// BuffersConfig sets the buffer size limits.
type BuffersConfig struct {
	MaxHTTPItems    int `yaml:"max_http_items"`    // HTTP buffer limit
	MaxNetworkItems int `yaml:"max_network_items"` // network buffer limit
}
|
||||
|
||||
// TTLConfig sets time-to-live values.
type TTLConfig struct {
	// NetworkTTLS is the network TTL in seconds.
	NetworkTTLS int `yaml:"network_ttl_s"`
}
|
||||
|
||||
// Load loads configuration from a YAML file.
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
cfg := defaultConfig()
|
||||
|
||||
if err := yaml.Unmarshal(data, cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// defaultConfig returns a Config with default values.
|
||||
func defaultConfig() *Config {
|
||||
return &Config{
|
||||
Log: LogConfig{
|
||||
Level: "INFO",
|
||||
},
|
||||
Inputs: InputsConfig{
|
||||
UnixSockets: make([]UnixSocketConfig, 0),
|
||||
},
|
||||
Outputs: OutputsConfig{
|
||||
File: FileOutputConfig{
|
||||
Enabled: true,
|
||||
Path: "/var/log/logcorrelator/correlated.log",
|
||||
},
|
||||
ClickHouse: ClickHouseOutputConfig{
|
||||
Enabled: false,
|
||||
BatchSize: 500,
|
||||
FlushIntervalMs: 200,
|
||||
MaxBufferSize: 5000,
|
||||
DropOnOverflow: true,
|
||||
AsyncInsert: true,
|
||||
TimeoutMs: 1000,
|
||||
},
|
||||
Stdout: StdoutOutputConfig{Enabled: false},
|
||||
},
|
||||
Correlation: CorrelationConfig{
|
||||
TimeWindowS: 1,
|
||||
EmitOrphans: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the configuration.
|
||||
func (c *Config) Validate() error {
|
||||
if len(c.Inputs.UnixSockets) < 2 {
|
||||
return fmt.Errorf("at least two unix socket inputs are required")
|
||||
}
|
||||
|
||||
seenNames := make(map[string]struct{}, len(c.Inputs.UnixSockets))
|
||||
seenPaths := make(map[string]struct{}, len(c.Inputs.UnixSockets))
|
||||
|
||||
for i, input := range c.Inputs.UnixSockets {
|
||||
if strings.TrimSpace(input.Name) == "" {
|
||||
return fmt.Errorf("inputs.unix_sockets[%d].name is required", i)
|
||||
}
|
||||
if strings.TrimSpace(input.Path) == "" {
|
||||
return fmt.Errorf("inputs.unix_sockets[%d].path is required", i)
|
||||
}
|
||||
|
||||
if _, exists := seenNames[input.Name]; exists {
|
||||
return fmt.Errorf("duplicate unix socket input name: %s", input.Name)
|
||||
}
|
||||
seenNames[input.Name] = struct{}{}
|
||||
|
||||
if _, exists := seenPaths[input.Path]; exists {
|
||||
return fmt.Errorf("duplicate unix socket input path: %s", input.Path)
|
||||
}
|
||||
seenPaths[input.Path] = struct{}{}
|
||||
}
|
||||
|
||||
// At least one output must be enabled
|
||||
hasOutput := false
|
||||
if c.Outputs.File.Enabled && c.Outputs.File.Path != "" {
|
||||
hasOutput = true
|
||||
}
|
||||
if c.Outputs.ClickHouse.Enabled {
|
||||
hasOutput = true
|
||||
}
|
||||
if c.Outputs.Stdout.Enabled {
|
||||
hasOutput = true
|
||||
}
|
||||
|
||||
if !hasOutput {
|
||||
return fmt.Errorf("at least one output must be enabled (file, clickhouse, or stdout)")
|
||||
}
|
||||
|
||||
if c.Outputs.ClickHouse.Enabled {
|
||||
if strings.TrimSpace(c.Outputs.ClickHouse.DSN) == "" {
|
||||
return fmt.Errorf("clickhouse DSN is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(c.Outputs.ClickHouse.Table) == "" {
|
||||
return fmt.Errorf("clickhouse table is required when enabled")
|
||||
}
|
||||
if c.Outputs.ClickHouse.BatchSize <= 0 {
|
||||
return fmt.Errorf("clickhouse batch_size must be > 0")
|
||||
}
|
||||
if c.Outputs.ClickHouse.MaxBufferSize <= 0 {
|
||||
return fmt.Errorf("clickhouse max_buffer_size must be > 0")
|
||||
}
|
||||
if c.Outputs.ClickHouse.TimeoutMs <= 0 {
|
||||
return fmt.Errorf("clickhouse timeout_ms must be > 0")
|
||||
}
|
||||
}
|
||||
|
||||
if c.Correlation.TimeWindowS <= 0 {
|
||||
return fmt.Errorf("correlation.time_window_s must be > 0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTimeWindow returns the time window as a duration.
|
||||
// Deprecated: Use TimeWindow.GetDuration() instead.
|
||||
func (c *CorrelationConfig) GetTimeWindow() time.Duration {
|
||||
// New config takes precedence
|
||||
if c.TimeWindow.Value > 0 {
|
||||
return c.TimeWindow.GetDuration()
|
||||
}
|
||||
// Fallback to deprecated field
|
||||
value := c.TimeWindowS
|
||||
if value <= 0 {
|
||||
value = 1
|
||||
}
|
||||
return time.Duration(value) * time.Second
|
||||
}
|
||||
|
||||
// GetApacheAlwaysEmit returns whether to always emit Apache events.
|
||||
func (c *CorrelationConfig) GetApacheAlwaysEmit() bool {
|
||||
if c.OrphanPolicy.ApacheAlwaysEmit {
|
||||
return true
|
||||
}
|
||||
// Fallback to deprecated field
|
||||
return c.EmitOrphans
|
||||
}
|
||||
|
||||
// GetApacheEmitDelayMs returns the delay in milliseconds before emitting orphan A events.
|
||||
func (c *CorrelationConfig) GetApacheEmitDelayMs() int {
|
||||
if c.OrphanPolicy.ApacheEmitDelayMs > 0 {
|
||||
return c.OrphanPolicy.ApacheEmitDelayMs
|
||||
}
|
||||
return domain.DefaultApacheEmitDelayMs // Default: 500ms
|
||||
}
|
||||
|
||||
// GetMatchingMode returns the matching mode.
|
||||
func (c *CorrelationConfig) GetMatchingMode() string {
|
||||
if c.Matching.Mode != "" {
|
||||
return c.Matching.Mode
|
||||
}
|
||||
return "one_to_many" // Default to Keep-Alive
|
||||
}
|
||||
|
||||
// GetMaxHTTPBufferSize returns the max HTTP buffer size.
|
||||
func (c *CorrelationConfig) GetMaxHTTPBufferSize() int {
|
||||
if c.Buffers.MaxHTTPItems > 0 {
|
||||
return c.Buffers.MaxHTTPItems
|
||||
}
|
||||
return domain.DefaultMaxHTTPBufferSize
|
||||
}
|
||||
|
||||
// GetMaxNetworkBufferSize returns the max network buffer size.
|
||||
func (c *CorrelationConfig) GetMaxNetworkBufferSize() int {
|
||||
if c.Buffers.MaxNetworkItems > 0 {
|
||||
return c.Buffers.MaxNetworkItems
|
||||
}
|
||||
return domain.DefaultMaxNetworkBufferSize
|
||||
}
|
||||
|
||||
// GetNetworkTTLS returns the network TTL in seconds.
|
||||
func (c *CorrelationConfig) GetNetworkTTLS() int {
|
||||
if c.TTL.NetworkTTLS > 0 {
|
||||
return c.TTL.NetworkTTLS
|
||||
}
|
||||
return domain.DefaultNetworkTTLS
|
||||
}
|
||||
|
||||
// GetSocketPermissions returns the socket permissions as os.FileMode.
|
||||
// Default is 0666 (world read/write).
|
||||
func (c *UnixSocketConfig) GetSocketPermissions() os.FileMode {
|
||||
trimmed := strings.TrimSpace(c.SocketPermissions)
|
||||
if trimmed == "" {
|
||||
return 0666
|
||||
}
|
||||
|
||||
// Parse octal string (e.g., "0660", "660", "0666")
|
||||
perms, err := strconv.ParseUint(trimmed, 8, 32)
|
||||
if err != nil {
|
||||
return 0666
|
||||
}
|
||||
|
||||
return os.FileMode(perms)
|
||||
}
|
||||
|
||||
// GetIncludeDestPorts returns the list of destination ports allowed for correlation.
|
||||
// An empty list means all ports are allowed.
|
||||
func (c *CorrelationConfig) GetIncludeDestPorts() []int {
|
||||
return c.IncludeDestPorts
|
||||
}
|
||||
|
||||
// GetExcludeSourceIPs returns the list of excluded source IPs or CIDR ranges.
|
||||
func (c *CorrelationConfig) GetExcludeSourceIPs() []string {
|
||||
return c.ExcludeSourceIPs
|
||||
}
|
||||
|
||||
// IsSourceIPExcluded checks if a source IP should be excluded.
|
||||
// Supports both exact IP matches and CIDR ranges.
|
||||
func (c *CorrelationConfig) IsSourceIPExcluded(ip string) bool {
|
||||
if len(c.ExcludeSourceIPs) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse the IP once
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return false // Invalid IP
|
||||
}
|
||||
|
||||
for _, exclude := range c.ExcludeSourceIPs {
|
||||
// Try CIDR first
|
||||
if strings.Contains(exclude, "/") {
|
||||
_, cidr, err := net.ParseCIDR(exclude)
|
||||
if err != nil {
|
||||
continue // Invalid CIDR, skip
|
||||
}
|
||||
if cidr.Contains(parsedIP) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
// Exact IP match
|
||||
if exclude == ip {
|
||||
return true
|
||||
}
|
||||
// Also try parsing as IP (handles different formats like 192.168.1.1 vs 192.168.001.001)
|
||||
if excludeIP := net.ParseIP(exclude); excludeIP != nil {
|
||||
if excludeIP.Equal(parsedIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
1253
services/correlator/internal/config/config_test.go
Normal file
1253
services/correlator/internal/config/config_test.go
Normal file
File diff suppressed because it is too large
Load Diff
151
services/correlator/internal/domain/correlated_log.go
Normal file
151
services/correlator/internal/domain/correlated_log.go
Normal file
@ -0,0 +1,151 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CorrelatedLog represents the output correlated log entry.
// All fields are flattened into a single-level structure.
type CorrelatedLog struct {
	Timestamp  time.Time      `json:"timestamp"`
	SrcIP      string         `json:"src_ip"`
	SrcPort    int            `json:"src_port"`
	DstIP      string         `json:"dst_ip,omitempty"`
	DstPort    int            `json:"dst_port,omitempty"`
	Correlated bool           `json:"correlated"`
	OrphanSide string         `json:"orphan_side,omitempty"`
	Fields     map[string]any `json:"-"` // Additional fields, merged at marshal time
}

// MarshalJSON implements custom JSON marshaling to flatten the structure.
//
// The core fields are written first: timestamp, src_ip, src_port and
// correlated always; dst_ip, dst_port and orphan_side only when non-zero.
// Entries of Fields are then merged into the same object, except that an
// entry whose key equals a core field name is dropped, so additional fields
// can neither override a core value nor reintroduce an omitted one.
func (c CorrelatedLog) MarshalJSON() ([]byte, error) {
	// Sized for the seven core keys plus every additional field.
	flat := make(map[string]any, len(c.Fields)+7)

	// Add core fields
	flat["timestamp"] = c.Timestamp
	flat["src_ip"] = c.SrcIP
	flat["src_port"] = c.SrcPort
	if c.DstIP != "" {
		flat["dst_ip"] = c.DstIP
	}
	if c.DstPort != 0 {
		flat["dst_port"] = c.DstPort
	}
	flat["correlated"] = c.Correlated
	if c.OrphanSide != "" {
		flat["orphan_side"] = c.OrphanSide
	}

	// Merge additional fields while preserving reserved keys. A switch avoids
	// allocating a lookup map on every call.
	for k, v := range c.Fields {
		switch k {
		case "timestamp", "src_ip", "src_port", "dst_ip", "dst_port", "correlated", "orphan_side":
			continue
		}
		flat[k] = v
	}

	return json.Marshal(flat)
}
|
||||
|
||||
// NewCorrelatedLogFromEvent creates a correlated log from a single event (orphan).
|
||||
func NewCorrelatedLogFromEvent(event *NormalizedEvent, orphanSide string) CorrelatedLog {
|
||||
fields := extractFields(event)
|
||||
if event.KeepAliveSeq > 0 {
|
||||
fields["keepalives"] = event.KeepAliveSeq
|
||||
}
|
||||
return CorrelatedLog{
|
||||
Timestamp: event.Timestamp,
|
||||
SrcIP: event.SrcIP,
|
||||
SrcPort: event.SrcPort,
|
||||
DstIP: event.DstIP,
|
||||
DstPort: event.DstPort,
|
||||
Correlated: false,
|
||||
OrphanSide: orphanSide,
|
||||
Fields: fields,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCorrelatedLog creates a correlated log from two matched events.
|
||||
func NewCorrelatedLog(apacheEvent, networkEvent *NormalizedEvent) CorrelatedLog {
|
||||
ts := apacheEvent.Timestamp
|
||||
if networkEvent.Timestamp.After(ts) {
|
||||
ts = networkEvent.Timestamp
|
||||
}
|
||||
|
||||
fields := mergeFields(apacheEvent, networkEvent)
|
||||
if apacheEvent.KeepAliveSeq > 0 {
|
||||
fields["keepalives"] = apacheEvent.KeepAliveSeq
|
||||
}
|
||||
|
||||
return CorrelatedLog{
|
||||
Timestamp: ts,
|
||||
SrcIP: apacheEvent.SrcIP,
|
||||
SrcPort: apacheEvent.SrcPort,
|
||||
DstIP: coalesceString(apacheEvent.DstIP, networkEvent.DstIP),
|
||||
DstPort: coalesceInt(apacheEvent.DstPort, networkEvent.DstPort),
|
||||
Correlated: true,
|
||||
OrphanSide: "",
|
||||
Fields: fields,
|
||||
}
|
||||
}
|
||||
|
||||
func extractFields(e *NormalizedEvent) map[string]any {
|
||||
result := make(map[string]any)
|
||||
for k, v := range e.Raw {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func mergeFields(a, b *NormalizedEvent) map[string]any {
|
||||
result := make(map[string]any)
|
||||
|
||||
// Start with A fields
|
||||
for k, v := range a.Raw {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
// Merge B fields with collision handling
|
||||
for k, v := range b.Raw {
|
||||
if existing, exists := result[k]; exists {
|
||||
if reflect.DeepEqual(existing, v) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Collision with different values: keep both with prefixes
|
||||
delete(result, k)
|
||||
result["a_"+k] = existing
|
||||
result["b_"+k] = v
|
||||
continue
|
||||
}
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// coalesceString returns a unless it is empty, in which case it returns b.
func coalesceString(a, b string) string {
	if a == "" {
		return b
	}
	return a
}
|
||||
|
||||
// coalesceInt returns a unless it is zero, in which case it returns b.
func coalesceInt(a, b int) int {
	if a == 0 {
		return b
	}
	return a
}
|
||||
365
services/correlator/internal/domain/correlated_log_test.go
Normal file
365
services/correlator/internal/domain/correlated_log_test.go
Normal file
@ -0,0 +1,365 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNormalizedEvent_CorrelationKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
event *NormalizedEvent
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic key",
|
||||
event: &NormalizedEvent{
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080,
|
||||
},
|
||||
expected: "192.168.1.1:8080",
|
||||
},
|
||||
{
|
||||
name: "different port",
|
||||
event: &NormalizedEvent{
|
||||
SrcIP: "10.0.0.1",
|
||||
SrcPort: 443,
|
||||
},
|
||||
expected: "10.0.0.1:443",
|
||||
},
|
||||
{
|
||||
name: "port zero",
|
||||
event: &NormalizedEvent{
|
||||
SrcIP: "127.0.0.1",
|
||||
SrcPort: 0,
|
||||
},
|
||||
expected: "127.0.0.1:0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
key := tt.event.CorrelationKey()
|
||||
if key != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, key)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCorrelatedLogFromEvent(t *testing.T) {
|
||||
event := &NormalizedEvent{
|
||||
Source: SourceA,
|
||||
Timestamp: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC),
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080,
|
||||
DstIP: "10.0.0.1",
|
||||
DstPort: 80,
|
||||
Raw: map[string]any{
|
||||
"method": "GET",
|
||||
"path": "/api/test",
|
||||
},
|
||||
}
|
||||
|
||||
log := NewCorrelatedLogFromEvent(event, "A")
|
||||
|
||||
if log.Correlated {
|
||||
t.Error("expected correlated to be false")
|
||||
}
|
||||
if log.OrphanSide != "A" {
|
||||
t.Errorf("expected orphan_side A, got %s", log.OrphanSide)
|
||||
}
|
||||
if log.SrcIP != "192.168.1.1" {
|
||||
t.Errorf("expected src_ip 192.168.1.1, got %s", log.SrcIP)
|
||||
}
|
||||
if log.Fields == nil {
|
||||
t.Error("expected fields to be non-nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCorrelatedLog(t *testing.T) {
|
||||
apacheEvent := &NormalizedEvent{
|
||||
Source: SourceA,
|
||||
Timestamp: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC),
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080,
|
||||
DstIP: "10.0.0.1",
|
||||
DstPort: 80,
|
||||
Raw: map[string]any{"method": "GET"},
|
||||
}
|
||||
|
||||
networkEvent := &NormalizedEvent{
|
||||
Source: SourceB,
|
||||
Timestamp: time.Date(2024, 1, 1, 12, 0, 0, 500000000, time.UTC),
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080,
|
||||
DstIP: "10.0.0.1",
|
||||
DstPort: 80,
|
||||
Raw: map[string]any{"ja3": "abc123"},
|
||||
}
|
||||
|
||||
log := NewCorrelatedLog(apacheEvent, networkEvent)
|
||||
|
||||
if !log.Correlated {
|
||||
t.Error("expected correlated to be true")
|
||||
}
|
||||
if log.OrphanSide != "" {
|
||||
t.Errorf("expected orphan_side to be empty, got %s", log.OrphanSide)
|
||||
}
|
||||
if log.Fields == nil {
|
||||
t.Error("expected fields to be non-nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCorrelatedLog_TimestampSelectionAEarlier verifies that when A is earlier the later (B) timestamp is used.
|
||||
func TestNewCorrelatedLog_TimestampSelectionAEarlier(t *testing.T) {
|
||||
tsA := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
tsB := time.Date(2024, 1, 1, 12, 0, 1, 0, time.UTC) // B is later
|
||||
|
||||
a := &NormalizedEvent{Source: SourceA, Timestamp: tsA, SrcIP: "1.1.1.1", SrcPort: 100, Raw: map[string]any{}}
|
||||
b := &NormalizedEvent{Source: SourceB, Timestamp: tsB, SrcIP: "1.1.1.1", SrcPort: 100, Raw: map[string]any{}}
|
||||
|
||||
log := NewCorrelatedLog(a, b)
|
||||
if !log.Timestamp.Equal(tsB) {
|
||||
t.Errorf("expected timestamp to be B's (later), got %v", log.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCorrelatedLog_TimestampSelectionBEarlier verifies that when B is earlier, A's timestamp is used.
|
||||
func TestNewCorrelatedLog_TimestampSelectionBEarlier(t *testing.T) {
|
||||
tsA := time.Date(2024, 1, 1, 12, 0, 1, 0, time.UTC) // A is later
|
||||
tsB := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
a := &NormalizedEvent{Source: SourceA, Timestamp: tsA, SrcIP: "1.1.1.1", SrcPort: 100, Raw: map[string]any{}}
|
||||
b := &NormalizedEvent{Source: SourceB, Timestamp: tsB, SrcIP: "1.1.1.1", SrcPort: 100, Raw: map[string]any{}}
|
||||
|
||||
log := NewCorrelatedLog(a, b)
|
||||
// The later timestamp wins. Since B is not After A, ts stays as A's timestamp.
|
||||
if !log.Timestamp.Equal(tsA) {
|
||||
t.Errorf("expected timestamp to be A's (later), got %v", log.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCorrelatedLog_TimestampEqual verifies equal timestamps yield A's timestamp.
|
||||
func TestNewCorrelatedLog_TimestampEqual(t *testing.T) {
|
||||
ts := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
a := &NormalizedEvent{Source: SourceA, Timestamp: ts, SrcIP: "1.1.1.1", SrcPort: 100, Raw: map[string]any{}}
|
||||
b := &NormalizedEvent{Source: SourceB, Timestamp: ts, SrcIP: "1.1.1.1", SrcPort: 100, Raw: map[string]any{}}
|
||||
|
||||
log := NewCorrelatedLog(a, b)
|
||||
if !log.Timestamp.Equal(ts) {
|
||||
t.Errorf("expected timestamp to be equal to both events' timestamp, got %v", log.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCorrelatedLogFromEvent_WithKeepAlive verifies keepalives field is added when KeepAliveSeq > 0.
|
||||
func TestNewCorrelatedLogFromEvent_WithKeepAlive(t *testing.T) {
|
||||
event := &NormalizedEvent{
|
||||
Source: SourceA,
|
||||
Timestamp: time.Now(),
|
||||
SrcIP: "1.1.1.1",
|
||||
SrcPort: 9999,
|
||||
KeepAliveSeq: 3,
|
||||
Raw: map[string]any{"method": "GET"},
|
||||
}
|
||||
|
||||
log := NewCorrelatedLogFromEvent(event, "A")
|
||||
if log.Fields["keepalives"] != 3 {
|
||||
t.Errorf("expected keepalives=3, got %v", log.Fields["keepalives"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCorrelatedLogFromEvent_NoKeepAlive verifies keepalives field is absent when KeepAliveSeq == 0.
|
||||
func TestNewCorrelatedLogFromEvent_NoKeepAlive(t *testing.T) {
|
||||
event := &NormalizedEvent{
|
||||
Source: SourceA,
|
||||
Timestamp: time.Now(),
|
||||
SrcIP: "1.1.1.1",
|
||||
SrcPort: 9999,
|
||||
KeepAliveSeq: 0,
|
||||
Raw: map[string]any{"method": "GET"},
|
||||
}
|
||||
|
||||
log := NewCorrelatedLogFromEvent(event, "A")
|
||||
if _, ok := log.Fields["keepalives"]; ok {
|
||||
t.Error("keepalives field should not be present when KeepAliveSeq == 0")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMergeFields_NoCollision verifies fields from A and B are merged without conflict.
|
||||
func TestMergeFields_NoCollision(t *testing.T) {
|
||||
a := &NormalizedEvent{Raw: map[string]any{"method": "GET", "path": "/foo"}}
|
||||
b := &NormalizedEvent{Raw: map[string]any{"ja4": "abc123", "proto": "TLS"}}
|
||||
|
||||
fields := mergeFields(a, b)
|
||||
if fields["method"] != "GET" {
|
||||
t.Errorf("expected method=GET, got %v", fields["method"])
|
||||
}
|
||||
if fields["ja4"] != "abc123" {
|
||||
t.Errorf("expected ja4=abc123, got %v", fields["ja4"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestMergeFields_SameValueNoPrefix verifies same-value fields are not prefixed.
|
||||
func TestMergeFields_SameValueNoPrefix(t *testing.T) {
|
||||
a := &NormalizedEvent{Raw: map[string]any{"proto": "TCP"}}
|
||||
b := &NormalizedEvent{Raw: map[string]any{"proto": "TCP"}}
|
||||
|
||||
fields := mergeFields(a, b)
|
||||
if fields["proto"] != "TCP" {
|
||||
t.Errorf("expected proto=TCP (no prefix), got %v", fields["proto"])
|
||||
}
|
||||
if _, ok := fields["a_proto"]; ok {
|
||||
t.Error("a_proto should not exist for same-value collision")
|
||||
}
|
||||
if _, ok := fields["b_proto"]; ok {
|
||||
t.Error("b_proto should not exist for same-value collision")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMergeFields_DifferentValuePrefix verifies different-value fields get a_/b_ prefix.
|
||||
func TestMergeFields_DifferentValuePrefix(t *testing.T) {
|
||||
a := &NormalizedEvent{Raw: map[string]any{"port": 80}}
|
||||
b := &NormalizedEvent{Raw: map[string]any{"port": 443}}
|
||||
|
||||
fields := mergeFields(a, b)
|
||||
if fields["a_port"] != 80 {
|
||||
t.Errorf("expected a_port=80, got %v", fields["a_port"])
|
||||
}
|
||||
if fields["b_port"] != 443 {
|
||||
t.Errorf("expected b_port=443, got %v", fields["b_port"])
|
||||
}
|
||||
if _, ok := fields["port"]; ok {
|
||||
t.Error("original 'port' key should be removed on collision")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoalesceString_EmptyA tests that when a is empty, b is returned.
|
||||
func TestCoalesceString_EmptyA(t *testing.T) {
|
||||
result := coalesceString("", "fallback")
|
||||
if result != "fallback" {
|
||||
t.Errorf("expected 'fallback', got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoalesceString_NonEmptyA tests that when a is non-empty, a is returned.
|
||||
func TestCoalesceString_NonEmptyA(t *testing.T) {
|
||||
result := coalesceString("primary", "fallback")
|
||||
if result != "primary" {
|
||||
t.Errorf("expected 'primary', got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoalesceInt_ZeroA tests that when a is zero, b is returned.
|
||||
func TestCoalesceInt_ZeroA(t *testing.T) {
|
||||
result := coalesceInt(0, 443)
|
||||
if result != 443 {
|
||||
t.Errorf("expected 443, got %d", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoalesceInt_NonZeroA tests that when a is non-zero, a is returned.
|
||||
func TestCoalesceInt_NonZeroA(t *testing.T) {
|
||||
result := coalesceInt(80, 443)
|
||||
if result != 80 {
|
||||
t.Errorf("expected 80, got %d", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMarshalJSON_ReservedKeyProtection verifies reserved keys in Fields are not overwritten.
|
||||
func TestMarshalJSON_ReservedKeyProtection(t *testing.T) {
|
||||
log := CorrelatedLog{
|
||||
Timestamp: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC),
|
||||
SrcIP: "1.2.3.4",
|
||||
SrcPort: 1234,
|
||||
Correlated: true,
|
||||
Fields: map[string]any{
|
||||
"src_ip": "EVIL_OVERRIDE", // should be ignored
|
||||
"correlated": false, // should be ignored
|
||||
"extra": "value",
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(log)
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalJSON failed: %v", err)
|
||||
}
|
||||
|
||||
var flat map[string]any
|
||||
if err := json.Unmarshal(data, &flat); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
if flat["src_ip"] != "1.2.3.4" {
|
||||
t.Errorf("reserved key src_ip should not be overwritten, got %v", flat["src_ip"])
|
||||
}
|
||||
if flat["correlated"] != true {
|
||||
t.Errorf("reserved key correlated should not be overwritten, got %v", flat["correlated"])
|
||||
}
|
||||
if flat["extra"] != "value" {
|
||||
t.Errorf("non-reserved key extra should be present, got %v", flat["extra"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestMarshalJSON_OptionalFieldsOmittedWhenZero verifies DstIP/DstPort are omitted when zero.
|
||||
func TestMarshalJSON_OptionalFieldsOmittedWhenZero(t *testing.T) {
|
||||
log := CorrelatedLog{
|
||||
Timestamp: time.Now(),
|
||||
SrcIP: "1.2.3.4",
|
||||
SrcPort: 1234,
|
||||
Correlated: false,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(log)
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalJSON failed: %v", err)
|
||||
}
|
||||
|
||||
var flat map[string]any
|
||||
if err := json.Unmarshal(data, &flat); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := flat["dst_ip"]; ok {
|
||||
t.Error("dst_ip should be omitted when empty")
|
||||
}
|
||||
if _, ok := flat["dst_port"]; ok {
|
||||
t.Error("dst_port should be omitted when zero")
|
||||
}
|
||||
if _, ok := flat["orphan_side"]; ok {
|
||||
t.Error("orphan_side should be omitted when empty")
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractFields_Basic verifies extractFields copies Raw fields.
|
||||
func TestExtractFields_Basic(t *testing.T) {
|
||||
e := &NormalizedEvent{
|
||||
Raw: map[string]any{"key1": "val1", "key2": 42},
|
||||
}
|
||||
fields := extractFields(e)
|
||||
if fields["key1"] != "val1" {
|
||||
t.Errorf("expected key1=val1, got %v", fields["key1"])
|
||||
}
|
||||
if fields["key2"] != 42 {
|
||||
t.Errorf("expected key2=42, got %v", fields["key2"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewCorrelatedLog_KeepAliveSeq verifies keepalives is set from apache event.
|
||||
func TestNewCorrelatedLog_KeepAliveSeq(t *testing.T) {
|
||||
a := &NormalizedEvent{
|
||||
Source: SourceA, Timestamp: time.Now(), SrcIP: "1.1.1.1", SrcPort: 100,
|
||||
KeepAliveSeq: 5,
|
||||
Raw: map[string]any{},
|
||||
}
|
||||
b := &NormalizedEvent{
|
||||
Source: SourceB, Timestamp: time.Now(), SrcIP: "1.1.1.1", SrcPort: 100,
|
||||
Raw: map[string]any{},
|
||||
}
|
||||
|
||||
log := NewCorrelatedLog(a, b)
|
||||
if log.Fields["keepalives"] != 5 {
|
||||
t.Errorf("expected keepalives=5, got %v", log.Fields["keepalives"])
|
||||
}
|
||||
}
|
||||
1017
services/correlator/internal/domain/correlation_service.go
Normal file
1017
services/correlator/internal/domain/correlation_service.go
Normal file
File diff suppressed because it is too large
Load Diff
1865
services/correlator/internal/domain/correlation_service_test.go
Normal file
1865
services/correlator/internal/domain/correlation_service_test.go
Normal file
File diff suppressed because it is too large
Load Diff
33
services/correlator/internal/domain/event.go
Normal file
33
services/correlator/internal/domain/event.go
Normal file
@ -0,0 +1,33 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EventSource identifies the source of an event.
type EventSource string

const (
	// SourceA marks events coming from the Apache/HTTP source.
	SourceA EventSource = "A"
	// SourceB marks events coming from the network source.
	SourceB EventSource = "B"
)

// NormalizedEvent represents a unified internal event from either source.
type NormalizedEvent struct {
	Source    EventSource // which side produced the event (SourceA or SourceB)
	Timestamp time.Time
	SrcIP     string // together with SrcPort forms the correlation key
	SrcPort   int
	DstIP     string
	DstPort   int
	Headers   map[string]string
	Extra     map[string]any
	Raw       map[string]any // Original raw data

	// KeepAliveSeq is the request sequence number within the Keep-Alive
	// connection (1-based).
	KeepAliveSeq int
}

// CorrelationKey returns the key used for correlation: the source IP and the
// decimal source port joined by a colon (src_ip + ":" + src_port).
func (e *NormalizedEvent) CorrelationKey() string {
	// Room for the IP, the separator and any decimal int64.
	key := make([]byte, 0, len(e.SrcIP)+1+20)
	key = append(key, e.SrcIP...)
	key = append(key, ':')
	key = strconv.AppendInt(key, int64(e.SrcPort), 10)
	return string(key)
}
|
||||
25
services/correlator/internal/observability/logger.go
Normal file
25
services/correlator/internal/observability/logger.go
Normal file
@ -0,0 +1,25 @@
|
||||
// Package observability provides structured logging for the correlator service.
|
||||
// Implementation is delegated to shared/go/ja4common/logger to avoid duplication.
|
||||
package observability
|
||||
|
||||
import jalogger "github.com/antitbone/ja4/ja4common/logger"
|
||||
|
||||
// Type aliases — all existing correlator code compiles unchanged.
|
||||
type Logger = jalogger.Logger
|
||||
type LogLevel = jalogger.LogLevel
|
||||
|
||||
const (
|
||||
DEBUG LogLevel = jalogger.DEBUG
|
||||
INFO LogLevel = jalogger.INFO
|
||||
WARN LogLevel = jalogger.WARN
|
||||
ERROR LogLevel = jalogger.ERROR
|
||||
)
|
||||
|
||||
// NewLogger creates a new Logger with INFO level.
|
||||
func NewLogger(prefix string) *Logger { return jalogger.New(prefix) }
|
||||
|
||||
// NewLoggerWithLevel creates a new Logger with the specified minimum level.
|
||||
func NewLoggerWithLevel(prefix, level string) *Logger { return jalogger.NewWithLevel(prefix, level) }
|
||||
|
||||
// ParseLogLevel converts a string to LogLevel.
|
||||
func ParseLogLevel(level string) LogLevel { return jalogger.ParseLogLevel(level) }
|
||||
296
services/correlator/internal/observability/logger_test.go
Normal file
296
services/correlator/internal/observability/logger_test.go
Normal file
@ -0,0 +1,296 @@
|
||||
// Package observability tests — behavioral tests for the Logger type alias.
|
||||
// Since Logger = jalogger.Logger, we test the observable API only.
|
||||
package observability_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/observability"
|
||||
)
|
||||
|
||||
func TestNewLogger_NonNil(t *testing.T) {
|
||||
logger := observability.NewLogger("test")
|
||||
if logger == nil {
|
||||
t.Fatal("expected non-nil logger")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_DefaultLevel_IsInfo(t *testing.T) {
|
||||
logger := observability.NewLogger("test")
|
||||
if !logger.ShouldLog(observability.INFO) {
|
||||
t.Error("INFO should be enabled by default")
|
||||
}
|
||||
if logger.ShouldLog(observability.DEBUG) {
|
||||
t.Error("DEBUG should be disabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_Info_NoPanic(t *testing.T) {
|
||||
logger := observability.NewLoggerWithLevel("test", "INFO")
|
||||
if !logger.ShouldLog(observability.INFO) {
|
||||
t.Error("INFO should be enabled")
|
||||
}
|
||||
logger.Info("test message")
|
||||
}
|
||||
|
||||
func TestLogger_Error_NoPanic(t *testing.T) {
|
||||
logger := observability.NewLoggerWithLevel("test", "ERROR")
|
||||
if !logger.ShouldLog(observability.ERROR) {
|
||||
t.Error("ERROR should be enabled")
|
||||
}
|
||||
logger.Error("error message", nil)
|
||||
}
|
||||
|
||||
func TestLogger_Debug_NoPanic(t *testing.T) {
|
||||
logger := observability.NewLogger("test")
|
||||
logger.SetLevel("DEBUG")
|
||||
if !logger.ShouldLog(observability.DEBUG) {
|
||||
t.Error("DEBUG should be enabled after SetLevel(DEBUG)")
|
||||
}
|
||||
logger.Debug("test message")
|
||||
}
|
||||
|
||||
func TestLogger_SetLevel(t *testing.T) {
|
||||
logger := observability.NewLogger("test")
|
||||
|
||||
logger.SetLevel("DEBUG")
|
||||
if !logger.ShouldLog(observability.DEBUG) {
|
||||
t.Error("DEBUG should be enabled after SetLevel(DEBUG)")
|
||||
}
|
||||
|
||||
logger.SetLevel("INFO")
|
||||
if logger.ShouldLog(observability.DEBUG) {
|
||||
t.Error("DEBUG should be disabled after SetLevel(INFO)")
|
||||
}
|
||||
|
||||
logger.SetLevel("WARN")
|
||||
if logger.ShouldLog(observability.INFO) {
|
||||
t.Error("INFO should be disabled after SetLevel(WARN)")
|
||||
}
|
||||
if !logger.ShouldLog(observability.WARN) {
|
||||
t.Error("WARN should be enabled after SetLevel(WARN)")
|
||||
}
|
||||
|
||||
logger.SetLevel("ERROR")
|
||||
if logger.ShouldLog(observability.WARN) {
|
||||
t.Error("WARN should be disabled after SetLevel(ERROR)")
|
||||
}
|
||||
if !logger.ShouldLog(observability.ERROR) {
|
||||
t.Error("ERROR should be enabled after SetLevel(ERROR)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLogLevel(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
expected observability.LogLevel
|
||||
}{
|
||||
{"DEBUG", observability.DEBUG},
|
||||
{"debug", observability.DEBUG},
|
||||
{"INFO", observability.INFO},
|
||||
{"info", observability.INFO},
|
||||
{"WARN", observability.WARN},
|
||||
{"warn", observability.WARN},
|
||||
{"WARNING", observability.WARN},
|
||||
{"ERROR", observability.ERROR},
|
||||
{"error", observability.ERROR},
|
||||
{"", observability.INFO},
|
||||
{"invalid", observability.INFO},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := observability.ParseLogLevel(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ParseLogLevel(%q) = %v, want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_WithFields_NoPanic(t *testing.T) {
|
||||
logger := observability.NewLogger("test")
|
||||
child := logger.WithFields(map[string]any{"key1": "value1", "key2": 42})
|
||||
if child == logger {
|
||||
t.Error("expected different logger instance")
|
||||
}
|
||||
child.Info("message with fields")
|
||||
}
|
||||
|
||||
func TestLogLevel_String(t *testing.T) {
|
||||
cases := []struct {
|
||||
level observability.LogLevel
|
||||
expected string
|
||||
}{
|
||||
{observability.DEBUG, "DEBUG"},
|
||||
{observability.INFO, "INFO"},
|
||||
{observability.WARN, "WARN"},
|
||||
{observability.ERROR, "ERROR"},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
if got := tt.level.String(); got != tt.expected {
|
||||
t.Errorf("LogLevel(%d).String() = %q, want %q", tt.level, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_Warn_NoPanic(t *testing.T) {
|
||||
logger := observability.NewLoggerWithLevel("test", "WARN")
|
||||
if !logger.ShouldLog(observability.WARN) {
|
||||
t.Error("WARN should be enabled")
|
||||
}
|
||||
logger.Warn("warning message")
|
||||
}
|
||||
|
||||
func TestLogger_Formatted_NoPanic(t *testing.T) {
|
||||
logger := observability.NewLoggerWithLevel("test", "DEBUG")
|
||||
logger.Warnf("formatted %s %d", "message", 42)
|
||||
logger.Infof("formatted %s %d", "message", 42)
|
||||
logger.Debugf("formatted %s %d", "message", 42)
|
||||
}
|
||||
|
||||
func TestLogger_Error_WithError(t *testing.T) {
|
||||
logger := observability.NewLoggerWithLevel("test", "ERROR")
|
||||
logger.Error("error occurred", &testErr{"test error"})
|
||||
}
|
||||
|
||||
func TestLogger_ShouldLog_Concurrent(t *testing.T) {
|
||||
logger := observability.NewLoggerWithLevel("test", "DEBUG")
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
_ = logger.ShouldLog(observability.DEBUG)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_Log_Concurrent(t *testing.T) {
|
||||
logger := observability.NewLoggerWithLevel("test", "DEBUG")
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(n int) {
|
||||
logger.Debugf("message %d", n)
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_WithFields_Concurrent(t *testing.T) {
|
||||
logger := observability.NewLogger("test")
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(n int) {
|
||||
_ = logger.WithFields(map[string]any{"key": n})
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_SetLevel_Concurrent(t *testing.T) {
|
||||
logger := observability.NewLogger("test")
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
logger.SetLevel("DEBUG")
|
||||
logger.SetLevel("INFO")
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// testErr is a minimal error implementation used by the logger tests.
type testErr struct {
	msg string
}

// Error returns the stored message, satisfying the error interface.
func (e *testErr) Error() string {
	return e.msg
}
|
||||
|
||||
func TestNewLoggerWithLevel_AllLevels(t *testing.T) {
|
||||
levels := []string{"DEBUG", "INFO", "WARN", "WARNING", "ERROR", "invalid", ""}
|
||||
for _, level := range levels {
|
||||
t.Run(level, func(t *testing.T) {
|
||||
logger := observability.NewLoggerWithLevel("test", level)
|
||||
if logger == nil {
|
||||
t.Errorf("NewLoggerWithLevel(%q) returned nil", level)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogLevel_Constants(t *testing.T) {
|
||||
if observability.DEBUG >= observability.INFO {
|
||||
t.Error("DEBUG should be less than INFO")
|
||||
}
|
||||
if observability.INFO >= observability.WARN {
|
||||
t.Error("INFO should be less than WARN")
|
||||
}
|
||||
if observability.WARN >= observability.ERROR {
|
||||
t.Error("WARN should be less than ERROR")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_ShouldLog_AllLevels(t *testing.T) {
|
||||
cases := []struct {
|
||||
minLevel string
|
||||
level observability.LogLevel
|
||||
want bool
|
||||
}{
|
||||
{"DEBUG", observability.DEBUG, true},
|
||||
{"DEBUG", observability.INFO, true},
|
||||
{"DEBUG", observability.WARN, true},
|
||||
{"DEBUG", observability.ERROR, true},
|
||||
{"INFO", observability.DEBUG, false},
|
||||
{"INFO", observability.INFO, true},
|
||||
{"INFO", observability.WARN, true},
|
||||
{"WARN", observability.INFO, false},
|
||||
{"WARN", observability.WARN, true},
|
||||
{"WARN", observability.ERROR, true},
|
||||
{"ERROR", observability.WARN, false},
|
||||
{"ERROR", observability.ERROR, true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.minLevel+"_"+tc.level.String(), func(t *testing.T) {
|
||||
logger := observability.NewLoggerWithLevel("test", tc.minLevel)
|
||||
got := logger.ShouldLog(tc.level)
|
||||
if got != tc.want {
|
||||
t.Errorf("ShouldLog(%v) with min=%s: expected %v, got %v",
|
||||
tc.level, tc.minLevel, tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLogLevel_WarningAlias(t *testing.T) {
|
||||
got := observability.ParseLogLevel("WARNING")
|
||||
if got != observability.WARN {
|
||||
t.Errorf("ParseLogLevel(WARNING) = %v, want WARN", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_Errorf_NoPanic(t *testing.T) {
|
||||
logger := observability.NewLoggerWithLevel("test", "DEBUG")
|
||||
// Errorf is not defined in the interface, but Warnf/Infof/Debugf are tested
|
||||
// Just ensure Error with a formatted message doesn't panic
|
||||
logger.Error("formatted error", &testErr{"err detail"})
|
||||
}
|
||||
|
||||
func TestNewLogger_PrefixIsUsed(t *testing.T) {
|
||||
logger := observability.NewLogger("my-prefix")
|
||||
if logger == nil {
|
||||
t.Fatal("expected non-nil logger")
|
||||
}
|
||||
// The logger should be usable
|
||||
logger.Infof("hello from %s", "my-prefix")
|
||||
}
|
||||
176
services/correlator/internal/observability/metrics.go
Normal file
176
services/correlator/internal/observability/metrics.go
Normal file
@ -0,0 +1,176 @@
|
||||
package observability
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// CorrelationMetrics tracks correlation statistics for debugging and monitoring.
// Every counter is an atomic.Int64 and is updated or read individually by the
// methods below.
type CorrelationMetrics struct {
	// NOTE(review): mu is not locked by any method in this part of the file —
	// confirm whether it is still needed.
	mu sync.RWMutex

	// Events received
	eventsReceivedA atomic.Int64
	eventsReceivedB atomic.Int64

	// Correlation results
	correlationsSuccess atomic.Int64
	correlationsFailed  atomic.Int64

	// Failure reasons
	failedNoMatchKey     atomic.Int64 // No event with same key in buffer
	failedTimeWindow     atomic.Int64 // Key found but outside time window
	failedBufferEviction atomic.Int64 // Event evicted due to buffer full
	failedTTLExpired     atomic.Int64 // B event TTL expired before match
	failedIPExcluded     atomic.Int64 // Event excluded by IP filter

	// Buffer stats
	bufferASize atomic.Int64
	bufferBSize atomic.Int64

	// Orphan stats
	orphansEmittedA    atomic.Int64
	orphansEmittedB    atomic.Int64
	orphansPendingA    atomic.Int64
	pendingOrphanMatch atomic.Int64 // B matched with pending orphan A

	// Keep-Alive stats
	keepAliveResets atomic.Int64 // Number of TTL resets (one-to-many mode)
}

// NewCorrelationMetrics creates a new metrics tracker with all counters at zero.
func NewCorrelationMetrics() *CorrelationMetrics {
	return new(CorrelationMetrics)
}

// RecordEventReceived counts one received event for source "A" or "B".
// Any other source value is ignored.
func (m *CorrelationMetrics) RecordEventReceived(source string) {
	switch source {
	case "A":
		m.eventsReceivedA.Add(1)
	case "B":
		m.eventsReceivedB.Add(1)
	}
}

// RecordCorrelationSuccess records a successful correlation.
func (m *CorrelationMetrics) RecordCorrelationSuccess() {
	m.correlationsSuccess.Add(1)
}

// RecordCorrelationFailed records a failed correlation attempt with the reason.
// The overall failure counter is always incremented; the per-reason counter is
// incremented only for the reasons "no_match_key", "time_window",
// "buffer_eviction", "ttl_expired" and "ip_excluded".
func (m *CorrelationMetrics) RecordCorrelationFailed(reason string) {
	m.correlationsFailed.Add(1)

	var byReason *atomic.Int64
	switch reason {
	case "no_match_key":
		byReason = &m.failedNoMatchKey
	case "time_window":
		byReason = &m.failedTimeWindow
	case "buffer_eviction":
		byReason = &m.failedBufferEviction
	case "ttl_expired":
		byReason = &m.failedTTLExpired
	case "ip_excluded":
		byReason = &m.failedIPExcluded
	}
	if byReason != nil {
		byReason.Add(1)
	}
}

// RecordBufferEviction records an event evicted from buffer.
// It is currently a no-op: no counter is changed and source is unused.
func (m *CorrelationMetrics) RecordBufferEviction(source string) {
	// Can be used for additional tracking if needed
}

// RecordOrphanEmitted counts one emitted orphan for source "A" or "B".
// Any other source value is ignored.
func (m *CorrelationMetrics) RecordOrphanEmitted(source string) {
	switch source {
	case "A":
		m.orphansEmittedA.Add(1)
	case "B":
		m.orphansEmittedB.Add(1)
	}
}

// RecordPendingOrphan records an A event added to pending orphans.
func (m *CorrelationMetrics) RecordPendingOrphan() {
	m.orphansPendingA.Add(1)
}

// RecordPendingOrphanMatch records a B event matching a pending orphan A.
func (m *CorrelationMetrics) RecordPendingOrphanMatch() {
	m.pendingOrphanMatch.Add(1)
}

// RecordKeepAliveReset records a TTL reset for Keep-Alive.
func (m *CorrelationMetrics) RecordKeepAliveReset() {
	m.keepAliveResets.Add(1)
}

// UpdateBufferSizes overwrites the stored sizes of buffers A and B.
func (m *CorrelationMetrics) UpdateBufferSizes(sizeA, sizeB int64) {
	m.bufferASize.Store(sizeA)
	m.bufferBSize.Store(sizeB)
}

// Snapshot returns a copy of all counters as plain int64 values.
// Each counter is loaded on its own; the loads are not grouped under a lock.
func (m *CorrelationMetrics) Snapshot() MetricsSnapshot {
	var snap MetricsSnapshot

	snap.EventsReceivedA = m.eventsReceivedA.Load()
	snap.EventsReceivedB = m.eventsReceivedB.Load()
	snap.CorrelationsSuccess = m.correlationsSuccess.Load()
	snap.CorrelationsFailed = m.correlationsFailed.Load()
	snap.FailedNoMatchKey = m.failedNoMatchKey.Load()
	snap.FailedTimeWindow = m.failedTimeWindow.Load()
	snap.FailedBufferEviction = m.failedBufferEviction.Load()
	snap.FailedTTLExpired = m.failedTTLExpired.Load()
	snap.FailedIPExcluded = m.failedIPExcluded.Load()
	snap.BufferASize = m.bufferASize.Load()
	snap.BufferBSize = m.bufferBSize.Load()
	snap.OrphansEmittedA = m.orphansEmittedA.Load()
	snap.OrphansEmittedB = m.orphansEmittedB.Load()
	snap.OrphansPendingA = m.orphansPendingA.Load()
	snap.PendingOrphanMatch = m.pendingOrphanMatch.Load()
	snap.KeepAliveResets = m.keepAliveResets.Load()

	return snap
}

// MetricsSnapshot is a point-in-time snapshot of metrics.
// The field order below is also the key order of its JSON encoding.
type MetricsSnapshot struct {
	EventsReceivedA      int64 `json:"events_received_a"`
	EventsReceivedB      int64 `json:"events_received_b"`
	CorrelationsSuccess  int64 `json:"correlations_success"`
	CorrelationsFailed   int64 `json:"correlations_failed"`
	FailedNoMatchKey     int64 `json:"failed_no_match_key"`
	FailedTimeWindow     int64 `json:"failed_time_window"`
	FailedBufferEviction int64 `json:"failed_buffer_eviction"`
	FailedTTLExpired     int64 `json:"failed_ttl_expired"`
	FailedIPExcluded     int64 `json:"failed_ip_excluded"`
	BufferASize          int64 `json:"buffer_a_size"`
	BufferBSize          int64 `json:"buffer_b_size"`
	OrphansEmittedA      int64 `json:"orphans_emitted_a"`
	OrphansEmittedB      int64 `json:"orphans_emitted_b"`
	OrphansPendingA      int64 `json:"orphans_pending_a"`
	PendingOrphanMatch   int64 `json:"pending_orphan_match"`
	KeepAliveResets      int64 `json:"keepalive_resets"`
}

// MarshalJSON implements json.Marshaler by encoding the current Snapshot.
func (m *CorrelationMetrics) MarshalJSON() ([]byte, error) {
	snap := m.Snapshot()
	return json.Marshal(snap)
}

// String returns a human-readable, multi-line report of the current Snapshot.
func (m *CorrelationMetrics) String() string {
	snap := m.Snapshot()
	received := snap.EventsReceivedA + snap.EventsReceivedB

	report := []string{
		"Correlation Metrics:\n",
		fmt.Sprintf(" Events Received: A=%d B=%d Total=%d\n", snap.EventsReceivedA, snap.EventsReceivedB, received),
		fmt.Sprintf(" Correlations: Success=%d Failed=%d\n", snap.CorrelationsSuccess, snap.CorrelationsFailed),
		fmt.Sprintf(" Failure Reasons: no_match_key=%d time_window=%d buffer_eviction=%d ttl_expired=%d ip_excluded=%d\n",
			snap.FailedNoMatchKey, snap.FailedTimeWindow, snap.FailedBufferEviction, snap.FailedTTLExpired, snap.FailedIPExcluded),
		fmt.Sprintf(" Buffer Sizes: A=%d B=%d\n", snap.BufferASize, snap.BufferBSize),
		fmt.Sprintf(" Orphans: Emitted A=%d B=%d Pending A=%d\n", snap.OrphansEmittedA, snap.OrphansEmittedB, snap.OrphansPendingA),
		fmt.Sprintf(" Pending Orphan Match: %d\n", snap.PendingOrphanMatch),
		fmt.Sprintf(" Keep-Alive Resets: %d\n", snap.KeepAliveResets),
	}
	return strings.Join(report, "")
}
|
||||
128
services/correlator/internal/observability/metrics_server.go
Normal file
128
services/correlator/internal/observability/metrics_server.go
Normal file
@ -0,0 +1,128 @@
|
||||
package observability
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MetricsServer exposes correlation metrics via HTTP.
|
||||
type MetricsServer struct {
|
||||
mu sync.Mutex
|
||||
server *http.Server
|
||||
listener net.Listener
|
||||
metricsFunc func() MetricsSnapshot
|
||||
running bool
|
||||
}
|
||||
|
||||
// NewMetricsServer creates a new metrics HTTP server.
|
||||
func NewMetricsServer(addr string, metricsFunc func() MetricsSnapshot) (*MetricsServer, error) {
|
||||
if metricsFunc == nil {
|
||||
return nil, fmt.Errorf("metricsFunc cannot be nil")
|
||||
}
|
||||
|
||||
ms := &MetricsServer{
|
||||
metricsFunc: metricsFunc,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/metrics", ms.handleMetrics)
|
||||
mux.HandleFunc("/health", ms.handleHealth)
|
||||
|
||||
ms.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
// Start begins listening on the configured address.
|
||||
func (ms *MetricsServer) Start() error {
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
|
||||
if ms.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", ms.server.Addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start metrics server: %w", err)
|
||||
}
|
||||
|
||||
ms.listener = listener
|
||||
ms.running = true
|
||||
|
||||
go func() {
|
||||
if err := ms.server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
// Server error or closed
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the metrics server.
|
||||
func (ms *MetricsServer) Stop(ctx context.Context) error {
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
|
||||
if !ms.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
ms.running = false
|
||||
return ms.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// handleMetrics returns the correlation metrics as JSON.
|
||||
func (ms *MetricsServer) handleMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
metrics := ms.metricsFunc()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(metrics); err != nil {
|
||||
http.Error(w, "Failed to encode metrics", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handleHealth returns a simple health check response.
|
||||
func (ms *MetricsServer) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"status":"healthy"}`)
|
||||
}
|
||||
|
||||
// IsRunning returns true if the server is running.
|
||||
func (ms *MetricsServer) IsRunning() bool {
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
return ms.running
|
||||
}
|
||||
|
||||
// Addr returns the listening address.
|
||||
func (ms *MetricsServer) Addr() string {
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
if ms.listener == nil {
|
||||
return ""
|
||||
}
|
||||
return ms.listener.Addr().String()
|
||||
}
|
||||
57
services/correlator/internal/ports/source.go
Normal file
57
services/correlator/internal/ports/source.go
Normal file
@ -0,0 +1,57 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/antitbone/ja4/correlator/internal/domain"
|
||||
)
|
||||
|
||||
// EventSource defines the interface for log sources.
|
||||
type EventSource interface {
|
||||
// Start begins reading events and sending them to the channel.
|
||||
// Returns an error if the source cannot be started.
|
||||
Start(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) error
|
||||
|
||||
// Stop gracefully stops the source.
|
||||
Stop() error
|
||||
|
||||
// Name returns the source name.
|
||||
Name() string
|
||||
}
|
||||
|
||||
// CorrelatedLogSink defines the interface for correlated log destinations.
|
||||
type CorrelatedLogSink interface {
|
||||
// Write sends a correlated log to the sink.
|
||||
Write(ctx context.Context, log domain.CorrelatedLog) error
|
||||
|
||||
// Flush flushes any buffered logs.
|
||||
Flush(ctx context.Context) error
|
||||
|
||||
// Close closes the sink.
|
||||
Close() error
|
||||
|
||||
// Name returns the sink name.
|
||||
Name() string
|
||||
|
||||
// Reopen closes and reopens the sink (for log rotation on SIGHUP).
|
||||
// Optional: only FileSink implements this.
|
||||
Reopen() error
|
||||
}
|
||||
|
||||
// CorrelationProcessor defines the interface for the correlation service.
|
||||
// This allows for easier testing and alternative implementations.
|
||||
type CorrelationProcessor interface {
|
||||
// ProcessEvent processes an incoming event and returns correlated logs.
|
||||
ProcessEvent(event *domain.NormalizedEvent) []domain.CorrelatedLog
|
||||
|
||||
// Flush forces emission of remaining buffered events.
|
||||
Flush() []domain.CorrelatedLog
|
||||
|
||||
// EmitPendingOrphans emits orphan A events whose delay has expired.
|
||||
// Called periodically by the Orchestrator ticker so orphans are not blocked
|
||||
// waiting for the next incoming event.
|
||||
EmitPendingOrphans() []domain.CorrelatedLog
|
||||
|
||||
// GetBufferSizes returns the current buffer sizes for monitoring.
|
||||
GetBufferSizes() (int, int)
|
||||
}
|
||||
Reference in New Issue
Block a user