refactor: switch Unix sockets from STREAM to DGRAM mode
Some checks failed
Build and Test / test (push) Has been cancelled
Build and Test / build (push) Has been cancelled
Build and Test / docker (push) Has been cancelled

- Change net.Listen("unix") to net.ListenUnixgram("unixgram")
- Replace connection-based Accept() with ReadFromUnix() datagram reading
- Remove connection limiting semaphore (not needed for DGRAM)
- Update tests with datagram-specific tests
- Socket permissions default to 0666 (world read/write)

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
This commit is contained in:
Jacquin Antoine
2026-03-02 22:43:10 +01:00
parent eb3cc78170
commit 15ca33ee3a
2 changed files with 211 additions and 67 deletions

View File

@ -1,7 +1,6 @@
package unixsocket
import (
"bufio"
"context"
"encoding/json"
"fmt"
@ -18,10 +17,8 @@ import (
)
const (
// Maximum line size for JSON logs (1MB)
MaxLineSize = 1024 * 1024
// Maximum concurrent connections per socket
MaxConcurrentConnections = 100
// Maximum datagram size for JSON logs (64KB - Unix datagram limit)
MaxDatagramSize = 65535
// Rate limit: max events per second
MaxEventsPerSecond = 10000
)
@ -34,14 +31,13 @@ type Config struct {
SocketPermissions os.FileMode
}
// UnixSocketSource reads JSON events from a Unix socket.
// UnixSocketSource reads JSON events from a Unix datagram socket.
type UnixSocketSource struct {
config Config
mu sync.Mutex
listener net.Listener
conn *net.UnixConn
done chan struct{}
wg sync.WaitGroup
semaphore chan struct{} // Limit concurrent connections
stopOnce sync.Once
logger *observability.Logger
}
@ -51,7 +47,6 @@ func NewUnixSocketSource(config Config) *UnixSocketSource {
return &UnixSocketSource{
config: config,
done: make(chan struct{}),
semaphore: make(chan struct{}, MaxConcurrentConnections),
logger: observability.NewLogger("unixsocket:" + config.Name),
}
}
@ -66,7 +61,7 @@ func (s *UnixSocketSource) Name() string {
return s.config.Name
}
// Start begins listening on the Unix socket.
// Start begins listening on the Unix datagram socket.
func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) error {
if strings.TrimSpace(s.config.Path) == "" {
return fmt.Errorf("socket path cannot be empty")
@ -83,12 +78,17 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N
}
}
// Create listener
listener, err := net.Listen("unix", s.config.Path)
// Create Unix datagram socket
addr, err := net.ResolveUnixAddr("unixgram", s.config.Path)
if err != nil {
return fmt.Errorf("failed to create unix socket listener: %w", err)
return fmt.Errorf("failed to resolve unix socket address: %w", err)
}
s.listener = listener
conn, err := net.ListenUnixgram("unixgram", addr)
if err != nil {
return fmt.Errorf("failed to create unix datagram socket: %w", err)
}
s.conn = conn
// Set permissions - fail if we can't
permissions := s.config.SocketPermissions
@ -96,7 +96,7 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N
permissions = 0666 // default
}
if err := os.Chmod(s.config.Path, permissions); err != nil {
_ = listener.Close()
_ = conn.Close()
_ = os.Remove(s.config.Path)
return fmt.Errorf("failed to set socket permissions: %w", err)
}
@ -104,13 +104,15 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.acceptConnections(ctx, eventChan)
s.readDatagrams(ctx, eventChan)
}()
return nil
}
func (s *UnixSocketSource) acceptConnections(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) {
func (s *UnixSocketSource) readDatagrams(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) {
buf := make([]byte, MaxDatagramSize)
for {
select {
case <-s.done:
@ -120,60 +122,35 @@ func (s *UnixSocketSource) acceptConnections(ctx context.Context, eventChan chan
default:
}
conn, err := s.listener.Accept()
// Set read deadline to allow periodic context checks
_ = s.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
n, _, err := s.conn.ReadFromUnix(buf)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Read timeout, continue to check context
continue
}
// Other errors (e.g., closed socket)
select {
case <-s.done:
return
case <-ctx.Done():
return
default:
s.logger.Warnf("read error: %v", err)
continue
}
}
// Check semaphore for connection limiting
select {
case s.semaphore <- struct{}{}:
// Connection accepted
default:
// Too many connections, reject
_ = conn.Close()
if n == 0 {
continue
}
s.wg.Add(1)
go func(c net.Conn) {
defer s.wg.Done()
defer func() { <-s.semaphore }()
defer c.Close()
s.readEvents(ctx, c, eventChan)
}(conn)
}
}
data := make([]byte, n)
copy(data, buf[:n])
func (s *UnixSocketSource) readEvents(ctx context.Context, conn net.Conn, eventChan chan<- *domain.NormalizedEvent) {
// Set read deadline to prevent hanging
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
scanner := bufio.NewScanner(conn)
// Increase buffer size limit to 1MB
buf := make([]byte, 0, 4096)
scanner.Buffer(buf, MaxLineSize)
for scanner.Scan() {
select {
case <-ctx.Done():
return
default:
}
line := scanner.Bytes()
if len(line) == 0 {
continue
}
event, err := parseJSONEvent(line, s.config.SourceType)
event, err := parseJSONEvent(data, s.config.SourceType)
if err != nil {
// Log parse errors as warnings
s.logger.Warnf("parse error: %v", err)
@ -360,8 +337,8 @@ func (s *UnixSocketSource) Stop() error {
close(s.done)
if s.listener != nil {
_ = s.listener.Close()
if s.conn != nil {
_ = s.conn.Close()
}
s.wg.Wait()

View File

@ -2,6 +2,9 @@ package unixsocket
import (
"context"
"fmt"
"net"
"os"
"testing"
"time"
@ -391,3 +394,167 @@ func TestParseJSONEvent_TimestampFallback(t *testing.T) {
t.Error("expected non-zero timestamp")
}
}
func TestUnixSocketSource_StartStopDatagram(t *testing.T) {
tmpPath := "/tmp/test_logcorrelator_datagram.sock"
// Clean up any existing socket
os.Remove(tmpPath)
source := NewUnixSocketSource(Config{
Name: "test_datagram",
Path: tmpPath,
SourceType: "B",
SocketPermissions: 0666,
})
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
eventChan := make(chan *domain.NormalizedEvent, 10)
err := source.Start(ctx, eventChan)
if err != nil {
t.Fatalf("failed to start source: %v", err)
}
// Give socket time to start
time.Sleep(100 * time.Millisecond)
// Verify socket file exists
if _, err := os.Stat(tmpPath); os.IsNotExist(err) {
t.Error("socket file should exist")
}
// Stop the source
err = source.Stop()
if err != nil {
t.Errorf("failed to stop source: %v", err)
}
// Socket file should be cleaned up
time.Sleep(100 * time.Millisecond)
if _, err := os.Stat(tmpPath); !os.IsNotExist(err) {
t.Error("socket file should be removed after stop")
}
}
func TestUnixSocketSource_SendDatagram(t *testing.T) {
tmpPath := "/tmp/test_logcorrelator_send.sock"
os.Remove(tmpPath)
source := NewUnixSocketSource(Config{
Name: "test_send",
Path: tmpPath,
SourceType: "B",
SocketPermissions: 0666,
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
eventChan := make(chan *domain.NormalizedEvent, 10)
err := source.Start(ctx, eventChan)
if err != nil {
t.Fatalf("failed to start source: %v", err)
}
// Give socket time to start
time.Sleep(100 * time.Millisecond)
// Connect and send a datagram
conn, err := net.Dial("unixgram", tmpPath)
if err != nil {
t.Fatalf("failed to dial socket: %v", err)
}
defer conn.Close()
data := []byte(`{"src_ip": "192.168.1.1", "src_port": 8080, "ja3": "test"}`)
_, err = conn.Write(data)
if err != nil {
t.Fatalf("failed to write: %v", err)
}
// Wait for event
select {
case event := <-eventChan:
if event.SrcIP != "192.168.1.1" {
t.Errorf("expected src_ip 192.168.1.1, got %s", event.SrcIP)
}
if event.SrcPort != 8080 {
t.Errorf("expected src_port 8080, got %d", event.SrcPort)
}
case <-time.After(2 * time.Second):
t.Error("timeout waiting for event")
case <-ctx.Done():
t.Error("context cancelled")
}
err = source.Stop()
if err != nil {
t.Errorf("failed to stop source: %v", err)
}
}
func TestUnixSocketSource_MultipleDatagrams(t *testing.T) {
tmpPath := "/tmp/test_logcorrelator_multi.sock"
os.Remove(tmpPath)
source := NewUnixSocketSource(Config{
Name: "test_multi",
Path: tmpPath,
SourceType: "B",
SocketPermissions: 0666,
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
eventChan := make(chan *domain.NormalizedEvent, 100)
err := source.Start(ctx, eventChan)
if err != nil {
t.Fatalf("failed to start source: %v", err)
}
// Give socket time to start
time.Sleep(100 * time.Millisecond)
// Connect and send multiple datagrams
conn, err := net.Dial("unixgram", tmpPath)
if err != nil {
t.Fatalf("failed to dial socket: %v", err)
}
defer conn.Close()
for i := 0; i < 5; i++ {
data := []byte(fmt.Sprintf(`{"src_ip": "192.168.1.%d", "src_port": %d, "ja3": "test%d"}`, i+1, 8080+i, i))
_, err = conn.Write(data)
if err != nil {
t.Fatalf("failed to write datagram %d: %v", i, err)
}
}
// Wait for all events
received := 0
timeout := time.After(3 * time.Second)
for received < 5 {
select {
case event := <-eventChan:
received++
t.Logf("received event %d: src_ip=%s", received, event.SrcIP)
case <-timeout:
t.Errorf("timeout waiting for events, received %d/5", received)
goto done
case <-ctx.Done():
t.Error("context cancelled")
goto done
}
}
done:
err = source.Stop()
if err != nil {
t.Errorf("failed to stop source: %v", err)
}
}