diff --git a/internal/adapters/inbound/unixsocket/source.go b/internal/adapters/inbound/unixsocket/source.go index 7b134a8..6ff8fba 100644 --- a/internal/adapters/inbound/unixsocket/source.go +++ b/internal/adapters/inbound/unixsocket/source.go @@ -1,7 +1,6 @@ package unixsocket import ( - "bufio" "context" "encoding/json" "fmt" @@ -18,10 +17,8 @@ import ( ) const ( - // Maximum line size for JSON logs (1MB) - MaxLineSize = 1024 * 1024 - // Maximum concurrent connections per socket - MaxConcurrentConnections = 100 + // Maximum datagram size for JSON logs (64KB - Unix datagram limit) + MaxDatagramSize = 65535 // Rate limit: max events per second MaxEventsPerSecond = 10000 ) @@ -34,25 +31,23 @@ type Config struct { SocketPermissions os.FileMode } -// UnixSocketSource reads JSON events from a Unix socket. +// UnixSocketSource reads JSON events from a Unix datagram socket. type UnixSocketSource struct { - config Config - mu sync.Mutex - listener net.Listener - done chan struct{} - wg sync.WaitGroup - semaphore chan struct{} // Limit concurrent connections - stopOnce sync.Once - logger *observability.Logger + config Config + mu sync.Mutex + conn *net.UnixConn + done chan struct{} + wg sync.WaitGroup + stopOnce sync.Once + logger *observability.Logger } // NewUnixSocketSource creates a new Unix socket source. func NewUnixSocketSource(config Config) *UnixSocketSource { return &UnixSocketSource{ - config: config, - done: make(chan struct{}), - semaphore: make(chan struct{}, MaxConcurrentConnections), - logger: observability.NewLogger("unixsocket:" + config.Name), + config: config, + done: make(chan struct{}), + logger: observability.NewLogger("unixsocket:" + config.Name), } } @@ -66,7 +61,7 @@ func (s *UnixSocketSource) Name() string { return s.config.Name } -// Start begins listening on the Unix socket. +// Start begins listening on the Unix datagram socket. func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) error { if strings.TrimSpace(s.config.Path) == "" { return fmt.Errorf("socket path cannot be empty") @@ -83,12 +78,17 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N } } - // Create listener - listener, err := net.Listen("unix", s.config.Path) + // Create Unix datagram socket + addr, err := net.ResolveUnixAddr("unixgram", s.config.Path) if err != nil { - return fmt.Errorf("failed to create unix socket listener: %w", err) + return fmt.Errorf("failed to resolve unix socket address: %w", err) } - s.listener = listener + + conn, err := net.ListenUnixgram("unixgram", addr) + if err != nil { + return fmt.Errorf("failed to create unix datagram socket: %w", err) + } + s.conn = conn // Set permissions - fail if we can't permissions := s.config.SocketPermissions @@ -96,7 +96,7 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N permissions = 0666 // default } if err := os.Chmod(s.config.Path, permissions); err != nil { - _ = listener.Close() + _ = conn.Close() _ = os.Remove(s.config.Path) return fmt.Errorf("failed to set socket permissions: %w", err) } @@ -104,13 +104,15 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N s.wg.Add(1) go func() { defer s.wg.Done() - s.acceptConnections(ctx, eventChan) + s.readDatagrams(ctx, eventChan) }() return nil } -func (s *UnixSocketSource) acceptConnections(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) { +func (s *UnixSocketSource) readDatagrams(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) { + buf := make([]byte, MaxDatagramSize) + for { select { case <-s.done: @@ -120,60 +122,35 @@ func (s *UnixSocketSource) acceptConnections(ctx context.Context, eventChan chan default: } - conn, err := s.listener.Accept() + // Set read deadline to allow periodic context checks + _ = s.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + + n, _, err := s.conn.ReadFromUnix(buf) if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + // Read timeout, continue to check context + continue + } + // Other errors (e.g., closed socket) select { case <-s.done: return case <-ctx.Done(): return default: + s.logger.Warnf("read error: %v", err) continue } } - // Check semaphore for connection limiting - select { - case s.semaphore <- struct{}{}: - // Connection accepted - default: - // Too many connections, reject - _ = conn.Close() + if n == 0 { continue } - s.wg.Add(1) - go func(c net.Conn) { - defer s.wg.Done() - defer func() { <-s.semaphore }() - defer c.Close() - s.readEvents(ctx, c, eventChan) - }(conn) - } -} + data := make([]byte, n) + copy(data, buf[:n]) -func (s *UnixSocketSource) readEvents(ctx context.Context, conn net.Conn, eventChan chan<- *domain.NormalizedEvent) { - // Set read deadline to prevent hanging - _ = conn.SetReadDeadline(time.Now().Add(5 * time.Minute)) - - scanner := bufio.NewScanner(conn) - // Increase buffer size limit to 1MB - buf := make([]byte, 0, 4096) - scanner.Buffer(buf, MaxLineSize) - - for scanner.Scan() { - select { - case <-ctx.Done(): - return - default: - } - - line := scanner.Bytes() - if len(line) == 0 { - continue - } - - event, err := parseJSONEvent(line, s.config.SourceType) + event, err := parseJSONEvent(data, s.config.SourceType) if err != nil { // Log parse errors as warnings s.logger.Warnf("parse error: %v", err) @@ -360,8 +337,8 @@ func (s *UnixSocketSource) Stop() error { close(s.done) - if s.listener != nil { - _ = s.listener.Close() + if s.conn != nil { + _ = s.conn.Close() } s.wg.Wait() diff --git a/internal/adapters/inbound/unixsocket/source_test.go b/internal/adapters/inbound/unixsocket/source_test.go index d706daa..f83e745 100644 --- a/internal/adapters/inbound/unixsocket/source_test.go +++ b/internal/adapters/inbound/unixsocket/source_test.go @@ -2,6 +2,9 @@ package unixsocket import ( "context" + "fmt" + "net" + "os" "testing" "time" @@ -391,3 +394,167 @@ func TestParseJSONEvent_TimestampFallback(t *testing.T) { t.Error("expected non-zero timestamp") } } + +func TestUnixSocketSource_StartStopDatagram(t *testing.T) { + tmpPath := "/tmp/test_logcorrelator_datagram.sock" + // Clean up any existing socket + os.Remove(tmpPath) + + source := NewUnixSocketSource(Config{ + Name: "test_datagram", + Path: tmpPath, + SourceType: "B", + SocketPermissions: 0666, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + eventChan := make(chan *domain.NormalizedEvent, 10) + + err := source.Start(ctx, eventChan) + if err != nil { + t.Fatalf("failed to start source: %v", err) + } + + // Give socket time to start + time.Sleep(100 * time.Millisecond) + + // Verify socket file exists + if _, err := os.Stat(tmpPath); os.IsNotExist(err) { + t.Error("socket file should exist") + } + + // Stop the source + err = source.Stop() + if err != nil { + t.Errorf("failed to stop source: %v", err) + } + + // Socket file should be cleaned up + time.Sleep(100 * time.Millisecond) + if _, err := os.Stat(tmpPath); !os.IsNotExist(err) { + t.Error("socket file should be removed after stop") + } +} + +func TestUnixSocketSource_SendDatagram(t *testing.T) { + tmpPath := "/tmp/test_logcorrelator_send.sock" + os.Remove(tmpPath) + + source := NewUnixSocketSource(Config{ + Name: "test_send", + Path: tmpPath, + SourceType: "B", + SocketPermissions: 0666, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + eventChan := make(chan *domain.NormalizedEvent, 10) + + err := source.Start(ctx, eventChan) + if err != nil { + t.Fatalf("failed to start source: %v", err) + } + + // Give socket time to start + time.Sleep(100 * time.Millisecond) + + // Connect and send a datagram + conn, err := net.Dial("unixgram", tmpPath) + if err != nil { + t.Fatalf("failed to dial socket: %v", err) + } + defer conn.Close() + + data := []byte(`{"src_ip": "192.168.1.1", "src_port": 8080, "ja3": "test"}`) + _, err = conn.Write(data) + if err != nil { + t.Fatalf("failed to write: %v", err) + } + + // Wait for event + select { + case event := <-eventChan: + if event.SrcIP != "192.168.1.1" { + t.Errorf("expected src_ip 192.168.1.1, got %s", event.SrcIP) + } + if event.SrcPort != 8080 { + t.Errorf("expected src_port 8080, got %d", event.SrcPort) + } + case <-time.After(2 * time.Second): + t.Error("timeout waiting for event") + case <-ctx.Done(): + t.Error("context cancelled") + } + + err = source.Stop() + if err != nil { + t.Errorf("failed to stop source: %v", err) + } +} + +func TestUnixSocketSource_MultipleDatagrams(t *testing.T) { + tmpPath := "/tmp/test_logcorrelator_multi.sock" + os.Remove(tmpPath) + + source := NewUnixSocketSource(Config{ + Name: "test_multi", + Path: tmpPath, + SourceType: "B", + SocketPermissions: 0666, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + eventChan := make(chan *domain.NormalizedEvent, 100) + + err := source.Start(ctx, eventChan) + if err != nil { + t.Fatalf("failed to start source: %v", err) + } + + // Give socket time to start + time.Sleep(100 * time.Millisecond) + + // Connect and send multiple datagrams + conn, err := net.Dial("unixgram", tmpPath) + if err != nil { + t.Fatalf("failed to dial socket: %v", err) + } + defer conn.Close() + + for i := 0; i < 5; i++ { + data := []byte(fmt.Sprintf(`{"src_ip": "192.168.1.%d", "src_port": %d, "ja3": "test%d"}`, i+1, 8080+i, i)) + _, err = conn.Write(data) + if err != nil { + t.Fatalf("failed to write datagram %d: %v", i, err) + } + } + + // Wait for all events + received := 0 + timeout := time.After(3 * time.Second) + for received < 5 { + select { + case event := <-eventChan: + received++ + t.Logf("received event %d: src_ip=%s", received, event.SrcIP) + case <-timeout: + t.Errorf("timeout waiting for events, received %d/5", received) + goto done + case <-ctx.Done(): + t.Error("context cancelled") + goto done + } + } + +done: + err = source.Stop() + if err != nil { + t.Errorf("failed to stop source: %v", err) + } +}