From 27c76593977955764cdbd6375c04f1e33561c45b Mon Sep 17 00:00:00 2001 From: Jacquin Antoine Date: Sun, 1 Mar 2026 12:10:17 +0100 Subject: [PATCH] =?UTF-8?q?fix:=20renforcer=20corr=C3=A9lation=20A/B=20et?= =?UTF-8?q?=20sorties=20stdout/fichier?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aider (openrouter/openai/gpt-5.3-codex) --- cmd/logcorrelator/main.go | 4 +- config.example.yml | 6 +- .../adapters/inbound/unixsocket/source.go | 96 +++++++------ .../inbound/unixsocket/source_test.go | 129 +++++++++++------- internal/adapters/outbound/clickhouse/sink.go | 74 +++++++--- internal/adapters/outbound/file/sink.go | 72 +++++----- internal/adapters/outbound/file/sink_test.go | 37 ++++- internal/config/config.go | 32 ++--- internal/config/config_test.go | 36 ++++- internal/domain/correlated_log.go | 38 +++++- internal/domain/correlation_service.go | 36 ++--- internal/domain/correlation_service_test.go | 42 +++++- internal/observability/logger.go | 98 ++++++------- 13 files changed, 441 insertions(+), 259 deletions(-) diff --git a/cmd/logcorrelator/main.go b/cmd/logcorrelator/main.go index ea77ab5..18fe2a9 100644 --- a/cmd/logcorrelator/main.go +++ b/cmd/logcorrelator/main.go @@ -92,7 +92,7 @@ func main() { logger.Info(fmt.Sprintf("Configured ClickHouse sink: table=%s", cfg.Outputs.ClickHouse.Table)) } - if cfg.Outputs.Stdout { + if cfg.Outputs.Stdout.Enabled { stdoutSink := stdout.NewStdoutSink(stdout.Config{ Enabled: true, }) @@ -110,7 +110,7 @@ func main() { NetworkEmit: false, MaxBufferSize: domain.DefaultMaxBufferSize, }, &domain.RealTimeProvider{}) - + // Set logger for correlation service correlationSvc.SetLogger(logger.WithFields(map[string]any{"component": "correlation"})) diff --git a/config.example.yml b/config.example.yml index 67bff45..39b6e08 100644 --- a/config.example.yml +++ b/config.example.yml @@ -8,10 +8,12 @@ log: inputs: unix_sockets: - name: http + source_type: A path: /var/run/logcorrelator/http.socket format: json socket_permissions: "0660" # owner + group read/write - name: network + source_type: B path: /var/run/logcorrelator/network.socket format: json socket_permissions: "0660" @@ -24,9 +26,9 @@ outputs: dsn: clickhouse://user:pass@localhost:9000/db table: correlated_logs_http_network - stdout: false + stdout: + enabled: false correlation: time_window_s: 1 emit_orphans: true # http toujours émis, network jamais seul - diff --git a/internal/adapters/inbound/unixsocket/source.go b/internal/adapters/inbound/unixsocket/source.go index 2733a2c..90aae41 100644 --- a/internal/adapters/inbound/unixsocket/source.go +++ b/internal/adapters/inbound/unixsocket/source.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "math" "net" "os" "strconv" @@ -180,7 +181,7 @@ func (s *UnixSocketSource) readEvents(ctx context.Context, conn net.Conn, eventC } // Debug: log raw events - s.logger.Debugf("event received: source=%s src_ip=%s src_port=%d", + s.logger.Debugf("event received: source=%s src_ip=%s src_port=%d", event.Source, event.SrcIP, event.SrcPort) select { @@ -191,6 +192,21 @@ func (s *UnixSocketSource) readEvents(ctx context.Context, conn net.Conn, eventC } } +func resolveSource(sourceType string, headers map[string]string) domain.EventSource { + switch strings.ToLower(strings.TrimSpace(sourceType)) { + case "a", "apache", "http": + return domain.SourceA + case "b", "network", "net": + return domain.SourceB + default: + // fallback compat + if len(headers) > 0 { + return domain.SourceA + } + return domain.SourceB + } +} + func parseJSONEvent(data []byte, sourceType string) (*domain.NormalizedEvent, error) { var raw map[string]any if err := json.Unmarshal(data, &raw); err != nil { @@ -198,12 +214,29 @@ func parseJSONEvent(data []byte, sourceType string) (*domain.NormalizedEvent, er } event := &domain.NormalizedEvent{ - Raw: raw, - Extra: make(map[string]any), + Raw: raw, + Extra: make(map[string]any), + Headers: make(map[string]string), } + // Extract headers (header_* fields) first + for k, v := range raw { + if strings.HasPrefix(k, "header_") { + if sv, ok := v.(string); ok { + event.Headers[k[7:]] = sv + } + } + } + + // Resolve source first (strict timestamp logic depends on source) + event.Source = resolveSource(sourceType, event.Headers) + // Extract and validate src_ip if v, ok := getString(raw, "src_ip"); ok { + v = strings.TrimSpace(v) + if v == "" { + return nil, fmt.Errorf("src_ip cannot be empty") + } event.SrcIP = v } else { return nil, fmt.Errorf("missing required field: src_ip") @@ -221,7 +254,7 @@ func parseJSONEvent(data []byte, sourceType string) (*domain.NormalizedEvent, er // Extract dst_ip (optional) if v, ok := getString(raw, "dst_ip"); ok { - event.DstIP = v + event.DstIP = strings.TrimSpace(v) } // Extract dst_port (optional) @@ -232,50 +265,23 @@ func parseJSONEvent(data []byte, sourceType string) (*domain.NormalizedEvent, er event.DstPort = v } - // Extract timestamp - try different fields - if ts, ok := getInt64(raw, "timestamp"); ok { + // Extract timestamp based on source contract + switch event.Source { + case domain.SourceA: + ts, ok := getInt64(raw, "timestamp") + if !ok { + return nil, fmt.Errorf("missing required numeric field: timestamp for source A") + } // Assume nanoseconds event.Timestamp = time.Unix(0, ts) - } else if tsStr, ok := getString(raw, "time"); ok { - if t, err := time.Parse(time.RFC3339, tsStr); err == nil { - event.Timestamp = t - } - } else if tsStr, ok := getString(raw, "timestamp"); ok { - if t, err := time.Parse(time.RFC3339, tsStr); err == nil { - event.Timestamp = t - } - } - - if event.Timestamp.IsZero() { + case domain.SourceB: + // For network source, always use local reception time event.Timestamp = time.Now() - } - - // Extract headers (header_* fields) - event.Headers = make(map[string]string) - for k, v := range raw { - if len(k) > 7 && k[:7] == "header_" { - if sv, ok := v.(string); ok { - event.Headers[k[7:]] = sv - } - } - } - - // Determine source based on explicit config or fallback to heuristic - switch sourceType { - case "A", "a", "apache", "http": - event.Source = domain.SourceA - case "B", "b", "network", "net": - event.Source = domain.SourceB default: - // Fallback to heuristic detection for backward compatibility - if len(event.Headers) > 0 { - event.Source = domain.SourceA - } else { - event.Source = domain.SourceB - } + return nil, fmt.Errorf("unsupported source type: %s", event.Source) } - // Extra fields (single pass) + // Extra fields knownFields := map[string]bool{ "src_ip": true, "src_port": true, "dst_ip": true, "dst_port": true, "timestamp": true, "time": true, @@ -306,6 +312,9 @@ func getInt(m map[string]any, key string) (int, bool) { if v, ok := m[key]; ok { switch val := v.(type) { case float64: + if math.Trunc(val) != val { + return 0, false + } return int(val), true case int: return val, true @@ -324,6 +333,9 @@ func getInt64(m map[string]any, key string) (int64, bool) { if v, ok := m[key]; ok { switch val := v.(type) { case float64: + if math.Trunc(val) != val { + return 0, false + } return int64(val), true case int: return int64(val), true diff --git a/internal/adapters/inbound/unixsocket/source_test.go b/internal/adapters/inbound/unixsocket/source_test.go index 4fd5235..d706daa 100644 --- a/internal/adapters/inbound/unixsocket/source_test.go +++ b/internal/adapters/inbound/unixsocket/source_test.go @@ -41,6 +41,10 @@ func TestParseJSONEvent_Apache(t *testing.T) { if event.Source != domain.SourceA { t.Errorf("expected source A, got %s", event.Source) } + expectedTs := time.Unix(0, 1704110400000000000) + if !event.Timestamp.Equal(expectedTs) { + t.Errorf("expected timestamp %v, got %v", expectedTs, event.Timestamp) + } } func TestParseJSONEvent_Network(t *testing.T) { @@ -49,12 +53,15 @@ func TestParseJSONEvent_Network(t *testing.T) { "src_port": 8080, "dst_ip": "10.0.0.1", "dst_port": 443, + "timestamp": 1704110400000000000, "ja3": "abc123def456", "ja4": "xyz789", "tcp_meta_flags": "SYN" }`) + before := time.Now() event, err := parseJSONEvent(data, "B") + after := time.Now() if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -68,6 +75,9 @@ func TestParseJSONEvent_Network(t *testing.T) { if event.Source != domain.SourceB { t.Errorf("expected source B, got %s", event.Source) } + if event.Timestamp.Before(before.Add(-2*time.Second)) || event.Timestamp.After(after.Add(2*time.Second)) { + t.Errorf("expected network timestamp near now, got %v", event.Timestamp) + } } func TestParseJSONEvent_InvalidJSON(t *testing.T) { @@ -88,21 +98,35 @@ func TestParseJSONEvent_MissingFields(t *testing.T) { } } -func TestParseJSONEvent_StringTimestamp(t *testing.T) { +func TestParseJSONEvent_SourceARequiresNumericTimestamp(t *testing.T) { data := []byte(`{ "src_ip": "192.168.1.1", "src_port": 8080, "time": "2024-01-01T12:00:00Z" }`) - event, err := parseJSONEvent(data, "") + _, err := parseJSONEvent(data, "A") + if err == nil { + t.Fatal("expected error for source A without numeric timestamp") + } +} + +func TestParseJSONEvent_SourceBIgnoresPayloadTimestamp(t *testing.T) { + data := []byte(`{ + "src_ip": "192.168.1.1", + "src_port": 8080, + "timestamp": 1 + }`) + + before := time.Now() + event, err := parseJSONEvent(data, "B") + after := time.Now() if err != nil { t.Fatalf("unexpected error: %v", err) } - expected := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) - if !event.Timestamp.Equal(expected) { - t.Errorf("expected timestamp %v, got %v", expected, event.Timestamp) + if event.Timestamp.Before(before.Add(-2*time.Second)) || event.Timestamp.After(after.Add(2*time.Second)) { + t.Errorf("expected source B timestamp near now, got %v", event.Timestamp) } } @@ -114,40 +138,40 @@ func TestParseJSONEvent_ExplicitSourceType(t *testing.T) { expected domain.EventSource }{ { - name: "explicit A", - data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, + name: "explicit A", + data: `{"src_ip": "192.168.1.1", "src_port": 8080, "timestamp": 1704110400000000000}`, sourceType: "A", - expected: domain.SourceA, + expected: domain.SourceA, }, { - name: "explicit B", - data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, + name: "explicit B", + data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, sourceType: "B", - expected: domain.SourceB, + expected: domain.SourceB, }, { - name: "explicit apache", - data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, + name: "explicit apache", + data: `{"src_ip": "192.168.1.1", "src_port": 8080, "timestamp": 1704110400000000000}`, sourceType: "apache", - expected: domain.SourceA, + expected: domain.SourceA, }, { - name: "explicit network", - data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, + name: "explicit network", + data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, sourceType: "network", - expected: domain.SourceB, + expected: domain.SourceB, }, { - name: "auto-detect A with headers", - data: `{"src_ip": "192.168.1.1", "src_port": 8080, "header_host": "example.com"}`, + name: "auto-detect A with headers", + data: `{"src_ip": "192.168.1.1", "src_port": 8080, "timestamp": 1704110400000000000, "header_host": "example.com"}`, sourceType: "", - expected: domain.SourceA, + expected: domain.SourceA, }, { - name: "auto-detect B without headers", - data: `{"src_ip": "192.168.1.1", "src_port": 8080, "ja3": "abc"}`, + name: "auto-detect B without headers", + data: `{"src_ip": "192.168.1.1", "src_port": 8080, "ja3": "abc"}`, sourceType: "", - expected: domain.SourceB, + expected: domain.SourceB, }, } @@ -241,7 +265,7 @@ func TestGetInt(t *testing.T) { expected int ok bool }{ - {"float", 42, true}, + {"float", 0, false}, {"int", 42, true}, {"int64", 42, true}, {"string", 42, true}, @@ -278,7 +302,7 @@ func TestGetInt64(t *testing.T) { expected int64 ok bool }{ - {"float", 42, true}, + {"float", 0, false}, {"int", 42, true}, {"int64", 42, true}, {"string", 42, true}, @@ -302,45 +326,52 @@ func TestGetInt64(t *testing.T) { func TestParseJSONEvent_PortValidation(t *testing.T) { tests := []struct { - name string - data string - wantErr bool + name string + data string + sourceType string + wantErr bool }{ { - name: "valid src_port", - data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, - wantErr: false, + name: "valid src_port", + data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, + sourceType: "B", + wantErr: false, }, { - name: "src_port zero", - data: `{"src_ip": "192.168.1.1", "src_port": 0}`, - wantErr: true, + name: "src_port zero", + data: `{"src_ip": "192.168.1.1", "src_port": 0}`, + sourceType: "B", + wantErr: true, }, { - name: "src_port negative", - data: `{"src_ip": "192.168.1.1", "src_port": -1}`, - wantErr: true, + name: "src_port negative", + data: `{"src_ip": "192.168.1.1", "src_port": -1}`, + sourceType: "B", + wantErr: true, }, { - name: "src_port too high", - data: `{"src_ip": "192.168.1.1", "src_port": 70000}`, - wantErr: true, + name: "src_port too high", + data: `{"src_ip": "192.168.1.1", "src_port": 70000}`, + sourceType: "B", + wantErr: true, }, { - name: "valid dst_port zero", - data: `{"src_ip": "192.168.1.1", "src_port": 8080, "dst_port": 0}`, - wantErr: false, + name: "valid dst_port zero", + data: `{"src_ip": "192.168.1.1", "src_port": 8080, "dst_port": 0}`, + sourceType: "B", + wantErr: false, }, { - name: "dst_port too high", - data: `{"src_ip": "192.168.1.1", "src_port": 8080, "dst_port": 70000}`, - wantErr: true, + name: "dst_port too high", + data: `{"src_ip": "192.168.1.1", "src_port": 8080, "dst_port": 70000}`, + sourceType: "B", + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := parseJSONEvent([]byte(tt.data), "") + _, err := parseJSONEvent([]byte(tt.data), tt.sourceType) if (err != nil) != tt.wantErr { t.Errorf("parseJSONEvent() error = %v, wantErr %v", err, tt.wantErr) } @@ -350,12 +381,12 @@ func TestParseJSONEvent_PortValidation(t *testing.T) { func TestParseJSONEvent_TimestampFallback(t *testing.T) { data := []byte(`{"src_ip": "192.168.1.1", "src_port": 8080}`) - event, err := parseJSONEvent(data, "") + event, err := parseJSONEvent(data, "B") if err != nil { t.Fatalf("unexpected error: %v", err) } - // Should fallback to current time if no timestamp provided + // For source B, timestamp is reception time if event.Timestamp.IsZero() { t.Error("expected non-zero timestamp") } diff --git a/internal/adapters/outbound/clickhouse/sink.go b/internal/adapters/outbound/clickhouse/sink.go index 78de0fb..2b1a0ba 100644 --- a/internal/adapters/outbound/clickhouse/sink.go +++ b/internal/adapters/outbound/clickhouse/sink.go @@ -4,7 +4,9 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" + "net" "strings" "sync" "time" @@ -115,35 +117,37 @@ func (s *ClickHouseSink) Name() string { // Write adds a log to the buffer. func (s *ClickHouseSink) Write(ctx context.Context, log domain.CorrelatedLog) error { - s.mu.Lock() - defer s.mu.Unlock() + deadline := time.Now().Add(time.Duration(s.config.TimeoutMs) * time.Millisecond) - // Check buffer overflow - if len(s.buffer) >= s.config.MaxBufferSize { - if s.config.DropOnOverflow { - // Drop the log + for { + s.mu.Lock() + if len(s.buffer) < s.config.MaxBufferSize { + s.buffer = append(s.buffer, log) + if len(s.buffer) >= s.config.BatchSize { + select { + case s.flushChan <- struct{}{}: + default: + } + } + s.mu.Unlock() return nil } - // Block until space is available (with timeout) + drop := s.config.DropOnOverflow + s.mu.Unlock() + + if drop { + return nil + } + if time.Now().After(deadline) { + return fmt.Errorf("buffer full, timeout exceeded") + } + select { case <-ctx.Done(): return ctx.Err() - case <-time.After(time.Duration(s.config.TimeoutMs) * time.Millisecond): - return fmt.Errorf("buffer full, timeout exceeded") + case <-time.After(10 * time.Millisecond): } } - - s.buffer = append(s.buffer, log) - - // Trigger flush if batch is full - if len(s.buffer) >= s.config.BatchSize { - select { - case s.flushChan <- struct{}{}: - default: - } - } - - return nil } // Flush flushes the buffer to ClickHouse. @@ -311,7 +315,33 @@ func isRetryableError(err error) bool { if err == nil { return false } + + if errors.Is(err, context.DeadlineExceeded) { + return true + } + + if errors.Is(err, context.Canceled) { + return false + } + + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { + return true + } + } + errStr := strings.ToLower(err.Error()) + + // Explicit non-retryable SQL/schema errors + if strings.Contains(errStr, "syntax error") || + strings.Contains(errStr, "unknown table") || + strings.Contains(errStr, "unknown column") || + (strings.Contains(errStr, "table") && strings.Contains(errStr, "not found")) { + return false + } + + // Fallback network/transient errors retryableErrors := []string{ "connection refused", "connection reset", @@ -319,11 +349,13 @@ func isRetryableError(err error) bool { "temporary failure", "network is unreachable", "broken pipe", + "no route to host", } for _, re := range retryableErrors { if strings.Contains(errStr, re) { return true } } + return false } diff --git a/internal/adapters/outbound/file/sink.go b/internal/adapters/outbound/file/sink.go index d20efdb..ccee1e1 100644 --- a/internal/adapters/outbound/file/sink.go +++ b/internal/adapters/outbound/file/sink.go @@ -1,7 +1,6 @@ package file import ( - "bufio" "context" "encoding/json" "fmt" @@ -30,7 +29,6 @@ type FileSink struct { config Config mu sync.Mutex file *os.File - writer *bufio.Writer } // NewFileSink creates a new file sink. @@ -66,11 +64,12 @@ func (s *FileSink) Write(ctx context.Context, log domain.CorrelatedLog) error { return fmt.Errorf("failed to marshal log: %w", err) } - if _, err := s.writer.Write(data); err != nil { - return fmt.Errorf("failed to write log: %w", err) + line := append(data, '\n') + if _, err := s.file.Write(line); err != nil { + return fmt.Errorf("failed to write log line: %w", err) } - if _, err := s.writer.WriteString("\n"); err != nil { - return fmt.Errorf("failed to write newline: %w", err) + if err := s.file.Sync(); err != nil { + return fmt.Errorf("failed to sync log line: %w", err) } return nil @@ -81,8 +80,8 @@ func (s *FileSink) Flush(ctx context.Context) error { s.mu.Lock() defer s.mu.Unlock() - if s.writer != nil { - return s.writer.Flush() + if s.file != nil { + return s.file.Sync() } return nil } @@ -92,12 +91,6 @@ func (s *FileSink) Close() error { s.mu.Lock() defer s.mu.Unlock() - if s.writer != nil { - if err := s.writer.Flush(); err != nil { - return err - } - } - if s.file != nil { return s.file.Close() } @@ -122,47 +115,54 @@ func (s *FileSink) openFile() error { } s.file = file - s.writer = bufio.NewWriter(file) return nil } // validateFilePath validates that the file path is safe and allowed. func validateFilePath(path string) error { - if path == "" { + if strings.TrimSpace(path) == "" { return fmt.Errorf("path cannot be empty") } - // Clean the path cleanPath := filepath.Clean(path) - // Ensure path is absolute or relative to allowed directories - allowedPrefixes := []string{ + // Allow relative paths for testing/dev + if !filepath.IsAbs(cleanPath) { + return nil + } + + absPath, err := filepath.Abs(cleanPath) + if err != nil { + return fmt.Errorf("failed to resolve absolute path: %w", err) + } + + allowedRoots := []string{ "/var/log/logcorrelator", "/var/log", "/tmp", } - // Check if path is in allowed directories - allowed := false - for _, prefix := range allowedPrefixes { - if strings.HasPrefix(cleanPath, prefix) { - allowed = true - break + for _, root := range allowedRoots { + absRoot, err := filepath.Abs(filepath.Clean(root)) + if err != nil { + continue } - } - if !allowed { - // Allow relative paths for testing - if !filepath.IsAbs(cleanPath) { + rel, err := filepath.Rel(absRoot, absPath) + if err != nil { + continue + } + + if rel == "." { + return nil + } + if rel == ".." { + continue + } + if !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { return nil } - return fmt.Errorf("path must be in allowed directories: %v", allowedPrefixes) } - // Check for path traversal - if strings.Contains(cleanPath, "..") { - return fmt.Errorf("path cannot contain '..'") - } - - return nil + return fmt.Errorf("path must be under allowed directories: %v", allowedRoots) } diff --git a/internal/adapters/outbound/file/sink_test.go b/internal/adapters/outbound/file/sink_test.go index 6f30f75..469123c 100644 --- a/internal/adapters/outbound/file/sink_test.go +++ b/internal/adapters/outbound/file/sink_test.go @@ -44,6 +44,36 @@ func TestFileSink_Write(t *testing.T) { } } +func TestFileSink_WriteImmediatePersist_NoFlushNeeded(t *testing.T) { + tmpDir := t.TempDir() + testPath := filepath.Join(tmpDir, "test.log") + + sink, err := NewFileSink(Config{Path: testPath}) + if err != nil { + t.Fatalf("failed to create sink: %v", err) + } + defer sink.Close() + + log := domain.CorrelatedLog{ + SrcIP: "192.168.1.1", + SrcPort: 8080, + Correlated: true, + } + + if err := sink.Write(context.Background(), log); err != nil { + t.Fatalf("failed to write: %v", err) + } + + // Must be visible immediately without Flush() + data, err := os.ReadFile(testPath) + if err != nil { + t.Fatalf("failed to read file: %v", err) + } + if len(data) == 0 { + t.Error("expected data to be present immediately after Write without Flush") + } +} + func TestFileSink_MultipleWrites(t *testing.T) { tmpDir := t.TempDir() testPath := filepath.Join(tmpDir, "test.log") @@ -105,7 +135,7 @@ func TestFileSink_ValidateFilePath(t *testing.T) { {"valid /var/log/logcorrelator", "/var/log/logcorrelator/test.log", false}, {"valid /var/log", "/var/log/test.log", false}, {"valid /tmp", "/tmp/test.log", false}, - {"path traversal", "/var/log/../etc/passwd", true}, + {"reject lookalike /var/logevil", "/var/logevil/test.log", true}, {"invalid directory", "/etc/logcorrelator/test.log", true}, {"relative path", "test.log", false}, // Allowed for testing } @@ -137,9 +167,6 @@ func TestFileSink_OpenFile(t *testing.T) { if sink.file == nil { t.Error("expected file to be opened") } - if sink.writer == nil { - t.Error("expected writer to be initialized") - } } func TestFileSink_WriteBeforeOpen(t *testing.T) { @@ -183,7 +210,7 @@ func TestFileSink_FlushBeforeOpen(t *testing.T) { } func TestFileSink_InvalidPath(t *testing.T) { - // Test with invalid path (path traversal) + // Test with invalid path (outside allowed directories) _, err := NewFileSink(Config{Path: "/etc/../passwd"}) if err == nil { t.Error("expected error for invalid path") diff --git a/internal/config/config.go b/internal/config/config.go index fa64d41..8d98300 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,10 +12,10 @@ import ( // Config holds the complete application configuration. type Config struct { - Log LogConfig `yaml:"log"` - Inputs InputsConfig `yaml:"inputs"` - Outputs OutputsConfig `yaml:"outputs"` - Correlation CorrelationConfig `yaml:"correlation"` + Log LogConfig `yaml:"log"` + Inputs InputsConfig `yaml:"inputs"` + Outputs OutputsConfig `yaml:"outputs"` + Correlation CorrelationConfig `yaml:"correlation"` } // LogConfig holds logging configuration. @@ -44,10 +44,10 @@ type InputsConfig struct { // UnixSocketConfig holds a Unix socket source configuration. type UnixSocketConfig struct { - Name string `yaml:"name"` - Path string `yaml:"path"` - Format string `yaml:"format"` - SourceType string `yaml:"source_type"` // "A" for Apache/HTTP, "B" for Network + Name string `yaml:"name"` + Path string `yaml:"path"` + Format string `yaml:"format"` + SourceType string `yaml:"source_type"` // "A" for Apache/HTTP, "B" for Network SocketPermissions string `yaml:"socket_permissions"` // octal string, e.g., "0660", "0666" } @@ -55,7 +55,7 @@ type UnixSocketConfig struct { type OutputsConfig struct { File FileOutputConfig `yaml:"file"` ClickHouse ClickHouseOutputConfig `yaml:"clickhouse"` - Stdout bool `yaml:"stdout"` + Stdout StdoutOutputConfig `yaml:"stdout"` } // FileOutputConfig holds file sink configuration. @@ -77,15 +77,14 @@ type ClickHouseOutputConfig struct { } // StdoutOutputConfig holds stdout sink configuration. -// Deprecated: stdout is now a boolean flag in OutputsConfig. type StdoutOutputConfig struct { Enabled bool `yaml:"enabled"` } // CorrelationConfig holds correlation configuration. type CorrelationConfig struct { - TimeWindowS int `yaml:"time_window_s"` - EmitOrphans bool `yaml:"emit_orphans"` + TimeWindowS int `yaml:"time_window_s"` + EmitOrphans bool `yaml:"emit_orphans"` } // Load loads configuration from a YAML file. @@ -130,7 +129,7 @@ func defaultConfig() *Config { AsyncInsert: true, TimeoutMs: 1000, }, - Stdout: false, + Stdout: StdoutOutputConfig{Enabled: false}, }, Correlation: CorrelationConfig{ TimeWindowS: 1, @@ -175,7 +174,7 @@ func (c *Config) Validate() error { if c.Outputs.ClickHouse.Enabled { hasOutput = true } - if c.Outputs.Stdout { + if c.Outputs.Stdout.Enabled { hasOutput = true } @@ -220,12 +219,13 @@ func (c *CorrelationConfig) GetTimeWindow() time.Duration { // GetSocketPermissions returns the socket permissions as os.FileMode. // Default is 0660 (owner + group read/write). func (c *UnixSocketConfig) GetSocketPermissions() os.FileMode { - if c.SocketPermissions == "" { + trimmed := strings.TrimSpace(c.SocketPermissions) + if trimmed == "" { return 0660 } // Parse octal string (e.g., "0660", "660", "0666") - perms, err := strconv.ParseUint(strings.TrimPrefix(c.SocketPermissions, "0"), 8, 32) + perms, err := strconv.ParseUint(trimmed, 8, 32) if err != nil { return 0660 } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 091ec70..ffaab1e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -131,7 +131,7 @@ func TestValidate_AtLeastOneOutput(t *testing.T) { Outputs: OutputsConfig{ File: FileOutputConfig{}, ClickHouse: ClickHouseOutputConfig{Enabled: false}, - Stdout: false, + Stdout: StdoutOutputConfig{Enabled: false}, }, } @@ -554,3 +554,37 @@ correlation: t.Errorf("expected log level DEBUG, got %s", cfg.Log.GetLevel()) } } + +func TestLoad_StdoutEnabledObject(t *testing.T) { + content := ` +inputs: + unix_sockets: + - name: http + path: /var/run/logcorrelator/http.socket + - name: network + path: /var/run/logcorrelator/network.socket + +outputs: + stdout: + enabled: true + +correlation: + time_window_s: 1 + emit_orphans: true +` + + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yml") + if err := os.WriteFile(configPath, []byte(content), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + cfg, err := Load(configPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !cfg.Outputs.Stdout.Enabled { + t.Error("expected outputs.stdout.enabled to be true") + } +} diff --git a/internal/domain/correlated_log.go b/internal/domain/correlated_log.go index 704741e..0a16b56 100644 --- a/internal/domain/correlated_log.go +++ b/internal/domain/correlated_log.go @@ -2,6 +2,7 @@ package domain import ( "encoding/json" + "reflect" "time" ) @@ -22,7 +23,7 @@ type CorrelatedLog struct { func (c CorrelatedLog) MarshalJSON() ([]byte, error) { // Create a flat map with all fields flat := make(map[string]any) - + // Add core fields flat["timestamp"] = c.Timestamp flat["src_ip"] = c.SrcIP @@ -37,12 +38,24 @@ func (c CorrelatedLog) MarshalJSON() ([]byte, error) { if c.OrphanSide != "" { flat["orphan_side"] = c.OrphanSide } - - // Merge additional fields + + // Merge additional fields while preserving reserved keys + reservedKeys := map[string]struct{}{ + "timestamp": {}, + "src_ip": {}, + "src_port": {}, + "dst_ip": {}, + "dst_port": {}, + "correlated": {}, + "orphan_side": {}, + } for k, v := range c.Fields { + if _, reserved := reservedKeys[k]; reserved { + continue + } flat[k] = v } - + return json.Marshal(flat) } @@ -89,13 +102,28 @@ func extractFields(e *NormalizedEvent) map[string]any { func mergeFields(a, b *NormalizedEvent) map[string]any { result := make(map[string]any) - // Merge fields from both events + + // Start with A fields for k, v := range a.Raw { result[k] = v } + + // Merge B fields with collision handling for k, v := range b.Raw { + if existing, exists := result[k]; exists { + if reflect.DeepEqual(existing, v) { + continue + } + + // Collision with different values: keep both with prefixes + delete(result, k) + result["a_"+k] = existing + result["b_"+k] = v + continue + } result[k] = v } + return result } diff --git a/internal/domain/correlation_service.go b/internal/domain/correlation_service.go index ec80f4f..35ed3b4 100644 --- a/internal/domain/correlation_service.go +++ b/internal/domain/correlation_service.go @@ -96,14 +96,11 @@ func (s *CorrelationService) ProcessEvent(event *NormalizedEvent) []CorrelatedLo // Check buffer overflow before adding if s.isBufferFull(event.Source) { // Buffer full, drop event or emit as orphan - s.logger.Warnf("buffer full, dropping event: source=%s src_ip=%s src_port=%d", + s.logger.Warnf("buffer full, dropping event: source=%s src_ip=%s src_port=%d", event.Source, event.SrcIP, event.SrcPort) if event.Source == SourceA && s.config.ApacheAlwaysEmit { return []CorrelatedLog{NewCorrelatedLogFromEvent(event, "A")} } - if event.Source == SourceB && s.config.NetworkEmit { - return []CorrelatedLog{NewCorrelatedLogFromEvent(event, "B")} - } return nil } @@ -123,7 +120,7 @@ func (s *CorrelationService) ProcessEvent(event *NormalizedEvent) []CorrelatedLo if shouldBuffer { s.addEvent(event) - s.logger.Debugf("event buffered: source=%s src_ip=%s src_port=%d buffer_size=%d", + s.logger.Debugf("event buffered: source=%s src_ip=%s src_port=%d buffer_size=%d", event.Source, event.SrcIP, event.SrcPort, s.getBufferSize(event.Source)) } @@ -158,7 +155,7 @@ func (s *CorrelationService) processSourceA(event *NormalizedEvent) ([]Correlate return s.eventsMatch(event, other) }); bEvent != nil { correlated := NewCorrelatedLog(event, bEvent) - s.logger.Debugf("correlation found: A(src_ip=%s src_port=%d) + B(src_ip=%s src_port=%d)", + s.logger.Debugf("correlation found: A(src_ip=%s src_port=%d) + B(src_ip=%s src_port=%d)", event.SrcIP, event.SrcPort, bEvent.SrcIP, bEvent.SrcPort) return []CorrelatedLog{correlated}, false } @@ -182,19 +179,12 @@ func (s *CorrelationService) processSourceB(event *NormalizedEvent) ([]Correlate return s.eventsMatch(other, event) }); aEvent != nil { correlated := NewCorrelatedLog(aEvent, event) - s.logger.Debugf("correlation found: A(src_ip=%s src_port=%d) + B(src_ip=%s src_port=%d)", + s.logger.Debugf("correlation found: A(src_ip=%s src_port=%d) + B(src_ip=%s src_port=%d)", aEvent.SrcIP, aEvent.SrcPort, event.SrcIP, event.SrcPort) return []CorrelatedLog{correlated}, false } - // No match found - orphan B event (not emitted by default) - if s.config.NetworkEmit { - orphan := NewCorrelatedLogFromEvent(event, "B") - s.logger.Warnf("orphan B event (no A match): src_ip=%s src_port=%d", event.SrcIP, event.SrcPort) - return []CorrelatedLog{orphan}, false - } - - // Keep in buffer for potential future match + // Never emit B alone. Keep in buffer for potential future match. return nil, true } @@ -231,20 +221,17 @@ func (s *CorrelationService) cleanExpired() { // cleanBuffer removes expired events from a buffer. func (s *CorrelationService) cleanBuffer(buffer *eventBuffer, pending map[string][]*list.Element, cutoff time.Time) { for elem := buffer.events.Front(); elem != nil; { + next := elem.Next() event := elem.Value.(*NormalizedEvent) if event.Timestamp.Before(cutoff) { - next := elem.Next() key := event.CorrelationKey() buffer.events.Remove(elem) pending[key] = removeElementFromSlice(pending[key], elem) if len(pending[key]) == 0 { delete(pending, key) } - elem = next - continue } - // Events are inserted in arrival order; once we hit a non-expired event, stop. - break + elem = next } } @@ -307,14 +294,7 @@ func (s *CorrelationService) Flush() []CorrelatedLog { } } - // Emit remaining B events as orphans only if explicitly enabled - if s.config.NetworkEmit { - for elem := s.bufferB.events.Front(); elem != nil; elem = elem.Next() { - event := elem.Value.(*NormalizedEvent) - orphan := NewCorrelatedLogFromEvent(event, "B") - results = append(results, orphan) - } - } + // Never emit remaining B events alone. // Clear buffers s.bufferA.events.Init() diff --git a/internal/domain/correlation_service_test.go b/internal/domain/correlation_service_test.go index 68fcb94..1883b81 100644 --- a/internal/domain/correlation_service_test.go +++ b/internal/domain/correlation_service_test.go @@ -193,7 +193,6 @@ func TestCorrelationService_FlushWithEvents(t *testing.T) { now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) timeProvider := &mockTimeProvider{now: now} - // Flush only emits events if ApacheAlwaysEmit and NetworkEmit are true config := CorrelationConfig{ TimeWindow: time.Second, ApacheAlwaysEmit: true, @@ -201,7 +200,6 @@ func TestCorrelationService_FlushWithEvents(t *testing.T) { } svc := NewCorrelationService(config, timeProvider) - // We need to bypass the normal ProcessEvent logic to get events into buffers // Add events directly to buffers for testing Flush keyA := "192.168.1.1:8080" keyB := "192.168.1.2:9090" @@ -219,7 +217,6 @@ func TestCorrelationService_FlushWithEvents(t *testing.T) { SrcPort: 9090, } - // Manually add to buffers (simulating events that couldn't be matched) elemA := svc.bufferA.events.PushBack(apacheEvent) svc.pendingA[keyA] = append(svc.pendingA[keyA], elemA) @@ -227,8 +224,11 @@ func TestCorrelationService_FlushWithEvents(t *testing.T) { svc.pendingB[keyB] = append(svc.pendingB[keyB], elemB) flushed := svc.Flush() - if len(flushed) != 2 { - t.Errorf("expected 2 flushed events, got %d", len(flushed)) + if len(flushed) != 1 { + t.Errorf("expected 1 flushed event (A only), got %d", len(flushed)) + } + if len(flushed) == 1 && flushed[0].OrphanSide != "A" { + t.Errorf("expected orphan side A, got %s", flushed[0].OrphanSide) } // Verify buffers are cleared @@ -262,7 +262,7 @@ func TestCorrelationService_BufferOverflow(t *testing.T) { svc.ProcessEvent(event) } - // Buffer full, next event should be dropped (not emitted since ApacheAlwaysEmit=false but buffer full) + // Buffer full, next event should be dropped (not emitted since ApacheAlwaysEmit=false) overflowEvent := &NormalizedEvent{ Source: SourceA, Timestamp: now, @@ -340,3 +340,33 @@ func TestCorrelationService_DifferentSourceTypes(t *testing.T) { t.Error("expected correlated result") } } + +func TestCorrelationService_NetworkEmitTrue_DoesNotEmitBAlone(t *testing.T) { + now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + timeProvider := &mockTimeProvider{now: now} + + config := CorrelationConfig{ + TimeWindow: time.Second, + ApacheAlwaysEmit: false, + NetworkEmit: true, + } + + svc := NewCorrelationService(config, timeProvider) + + networkEvent := &NormalizedEvent{ + Source: SourceB, + Timestamp: now, + SrcIP: "10.10.10.10", + SrcPort: 5555, + } + + results := svc.ProcessEvent(networkEvent) + if len(results) != 0 { + t.Errorf("expected 0 immediate results for orphan B, got %d", len(results)) + } + + flushed := svc.Flush() + if len(flushed) != 0 { + t.Errorf("expected 0 flushed orphan B events, got %d", len(flushed)) + } +} diff --git a/internal/observability/logger.go b/internal/observability/logger.go index c8af44e..46bf7f8 100644 --- a/internal/observability/logger.go +++ b/internal/observability/logger.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "os" + "sort" "strings" "sync" ) @@ -52,10 +53,10 @@ func (l LogLevel) String() string { // Logger provides structured logging. type Logger struct { - mu sync.Mutex - logger *log.Logger - prefix string - fields map[string]any + mu sync.RWMutex + logger *log.Logger + prefix string + fields map[string]any minLevel LogLevel } @@ -88,30 +89,32 @@ func (l *Logger) SetLevel(level string) { // ShouldLog returns true if the given level should be logged. func (l *Logger) ShouldLog(level LogLevel) bool { - l.mu.Lock() - defer l.mu.Unlock() + l.mu.RLock() + defer l.mu.RUnlock() return level >= l.minLevel } // WithFields returns a new logger with additional fields. func (l *Logger) WithFields(fields map[string]any) *Logger { - l.mu.Lock() + l.mu.RLock() minLevel := l.minLevel - l.mu.Unlock() - - newLogger := &Logger{ + prefix := l.prefix + existing := make(map[string]any, len(l.fields)) + for k, v := range l.fields { + existing[k] = v + } + l.mu.RUnlock() + + for k, v := range fields { + existing[k] = v + } + + return &Logger{ logger: l.logger, - prefix: l.prefix, - fields: make(map[string]any), + prefix: prefix, + fields: existing, minLevel: minLevel, } - for k, v := range l.fields { - newLogger.fields[k] = v - } - for k, v := range fields { - newLogger.fields[k] = v - } - return newLogger } // Info logs an info message. @@ -119,8 +122,6 @@ func (l *Logger) Info(msg string) { if !l.ShouldLog(INFO) { return } - l.mu.Lock() - defer l.mu.Unlock() l.log("INFO", msg) } @@ -129,8 +130,6 @@ func (l *Logger) Warn(msg string) { if !l.ShouldLog(WARN) { return } - l.mu.Lock() - defer l.mu.Unlock() l.log("WARN", msg) } @@ -139,8 +138,6 @@ func (l *Logger) Error(msg string, err error) { if !l.ShouldLog(ERROR) { return } - l.mu.Lock() - defer l.mu.Unlock() if err != nil { l.log("ERROR", msg+" "+err.Error()) } else { @@ -153,8 +150,6 @@ func (l *Logger) Debug(msg string) { if !l.ShouldLog(DEBUG) { return } - l.mu.Lock() - defer l.mu.Unlock() l.log("DEBUG", msg) } @@ -163,8 +158,6 @@ func (l *Logger) Debugf(msg string, args ...any) { if !l.ShouldLog(DEBUG) { return } - l.mu.Lock() - defer l.mu.Unlock() l.log("DEBUG", fmt.Sprintf(msg, args...)) } @@ -173,8 +166,6 @@ func (l *Logger) Warnf(msg string, args ...any) { if !l.ShouldLog(WARN) { return } - l.mu.Lock() - defer l.mu.Unlock() l.log("WARN", fmt.Sprintf(msg, args...)) } @@ -183,27 +174,42 @@ func (l *Logger) Infof(msg string, args ...any) { if !l.ShouldLog(INFO) { return } - l.mu.Lock() - defer l.mu.Unlock() l.log("INFO", fmt.Sprintf(msg, args...)) } func (l *Logger) log(level, msg string) { + l.mu.RLock() prefix := l.prefix + fields := make(map[string]any, len(l.fields)) + for k, v := range l.fields { + fields[k] = v + } + l.mu.RUnlock() + + var b strings.Builder if prefix != "" { - prefix = "[" + prefix + "] " + b.WriteString("[") + b.WriteString(prefix) + b.WriteString("] ") + } + b.WriteString(level) + b.WriteString(" ") + b.WriteString(msg) + + if len(fields) > 0 { + keys := make([]string, 0, len(fields)) + for k := range fields { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + b.WriteString(" ") + b.WriteString(k) + b.WriteString("=") + b.WriteString(fmt.Sprintf("%v", fields[k])) + } } - l.logger.SetPrefix(prefix + level + " ") - - var args []any - for k, v := range l.fields { - args = append(args, k, v) - } - - if len(args) > 0 { - l.logger.Printf(msg+" %+v", args...) - } else { - l.logger.Print(msg) - } + l.logger.Print(b.String()) }