diff --git a/internal/adapters/inbound/unixsocket/source.go b/internal/adapters/inbound/unixsocket/source.go index 5b0907b..8958e26 100644 --- a/internal/adapters/inbound/unixsocket/source.go +++ b/internal/adapters/inbound/unixsocket/source.go @@ -40,6 +40,7 @@ type UnixSocketSource struct { done chan struct{} wg sync.WaitGroup semaphore chan struct{} // Limit concurrent connections + stopOnce sync.Once } // NewUnixSocketSource creates a new Unix socket source. @@ -58,6 +59,10 @@ func (s *UnixSocketSource) Name() string { // Start begins listening on the Unix socket. func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) error { + if strings.TrimSpace(s.config.Path) == "" { + return fmt.Errorf("socket path cannot be empty") + } + // Remove existing socket file if present if info, err := os.Stat(s.config.Path); err == nil { if info.Mode()&os.ModeSocket != 0 { @@ -78,8 +83,8 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N // Set permissions - fail if we can't if err := os.Chmod(s.config.Path, DefaultSocketPermissions); err != nil { - listener.Close() - os.Remove(s.config.Path) + _ = listener.Close() + _ = os.Remove(s.config.Path) return fmt.Errorf("failed to set socket permissions: %w", err) } @@ -120,7 +125,7 @@ func (s *UnixSocketSource) acceptConnections(ctx context.Context, eventChan chan // Connection accepted default: // Too many connections, reject - conn.Close() + _ = conn.Close() continue } @@ -136,7 +141,7 @@ func (s *UnixSocketSource) acceptConnections(ctx context.Context, eventChan chan func (s *UnixSocketSource) readEvents(ctx context.Context, conn net.Conn, eventChan chan<- *domain.NormalizedEvent) { // Set read deadline to prevent hanging - conn.SetReadDeadline(time.Now().Add(5 * time.Minute)) + _ = conn.SetReadDeadline(time.Now().Add(5 * time.Minute)) scanner := bufio.NewScanner(conn) // Increase buffer size limit to 1MB @@ -167,10 +172,6 @@ func (s *UnixSocketSource) 
readEvents(ctx context.Context, conn net.Conn, eventC return } } - - if err := scanner.Err(); err != nil { - // Connection error, log but don't crash - } } func parseJSONEvent(data []byte) (*domain.NormalizedEvent, error) { @@ -314,21 +315,26 @@ func getInt64(m map[string]any, key string) (int64, bool) { // Stop gracefully stops the source. func (s *UnixSocketSource) Stop() error { - s.mu.Lock() - defer s.mu.Unlock() + var stopErr error - close(s.done) + s.stopOnce.Do(func() { + s.mu.Lock() + defer s.mu.Unlock() - if s.listener != nil { - s.listener.Close() - } + close(s.done) - s.wg.Wait() + if s.listener != nil { + _ = s.listener.Close() + } - // Clean up socket file - if err := os.Remove(s.config.Path); err != nil && !os.IsNotExist(err) { - return fmt.Errorf("failed to remove socket file: %w", err) - } + s.wg.Wait() - return nil + // Clean up socket file + if err := os.Remove(s.config.Path); err != nil && !os.IsNotExist(err) { + stopErr = fmt.Errorf("failed to remove socket file: %w", err) + return + } + }) + + return stopErr } diff --git a/internal/adapters/outbound/clickhouse/sink.go b/internal/adapters/outbound/clickhouse/sink.go index a9cfdfa..36d780e 100644 --- a/internal/adapters/outbound/clickhouse/sink.go +++ b/internal/adapters/outbound/clickhouse/sink.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "strings" "sync" "time" @@ -49,10 +50,18 @@ type ClickHouseSink struct { flushChan chan struct{} done chan struct{} wg sync.WaitGroup + closeOnce sync.Once } // NewClickHouseSink creates a new ClickHouse sink. 
func NewClickHouseSink(config Config) (*ClickHouseSink, error) { + if strings.TrimSpace(config.DSN) == "" { + return nil, fmt.Errorf("clickhouse DSN is required") + } + if strings.TrimSpace(config.Table) == "" { + return nil, fmt.Errorf("clickhouse table is required") + } + // Apply defaults if config.BatchSize <= 0 { config.BatchSize = DefaultBatchSize @@ -85,7 +94,7 @@ func NewClickHouseSink(config Config) (*ClickHouseSink, error) { defer pingCancel() if err := db.PingContext(pingCtx); err != nil { - db.Close() + _ = db.Close() return nil, fmt.Errorf("failed to ping ClickHouse: %w", err) } @@ -143,13 +152,28 @@ func (s *ClickHouseSink) Flush(ctx context.Context) error { // Close closes the sink. func (s *ClickHouseSink) Close() error { - close(s.done) - s.wg.Wait() + var closeErr error - if s.db != nil { - return s.db.Close() - } - return nil + s.closeOnce.Do(func() { + if s.done != nil { + close(s.done) + } + s.wg.Wait() + + flushCtx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.TimeoutMs)*time.Millisecond) + defer cancel() + if err := s.doFlush(flushCtx); err != nil { + closeErr = err + } + + if s.db != nil { + if err := s.db.Close(); err != nil && closeErr == nil { + closeErr = err + } + } + }) + + return closeErr } func (s *ClickHouseSink) flushLoop() { @@ -161,25 +185,30 @@ func (s *ClickHouseSink) flushLoop() { for { select { case <-s.done: + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.TimeoutMs)*time.Millisecond) + _ = s.doFlush(ctx) + cancel() return + case <-ticker.C: s.mu.Lock() needsFlush := len(s.buffer) > 0 s.mu.Unlock() + if needsFlush { - // Use timeout context for flush ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.TimeoutMs)*time.Millisecond) - s.doFlush(ctx) + _ = s.doFlush(ctx) cancel() } + case <-s.flushChan: s.mu.Lock() needsFlush := len(s.buffer) >= s.config.BatchSize s.mu.Unlock() + if needsFlush { - // Use timeout context for flush ctx, 
cancel := context.WithTimeout(context.Background(), time.Duration(s.config.TimeoutMs)*time.Millisecond) - s.doFlush(ctx) + _ = s.doFlush(ctx) cancel() } } @@ -199,7 +228,10 @@ func (s *ClickHouseSink) doFlush(ctx context.Context) error { s.buffer = make([]domain.CorrelatedLog, 0, s.config.BatchSize) s.mu.Unlock() - // Prepare batch insert with retry + if s.db == nil { + return fmt.Errorf("clickhouse connection is not initialized") + } + query := fmt.Sprintf(` INSERT INTO %s (timestamp, src_ip, src_port, dst_ip, dst_port, correlated, orphan_side, apache, network) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) @@ -209,7 +241,6 @@ func (s *ClickHouseSink) doFlush(ctx context.Context) error { var lastErr error for attempt := 0; attempt < MaxRetries; attempt++ { if attempt > 0 { - // Exponential backoff delay := RetryBaseDelay * time.Duration(1<= len(substr) && containsLower(s, substr) -} - -func containsLower(s, substr string) bool { - s = toLower(s) - substr = toLower(substr) - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} - -func toLower(s string) string { - var result []byte - for i := 0; i < len(s); i++ { - c := s[i] - if c >= 'A' && c <= 'Z' { - c = c + ('a' - 'A') - } - result = append(result, c) - } - return string(result) -} diff --git a/internal/config/config.go b/internal/config/config.go index 59ff89c..c2b8a71 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,7 @@ package config import ( "fmt" "os" + "strings" "time" "gopkg.in/yaml.v3" @@ -36,9 +37,9 @@ type UnixSocketConfig struct { // OutputsConfig holds output sinks configuration. type OutputsConfig struct { - File FileOutputConfig `yaml:"file"` + File FileOutputConfig `yaml:"file"` ClickHouse ClickHouseOutputConfig `yaml:"clickhouse"` - Stdout StdoutOutputConfig `yaml:"stdout"` + Stdout StdoutOutputConfig `yaml:"stdout"` } // FileOutputConfig holds file sink configuration. 
@@ -152,12 +153,59 @@ func (c *Config) Validate() error { return fmt.Errorf("at least two unix socket inputs are required") } + seenNames := make(map[string]struct{}, len(c.Inputs.UnixSockets)) + seenPaths := make(map[string]struct{}, len(c.Inputs.UnixSockets)) + + for i, input := range c.Inputs.UnixSockets { + if strings.TrimSpace(input.Name) == "" { + return fmt.Errorf("inputs.unix_sockets[%d].name is required", i) + } + if strings.TrimSpace(input.Path) == "" { + return fmt.Errorf("inputs.unix_sockets[%d].path is required", i) + } + + if _, exists := seenNames[input.Name]; exists { + return fmt.Errorf("duplicate unix socket input name: %s", input.Name) + } + seenNames[input.Name] = struct{}{} + + if _, exists := seenPaths[input.Path]; exists { + return fmt.Errorf("duplicate unix socket input path: %s", input.Path) + } + seenPaths[input.Path] = struct{}{} + } + if !c.Outputs.File.Enabled && !c.Outputs.ClickHouse.Enabled && !c.Outputs.Stdout.Enabled { return fmt.Errorf("at least one output must be enabled") } - if c.Outputs.ClickHouse.Enabled && c.Outputs.ClickHouse.DSN == "" { - return fmt.Errorf("clickhouse DSN is required when enabled") + if c.Outputs.File.Enabled && strings.TrimSpace(c.Outputs.File.Path) == "" { + return fmt.Errorf("file output path is required when file output is enabled") + } + + if c.Outputs.ClickHouse.Enabled { + if strings.TrimSpace(c.Outputs.ClickHouse.DSN) == "" { + return fmt.Errorf("clickhouse DSN is required when enabled") + } + if strings.TrimSpace(c.Outputs.ClickHouse.Table) == "" { + return fmt.Errorf("clickhouse table is required when enabled") + } + if c.Outputs.ClickHouse.BatchSize <= 0 { + return fmt.Errorf("clickhouse batch_size must be > 0") + } + if c.Outputs.ClickHouse.MaxBufferSize <= 0 { + return fmt.Errorf("clickhouse max_buffer_size must be > 0") + } + if c.Outputs.ClickHouse.TimeoutMs <= 0 { + return fmt.Errorf("clickhouse timeout_ms must be > 0") + } + } + + if len(c.Correlation.Key) == 0 { + return 
fmt.Errorf("correlation.key cannot be empty") + } + if c.Correlation.TimeWindow.Value <= 0 { + return fmt.Errorf("correlation.time_window.value must be > 0") } return nil diff --git a/internal/domain/correlation_service.go b/internal/domain/correlation_service.go index dfac022..33f5207 100644 --- a/internal/domain/correlation_service.go +++ b/internal/domain/correlation_service.go @@ -9,6 +9,8 @@ import ( const ( // DefaultMaxBufferSize is the default maximum number of events per buffer DefaultMaxBufferSize = 10000 + // DefaultTimeWindow is used when no valid time window is provided + DefaultTimeWindow = time.Second ) // CorrelationConfig holds the correlation configuration. @@ -25,8 +27,8 @@ type CorrelationService struct { mu sync.Mutex bufferA *eventBuffer bufferB *eventBuffer - pendingA map[string]*list.Element // key -> list element containing NormalizedEvent - pendingB map[string]*list.Element + pendingA map[string][]*list.Element // key -> ordered elements containing *NormalizedEvent + pendingB map[string][]*list.Element timeProvider TimeProvider } @@ -60,12 +62,16 @@ func NewCorrelationService(config CorrelationConfig, timeProvider TimeProvider) if config.MaxBufferSize <= 0 { config.MaxBufferSize = DefaultMaxBufferSize } + if config.TimeWindow <= 0 { + config.TimeWindow = DefaultTimeWindow + } + return &CorrelationService{ config: config, bufferA: newEventBuffer(), bufferB: newEventBuffer(), - pendingA: make(map[string]*list.Element), - pendingB: make(map[string]*list.Element), + pendingA: make(map[string][]*list.Element), + pendingB: make(map[string][]*list.Element), timeProvider: timeProvider, } } @@ -84,20 +90,29 @@ func (s *CorrelationService) ProcessEvent(event *NormalizedEvent) []CorrelatedLo if event.Source == SourceA && s.config.ApacheAlwaysEmit { return []CorrelatedLog{NewCorrelatedLogFromEvent(event, "A")} } + if event.Source == SourceB && s.config.NetworkEmit { + return []CorrelatedLog{NewCorrelatedLogFromEvent(event, "B")} + } return nil } - var 
results []CorrelatedLog + var ( + results []CorrelatedLog + shouldBuffer bool + ) switch event.Source { case SourceA: - results = s.processSourceA(event) + results, shouldBuffer = s.processSourceA(event) case SourceB: - results = s.processSourceB(event) + results, shouldBuffer = s.processSourceB(event) + default: + return nil } - // Add the new event to the appropriate buffer - s.addEvent(event) + if shouldBuffer { + s.addEvent(event) + } return results } @@ -112,54 +127,46 @@ func (s *CorrelationService) isBufferFull(source EventSource) bool { return false } -func (s *CorrelationService) processSourceA(event *NormalizedEvent) []CorrelatedLog { +func (s *CorrelationService) processSourceA(event *NormalizedEvent) ([]CorrelatedLog, bool) { key := event.CorrelationKeyFull() - // Look for a matching B event - if elem, ok := s.pendingB[key]; ok { - bEvent := elem.Value.(*NormalizedEvent) - if s.eventsMatch(event, bEvent) { - // Found a match! - correlated := NewCorrelatedLog(event, bEvent) - s.bufferB.events.Remove(elem) - delete(s.pendingB, key) - return []CorrelatedLog{correlated} - } + // Look for the first matching B event (one-to-one first match) + if bEvent := s.findAndPopFirstMatch(s.bufferB, s.pendingB, key, func(other *NormalizedEvent) bool { + return s.eventsMatch(event, other) + }); bEvent != nil { + correlated := NewCorrelatedLog(event, bEvent) + return []CorrelatedLog{correlated}, false } // No match found if s.config.ApacheAlwaysEmit { orphan := NewCorrelatedLogFromEvent(event, "A") - return []CorrelatedLog{orphan} + return []CorrelatedLog{orphan}, false } // Keep in buffer for potential future match - return nil + return nil, true } -func (s *CorrelationService) processSourceB(event *NormalizedEvent) []CorrelatedLog { +func (s *CorrelationService) processSourceB(event *NormalizedEvent) ([]CorrelatedLog, bool) { key := event.CorrelationKeyFull() - // Look for a matching A event - if elem, ok := s.pendingA[key]; ok { - aEvent := 
elem.Value.(*NormalizedEvent) - if s.eventsMatch(aEvent, event) { - // Found a match! - correlated := NewCorrelatedLog(aEvent, event) - s.bufferA.events.Remove(elem) - delete(s.pendingA, key) - return []CorrelatedLog{correlated} - } + // Look for the first matching A event (one-to-one first match) + if aEvent := s.findAndPopFirstMatch(s.bufferA, s.pendingA, key, func(other *NormalizedEvent) bool { + return s.eventsMatch(other, event) + }); aEvent != nil { + correlated := NewCorrelatedLog(aEvent, event) + return []CorrelatedLog{correlated}, false } - // No match found - B is never emitted alone per spec + // No match found if s.config.NetworkEmit { orphan := NewCorrelatedLogFromEvent(event, "B") - return []CorrelatedLog{orphan} + return []CorrelatedLog{orphan}, false } - // Keep in buffer for potential future match (but won't be emitted alone) - return nil + // Keep in buffer for potential future match + return nil, true } func (s *CorrelationService) eventsMatch(a, b *NormalizedEvent) bool { @@ -176,10 +183,10 @@ func (s *CorrelationService) addEvent(event *NormalizedEvent) { switch event.Source { case SourceA: elem := s.bufferA.events.PushBack(event) - s.pendingA[key] = elem + s.pendingA[key] = append(s.pendingA[key], elem) case SourceB: elem := s.bufferB.events.PushBack(event) - s.pendingB[key] = elem + s.pendingB[key] = append(s.pendingB[key], elem) } } @@ -192,22 +199,67 @@ func (s *CorrelationService) cleanExpired() { s.cleanBuffer(s.bufferB, s.pendingB, cutoff) } -// cleanBuffer removes expired events from a buffer (shared logic for A and B). -func (s *CorrelationService) cleanBuffer(buffer *eventBuffer, pending map[string]*list.Element, cutoff time.Time) { +// cleanBuffer removes expired events from a buffer. 
+func (s *CorrelationService) cleanBuffer(buffer *eventBuffer, pending map[string][]*list.Element, cutoff time.Time) { for elem := buffer.events.Front(); elem != nil; { event := elem.Value.(*NormalizedEvent) if event.Timestamp.Before(cutoff) { next := elem.Next() key := event.CorrelationKeyFull() buffer.events.Remove(elem) - if pending[key] == elem { + pending[key] = removeElementFromSlice(pending[key], elem) + if len(pending[key]) == 0 { delete(pending, key) } elem = next + continue + } + // Events are inserted in arrival order; once we hit a non-expired event, stop. + break + } +} + +func (s *CorrelationService) findAndPopFirstMatch( + buffer *eventBuffer, + pending map[string][]*list.Element, + key string, + matcher func(*NormalizedEvent) bool, +) *NormalizedEvent { + elements, ok := pending[key] + if !ok || len(elements) == 0 { + return nil + } + + for idx, elem := range elements { + other := elem.Value.(*NormalizedEvent) + if !matcher(other) { + continue + } + + buffer.events.Remove(elem) + updated := append(elements[:idx], elements[idx+1:]...) + if len(updated) == 0 { + delete(pending, key) } else { - break // Events are ordered, so we can stop early + pending[key] = updated + } + + return other + } + + return nil +} + +func removeElementFromSlice(elements []*list.Element, target *list.Element) []*list.Element { + if len(elements) == 0 { + return elements + } + for i, elem := range elements { + if elem == target { + return append(elements[:i], elements[i+1:]...) } } + return elements } // Flush forces emission of remaining buffered events (for shutdown). 
@@ -226,11 +278,20 @@ func (s *CorrelationService) Flush() []CorrelatedLog { } } + // Emit remaining B events as orphans only if explicitly enabled + if s.config.NetworkEmit { + for elem := s.bufferB.events.Front(); elem != nil; elem = elem.Next() { + event := elem.Value.(*NormalizedEvent) + orphan := NewCorrelatedLogFromEvent(event, "B") + results = append(results, orphan) + } + } + // Clear buffers s.bufferA.events.Init() s.bufferB.events.Init() - s.pendingA = make(map[string]*list.Element) - s.pendingB = make(map[string]*list.Element) + s.pendingA = make(map[string][]*list.Element) + s.pendingB = make(map[string][]*list.Element) return results } diff --git a/internal/domain/correlation_service_test.go b/internal/domain/correlation_service_test.go index ffc03d6..312f538 100644 --- a/internal/domain/correlation_service_test.go +++ b/internal/domain/correlation_service_test.go @@ -144,10 +144,14 @@ func TestCorrelationService_Flush(t *testing.T) { SrcPort: 8080, } - svc.ProcessEvent(apacheEvent) + // A is emitted immediately when ApacheAlwaysEmit=true + results := svc.ProcessEvent(apacheEvent) + if len(results) != 1 { + t.Fatalf("expected 1 immediate orphan event, got %d", len(results)) + } flushed := svc.Flush() - if len(flushed) != 1 { - t.Errorf("expected 1 flushed event, got %d", len(flushed)) + if len(flushed) != 0 { + t.Errorf("expected 0 flushed events, got %d", len(flushed)) } }