fix: renforcer corrélation A/B et sorties stdout/fichier
Some checks failed
Build and Test / test (push) Has been cancelled
Build and Test / build (push) Has been cancelled
Build and Test / docker (push) Has been cancelled

Co-authored-by: aider (openrouter/openai/gpt-5.3-codex) <aider@aider.chat>
This commit is contained in:
Jacquin Antoine
2026-03-01 12:10:17 +01:00
parent d3436f6245
commit 27c7659397
13 changed files with 441 additions and 259 deletions

View File

@ -92,7 +92,7 @@ func main() {
logger.Info(fmt.Sprintf("Configured ClickHouse sink: table=%s", cfg.Outputs.ClickHouse.Table)) logger.Info(fmt.Sprintf("Configured ClickHouse sink: table=%s", cfg.Outputs.ClickHouse.Table))
} }
if cfg.Outputs.Stdout { if cfg.Outputs.Stdout.Enabled {
stdoutSink := stdout.NewStdoutSink(stdout.Config{ stdoutSink := stdout.NewStdoutSink(stdout.Config{
Enabled: true, Enabled: true,
}) })

View File

@ -8,10 +8,12 @@ log:
inputs: inputs:
unix_sockets: unix_sockets:
- name: http - name: http
source_type: A
path: /var/run/logcorrelator/http.socket path: /var/run/logcorrelator/http.socket
format: json format: json
socket_permissions: "0660" # owner + group read/write socket_permissions: "0660" # owner + group read/write
- name: network - name: network
source_type: B
path: /var/run/logcorrelator/network.socket path: /var/run/logcorrelator/network.socket
format: json format: json
socket_permissions: "0660" socket_permissions: "0660"
@ -24,9 +26,9 @@ outputs:
dsn: clickhouse://user:pass@localhost:9000/db dsn: clickhouse://user:pass@localhost:9000/db
table: correlated_logs_http_network table: correlated_logs_http_network
stdout: false stdout:
enabled: false
correlation: correlation:
time_window_s: 1 time_window_s: 1
emit_orphans: true # http toujours émis, network jamais seul emit_orphans: true # http toujours émis, network jamais seul

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math"
"net" "net"
"os" "os"
"strconv" "strconv"
@ -191,6 +192,21 @@ func (s *UnixSocketSource) readEvents(ctx context.Context, conn net.Conn, eventC
} }
} }
// resolveSource maps a configured source-type string onto a domain event
// source. Matching ignores case and surrounding whitespace. Unrecognized
// (or empty) values fall back to a backward-compatible heuristic: events
// that carry HTTP-style headers are treated as source A, all others as
// source B.
func resolveSource(sourceType string, headers map[string]string) domain.EventSource {
	switch normalized := strings.ToLower(strings.TrimSpace(sourceType)); normalized {
	case "a", "apache", "http":
		return domain.SourceA
	case "b", "network", "net":
		return domain.SourceB
	}
	// Compatibility fallback: header presence implies an HTTP (A) event.
	if len(headers) == 0 {
		return domain.SourceB
	}
	return domain.SourceA
}
func parseJSONEvent(data []byte, sourceType string) (*domain.NormalizedEvent, error) { func parseJSONEvent(data []byte, sourceType string) (*domain.NormalizedEvent, error) {
var raw map[string]any var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil { if err := json.Unmarshal(data, &raw); err != nil {
@ -198,12 +214,29 @@ func parseJSONEvent(data []byte, sourceType string) (*domain.NormalizedEvent, er
} }
event := &domain.NormalizedEvent{ event := &domain.NormalizedEvent{
Raw: raw, Raw: raw,
Extra: make(map[string]any), Extra: make(map[string]any),
Headers: make(map[string]string),
} }
// Extract headers (header_* fields) first
for k, v := range raw {
if strings.HasPrefix(k, "header_") {
if sv, ok := v.(string); ok {
event.Headers[k[7:]] = sv
}
}
}
// Resolve source first (strict timestamp logic depends on source)
event.Source = resolveSource(sourceType, event.Headers)
// Extract and validate src_ip // Extract and validate src_ip
if v, ok := getString(raw, "src_ip"); ok { if v, ok := getString(raw, "src_ip"); ok {
v = strings.TrimSpace(v)
if v == "" {
return nil, fmt.Errorf("src_ip cannot be empty")
}
event.SrcIP = v event.SrcIP = v
} else { } else {
return nil, fmt.Errorf("missing required field: src_ip") return nil, fmt.Errorf("missing required field: src_ip")
@ -221,7 +254,7 @@ func parseJSONEvent(data []byte, sourceType string) (*domain.NormalizedEvent, er
// Extract dst_ip (optional) // Extract dst_ip (optional)
if v, ok := getString(raw, "dst_ip"); ok { if v, ok := getString(raw, "dst_ip"); ok {
event.DstIP = v event.DstIP = strings.TrimSpace(v)
} }
// Extract dst_port (optional) // Extract dst_port (optional)
@ -232,50 +265,23 @@ func parseJSONEvent(data []byte, sourceType string) (*domain.NormalizedEvent, er
event.DstPort = v event.DstPort = v
} }
// Extract timestamp - try different fields // Extract timestamp based on source contract
if ts, ok := getInt64(raw, "timestamp"); ok { switch event.Source {
case domain.SourceA:
ts, ok := getInt64(raw, "timestamp")
if !ok {
return nil, fmt.Errorf("missing required numeric field: timestamp for source A")
}
// Assume nanoseconds // Assume nanoseconds
event.Timestamp = time.Unix(0, ts) event.Timestamp = time.Unix(0, ts)
} else if tsStr, ok := getString(raw, "time"); ok { case domain.SourceB:
if t, err := time.Parse(time.RFC3339, tsStr); err == nil { // For network source, always use local reception time
event.Timestamp = t
}
} else if tsStr, ok := getString(raw, "timestamp"); ok {
if t, err := time.Parse(time.RFC3339, tsStr); err == nil {
event.Timestamp = t
}
}
if event.Timestamp.IsZero() {
event.Timestamp = time.Now() event.Timestamp = time.Now()
}
// Extract headers (header_* fields)
event.Headers = make(map[string]string)
for k, v := range raw {
if len(k) > 7 && k[:7] == "header_" {
if sv, ok := v.(string); ok {
event.Headers[k[7:]] = sv
}
}
}
// Determine source based on explicit config or fallback to heuristic
switch sourceType {
case "A", "a", "apache", "http":
event.Source = domain.SourceA
case "B", "b", "network", "net":
event.Source = domain.SourceB
default: default:
// Fallback to heuristic detection for backward compatibility return nil, fmt.Errorf("unsupported source type: %s", event.Source)
if len(event.Headers) > 0 {
event.Source = domain.SourceA
} else {
event.Source = domain.SourceB
}
} }
// Extra fields (single pass) // Extra fields
knownFields := map[string]bool{ knownFields := map[string]bool{
"src_ip": true, "src_port": true, "dst_ip": true, "dst_port": true, "src_ip": true, "src_port": true, "dst_ip": true, "dst_port": true,
"timestamp": true, "time": true, "timestamp": true, "time": true,
@ -306,6 +312,9 @@ func getInt(m map[string]any, key string) (int, bool) {
if v, ok := m[key]; ok { if v, ok := m[key]; ok {
switch val := v.(type) { switch val := v.(type) {
case float64: case float64:
if math.Trunc(val) != val {
return 0, false
}
return int(val), true return int(val), true
case int: case int:
return val, true return val, true
@ -324,6 +333,9 @@ func getInt64(m map[string]any, key string) (int64, bool) {
if v, ok := m[key]; ok { if v, ok := m[key]; ok {
switch val := v.(type) { switch val := v.(type) {
case float64: case float64:
if math.Trunc(val) != val {
return 0, false
}
return int64(val), true return int64(val), true
case int: case int:
return int64(val), true return int64(val), true

View File

@ -41,6 +41,10 @@ func TestParseJSONEvent_Apache(t *testing.T) {
if event.Source != domain.SourceA { if event.Source != domain.SourceA {
t.Errorf("expected source A, got %s", event.Source) t.Errorf("expected source A, got %s", event.Source)
} }
expectedTs := time.Unix(0, 1704110400000000000)
if !event.Timestamp.Equal(expectedTs) {
t.Errorf("expected timestamp %v, got %v", expectedTs, event.Timestamp)
}
} }
func TestParseJSONEvent_Network(t *testing.T) { func TestParseJSONEvent_Network(t *testing.T) {
@ -49,12 +53,15 @@ func TestParseJSONEvent_Network(t *testing.T) {
"src_port": 8080, "src_port": 8080,
"dst_ip": "10.0.0.1", "dst_ip": "10.0.0.1",
"dst_port": 443, "dst_port": 443,
"timestamp": 1704110400000000000,
"ja3": "abc123def456", "ja3": "abc123def456",
"ja4": "xyz789", "ja4": "xyz789",
"tcp_meta_flags": "SYN" "tcp_meta_flags": "SYN"
}`) }`)
before := time.Now()
event, err := parseJSONEvent(data, "B") event, err := parseJSONEvent(data, "B")
after := time.Now()
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@ -68,6 +75,9 @@ func TestParseJSONEvent_Network(t *testing.T) {
if event.Source != domain.SourceB { if event.Source != domain.SourceB {
t.Errorf("expected source B, got %s", event.Source) t.Errorf("expected source B, got %s", event.Source)
} }
if event.Timestamp.Before(before.Add(-2*time.Second)) || event.Timestamp.After(after.Add(2*time.Second)) {
t.Errorf("expected network timestamp near now, got %v", event.Timestamp)
}
} }
func TestParseJSONEvent_InvalidJSON(t *testing.T) { func TestParseJSONEvent_InvalidJSON(t *testing.T) {
@ -88,21 +98,35 @@ func TestParseJSONEvent_MissingFields(t *testing.T) {
} }
} }
func TestParseJSONEvent_StringTimestamp(t *testing.T) { func TestParseJSONEvent_SourceARequiresNumericTimestamp(t *testing.T) {
data := []byte(`{ data := []byte(`{
"src_ip": "192.168.1.1", "src_ip": "192.168.1.1",
"src_port": 8080, "src_port": 8080,
"time": "2024-01-01T12:00:00Z" "time": "2024-01-01T12:00:00Z"
}`) }`)
event, err := parseJSONEvent(data, "") _, err := parseJSONEvent(data, "A")
if err == nil {
t.Fatal("expected error for source A without numeric timestamp")
}
}
func TestParseJSONEvent_SourceBIgnoresPayloadTimestamp(t *testing.T) {
data := []byte(`{
"src_ip": "192.168.1.1",
"src_port": 8080,
"timestamp": 1
}`)
before := time.Now()
event, err := parseJSONEvent(data, "B")
after := time.Now()
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
expected := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) if event.Timestamp.Before(before.Add(-2*time.Second)) || event.Timestamp.After(after.Add(2*time.Second)) {
if !event.Timestamp.Equal(expected) { t.Errorf("expected source B timestamp near now, got %v", event.Timestamp)
t.Errorf("expected timestamp %v, got %v", expected, event.Timestamp)
} }
} }
@ -114,40 +138,40 @@ func TestParseJSONEvent_ExplicitSourceType(t *testing.T) {
expected domain.EventSource expected domain.EventSource
}{ }{
{ {
name: "explicit A", name: "explicit A",
data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, data: `{"src_ip": "192.168.1.1", "src_port": 8080, "timestamp": 1704110400000000000}`,
sourceType: "A", sourceType: "A",
expected: domain.SourceA, expected: domain.SourceA,
}, },
{ {
name: "explicit B", name: "explicit B",
data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, data: `{"src_ip": "192.168.1.1", "src_port": 8080}`,
sourceType: "B", sourceType: "B",
expected: domain.SourceB, expected: domain.SourceB,
}, },
{ {
name: "explicit apache", name: "explicit apache",
data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, data: `{"src_ip": "192.168.1.1", "src_port": 8080, "timestamp": 1704110400000000000}`,
sourceType: "apache", sourceType: "apache",
expected: domain.SourceA, expected: domain.SourceA,
}, },
{ {
name: "explicit network", name: "explicit network",
data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, data: `{"src_ip": "192.168.1.1", "src_port": 8080}`,
sourceType: "network", sourceType: "network",
expected: domain.SourceB, expected: domain.SourceB,
}, },
{ {
name: "auto-detect A with headers", name: "auto-detect A with headers",
data: `{"src_ip": "192.168.1.1", "src_port": 8080, "header_host": "example.com"}`, data: `{"src_ip": "192.168.1.1", "src_port": 8080, "timestamp": 1704110400000000000, "header_host": "example.com"}`,
sourceType: "", sourceType: "",
expected: domain.SourceA, expected: domain.SourceA,
}, },
{ {
name: "auto-detect B without headers", name: "auto-detect B without headers",
data: `{"src_ip": "192.168.1.1", "src_port": 8080, "ja3": "abc"}`, data: `{"src_ip": "192.168.1.1", "src_port": 8080, "ja3": "abc"}`,
sourceType: "", sourceType: "",
expected: domain.SourceB, expected: domain.SourceB,
}, },
} }
@ -241,7 +265,7 @@ func TestGetInt(t *testing.T) {
expected int expected int
ok bool ok bool
}{ }{
{"float", 42, true}, {"float", 0, false},
{"int", 42, true}, {"int", 42, true},
{"int64", 42, true}, {"int64", 42, true},
{"string", 42, true}, {"string", 42, true},
@ -278,7 +302,7 @@ func TestGetInt64(t *testing.T) {
expected int64 expected int64
ok bool ok bool
}{ }{
{"float", 42, true}, {"float", 0, false},
{"int", 42, true}, {"int", 42, true},
{"int64", 42, true}, {"int64", 42, true},
{"string", 42, true}, {"string", 42, true},
@ -302,45 +326,52 @@ func TestGetInt64(t *testing.T) {
func TestParseJSONEvent_PortValidation(t *testing.T) { func TestParseJSONEvent_PortValidation(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
data string data string
wantErr bool sourceType string
wantErr bool
}{ }{
{ {
name: "valid src_port", name: "valid src_port",
data: `{"src_ip": "192.168.1.1", "src_port": 8080}`, data: `{"src_ip": "192.168.1.1", "src_port": 8080}`,
wantErr: false, sourceType: "B",
wantErr: false,
}, },
{ {
name: "src_port zero", name: "src_port zero",
data: `{"src_ip": "192.168.1.1", "src_port": 0}`, data: `{"src_ip": "192.168.1.1", "src_port": 0}`,
wantErr: true, sourceType: "B",
wantErr: true,
}, },
{ {
name: "src_port negative", name: "src_port negative",
data: `{"src_ip": "192.168.1.1", "src_port": -1}`, data: `{"src_ip": "192.168.1.1", "src_port": -1}`,
wantErr: true, sourceType: "B",
wantErr: true,
}, },
{ {
name: "src_port too high", name: "src_port too high",
data: `{"src_ip": "192.168.1.1", "src_port": 70000}`, data: `{"src_ip": "192.168.1.1", "src_port": 70000}`,
wantErr: true, sourceType: "B",
wantErr: true,
}, },
{ {
name: "valid dst_port zero", name: "valid dst_port zero",
data: `{"src_ip": "192.168.1.1", "src_port": 8080, "dst_port": 0}`, data: `{"src_ip": "192.168.1.1", "src_port": 8080, "dst_port": 0}`,
wantErr: false, sourceType: "B",
wantErr: false,
}, },
{ {
name: "dst_port too high", name: "dst_port too high",
data: `{"src_ip": "192.168.1.1", "src_port": 8080, "dst_port": 70000}`, data: `{"src_ip": "192.168.1.1", "src_port": 8080, "dst_port": 70000}`,
wantErr: true, sourceType: "B",
wantErr: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := parseJSONEvent([]byte(tt.data), "") _, err := parseJSONEvent([]byte(tt.data), tt.sourceType)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("parseJSONEvent() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("parseJSONEvent() error = %v, wantErr %v", err, tt.wantErr)
} }
@ -350,12 +381,12 @@ func TestParseJSONEvent_PortValidation(t *testing.T) {
func TestParseJSONEvent_TimestampFallback(t *testing.T) { func TestParseJSONEvent_TimestampFallback(t *testing.T) {
data := []byte(`{"src_ip": "192.168.1.1", "src_port": 8080}`) data := []byte(`{"src_ip": "192.168.1.1", "src_port": 8080}`)
event, err := parseJSONEvent(data, "") event, err := parseJSONEvent(data, "B")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
// Should fallback to current time if no timestamp provided // For source B, timestamp is reception time
if event.Timestamp.IsZero() { if event.Timestamp.IsZero() {
t.Error("expected non-zero timestamp") t.Error("expected non-zero timestamp")
} }

View File

@ -4,7 +4,9 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -115,35 +117,37 @@ func (s *ClickHouseSink) Name() string {
// Write adds a log to the buffer. // Write adds a log to the buffer.
func (s *ClickHouseSink) Write(ctx context.Context, log domain.CorrelatedLog) error { func (s *ClickHouseSink) Write(ctx context.Context, log domain.CorrelatedLog) error {
s.mu.Lock() deadline := time.Now().Add(time.Duration(s.config.TimeoutMs) * time.Millisecond)
defer s.mu.Unlock()
// Check buffer overflow for {
if len(s.buffer) >= s.config.MaxBufferSize { s.mu.Lock()
if s.config.DropOnOverflow { if len(s.buffer) < s.config.MaxBufferSize {
// Drop the log s.buffer = append(s.buffer, log)
if len(s.buffer) >= s.config.BatchSize {
select {
case s.flushChan <- struct{}{}:
default:
}
}
s.mu.Unlock()
return nil return nil
} }
// Block until space is available (with timeout) drop := s.config.DropOnOverflow
s.mu.Unlock()
if drop {
return nil
}
if time.Now().After(deadline) {
return fmt.Errorf("buffer full, timeout exceeded")
}
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-time.After(time.Duration(s.config.TimeoutMs) * time.Millisecond): case <-time.After(10 * time.Millisecond):
return fmt.Errorf("buffer full, timeout exceeded")
} }
} }
s.buffer = append(s.buffer, log)
// Trigger flush if batch is full
if len(s.buffer) >= s.config.BatchSize {
select {
case s.flushChan <- struct{}{}:
default:
}
}
return nil
} }
// Flush flushes the buffer to ClickHouse. // Flush flushes the buffer to ClickHouse.
@ -311,7 +315,33 @@ func isRetryableError(err error) bool {
if err == nil { if err == nil {
return false return false
} }
if errors.Is(err, context.DeadlineExceeded) {
return true
}
if errors.Is(err, context.Canceled) {
return false
}
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() {
return true
}
}
errStr := strings.ToLower(err.Error()) errStr := strings.ToLower(err.Error())
// Explicit non-retryable SQL/schema errors
if strings.Contains(errStr, "syntax error") ||
strings.Contains(errStr, "unknown table") ||
strings.Contains(errStr, "unknown column") ||
(strings.Contains(errStr, "table") && strings.Contains(errStr, "not found")) {
return false
}
// Fallback network/transient errors
retryableErrors := []string{ retryableErrors := []string{
"connection refused", "connection refused",
"connection reset", "connection reset",
@ -319,11 +349,13 @@ func isRetryableError(err error) bool {
"temporary failure", "temporary failure",
"network is unreachable", "network is unreachable",
"broken pipe", "broken pipe",
"no route to host",
} }
for _, re := range retryableErrors { for _, re := range retryableErrors {
if strings.Contains(errStr, re) { if strings.Contains(errStr, re) {
return true return true
} }
} }
return false return false
} }

View File

@ -1,7 +1,6 @@
package file package file
import ( import (
"bufio"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -30,7 +29,6 @@ type FileSink struct {
config Config config Config
mu sync.Mutex mu sync.Mutex
file *os.File file *os.File
writer *bufio.Writer
} }
// NewFileSink creates a new file sink. // NewFileSink creates a new file sink.
@ -66,11 +64,12 @@ func (s *FileSink) Write(ctx context.Context, log domain.CorrelatedLog) error {
return fmt.Errorf("failed to marshal log: %w", err) return fmt.Errorf("failed to marshal log: %w", err)
} }
if _, err := s.writer.Write(data); err != nil { line := append(data, '\n')
return fmt.Errorf("failed to write log: %w", err) if _, err := s.file.Write(line); err != nil {
return fmt.Errorf("failed to write log line: %w", err)
} }
if _, err := s.writer.WriteString("\n"); err != nil { if err := s.file.Sync(); err != nil {
return fmt.Errorf("failed to write newline: %w", err) return fmt.Errorf("failed to sync log line: %w", err)
} }
return nil return nil
@ -81,8 +80,8 @@ func (s *FileSink) Flush(ctx context.Context) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.writer != nil { if s.file != nil {
return s.writer.Flush() return s.file.Sync()
} }
return nil return nil
} }
@ -92,12 +91,6 @@ func (s *FileSink) Close() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.writer != nil {
if err := s.writer.Flush(); err != nil {
return err
}
}
if s.file != nil { if s.file != nil {
return s.file.Close() return s.file.Close()
} }
@ -122,47 +115,54 @@ func (s *FileSink) openFile() error {
} }
s.file = file s.file = file
s.writer = bufio.NewWriter(file)
return nil return nil
} }
// validateFilePath validates that the file path is safe and allowed. // validateFilePath validates that the file path is safe and allowed.
func validateFilePath(path string) error { func validateFilePath(path string) error {
if path == "" { if strings.TrimSpace(path) == "" {
return fmt.Errorf("path cannot be empty") return fmt.Errorf("path cannot be empty")
} }
// Clean the path
cleanPath := filepath.Clean(path) cleanPath := filepath.Clean(path)
// Ensure path is absolute or relative to allowed directories // Allow relative paths for testing/dev
allowedPrefixes := []string{ if !filepath.IsAbs(cleanPath) {
return nil
}
absPath, err := filepath.Abs(cleanPath)
if err != nil {
return fmt.Errorf("failed to resolve absolute path: %w", err)
}
allowedRoots := []string{
"/var/log/logcorrelator", "/var/log/logcorrelator",
"/var/log", "/var/log",
"/tmp", "/tmp",
} }
// Check if path is in allowed directories for _, root := range allowedRoots {
allowed := false absRoot, err := filepath.Abs(filepath.Clean(root))
for _, prefix := range allowedPrefixes { if err != nil {
if strings.HasPrefix(cleanPath, prefix) { continue
allowed = true
break
} }
}
if !allowed { rel, err := filepath.Rel(absRoot, absPath)
// Allow relative paths for testing if err != nil {
if !filepath.IsAbs(cleanPath) { continue
}
if rel == "." {
return nil
}
if rel == ".." {
continue
}
if !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
return nil return nil
} }
return fmt.Errorf("path must be in allowed directories: %v", allowedPrefixes)
} }
// Check for path traversal return fmt.Errorf("path must be under allowed directories: %v", allowedRoots)
if strings.Contains(cleanPath, "..") {
return fmt.Errorf("path cannot contain '..'")
}
return nil
} }

View File

@ -44,6 +44,36 @@ func TestFileSink_Write(t *testing.T) {
} }
} }
// TestFileSink_WriteImmediatePersist_NoFlushNeeded verifies that a single
// Write makes the record visible on disk right away, with no explicit
// Flush required.
func TestFileSink_WriteImmediatePersist_NoFlushNeeded(t *testing.T) {
	logPath := filepath.Join(t.TempDir(), "test.log")

	sink, err := NewFileSink(Config{Path: logPath})
	if err != nil {
		t.Fatalf("failed to create sink: %v", err)
	}
	defer sink.Close()

	entry := domain.CorrelatedLog{
		SrcIP:      "192.168.1.1",
		SrcPort:    8080,
		Correlated: true,
	}
	if err := sink.Write(context.Background(), entry); err != nil {
		t.Fatalf("failed to write: %v", err)
	}

	// Read back without calling Flush(): the line must already be durable.
	contents, err := os.ReadFile(logPath)
	if err != nil {
		t.Fatalf("failed to read file: %v", err)
	}
	if len(contents) == 0 {
		t.Error("expected data to be present immediately after Write without Flush")
	}
}
func TestFileSink_MultipleWrites(t *testing.T) { func TestFileSink_MultipleWrites(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
testPath := filepath.Join(tmpDir, "test.log") testPath := filepath.Join(tmpDir, "test.log")
@ -105,7 +135,7 @@ func TestFileSink_ValidateFilePath(t *testing.T) {
{"valid /var/log/logcorrelator", "/var/log/logcorrelator/test.log", false}, {"valid /var/log/logcorrelator", "/var/log/logcorrelator/test.log", false},
{"valid /var/log", "/var/log/test.log", false}, {"valid /var/log", "/var/log/test.log", false},
{"valid /tmp", "/tmp/test.log", false}, {"valid /tmp", "/tmp/test.log", false},
{"path traversal", "/var/log/../etc/passwd", true}, {"reject lookalike /var/logevil", "/var/logevil/test.log", true},
{"invalid directory", "/etc/logcorrelator/test.log", true}, {"invalid directory", "/etc/logcorrelator/test.log", true},
{"relative path", "test.log", false}, // Allowed for testing {"relative path", "test.log", false}, // Allowed for testing
} }
@ -137,9 +167,6 @@ func TestFileSink_OpenFile(t *testing.T) {
if sink.file == nil { if sink.file == nil {
t.Error("expected file to be opened") t.Error("expected file to be opened")
} }
if sink.writer == nil {
t.Error("expected writer to be initialized")
}
} }
func TestFileSink_WriteBeforeOpen(t *testing.T) { func TestFileSink_WriteBeforeOpen(t *testing.T) {
@ -183,7 +210,7 @@ func TestFileSink_FlushBeforeOpen(t *testing.T) {
} }
func TestFileSink_InvalidPath(t *testing.T) { func TestFileSink_InvalidPath(t *testing.T) {
// Test with invalid path (path traversal) // Test with invalid path (outside allowed directories)
_, err := NewFileSink(Config{Path: "/etc/../passwd"}) _, err := NewFileSink(Config{Path: "/etc/../passwd"})
if err == nil { if err == nil {
t.Error("expected error for invalid path") t.Error("expected error for invalid path")

View File

@ -12,10 +12,10 @@ import (
// Config holds the complete application configuration. // Config holds the complete application configuration.
type Config struct { type Config struct {
Log LogConfig `yaml:"log"` Log LogConfig `yaml:"log"`
Inputs InputsConfig `yaml:"inputs"` Inputs InputsConfig `yaml:"inputs"`
Outputs OutputsConfig `yaml:"outputs"` Outputs OutputsConfig `yaml:"outputs"`
Correlation CorrelationConfig `yaml:"correlation"` Correlation CorrelationConfig `yaml:"correlation"`
} }
// LogConfig holds logging configuration. // LogConfig holds logging configuration.
@ -44,10 +44,10 @@ type InputsConfig struct {
// UnixSocketConfig holds a Unix socket source configuration. // UnixSocketConfig holds a Unix socket source configuration.
type UnixSocketConfig struct { type UnixSocketConfig struct {
Name string `yaml:"name"` Name string `yaml:"name"`
Path string `yaml:"path"` Path string `yaml:"path"`
Format string `yaml:"format"` Format string `yaml:"format"`
SourceType string `yaml:"source_type"` // "A" for Apache/HTTP, "B" for Network SourceType string `yaml:"source_type"` // "A" for Apache/HTTP, "B" for Network
SocketPermissions string `yaml:"socket_permissions"` // octal string, e.g., "0660", "0666" SocketPermissions string `yaml:"socket_permissions"` // octal string, e.g., "0660", "0666"
} }
@ -55,7 +55,7 @@ type UnixSocketConfig struct {
type OutputsConfig struct { type OutputsConfig struct {
File FileOutputConfig `yaml:"file"` File FileOutputConfig `yaml:"file"`
ClickHouse ClickHouseOutputConfig `yaml:"clickhouse"` ClickHouse ClickHouseOutputConfig `yaml:"clickhouse"`
Stdout bool `yaml:"stdout"` Stdout StdoutOutputConfig `yaml:"stdout"`
} }
// FileOutputConfig holds file sink configuration. // FileOutputConfig holds file sink configuration.
@ -77,15 +77,14 @@ type ClickHouseOutputConfig struct {
} }
// StdoutOutputConfig holds stdout sink configuration. // StdoutOutputConfig holds stdout sink configuration.
// Deprecated: stdout is now a boolean flag in OutputsConfig.
type StdoutOutputConfig struct { type StdoutOutputConfig struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
} }
// CorrelationConfig holds correlation configuration. // CorrelationConfig holds correlation configuration.
type CorrelationConfig struct { type CorrelationConfig struct {
TimeWindowS int `yaml:"time_window_s"` TimeWindowS int `yaml:"time_window_s"`
EmitOrphans bool `yaml:"emit_orphans"` EmitOrphans bool `yaml:"emit_orphans"`
} }
// Load loads configuration from a YAML file. // Load loads configuration from a YAML file.
@ -130,7 +129,7 @@ func defaultConfig() *Config {
AsyncInsert: true, AsyncInsert: true,
TimeoutMs: 1000, TimeoutMs: 1000,
}, },
Stdout: false, Stdout: StdoutOutputConfig{Enabled: false},
}, },
Correlation: CorrelationConfig{ Correlation: CorrelationConfig{
TimeWindowS: 1, TimeWindowS: 1,
@ -175,7 +174,7 @@ func (c *Config) Validate() error {
if c.Outputs.ClickHouse.Enabled { if c.Outputs.ClickHouse.Enabled {
hasOutput = true hasOutput = true
} }
if c.Outputs.Stdout { if c.Outputs.Stdout.Enabled {
hasOutput = true hasOutput = true
} }
@ -220,12 +219,13 @@ func (c *CorrelationConfig) GetTimeWindow() time.Duration {
// GetSocketPermissions returns the socket permissions as os.FileMode. // GetSocketPermissions returns the socket permissions as os.FileMode.
// Default is 0660 (owner + group read/write). // Default is 0660 (owner + group read/write).
func (c *UnixSocketConfig) GetSocketPermissions() os.FileMode { func (c *UnixSocketConfig) GetSocketPermissions() os.FileMode {
if c.SocketPermissions == "" { trimmed := strings.TrimSpace(c.SocketPermissions)
if trimmed == "" {
return 0660 return 0660
} }
// Parse octal string (e.g., "0660", "660", "0666") // Parse octal string (e.g., "0660", "660", "0666")
perms, err := strconv.ParseUint(strings.TrimPrefix(c.SocketPermissions, "0"), 8, 32) perms, err := strconv.ParseUint(trimmed, 8, 32)
if err != nil { if err != nil {
return 0660 return 0660
} }

View File

@ -131,7 +131,7 @@ func TestValidate_AtLeastOneOutput(t *testing.T) {
Outputs: OutputsConfig{ Outputs: OutputsConfig{
File: FileOutputConfig{}, File: FileOutputConfig{},
ClickHouse: ClickHouseOutputConfig{Enabled: false}, ClickHouse: ClickHouseOutputConfig{Enabled: false},
Stdout: false, Stdout: StdoutOutputConfig{Enabled: false},
}, },
} }
@ -554,3 +554,37 @@ correlation:
t.Errorf("expected log level DEBUG, got %s", cfg.Log.GetLevel()) t.Errorf("expected log level DEBUG, got %s", cfg.Log.GetLevel())
} }
} }
// TestLoad_StdoutEnabledObject checks that the outputs.stdout section is
// parsed as a structured object carrying an "enabled" flag (rather than a
// bare boolean).
func TestLoad_StdoutEnabledObject(t *testing.T) {
	content := `
inputs:
  unix_sockets:
    - name: http
      path: /var/run/logcorrelator/http.socket
    - name: network
      path: /var/run/logcorrelator/network.socket
outputs:
  stdout:
    enabled: true
correlation:
  time_window_s: 1
  emit_orphans: true
`
	configPath := filepath.Join(t.TempDir(), "config.yml")
	if err := os.WriteFile(configPath, []byte(content), 0644); err != nil {
		t.Fatalf("failed to write config: %v", err)
	}

	cfg, err := Load(configPath)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !cfg.Outputs.Stdout.Enabled {
		t.Error("expected outputs.stdout.enabled to be true")
	}
}

View File

@ -2,6 +2,7 @@ package domain
import ( import (
"encoding/json" "encoding/json"
"reflect"
"time" "time"
) )
@ -38,8 +39,20 @@ func (c CorrelatedLog) MarshalJSON() ([]byte, error) {
flat["orphan_side"] = c.OrphanSide flat["orphan_side"] = c.OrphanSide
} }
// Merge additional fields // Merge additional fields while preserving reserved keys
reservedKeys := map[string]struct{}{
"timestamp": {},
"src_ip": {},
"src_port": {},
"dst_ip": {},
"dst_port": {},
"correlated": {},
"orphan_side": {},
}
for k, v := range c.Fields { for k, v := range c.Fields {
if _, reserved := reservedKeys[k]; reserved {
continue
}
flat[k] = v flat[k] = v
} }
@ -89,13 +102,28 @@ func extractFields(e *NormalizedEvent) map[string]any {
func mergeFields(a, b *NormalizedEvent) map[string]any { func mergeFields(a, b *NormalizedEvent) map[string]any {
result := make(map[string]any) result := make(map[string]any)
// Merge fields from both events
// Start with A fields
for k, v := range a.Raw { for k, v := range a.Raw {
result[k] = v result[k] = v
} }
// Merge B fields with collision handling
for k, v := range b.Raw { for k, v := range b.Raw {
if existing, exists := result[k]; exists {
if reflect.DeepEqual(existing, v) {
continue
}
// Collision with different values: keep both with prefixes
delete(result, k)
result["a_"+k] = existing
result["b_"+k] = v
continue
}
result[k] = v result[k] = v
} }
return result return result
} }

View File

@ -101,9 +101,6 @@ func (s *CorrelationService) ProcessEvent(event *NormalizedEvent) []CorrelatedLo
if event.Source == SourceA && s.config.ApacheAlwaysEmit { if event.Source == SourceA && s.config.ApacheAlwaysEmit {
return []CorrelatedLog{NewCorrelatedLogFromEvent(event, "A")} return []CorrelatedLog{NewCorrelatedLogFromEvent(event, "A")}
} }
if event.Source == SourceB && s.config.NetworkEmit {
return []CorrelatedLog{NewCorrelatedLogFromEvent(event, "B")}
}
return nil return nil
} }
@ -187,14 +184,7 @@ func (s *CorrelationService) processSourceB(event *NormalizedEvent) ([]Correlate
return []CorrelatedLog{correlated}, false return []CorrelatedLog{correlated}, false
} }
// No match found - orphan B event (not emitted by default) // Never emit B alone. Keep in buffer for potential future match.
if s.config.NetworkEmit {
orphan := NewCorrelatedLogFromEvent(event, "B")
s.logger.Warnf("orphan B event (no A match): src_ip=%s src_port=%d", event.SrcIP, event.SrcPort)
return []CorrelatedLog{orphan}, false
}
// Keep in buffer for potential future match
return nil, true return nil, true
} }
@ -231,20 +221,17 @@ func (s *CorrelationService) cleanExpired() {
// cleanBuffer removes expired events from a buffer. // cleanBuffer removes expired events from a buffer.
func (s *CorrelationService) cleanBuffer(buffer *eventBuffer, pending map[string][]*list.Element, cutoff time.Time) { func (s *CorrelationService) cleanBuffer(buffer *eventBuffer, pending map[string][]*list.Element, cutoff time.Time) {
for elem := buffer.events.Front(); elem != nil; { for elem := buffer.events.Front(); elem != nil; {
next := elem.Next()
event := elem.Value.(*NormalizedEvent) event := elem.Value.(*NormalizedEvent)
if event.Timestamp.Before(cutoff) { if event.Timestamp.Before(cutoff) {
next := elem.Next()
key := event.CorrelationKey() key := event.CorrelationKey()
buffer.events.Remove(elem) buffer.events.Remove(elem)
pending[key] = removeElementFromSlice(pending[key], elem) pending[key] = removeElementFromSlice(pending[key], elem)
if len(pending[key]) == 0 { if len(pending[key]) == 0 {
delete(pending, key) delete(pending, key)
} }
elem = next
continue
} }
// Events are inserted in arrival order; once we hit a non-expired event, stop. elem = next
break
} }
} }
@ -307,14 +294,7 @@ func (s *CorrelationService) Flush() []CorrelatedLog {
} }
} }
// Emit remaining B events as orphans only if explicitly enabled // Never emit remaining B events alone.
if s.config.NetworkEmit {
for elem := s.bufferB.events.Front(); elem != nil; elem = elem.Next() {
event := elem.Value.(*NormalizedEvent)
orphan := NewCorrelatedLogFromEvent(event, "B")
results = append(results, orphan)
}
}
// Clear buffers // Clear buffers
s.bufferA.events.Init() s.bufferA.events.Init()

View File

@ -193,7 +193,6 @@ func TestCorrelationService_FlushWithEvents(t *testing.T) {
now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
timeProvider := &mockTimeProvider{now: now} timeProvider := &mockTimeProvider{now: now}
// Flush only emits events if ApacheAlwaysEmit and NetworkEmit are true
config := CorrelationConfig{ config := CorrelationConfig{
TimeWindow: time.Second, TimeWindow: time.Second,
ApacheAlwaysEmit: true, ApacheAlwaysEmit: true,
@ -201,7 +200,6 @@ func TestCorrelationService_FlushWithEvents(t *testing.T) {
} }
svc := NewCorrelationService(config, timeProvider) svc := NewCorrelationService(config, timeProvider)
// We need to bypass the normal ProcessEvent logic to get events into buffers
// Add events directly to buffers for testing Flush // Add events directly to buffers for testing Flush
keyA := "192.168.1.1:8080" keyA := "192.168.1.1:8080"
keyB := "192.168.1.2:9090" keyB := "192.168.1.2:9090"
@ -219,7 +217,6 @@ func TestCorrelationService_FlushWithEvents(t *testing.T) {
SrcPort: 9090, SrcPort: 9090,
} }
// Manually add to buffers (simulating events that couldn't be matched)
elemA := svc.bufferA.events.PushBack(apacheEvent) elemA := svc.bufferA.events.PushBack(apacheEvent)
svc.pendingA[keyA] = append(svc.pendingA[keyA], elemA) svc.pendingA[keyA] = append(svc.pendingA[keyA], elemA)
@ -227,8 +224,11 @@ func TestCorrelationService_FlushWithEvents(t *testing.T) {
svc.pendingB[keyB] = append(svc.pendingB[keyB], elemB) svc.pendingB[keyB] = append(svc.pendingB[keyB], elemB)
flushed := svc.Flush() flushed := svc.Flush()
if len(flushed) != 2 { if len(flushed) != 1 {
t.Errorf("expected 2 flushed events, got %d", len(flushed)) t.Errorf("expected 1 flushed event (A only), got %d", len(flushed))
}
if len(flushed) == 1 && flushed[0].OrphanSide != "A" {
t.Errorf("expected orphan side A, got %s", flushed[0].OrphanSide)
} }
// Verify buffers are cleared // Verify buffers are cleared
@ -262,7 +262,7 @@ func TestCorrelationService_BufferOverflow(t *testing.T) {
svc.ProcessEvent(event) svc.ProcessEvent(event)
} }
// Buffer full, next event should be dropped (not emitted since ApacheAlwaysEmit=false but buffer full) // Buffer full, next event should be dropped (not emitted since ApacheAlwaysEmit=false)
overflowEvent := &NormalizedEvent{ overflowEvent := &NormalizedEvent{
Source: SourceA, Source: SourceA,
Timestamp: now, Timestamp: now,
@ -340,3 +340,33 @@ func TestCorrelationService_DifferentSourceTypes(t *testing.T) {
t.Error("expected correlated result") t.Error("expected correlated result")
} }
} }
func TestCorrelationService_NetworkEmitTrue_DoesNotEmitBAlone(t *testing.T) {
now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
timeProvider := &mockTimeProvider{now: now}
config := CorrelationConfig{
TimeWindow: time.Second,
ApacheAlwaysEmit: false,
NetworkEmit: true,
}
svc := NewCorrelationService(config, timeProvider)
networkEvent := &NormalizedEvent{
Source: SourceB,
Timestamp: now,
SrcIP: "10.10.10.10",
SrcPort: 5555,
}
results := svc.ProcessEvent(networkEvent)
if len(results) != 0 {
t.Errorf("expected 0 immediate results for orphan B, got %d", len(results))
}
flushed := svc.Flush()
if len(flushed) != 0 {
t.Errorf("expected 0 flushed orphan B events, got %d", len(flushed))
}
}

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"log" "log"
"os" "os"
"sort"
"strings" "strings"
"sync" "sync"
) )
@ -52,10 +53,10 @@ func (l LogLevel) String() string {
// Logger provides structured logging. // Logger provides structured logging.
type Logger struct { type Logger struct {
mu sync.Mutex mu sync.RWMutex
logger *log.Logger logger *log.Logger
prefix string prefix string
fields map[string]any fields map[string]any
minLevel LogLevel minLevel LogLevel
} }
@ -88,30 +89,32 @@ func (l *Logger) SetLevel(level string) {
// ShouldLog returns true if the given level should be logged. // ShouldLog returns true if the given level should be logged.
func (l *Logger) ShouldLog(level LogLevel) bool { func (l *Logger) ShouldLog(level LogLevel) bool {
l.mu.Lock() l.mu.RLock()
defer l.mu.Unlock() defer l.mu.RUnlock()
return level >= l.minLevel return level >= l.minLevel
} }
// WithFields returns a new logger with additional fields. // WithFields returns a new logger with additional fields.
func (l *Logger) WithFields(fields map[string]any) *Logger { func (l *Logger) WithFields(fields map[string]any) *Logger {
l.mu.Lock() l.mu.RLock()
minLevel := l.minLevel minLevel := l.minLevel
l.mu.Unlock() prefix := l.prefix
existing := make(map[string]any, len(l.fields))
for k, v := range l.fields {
existing[k] = v
}
l.mu.RUnlock()
newLogger := &Logger{ for k, v := range fields {
existing[k] = v
}
return &Logger{
logger: l.logger, logger: l.logger,
prefix: l.prefix, prefix: prefix,
fields: make(map[string]any), fields: existing,
minLevel: minLevel, minLevel: minLevel,
} }
for k, v := range l.fields {
newLogger.fields[k] = v
}
for k, v := range fields {
newLogger.fields[k] = v
}
return newLogger
} }
// Info logs an info message. // Info logs an info message.
@ -119,8 +122,6 @@ func (l *Logger) Info(msg string) {
if !l.ShouldLog(INFO) { if !l.ShouldLog(INFO) {
return return
} }
l.mu.Lock()
defer l.mu.Unlock()
l.log("INFO", msg) l.log("INFO", msg)
} }
@ -129,8 +130,6 @@ func (l *Logger) Warn(msg string) {
if !l.ShouldLog(WARN) { if !l.ShouldLog(WARN) {
return return
} }
l.mu.Lock()
defer l.mu.Unlock()
l.log("WARN", msg) l.log("WARN", msg)
} }
@ -139,8 +138,6 @@ func (l *Logger) Error(msg string, err error) {
if !l.ShouldLog(ERROR) { if !l.ShouldLog(ERROR) {
return return
} }
l.mu.Lock()
defer l.mu.Unlock()
if err != nil { if err != nil {
l.log("ERROR", msg+" "+err.Error()) l.log("ERROR", msg+" "+err.Error())
} else { } else {
@ -153,8 +150,6 @@ func (l *Logger) Debug(msg string) {
if !l.ShouldLog(DEBUG) { if !l.ShouldLog(DEBUG) {
return return
} }
l.mu.Lock()
defer l.mu.Unlock()
l.log("DEBUG", msg) l.log("DEBUG", msg)
} }
@ -163,8 +158,6 @@ func (l *Logger) Debugf(msg string, args ...any) {
if !l.ShouldLog(DEBUG) { if !l.ShouldLog(DEBUG) {
return return
} }
l.mu.Lock()
defer l.mu.Unlock()
l.log("DEBUG", fmt.Sprintf(msg, args...)) l.log("DEBUG", fmt.Sprintf(msg, args...))
} }
@ -173,8 +166,6 @@ func (l *Logger) Warnf(msg string, args ...any) {
if !l.ShouldLog(WARN) { if !l.ShouldLog(WARN) {
return return
} }
l.mu.Lock()
defer l.mu.Unlock()
l.log("WARN", fmt.Sprintf(msg, args...)) l.log("WARN", fmt.Sprintf(msg, args...))
} }
@ -183,27 +174,42 @@ func (l *Logger) Infof(msg string, args ...any) {
if !l.ShouldLog(INFO) { if !l.ShouldLog(INFO) {
return return
} }
l.mu.Lock()
defer l.mu.Unlock()
l.log("INFO", fmt.Sprintf(msg, args...)) l.log("INFO", fmt.Sprintf(msg, args...))
} }
func (l *Logger) log(level, msg string) { func (l *Logger) log(level, msg string) {
l.mu.RLock()
prefix := l.prefix prefix := l.prefix
if prefix != "" { fields := make(map[string]any, len(l.fields))
prefix = "[" + prefix + "] "
}
l.logger.SetPrefix(prefix + level + " ")
var args []any
for k, v := range l.fields { for k, v := range l.fields {
args = append(args, k, v) fields[k] = v
}
l.mu.RUnlock()
var b strings.Builder
if prefix != "" {
b.WriteString("[")
b.WriteString(prefix)
b.WriteString("] ")
}
b.WriteString(level)
b.WriteString(" ")
b.WriteString(msg)
if len(fields) > 0 {
keys := make([]string, 0, len(fields))
for k := range fields {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
b.WriteString(" ")
b.WriteString(k)
b.WriteString("=")
b.WriteString(fmt.Sprintf("%v", fields[k]))
}
} }
if len(args) > 0 { l.logger.Print(b.String())
l.logger.Printf(msg+" %+v", args...)
} else {
l.logger.Print(msg)
}
} }