refactor: switch Unix sockets from STREAM to DGRAM mode
Some checks failed
Build and Test / test (push) Has been cancelled
Build and Test / build (push) Has been cancelled
Build and Test / docker (push) Has been cancelled

- Change net.Listen("unix") to net.ListenUnixgram("unixgram")
- Replace connection-based Accept() with ReadFromUnix() datagram reading
- Remove connection limiting semaphore (not needed for DGRAM)
- Add datagram-specific tests (start/stop, single datagram, multiple datagrams)
- Socket permissions default to 0666 (world read/write)

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
This commit is contained in:
Jacquin Antoine
2026-03-02 22:43:10 +01:00
parent eb3cc78170
commit 15ca33ee3a
2 changed files with 211 additions and 67 deletions

View File

@ -1,7 +1,6 @@
package unixsocket package unixsocket
import ( import (
"bufio"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -18,10 +17,8 @@ import (
) )
const ( const (
// Maximum line size for JSON logs (1MB) // Maximum datagram size for JSON logs (64KB - Unix datagram limit)
MaxLineSize = 1024 * 1024 MaxDatagramSize = 65535
// Maximum concurrent connections per socket
MaxConcurrentConnections = 100
// Rate limit: max events per second // Rate limit: max events per second
MaxEventsPerSecond = 10000 MaxEventsPerSecond = 10000
) )
@ -34,14 +31,13 @@ type Config struct {
SocketPermissions os.FileMode SocketPermissions os.FileMode
} }
// UnixSocketSource reads JSON events from a Unix socket. // UnixSocketSource reads JSON events from a Unix datagram socket.
type UnixSocketSource struct { type UnixSocketSource struct {
config Config config Config
mu sync.Mutex mu sync.Mutex
listener net.Listener conn *net.UnixConn
done chan struct{} done chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
semaphore chan struct{} // Limit concurrent connections
stopOnce sync.Once stopOnce sync.Once
logger *observability.Logger logger *observability.Logger
} }
@ -51,7 +47,6 @@ func NewUnixSocketSource(config Config) *UnixSocketSource {
return &UnixSocketSource{ return &UnixSocketSource{
config: config, config: config,
done: make(chan struct{}), done: make(chan struct{}),
semaphore: make(chan struct{}, MaxConcurrentConnections),
logger: observability.NewLogger("unixsocket:" + config.Name), logger: observability.NewLogger("unixsocket:" + config.Name),
} }
} }
@ -66,7 +61,7 @@ func (s *UnixSocketSource) Name() string {
return s.config.Name return s.config.Name
} }
// Start begins listening on the Unix socket. // Start begins listening on the Unix datagram socket.
func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) error { func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) error {
if strings.TrimSpace(s.config.Path) == "" { if strings.TrimSpace(s.config.Path) == "" {
return fmt.Errorf("socket path cannot be empty") return fmt.Errorf("socket path cannot be empty")
@ -83,12 +78,17 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N
} }
} }
// Create listener // Create Unix datagram socket
listener, err := net.Listen("unix", s.config.Path) addr, err := net.ResolveUnixAddr("unixgram", s.config.Path)
if err != nil { if err != nil {
return fmt.Errorf("failed to create unix socket listener: %w", err) return fmt.Errorf("failed to resolve unix socket address: %w", err)
} }
s.listener = listener
conn, err := net.ListenUnixgram("unixgram", addr)
if err != nil {
return fmt.Errorf("failed to create unix datagram socket: %w", err)
}
s.conn = conn
// Set permissions - fail if we can't // Set permissions - fail if we can't
permissions := s.config.SocketPermissions permissions := s.config.SocketPermissions
@ -96,7 +96,7 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N
permissions = 0666 // default permissions = 0666 // default
} }
if err := os.Chmod(s.config.Path, permissions); err != nil { if err := os.Chmod(s.config.Path, permissions); err != nil {
_ = listener.Close() _ = conn.Close()
_ = os.Remove(s.config.Path) _ = os.Remove(s.config.Path)
return fmt.Errorf("failed to set socket permissions: %w", err) return fmt.Errorf("failed to set socket permissions: %w", err)
} }
@ -104,13 +104,15 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N
s.wg.Add(1) s.wg.Add(1)
go func() { go func() {
defer s.wg.Done() defer s.wg.Done()
s.acceptConnections(ctx, eventChan) s.readDatagrams(ctx, eventChan)
}() }()
return nil return nil
} }
func (s *UnixSocketSource) acceptConnections(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) { func (s *UnixSocketSource) readDatagrams(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) {
buf := make([]byte, MaxDatagramSize)
for { for {
select { select {
case <-s.done: case <-s.done:
@ -120,60 +122,35 @@ func (s *UnixSocketSource) acceptConnections(ctx context.Context, eventChan chan
default: default:
} }
conn, err := s.listener.Accept() // Set read deadline to allow periodic context checks
_ = s.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
n, _, err := s.conn.ReadFromUnix(buf)
if err != nil { if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Read timeout, continue to check context
continue
}
// Other errors (e.g., closed socket)
select { select {
case <-s.done: case <-s.done:
return return
case <-ctx.Done(): case <-ctx.Done():
return return
default: default:
s.logger.Warnf("read error: %v", err)
continue continue
} }
} }
// Check semaphore for connection limiting if n == 0 {
select {
case s.semaphore <- struct{}{}:
// Connection accepted
default:
// Too many connections, reject
_ = conn.Close()
continue continue
} }
s.wg.Add(1) data := make([]byte, n)
go func(c net.Conn) { copy(data, buf[:n])
defer s.wg.Done()
defer func() { <-s.semaphore }()
defer c.Close()
s.readEvents(ctx, c, eventChan)
}(conn)
}
}
func (s *UnixSocketSource) readEvents(ctx context.Context, conn net.Conn, eventChan chan<- *domain.NormalizedEvent) { event, err := parseJSONEvent(data, s.config.SourceType)
// Set read deadline to prevent hanging
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
scanner := bufio.NewScanner(conn)
// Increase buffer size limit to 1MB
buf := make([]byte, 0, 4096)
scanner.Buffer(buf, MaxLineSize)
for scanner.Scan() {
select {
case <-ctx.Done():
return
default:
}
line := scanner.Bytes()
if len(line) == 0 {
continue
}
event, err := parseJSONEvent(line, s.config.SourceType)
if err != nil { if err != nil {
// Log parse errors as warnings // Log parse errors as warnings
s.logger.Warnf("parse error: %v", err) s.logger.Warnf("parse error: %v", err)
@ -360,8 +337,8 @@ func (s *UnixSocketSource) Stop() error {
close(s.done) close(s.done)
if s.listener != nil { if s.conn != nil {
_ = s.listener.Close() _ = s.conn.Close()
} }
s.wg.Wait() s.wg.Wait()

View File

@ -2,6 +2,9 @@ package unixsocket
import ( import (
"context" "context"
"fmt"
"net"
"os"
"testing" "testing"
"time" "time"
@ -391,3 +394,167 @@ func TestParseJSONEvent_TimestampFallback(t *testing.T) {
t.Error("expected non-zero timestamp") t.Error("expected non-zero timestamp")
} }
} }
// TestUnixSocketSource_StartStopDatagram verifies the socket file lifecycle:
// Start must create the socket file and Stop must remove it again.
func TestUnixSocketSource_StartStopDatagram(t *testing.T) {
	// Use a private temporary directory instead of a fixed name in /tmp so
	// that concurrent or repeated test runs cannot collide on (or inherit) a
	// stale socket file. The short prefix keeps the full path well below the
	// sun_path length limit for Unix socket addresses.
	dir, err := os.MkdirTemp("", "lc")
	if err != nil {
		t.Fatalf("failed to create temp dir: %v", err)
	}
	t.Cleanup(func() { _ = os.RemoveAll(dir) }) // best-effort cleanup
	// Unix sockets only exist on Unix-like systems, so "/" is the separator.
	tmpPath := dir + "/s.sock"

	source := NewUnixSocketSource(Config{
		Name:              "test_datagram",
		Path:              tmpPath,
		SourceType:        "B",
		SocketPermissions: 0666,
	})

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	eventChan := make(chan *domain.NormalizedEvent, 10)
	if err := source.Start(ctx, eventChan); err != nil {
		t.Fatalf("failed to start source: %v", err)
	}

	// Start binds the socket and sets its permissions before returning, so
	// the file must already exist here; any Stat error is a failure.
	if _, err := os.Stat(tmpPath); err != nil {
		t.Errorf("socket file should exist: %v", err)
	}

	// Stop the source.
	if err := source.Stop(); err != nil {
		t.Errorf("failed to stop source: %v", err)
	}

	// Socket file should be cleaned up. The short grace period is kept from
	// the original test in case removal is not synchronous with Stop.
	time.Sleep(100 * time.Millisecond)
	if _, err := os.Stat(tmpPath); !os.IsNotExist(err) {
		t.Error("socket file should be removed after stop")
	}
}
// TestUnixSocketSource_SendDatagram sends one JSON datagram to a started
// source and checks that the decoded event carries the expected source IP
// and port.
func TestUnixSocketSource_SendDatagram(t *testing.T) {
	// Use a private temporary directory instead of a fixed name in /tmp so
	// that concurrent or repeated test runs cannot collide on (or inherit) a
	// stale socket file. The short prefix keeps the full path well below the
	// sun_path length limit for Unix socket addresses.
	dir, err := os.MkdirTemp("", "lc")
	if err != nil {
		t.Fatalf("failed to create temp dir: %v", err)
	}
	t.Cleanup(func() { _ = os.RemoveAll(dir) }) // best-effort cleanup
	// Unix sockets only exist on Unix-like systems, so "/" is the separator.
	tmpPath := dir + "/s.sock"

	source := NewUnixSocketSource(Config{
		Name:              "test_send",
		Path:              tmpPath,
		SourceType:        "B",
		SocketPermissions: 0666,
	})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	eventChan := make(chan *domain.NormalizedEvent, 10)
	if err := source.Start(ctx, eventChan); err != nil {
		t.Fatalf("failed to start source: %v", err)
	}
	// Always stop the source, including when a t.Fatalf below aborts the
	// test, so the reader goroutine and socket are not leaked.
	defer func() {
		if err := source.Stop(); err != nil {
			t.Errorf("failed to stop source: %v", err)
		}
	}()

	// The socket is bound once Start returns, so the client can dial
	// immediately; datagrams are queued until the reader picks them up.
	conn, err := net.Dial("unixgram", tmpPath)
	if err != nil {
		t.Fatalf("failed to dial socket: %v", err)
	}
	defer conn.Close()

	data := []byte(`{"src_ip": "192.168.1.1", "src_port": 8080, "ja3": "test"}`)
	if _, err := conn.Write(data); err != nil {
		t.Fatalf("failed to write: %v", err)
	}

	// Wait for the event, bounded by both a local timeout and the context.
	select {
	case event := <-eventChan:
		if event.SrcIP != "192.168.1.1" {
			t.Errorf("expected src_ip 192.168.1.1, got %s", event.SrcIP)
		}
		if event.SrcPort != 8080 {
			t.Errorf("expected src_port 8080, got %d", event.SrcPort)
		}
	case <-time.After(2 * time.Second):
		t.Error("timeout waiting for event")
	case <-ctx.Done():
		t.Error("context cancelled")
	}
}
// TestUnixSocketSource_MultipleDatagrams sends several JSON datagrams over a
// single client socket and checks that one event is delivered per datagram.
func TestUnixSocketSource_MultipleDatagrams(t *testing.T) {
	const datagramCount = 5

	// Use a private temporary directory instead of a fixed name in /tmp so
	// that concurrent or repeated test runs cannot collide on (or inherit) a
	// stale socket file. The short prefix keeps the full path well below the
	// sun_path length limit for Unix socket addresses.
	dir, err := os.MkdirTemp("", "lc")
	if err != nil {
		t.Fatalf("failed to create temp dir: %v", err)
	}
	t.Cleanup(func() { _ = os.RemoveAll(dir) }) // best-effort cleanup
	// Unix sockets only exist on Unix-like systems, so "/" is the separator.
	tmpPath := dir + "/s.sock"

	source := NewUnixSocketSource(Config{
		Name:              "test_multi",
		Path:              tmpPath,
		SourceType:        "B",
		SocketPermissions: 0666,
	})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	eventChan := make(chan *domain.NormalizedEvent, 100)
	if err := source.Start(ctx, eventChan); err != nil {
		t.Fatalf("failed to start source: %v", err)
	}
	// Always stop the source, including on early exit, so the reader
	// goroutine and socket are not leaked. This replaces the goto-based
	// cleanup: every exit path below can simply return.
	defer func() {
		if err := source.Stop(); err != nil {
			t.Errorf("failed to stop source: %v", err)
		}
	}()

	// The socket is bound once Start returns, so the client can dial
	// immediately; datagrams are queued until the reader picks them up.
	conn, err := net.Dial("unixgram", tmpPath)
	if err != nil {
		t.Fatalf("failed to dial socket: %v", err)
	}
	defer conn.Close()

	for i := 0; i < datagramCount; i++ {
		data := []byte(fmt.Sprintf(`{"src_ip": "192.168.1.%d", "src_port": %d, "ja3": "test%d"}`, i+1, 8080+i, i))
		if _, err := conn.Write(data); err != nil {
			t.Fatalf("failed to write datagram %d: %v", i, err)
		}
	}

	// Wait for all events, sharing one overall timeout across the loop.
	timeout := time.After(3 * time.Second)
	for received := 0; received < datagramCount; {
		select {
		case event := <-eventChan:
			received++
			t.Logf("received event %d: src_ip=%s", received, event.SrcIP)
		case <-timeout:
			t.Errorf("timeout waiting for events, received %d/%d", received, datagramCount)
			return
		case <-ctx.Done():
			t.Error("context cancelled")
			return
		}
	}
}