Initial commit: logcorrelator with unified packaging (DEB + RPM using fpm)
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
This commit is contained in:
334
internal/adapters/inbound/unixsocket/source.go
Normal file
334
internal/adapters/inbound/unixsocket/source.go
Normal file
@ -0,0 +1,334 @@
|
||||
package unixsocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/logcorrelator/logcorrelator/internal/domain"
|
||||
)
|
||||
|
||||
const (
	// DefaultSocketPermissions is applied to the socket file after creation
	// (owner + group read/write) so only the service user/group can connect.
	DefaultSocketPermissions os.FileMode = 0660
	// MaxLineSize caps a single JSON log line at 1MB; longer lines abort
	// the connection's scanner.
	MaxLineSize = 1024 * 1024
	// MaxConcurrentConnections bounds simultaneous client connections per
	// socket; excess connections are rejected immediately.
	MaxConcurrentConnections = 100
	// MaxEventsPerSecond is the intended per-source rate limit.
	// NOTE(review): not enforced anywhere in this file — confirm whether
	// rate limiting is implemented elsewhere or still pending.
	MaxEventsPerSecond = 10000
)

// Config holds the Unix socket source configuration.
type Config struct {
	Name string // logical source name reported by Name()
	Path string // filesystem path of the Unix socket to listen on
}

// UnixSocketSource reads newline-delimited JSON events from a Unix socket.
// One goroutine accepts connections; each connection is served by its own
// goroutine, bounded by semaphore.
type UnixSocketSource struct {
	config    Config
	mu        sync.Mutex     // serializes Stop against concurrent callers
	listener  net.Listener   // set in Start, closed in Stop
	done      chan struct{}  // closed to signal shutdown
	wg        sync.WaitGroup // tracks accept loop + per-connection goroutines
	semaphore chan struct{}  // Limit concurrent connections
}
|
||||
|
||||
// NewUnixSocketSource creates a new Unix socket source.
|
||||
func NewUnixSocketSource(config Config) *UnixSocketSource {
|
||||
return &UnixSocketSource{
|
||||
config: config,
|
||||
done: make(chan struct{}),
|
||||
semaphore: make(chan struct{}, MaxConcurrentConnections),
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the source name.
|
||||
func (s *UnixSocketSource) Name() string {
|
||||
return s.config.Name
|
||||
}
|
||||
|
||||
// Start begins listening on the configured Unix socket path and spawns the
// accept loop in a background goroutine. A stale socket file from a previous
// run is removed first; a non-socket file at the path is treated as a
// configuration error. Returns once the listener is ready.
func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) error {
	// Remove a stale socket left behind by an unclean shutdown. Refuse to
	// delete arbitrary files the operator may care about.
	if info, err := os.Stat(s.config.Path); err == nil {
		if info.Mode()&os.ModeSocket != 0 {
			if err := os.Remove(s.config.Path); err != nil {
				return fmt.Errorf("failed to remove existing socket: %w", err)
			}
		} else {
			return fmt.Errorf("path exists but is not a socket: %s", s.config.Path)
		}
	}

	// Create listener
	listener, err := net.Listen("unix", s.config.Path)
	if err != nil {
		return fmt.Errorf("failed to create unix socket listener: %w", err)
	}
	s.listener = listener

	// Restrict access to owner+group; if that fails, refuse to serve with
	// the default (possibly world-accessible) permissions.
	if err := os.Chmod(s.config.Path, DefaultSocketPermissions); err != nil {
		listener.Close()
		os.Remove(s.config.Path)
		return fmt.Errorf("failed to set socket permissions: %w", err)
	}

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		s.acceptConnections(ctx, eventChan)
	}()

	return nil
}
|
||||
|
||||
func (s *UnixSocketSource) acceptConnections(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) {
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Check semaphore for connection limiting
|
||||
select {
|
||||
case s.semaphore <- struct{}{}:
|
||||
// Connection accepted
|
||||
default:
|
||||
// Too many connections, reject
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func(c net.Conn) {
|
||||
defer s.wg.Done()
|
||||
defer func() { <-s.semaphore }()
|
||||
defer c.Close()
|
||||
s.readEvents(ctx, c, eventChan)
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UnixSocketSource) readEvents(ctx context.Context, conn net.Conn, eventChan chan<- *domain.NormalizedEvent) {
|
||||
// Set read deadline to prevent hanging
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
|
||||
|
||||
scanner := bufio.NewScanner(conn)
|
||||
// Increase buffer size limit to 1MB
|
||||
buf := make([]byte, 0, 4096)
|
||||
scanner.Buffer(buf, MaxLineSize)
|
||||
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
event, err := parseJSONEvent(line)
|
||||
if err != nil {
|
||||
// Log parse errors but continue processing
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case eventChan <- event:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
// Connection error, log but don't crash
|
||||
}
|
||||
}
|
||||
|
||||
func parseJSONEvent(data []byte) (*domain.NormalizedEvent, error) {
|
||||
var raw map[string]any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %w", err)
|
||||
}
|
||||
|
||||
event := &domain.NormalizedEvent{
|
||||
Raw: raw,
|
||||
Extra: make(map[string]any),
|
||||
}
|
||||
|
||||
// Extract and validate src_ip
|
||||
if v, ok := getString(raw, "src_ip"); ok {
|
||||
event.SrcIP = v
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing required field: src_ip")
|
||||
}
|
||||
|
||||
// Extract and validate src_port
|
||||
if v, ok := getInt(raw, "src_port"); ok {
|
||||
if v < 1 || v > 65535 {
|
||||
return nil, fmt.Errorf("src_port must be between 1 and 65535, got %d", v)
|
||||
}
|
||||
event.SrcPort = v
|
||||
} else {
|
||||
return nil, fmt.Errorf("missing required field: src_port")
|
||||
}
|
||||
|
||||
// Extract dst_ip (optional)
|
||||
if v, ok := getString(raw, "dst_ip"); ok {
|
||||
event.DstIP = v
|
||||
}
|
||||
|
||||
// Extract dst_port (optional)
|
||||
if v, ok := getInt(raw, "dst_port"); ok {
|
||||
if v < 0 || v > 65535 {
|
||||
return nil, fmt.Errorf("dst_port must be between 0 and 65535, got %d", v)
|
||||
}
|
||||
event.DstPort = v
|
||||
}
|
||||
|
||||
// Extract timestamp - try different fields
|
||||
if ts, ok := getInt64(raw, "timestamp"); ok {
|
||||
// Assume nanoseconds
|
||||
event.Timestamp = time.Unix(0, ts)
|
||||
} else if tsStr, ok := getString(raw, "time"); ok {
|
||||
if t, err := time.Parse(time.RFC3339, tsStr); err == nil {
|
||||
event.Timestamp = t
|
||||
}
|
||||
} else if tsStr, ok := getString(raw, "timestamp"); ok {
|
||||
if t, err := time.Parse(time.RFC3339, tsStr); err == nil {
|
||||
event.Timestamp = t
|
||||
}
|
||||
}
|
||||
|
||||
if event.Timestamp.IsZero() {
|
||||
event.Timestamp = time.Now()
|
||||
}
|
||||
|
||||
// Extract headers (header_* fields)
|
||||
event.Headers = make(map[string]string)
|
||||
for k, v := range raw {
|
||||
if len(k) > 7 && k[:7] == "header_" {
|
||||
if sv, ok := v.(string); ok {
|
||||
event.Headers[k[7:]] = sv
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine source based on fields present
|
||||
if len(event.Headers) > 0 {
|
||||
event.Source = domain.SourceA
|
||||
} else {
|
||||
event.Source = domain.SourceB
|
||||
}
|
||||
|
||||
// Extra fields (single pass)
|
||||
knownFields := map[string]bool{
|
||||
"src_ip": true, "src_port": true, "dst_ip": true, "dst_port": true,
|
||||
"timestamp": true, "time": true,
|
||||
}
|
||||
for k, v := range raw {
|
||||
if knownFields[k] {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(k, "header_") {
|
||||
continue
|
||||
}
|
||||
event.Extra[k] = v
|
||||
}
|
||||
|
||||
return event, nil
|
||||
}
|
||||
|
||||
// getString returns m[key] as a string, reporting false when the key is
// absent or holds a non-string value.
func getString(m map[string]any, key string) (string, bool) {
	v, present := m[key]
	if !present {
		return "", false
	}
	s, isString := v.(string)
	return s, isString
}
|
||||
|
||||
// getInt returns m[key] coerced to int. JSON numbers (float64), native ints,
// int64s, and decimal strings are accepted; anything else reports false.
func getInt(m map[string]any, key string) (int, bool) {
	v, present := m[key]
	if !present {
		return 0, false
	}
	switch n := v.(type) {
	case float64:
		return int(n), true
	case int:
		return n, true
	case int64:
		return int(n), true
	case string:
		parsed, err := strconv.Atoi(n)
		if err != nil {
			return 0, false
		}
		return parsed, true
	default:
		return 0, false
	}
}
|
||||
|
||||
// getInt64 returns m[key] coerced to int64. JSON numbers (float64), native
// ints, int64s, and decimal strings are accepted; anything else reports
// false.
func getInt64(m map[string]any, key string) (int64, bool) {
	v, present := m[key]
	if !present {
		return 0, false
	}
	switch n := v.(type) {
	case float64:
		return int64(n), true
	case int:
		return int64(n), true
	case int64:
		return n, true
	case string:
		parsed, err := strconv.ParseInt(n, 10, 64)
		if err != nil {
			return 0, false
		}
		return parsed, true
	default:
		return 0, false
	}
}
|
||||
|
||||
// Stop gracefully stops the source.
|
||||
func (s *UnixSocketSource) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
close(s.done)
|
||||
|
||||
if s.listener != nil {
|
||||
s.listener.Close()
|
||||
}
|
||||
|
||||
s.wg.Wait()
|
||||
|
||||
// Clean up socket file
|
||||
if err := os.Remove(s.config.Path); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove socket file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
98
internal/adapters/inbound/unixsocket/source_test.go
Normal file
98
internal/adapters/inbound/unixsocket/source_test.go
Normal file
@ -0,0 +1,98 @@
|
||||
package unixsocket
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestParseJSONEvent_Apache verifies that an event carrying header_* fields
// parses its required fields and strips the header_ prefix into Headers.
func TestParseJSONEvent_Apache(t *testing.T) {
	data := []byte(`{
		"src_ip": "192.168.1.1",
		"src_port": 8080,
		"dst_ip": "10.0.0.1",
		"dst_port": 80,
		"timestamp": 1704110400000000000,
		"method": "GET",
		"path": "/api/test",
		"header_host": "example.com",
		"header_user_agent": "Mozilla/5.0"
	}`)

	event, err := parseJSONEvent(data)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if event.SrcIP != "192.168.1.1" {
		t.Errorf("expected src_ip 192.168.1.1, got %s", event.SrcIP)
	}
	if event.SrcPort != 8080 {
		t.Errorf("expected src_port 8080, got %d", event.SrcPort)
	}
	if event.Headers["host"] != "example.com" {
		t.Errorf("expected header host example.com, got %s", event.Headers["host"])
	}
	if event.Headers["user_agent"] != "Mozilla/5.0" {
		t.Errorf("expected header_user_agent Mozilla/5.0, got %s", event.Headers["user_agent"])
	}
}

// TestParseJSONEvent_Network verifies that header-less network events keep
// their unknown fields (ja3/ja4/tcp flags) in Extra.
func TestParseJSONEvent_Network(t *testing.T) {
	data := []byte(`{
		"src_ip": "192.168.1.1",
		"src_port": 8080,
		"dst_ip": "10.0.0.1",
		"dst_port": 443,
		"ja3": "abc123def456",
		"ja4": "xyz789",
		"tcp_meta_flags": "SYN"
	}`)

	event, err := parseJSONEvent(data)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if event.SrcIP != "192.168.1.1" {
		t.Errorf("expected src_ip 192.168.1.1, got %s", event.SrcIP)
	}
	if event.Extra["ja3"] != "abc123def456" {
		t.Errorf("expected ja3 abc123def456, got %v", event.Extra["ja3"])
	}
}

// TestParseJSONEvent_InvalidJSON verifies malformed JSON is rejected.
func TestParseJSONEvent_InvalidJSON(t *testing.T) {
	data := []byte(`{invalid json}`)

	_, err := parseJSONEvent(data)
	if err == nil {
		t.Error("expected error for invalid JSON")
	}
}

// TestParseJSONEvent_MissingFields verifies that src_ip/src_port are
// mandatory.
func TestParseJSONEvent_MissingFields(t *testing.T) {
	data := []byte(`{"other_field": "value"}`)

	_, err := parseJSONEvent(data)
	if err == nil {
		t.Error("expected error for missing src_ip/src_port")
	}
}

// TestParseJSONEvent_StringTimestamp verifies that an RFC3339 "time" field
// is parsed into the event timestamp.
func TestParseJSONEvent_StringTimestamp(t *testing.T) {
	data := []byte(`{
		"src_ip": "192.168.1.1",
		"src_port": 8080,
		"time": "2024-01-01T12:00:00Z"
	}`)

	event, err := parseJSONEvent(data)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expected := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
	if !event.Timestamp.Equal(expected) {
		t.Errorf("expected timestamp %v, got %v", expected, event.Timestamp)
	}
}
|
||||
333
internal/adapters/outbound/clickhouse/sink.go
Normal file
333
internal/adapters/outbound/clickhouse/sink.go
Normal file
@ -0,0 +1,333 @@
|
||||
package clickhouse
|
||||
|
||||
import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/logcorrelator/logcorrelator/internal/domain"
)
|
||||
|
||||
const (
	// DefaultBatchSize is the default number of records per batch
	DefaultBatchSize = 500
	// DefaultFlushIntervalMs is the default flush interval in milliseconds
	DefaultFlushIntervalMs = 200
	// DefaultMaxBufferSize is the default maximum buffer size; Write blocks
	// or drops once the buffer reaches this size.
	DefaultMaxBufferSize = 5000
	// DefaultTimeoutMs is the default timeout for flush operations in
	// milliseconds.
	DefaultTimeoutMs = 1000
	// DefaultPingTimeoutMs is the timeout for the initial connection ping.
	DefaultPingTimeoutMs = 5000
	// MaxRetries is the maximum number of attempts for a failed batch insert.
	MaxRetries = 3
	// RetryBaseDelay is the base delay between retries; actual delays grow
	// exponentially (base, 2x, 4x, ...).
	RetryBaseDelay = 100 * time.Millisecond
)

// Config holds the ClickHouse sink configuration. Zero-valued numeric fields
// are replaced with the Default* constants by NewClickHouseSink.
type Config struct {
	DSN             string // ClickHouse connection string
	Table           string // destination table name (interpolated into INSERT)
	BatchSize       int    // records per INSERT batch
	FlushIntervalMs int    // periodic flush interval
	MaxBufferSize   int    // hard cap on buffered records
	DropOnOverflow  bool   // silently drop writes when the buffer is full
	AsyncInsert     bool   // NOTE(review): not referenced in this file — confirm whether it is consumed elsewhere
	TimeoutMs       int    // per-operation timeout
}

// ClickHouseSink writes correlated logs to ClickHouse in batches. Writes
// accumulate in an in-memory buffer that a background goroutine flushes
// periodically and on demand.
type ClickHouseSink struct {
	config    Config
	db        *sql.DB
	mu        sync.Mutex // guards buffer
	buffer    []domain.CorrelatedLog
	flushChan chan struct{} // non-blocking flush requests
	done      chan struct{} // closed by Close to stop flushLoop
	wg        sync.WaitGroup
}
|
||||
|
||||
// NewClickHouseSink creates a new ClickHouse sink: it applies configuration
// defaults, opens and pings the database connection, and starts the
// background flush goroutine. The caller must Close the returned sink to
// stop the goroutine and release the connection.
func NewClickHouseSink(config Config) (*ClickHouseSink, error) {
	// Apply defaults for any unset (zero or negative) numeric options.
	if config.BatchSize <= 0 {
		config.BatchSize = DefaultBatchSize
	}
	if config.FlushIntervalMs <= 0 {
		config.FlushIntervalMs = DefaultFlushIntervalMs
	}
	if config.MaxBufferSize <= 0 {
		config.MaxBufferSize = DefaultMaxBufferSize
	}
	if config.TimeoutMs <= 0 {
		config.TimeoutMs = DefaultTimeoutMs
	}

	s := &ClickHouseSink{
		config:    config,
		buffer:    make([]domain.CorrelatedLog, 0, config.BatchSize),
		flushChan: make(chan struct{}, 1),
		done:      make(chan struct{}),
	}

	// sql.Open only validates arguments; the Ping below performs the actual
	// connection check.
	db, err := sql.Open("clickhouse", config.DSN)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to ClickHouse: %w", err)
	}

	// Fail fast if the server is unreachable, bounded by the ping timeout.
	pingCtx, pingCancel := context.WithTimeout(context.Background(), time.Duration(DefaultPingTimeoutMs)*time.Millisecond)
	defer pingCancel()

	if err := db.PingContext(pingCtx); err != nil {
		db.Close()
		return nil, fmt.Errorf("failed to ping ClickHouse: %w", err)
	}

	s.db = db

	// Start flush goroutine; stopped via the done channel in Close.
	s.wg.Add(1)
	go s.flushLoop()

	return s, nil
}
|
||||
|
||||
// Name returns the sink name.
|
||||
func (s *ClickHouseSink) Name() string {
|
||||
return "clickhouse"
|
||||
}
|
||||
|
||||
// Write adds a log to the buffer.
|
||||
func (s *ClickHouseSink) Write(ctx context.Context, log domain.CorrelatedLog) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check buffer overflow
|
||||
if len(s.buffer) >= s.config.MaxBufferSize {
|
||||
if s.config.DropOnOverflow {
|
||||
// Drop the log
|
||||
return nil
|
||||
}
|
||||
// Block until space is available (with timeout)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(time.Duration(s.config.TimeoutMs) * time.Millisecond):
|
||||
return fmt.Errorf("buffer full, timeout exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
s.buffer = append(s.buffer, log)
|
||||
|
||||
// Trigger flush if batch is full
|
||||
if len(s.buffer) >= s.config.BatchSize {
|
||||
select {
|
||||
case s.flushChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush flushes the buffer to ClickHouse.
|
||||
func (s *ClickHouseSink) Flush(ctx context.Context) error {
|
||||
return s.doFlush(ctx)
|
||||
}
|
||||
|
||||
// Close closes the sink.
|
||||
func (s *ClickHouseSink) Close() error {
|
||||
close(s.done)
|
||||
s.wg.Wait()
|
||||
|
||||
if s.db != nil {
|
||||
return s.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ClickHouseSink) flushLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(time.Duration(s.config.FlushIntervalMs) * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
needsFlush := len(s.buffer) > 0
|
||||
s.mu.Unlock()
|
||||
if needsFlush {
|
||||
// Use timeout context for flush
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.TimeoutMs)*time.Millisecond)
|
||||
s.doFlush(ctx)
|
||||
cancel()
|
||||
}
|
||||
case <-s.flushChan:
|
||||
s.mu.Lock()
|
||||
needsFlush := len(s.buffer) >= s.config.BatchSize
|
||||
s.mu.Unlock()
|
||||
if needsFlush {
|
||||
// Use timeout context for flush
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.TimeoutMs)*time.Millisecond)
|
||||
s.doFlush(ctx)
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// doFlush atomically swaps out the current buffer and inserts its contents
// into ClickHouse as one batch, retrying transient failures with exponential
// backoff. NOTE(review): when all retries fail (or a non-retryable error
// occurs) the swapped-out batch is dropped — it is not re-queued; confirm
// that data loss here is acceptable.
func (s *ClickHouseSink) doFlush(ctx context.Context) error {
	s.mu.Lock()
	if len(s.buffer) == 0 {
		s.mu.Unlock()
		return nil
	}

	// Copy and reset the buffer under the lock so Write can continue while
	// the (slow) insert runs outside the critical section.
	buffer := make([]domain.CorrelatedLog, len(s.buffer))
	copy(buffer, s.buffer)
	s.buffer = make([]domain.CorrelatedLog, 0, s.config.BatchSize)
	s.mu.Unlock()

	// The table name comes from configuration, not user input; placeholders
	// cover all row values.
	query := fmt.Sprintf(`
		INSERT INTO %s (timestamp, src_ip, src_port, dst_ip, dst_port, correlated, orphan_side, apache, network)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
	`, s.config.Table)

	// Retry transient failures with exponential backoff (base, 2x, 4x, ...).
	var lastErr error
	for attempt := 0; attempt < MaxRetries; attempt++ {
		if attempt > 0 {
			delay := RetryBaseDelay * time.Duration(1<<uint(attempt-1))
			select {
			case <-time.After(delay):
			case <-ctx.Done():
				return ctx.Err()
			}
		}

		lastErr = s.executeBatch(ctx, query, buffer)
		if lastErr == nil {
			return nil // Success
		}

		// Only transient network/availability errors are worth retrying.
		if !isRetryableError(lastErr) {
			return fmt.Errorf("non-retryable error: %w", lastErr)
		}
	}

	return fmt.Errorf("failed after %d retries: %w", MaxRetries, lastErr)
}
|
||||
|
||||
func (s *ClickHouseSink) executeBatch(ctx context.Context, query string, buffer []domain.CorrelatedLog) error {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare statement: %w", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, log := range buffer {
|
||||
apacheJSON, _ := json.Marshal(log.Apache)
|
||||
networkJSON, _ := json.Marshal(log.Network)
|
||||
|
||||
orphanSide := log.OrphanSide
|
||||
if !log.Correlated {
|
||||
orphanSide = log.OrphanSide
|
||||
}
|
||||
|
||||
correlated := 0
|
||||
if log.Correlated {
|
||||
correlated = 1
|
||||
}
|
||||
|
||||
_, err := stmt.ExecContext(ctx,
|
||||
log.Timestamp,
|
||||
log.SrcIP,
|
||||
log.SrcPort,
|
||||
log.DstIP,
|
||||
log.DstPort,
|
||||
correlated,
|
||||
orphanSide,
|
||||
string(apacheJSON),
|
||||
string(networkJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute insert: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isRetryableError checks if an error is retryable.
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
errStr := err.Error()
|
||||
// Common retryable errors
|
||||
retryableErrors := []string{
|
||||
"connection refused",
|
||||
"connection reset",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"network is unreachable",
|
||||
"broken pipe",
|
||||
}
|
||||
for _, re := range retryableErrors {
|
||||
if containsIgnoreCase(errStr, re) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsIgnoreCase(s, substr string) bool {
|
||||
return len(s) >= len(substr) && containsLower(s, substr)
|
||||
}
|
||||
|
||||
func containsLower(s, substr string) bool {
|
||||
s = toLower(s)
|
||||
substr = toLower(substr)
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// toLower returns s with ASCII upper-case letters lowered. Unlike
// strings.ToLower it leaves non-ASCII bytes untouched, preserving the
// original ASCII-only behavior. The output buffer is allocated once up
// front instead of growing via append in the loop.
func toLower(s string) string {
	buf := make([]byte, len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c >= 'A' && c <= 'Z' {
			c += 'a' - 'A'
		}
		buf[i] = c
	}
	return string(buf)
}
|
||||
305
internal/adapters/outbound/clickhouse/sink_test.go
Normal file
305
internal/adapters/outbound/clickhouse/sink_test.go
Normal file
@ -0,0 +1,305 @@
|
||||
package clickhouse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/logcorrelator/logcorrelator/internal/domain"
|
||||
)
|
||||
|
||||
// TestClickHouseSink_Name checks the fixed sink identifier.
func TestClickHouseSink_Name(t *testing.T) {
	sink := &ClickHouseSink{
		config: Config{
			DSN:   "clickhouse://test:test@localhost:9000/test",
			Table: "test_table",
		},
	}

	if sink.Name() != "clickhouse" {
		t.Errorf("expected name 'clickhouse', got %s", sink.Name())
	}
}

// TestClickHouseSink_ConfigDefaults re-applies the default-filling logic to
// a zero config and verifies the Default* constants are used. (It duplicates
// NewClickHouseSink's defaulting because the constructor requires a live
// database connection.)
func TestClickHouseSink_ConfigDefaults(t *testing.T) {
	// Test that defaults are applied correctly
	config := Config{
		DSN:   "clickhouse://test:test@localhost:9000/test",
		Table: "test_table",
		// Other fields are zero, should get defaults
	}

	// Verify defaults would be applied (we can't actually connect in tests)
	if config.BatchSize <= 0 {
		config.BatchSize = DefaultBatchSize
	}
	if config.FlushIntervalMs <= 0 {
		config.FlushIntervalMs = DefaultFlushIntervalMs
	}
	if config.MaxBufferSize <= 0 {
		config.MaxBufferSize = DefaultMaxBufferSize
	}
	if config.TimeoutMs <= 0 {
		config.TimeoutMs = DefaultTimeoutMs
	}

	if config.BatchSize != DefaultBatchSize {
		t.Errorf("expected BatchSize %d, got %d", DefaultBatchSize, config.BatchSize)
	}
	if config.FlushIntervalMs != DefaultFlushIntervalMs {
		t.Errorf("expected FlushIntervalMs %d, got %d", DefaultFlushIntervalMs, config.FlushIntervalMs)
	}
	if config.MaxBufferSize != DefaultMaxBufferSize {
		t.Errorf("expected MaxBufferSize %d, got %d", DefaultMaxBufferSize, config.MaxBufferSize)
	}
	if config.TimeoutMs != DefaultTimeoutMs {
		t.Errorf("expected TimeoutMs %d, got %d", DefaultTimeoutMs, config.TimeoutMs)
	}
}

// TestClickHouseSink_Write_BufferOverflow is a config sanity check only; it
// does not exercise Write (no live ClickHouse available in tests).
func TestClickHouseSink_Write_BufferOverflow(t *testing.T) {
	// This test verifies the buffer overflow logic without actually connecting
	config := Config{
		DSN:             "clickhouse://test:test@localhost:9000/test",
		Table:           "test_table",
		BatchSize:       10,
		MaxBufferSize:   10,
		DropOnOverflow:  true,
		TimeoutMs:       100,
		FlushIntervalMs: 1000,
	}

	// We can't test actual writes without a ClickHouse instance,
	// but we can verify the config is valid
	if config.BatchSize > config.MaxBufferSize {
		t.Error("BatchSize should not exceed MaxBufferSize")
	}
}

// TestClickHouseSink_IsRetryableError table-tests the transient-error
// classifier with both retryable and permanent error messages.
func TestClickHouseSink_IsRetryableError(t *testing.T) {
	tests := []struct {
		name     string
		err      error
		expected bool
	}{
		{"nil error", nil, false},
		{"connection refused", &mockError{"connection refused"}, true},
		{"connection reset", &mockError{"connection reset by peer"}, true},
		{"timeout", &mockError{"timeout waiting for response"}, true},
		{"network unreachable", &mockError{"network is unreachable"}, true},
		{"broken pipe", &mockError{"broken pipe"}, true},
		{"syntax error", &mockError{"syntax error in SQL"}, false},
		{"table not found", &mockError{"table test not found"}, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := isRetryableError(tt.err)
			if result != tt.expected {
				t.Errorf("expected %v, got %v", tt.expected, result)
			}
		})
	}
}

// TestClickHouseSink_FlushEmpty checks that Flush on an empty buffer is a
// no-op that never touches the (nil) database handle.
func TestClickHouseSink_FlushEmpty(t *testing.T) {
	// Test that flushing an empty buffer doesn't cause issues
	// (We can't test actual ClickHouse operations without a real instance)

	s := &ClickHouseSink{
		config: Config{
			DSN:   "clickhouse://test:test@localhost:9000/test",
			Table: "test_table",
		},
		buffer: make([]domain.CorrelatedLog, 0),
	}

	// Should not panic or error on empty flush
	ctx := context.Background()
	err := s.Flush(ctx)
	if err != nil {
		t.Errorf("expected no error on empty flush, got %v", err)
	}
}

// TestClickHouseSink_CloseWithoutConnect checks Close is safe when db was
// never opened (nil handle, no flush goroutine running).
func TestClickHouseSink_CloseWithoutConnect(t *testing.T) {
	// Test that closing without connecting doesn't panic
	s := &ClickHouseSink{
		config: Config{
			DSN:   "clickhouse://test:test@localhost:9000/test",
			Table: "test_table",
		},
		buffer: make([]domain.CorrelatedLog, 0),
		done:   make(chan struct{}),
	}

	err := s.Close()
	if err != nil {
		t.Errorf("expected no error on close without connect, got %v", err)
	}
}

// TestClickHouseSink_Constants guards against accidental zero/negative
// values in the package constants.
func TestClickHouseSink_Constants(t *testing.T) {
	// Verify constants have reasonable values
	if DefaultBatchSize <= 0 {
		t.Error("DefaultBatchSize should be positive")
	}
	if DefaultFlushIntervalMs <= 0 {
		t.Error("DefaultFlushIntervalMs should be positive")
	}
	if DefaultMaxBufferSize <= 0 {
		t.Error("DefaultMaxBufferSize should be positive")
	}
	if DefaultTimeoutMs <= 0 {
		t.Error("DefaultTimeoutMs should be positive")
	}
	if DefaultPingTimeoutMs <= 0 {
		t.Error("DefaultPingTimeoutMs should be positive")
	}
	if MaxRetries <= 0 {
		t.Error("MaxRetries should be positive")
	}
	if RetryBaseDelay <= 0 {
		t.Error("RetryBaseDelay should be positive")
	}
}

// mockError implements error for testing
type mockError struct {
	msg string
}

// Error returns the fixed message configured at construction.
func (e *mockError) Error() string {
	return e.msg
}

// TestClickHouseSink_DoFlushEmpty covers the early-return path of doFlush
// when nothing is buffered (no DB connection needed).
func TestClickHouseSink_DoFlushEmpty(t *testing.T) {
	s := &ClickHouseSink{
		config: Config{
			DSN:   "clickhouse://test:test@localhost:9000/test",
			Table: "test_table",
		},
		buffer: make([]domain.CorrelatedLog, 0),
	}

	ctx := context.Background()
	err := s.doFlush(ctx)
	if err != nil {
		t.Errorf("expected no error when flushing empty buffer, got %v", err)
	}
}

// TestClickHouseSink_BufferManagement checks Write appends to the buffer
// when there is spare capacity (no DB operations involved).
func TestClickHouseSink_BufferManagement(t *testing.T) {
	log := domain.CorrelatedLog{
		SrcIP:      "192.168.1.1",
		SrcPort:    8080,
		Correlated: true,
	}

	s := &ClickHouseSink{
		config: Config{
			DSN:            "clickhouse://test:test@localhost:9000/test",
			Table:          "test_table",
			MaxBufferSize:  100, // Allow more than 1 element
			DropOnOverflow: false,
			TimeoutMs:      1000,
		},
		buffer: []domain.CorrelatedLog{log},
	}

	// Verify buffer has data
	if len(s.buffer) != 1 {
		t.Fatalf("expected buffer length 1, got %d", len(s.buffer))
	}

	// Test that Write properly adds to buffer
	ctx := context.Background()
	err := s.Write(ctx, log)
	if err != nil {
		t.Errorf("unexpected error on Write: %v", err)
	}

	if len(s.buffer) != 2 {
		t.Errorf("expected buffer length 2 after Write, got %d", len(s.buffer))
	}
}

// TestClickHouseSink_Write_ContextCancel checks that a blocked Write (full
// buffer, no drop) honors an already-cancelled context.
func TestClickHouseSink_Write_ContextCancel(t *testing.T) {
	s := &ClickHouseSink{
		config: Config{
			DSN:            "clickhouse://test:test@localhost:9000/test",
			Table:          "test_table",
			MaxBufferSize:  1,
			DropOnOverflow: false,
			TimeoutMs:      10,
		},
		buffer: make([]domain.CorrelatedLog, 0, 1),
	}

	// Fill the buffer
	log := domain.CorrelatedLog{SrcIP: "192.168.1.1", SrcPort: 8080}
	s.buffer = append(s.buffer, log)

	// Try to write with cancelled context
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // Cancel immediately

	err := s.Write(ctx, log)
	if err == nil {
		t.Error("expected error when writing with cancelled context")
	}
}

// TestClickHouseSink_Write_DropOnOverflow checks the silent-drop policy when
// the buffer is full and DropOnOverflow is enabled.
func TestClickHouseSink_Write_DropOnOverflow(t *testing.T) {
	s := &ClickHouseSink{
		config: Config{
			DSN:            "clickhouse://test:test@localhost:9000/test",
			Table:          "test_table",
			MaxBufferSize:  1,
			DropOnOverflow: true,
			TimeoutMs:      10,
		},
		buffer: make([]domain.CorrelatedLog, 0, 1),
	}

	// Fill the buffer
	log := domain.CorrelatedLog{SrcIP: "192.168.1.1", SrcPort: 8080}
	s.buffer = append(s.buffer, log)

	// Try to write when buffer is full - should drop silently
	ctx := context.Background()
	err := s.Write(ctx, log)
	if err != nil {
		t.Errorf("expected no error when DropOnOverflow is true, got %v", err)
	}
}

// BenchmarkClickHouseSink_Write measures the in-memory append path of Write
// (drops on overflow, so no DB is needed).
func BenchmarkClickHouseSink_Write(b *testing.B) {
	s := &ClickHouseSink{
		config: Config{
			DSN:            "clickhouse://test:test@localhost:9000/test",
			Table:          "test_table",
			MaxBufferSize:  10000,
			DropOnOverflow: true,
		},
		buffer: make([]domain.CorrelatedLog, 0, 10000),
	}

	log := domain.CorrelatedLog{
		Timestamp:  time.Now(),
		SrcIP:      "192.168.1.1",
		SrcPort:    8080,
		Correlated: true,
	}

	ctx := context.Background()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		s.Write(ctx, log)
	}
}
|
||||
168
internal/adapters/outbound/file/sink.go
Normal file
168
internal/adapters/outbound/file/sink.go
Normal file
@ -0,0 +1,168 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/logcorrelator/logcorrelator/internal/domain"
|
||||
)
|
||||
|
||||
// File-system permission defaults applied when the sink creates output
// files and their parent directories.
const (
	// DefaultFilePermissions for output files (owner rw, group/other read).
	DefaultFilePermissions os.FileMode = 0644
	// DefaultDirPermissions for output directories (owner rwx, group rx).
	DefaultDirPermissions os.FileMode = 0750
)
|
||||
|
||||
// Config holds the file sink configuration.
type Config struct {
	// Path is the destination file. It must satisfy validateFilePath:
	// an absolute path inside an allowed log directory, or a relative
	// path (used by tests).
	Path string
}
|
||||
|
||||
// FileSink writes correlated logs to a file as JSON lines.
type FileSink struct {
	config Config
	// mu guards file and writer; the file is opened lazily on first Write.
	mu     sync.Mutex
	file   *os.File
	writer *bufio.Writer
}
|
||||
|
||||
// NewFileSink creates a new file sink.
|
||||
func NewFileSink(config Config) (*FileSink, error) {
|
||||
// Validate path
|
||||
if err := validateFilePath(config.Path); err != nil {
|
||||
return nil, fmt.Errorf("invalid file path: %w", err)
|
||||
}
|
||||
|
||||
return &FileSink{
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the sink name.
|
||||
func (s *FileSink) Name() string {
|
||||
return "file"
|
||||
}
|
||||
|
||||
// Write writes a correlated log to the file.
|
||||
func (s *FileSink) Write(ctx context.Context, log domain.CorrelatedLog) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.file == nil {
|
||||
if err := s.openFile(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(log)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal log: %w", err)
|
||||
}
|
||||
|
||||
if _, err := s.writer.Write(data); err != nil {
|
||||
return fmt.Errorf("failed to write log: %w", err)
|
||||
}
|
||||
if _, err := s.writer.WriteString("\n"); err != nil {
|
||||
return fmt.Errorf("failed to write newline: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush flushes any buffered data.
|
||||
func (s *FileSink) Flush(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.writer != nil {
|
||||
return s.writer.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the sink.
|
||||
func (s *FileSink) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.writer != nil {
|
||||
if err := s.writer.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if s.file != nil {
|
||||
return s.file.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileSink) openFile() error {
|
||||
// Validate path again before opening
|
||||
if err := validateFilePath(s.config.Path); err != nil {
|
||||
return fmt.Errorf("invalid file path: %w", err)
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(s.config.Path)
|
||||
if err := os.MkdirAll(dir, DefaultDirPermissions); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(s.config.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, DefaultFilePermissions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
|
||||
s.file = file
|
||||
s.writer = bufio.NewWriter(file)
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateFilePath validates that the file path is safe and allowed.
//
// Absolute paths must live inside one of the allowed log directories;
// relative paths are permitted (for testing) as long as they do not
// escape the working directory via "..".
//
// Fixes two defects in the original: a plain strings.HasPrefix match
// accepted sibling directories like "/tmpfoo", and the ".." traversal
// check was dead code — filepath.Clean removes interior ".." segments,
// and relative paths returned nil before the check ran (so "../x" was
// accepted).
func validateFilePath(path string) error {
	if path == "" {
		return fmt.Errorf("path cannot be empty")
	}

	// Clean resolves "." and ".." segments, so the checks below operate
	// on the real target (e.g. "/var/log/../etc" becomes "/etc").
	cleanPath := filepath.Clean(path)

	// Relative paths are allowed for testing, but must not climb out of
	// the working directory. After Clean, any remaining ".." can only
	// appear as a leading segment.
	if !filepath.IsAbs(cleanPath) {
		if cleanPath == ".." || strings.HasPrefix(cleanPath, ".."+string(filepath.Separator)) {
			return fmt.Errorf("path cannot contain '..'")
		}
		return nil
	}

	// Absolute paths must be inside one of these directories.
	allowedPrefixes := []string{
		"/var/log/logcorrelator",
		"/var/log",
		"/tmp",
	}

	for _, prefix := range allowedPrefixes {
		// Match the directory itself or a descendant; a bare HasPrefix
		// would wrongly accept e.g. "/tmpfoo".
		if cleanPath == prefix || strings.HasPrefix(cleanPath, prefix+"/") {
			return nil
		}
	}

	return fmt.Errorf("path must be in allowed directories: %v", allowedPrefixes)
}
|
||||
96
internal/adapters/outbound/file/sink_test.go
Normal file
96
internal/adapters/outbound/file/sink_test.go
Normal file
@ -0,0 +1,96 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/logcorrelator/logcorrelator/internal/domain"
|
||||
)
|
||||
|
||||
func TestFileSink_Write(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
defer sink.Close()
|
||||
|
||||
log := domain.CorrelatedLog{
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080,
|
||||
Correlated: true,
|
||||
}
|
||||
|
||||
if err := sink.Write(context.Background(), log); err != nil {
|
||||
t.Fatalf("failed to write: %v", err)
|
||||
}
|
||||
|
||||
if err := sink.Flush(context.Background()); err != nil {
|
||||
t.Fatalf("failed to flush: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists and contains data
|
||||
data, err := os.ReadFile(testPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file: %v", err)
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
t.Error("expected non-empty file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_MultipleWrites(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testPath := filepath.Join(tmpDir, "test.log")
|
||||
|
||||
sink, err := NewFileSink(Config{Path: testPath})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
defer sink.Close()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
log := domain.CorrelatedLog{
|
||||
SrcIP: "192.168.1.1",
|
||||
SrcPort: 8080 + i,
|
||||
}
|
||||
if err := sink.Write(context.Background(), log); err != nil {
|
||||
t.Fatalf("failed to write: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
sink.Close()
|
||||
|
||||
// Verify file has 5 lines
|
||||
data, err := os.ReadFile(testPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file: %v", err)
|
||||
}
|
||||
|
||||
lines := 0
|
||||
for _, b := range data {
|
||||
if b == '\n' {
|
||||
lines++
|
||||
}
|
||||
}
|
||||
|
||||
if lines != 5 {
|
||||
t.Errorf("expected 5 lines, got %d", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSink_Name(t *testing.T) {
|
||||
sink, err := NewFileSink(Config{Path: "/tmp/test.log"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sink: %v", err)
|
||||
}
|
||||
|
||||
if sink.Name() != "file" {
|
||||
t.Errorf("expected name 'file', got %s", sink.Name())
|
||||
}
|
||||
}
|
||||
123
internal/adapters/outbound/multi/sink.go
Normal file
123
internal/adapters/outbound/multi/sink.go
Normal file
@ -0,0 +1,123 @@
|
||||
package multi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/logcorrelator/logcorrelator/internal/domain"
|
||||
"github.com/logcorrelator/logcorrelator/internal/ports"
|
||||
)
|
||||
|
||||
// MultiSink fans out correlated logs to multiple sinks.
type MultiSink struct {
	// mu guards sinks; AddSink may run concurrently with Write/Flush/Close.
	mu    sync.RWMutex
	sinks []ports.CorrelatedLogSink
}
|
||||
|
||||
// NewMultiSink creates a new multi-sink.
|
||||
func NewMultiSink(sinks ...ports.CorrelatedLogSink) *MultiSink {
|
||||
return &MultiSink{
|
||||
sinks: sinks,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the sink name.
|
||||
func (s *MultiSink) Name() string {
|
||||
return "multi"
|
||||
}
|
||||
|
||||
// AddSink adds a sink to the fan-out.
|
||||
func (s *MultiSink) AddSink(sink ports.CorrelatedLogSink) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sinks = append(s.sinks, sink)
|
||||
}
|
||||
|
||||
// Write writes a correlated log to all sinks concurrently.
|
||||
// Returns the first error encountered (but all sinks are attempted).
|
||||
func (s *MultiSink) Write(ctx context.Context, log domain.CorrelatedLog) error {
|
||||
s.mu.RLock()
|
||||
sinks := make([]ports.CorrelatedLogSink, len(s.sinks))
|
||||
copy(sinks, s.sinks)
|
||||
s.mu.RUnlock()
|
||||
|
||||
if len(sinks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var firstErr error
|
||||
var firstErrMu sync.Mutex
|
||||
errChan := make(chan error, len(sinks))
|
||||
|
||||
for _, sink := range sinks {
|
||||
wg.Add(1)
|
||||
go func(sk ports.CorrelatedLogSink) {
|
||||
defer wg.Done()
|
||||
if err := sk.Write(ctx, log); err != nil {
|
||||
// Non-blocking send to errChan
|
||||
select {
|
||||
case errChan <- err:
|
||||
default:
|
||||
// Channel full, error will be handled via firstErr
|
||||
}
|
||||
}
|
||||
}(sink)
|
||||
}
|
||||
|
||||
// Wait for all writes to complete in a separate goroutine
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Collect errors with timeout
|
||||
select {
|
||||
case <-done:
|
||||
close(errChan)
|
||||
// Collect first error
|
||||
for err := range errChan {
|
||||
if err != nil {
|
||||
firstErrMu.Lock()
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
firstErrMu.Unlock()
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
firstErrMu.Lock()
|
||||
defer firstErrMu.Unlock()
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// Flush flushes all sinks.
|
||||
func (s *MultiSink) Flush(ctx context.Context) error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
for _, sink := range s.sinks {
|
||||
if err := sink.Flush(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes all sinks.
|
||||
func (s *MultiSink) Close() error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var firstErr error
|
||||
for _, sink := range s.sinks {
|
||||
if err := sink.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
114
internal/adapters/outbound/multi/sink_test.go
Normal file
114
internal/adapters/outbound/multi/sink_test.go
Normal file
@ -0,0 +1,114 @@
|
||||
package multi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/logcorrelator/logcorrelator/internal/domain"
|
||||
)
|
||||
|
||||
type mockSink struct {
|
||||
name string
|
||||
mu sync.Mutex
|
||||
writeFunc func(domain.CorrelatedLog) error
|
||||
flushFunc func() error
|
||||
closeFunc func() error
|
||||
}
|
||||
|
||||
func (m *mockSink) Name() string { return m.name }
|
||||
func (m *mockSink) Write(ctx context.Context, log domain.CorrelatedLog) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.writeFunc(log)
|
||||
}
|
||||
func (m *mockSink) Flush(ctx context.Context) error { return m.flushFunc() }
|
||||
func (m *mockSink) Close() error { return m.closeFunc() }
|
||||
|
||||
func TestMultiSink_Write(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
writeCount := 0
|
||||
|
||||
sink1 := &mockSink{
|
||||
name: "sink1",
|
||||
writeFunc: func(log domain.CorrelatedLog) error {
|
||||
mu.Lock()
|
||||
writeCount++
|
||||
mu.Unlock()
|
||||
return nil
|
||||
},
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
sink2 := &mockSink{
|
||||
name: "sink2",
|
||||
writeFunc: func(log domain.CorrelatedLog) error {
|
||||
mu.Lock()
|
||||
writeCount++
|
||||
mu.Unlock()
|
||||
return nil
|
||||
},
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
ms := NewMultiSink(sink1, sink2)
|
||||
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1"}
|
||||
err := ms.Write(context.Background(), log)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if writeCount != 2 {
|
||||
t.Errorf("expected 2 writes, got %d", writeCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSink_Write_OneFails(t *testing.T) {
|
||||
sink1 := &mockSink{
|
||||
name: "sink1",
|
||||
writeFunc: func(log domain.CorrelatedLog) error {
|
||||
return nil
|
||||
},
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
sink2 := &mockSink{
|
||||
name: "sink2",
|
||||
writeFunc: func(log domain.CorrelatedLog) error {
|
||||
return context.Canceled
|
||||
},
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
ms := NewMultiSink(sink1, sink2)
|
||||
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1"}
|
||||
err := ms.Write(context.Background(), log)
|
||||
if err == nil {
|
||||
t.Error("expected error when one sink fails")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSink_AddSink(t *testing.T) {
|
||||
ms := NewMultiSink()
|
||||
|
||||
sink := &mockSink{
|
||||
name: "dynamic",
|
||||
writeFunc: func(log domain.CorrelatedLog) error { return nil },
|
||||
flushFunc: func() error { return nil },
|
||||
closeFunc: func() error { return nil },
|
||||
}
|
||||
|
||||
ms.AddSink(sink)
|
||||
|
||||
log := domain.CorrelatedLog{SrcIP: "192.168.1.1"}
|
||||
err := ms.Write(context.Background(), log)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user