refactor: switch Unix sockets from STREAM to DGRAM mode
- Change net.Listen("unix") to net.ListenUnixgram("unixgram")
- Replace connection-based Accept() with ReadFromUnix() datagram reading
- Remove connection limiting semaphore (not needed for DGRAM)
- Update tests with datagram-specific tests
- Socket permissions default to 0666 (world read/write)
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
package unixsocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -18,10 +17,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum line size for JSON logs (1MB)
|
||||
MaxLineSize = 1024 * 1024
|
||||
// Maximum concurrent connections per socket
|
||||
MaxConcurrentConnections = 100
|
||||
// Maximum datagram size for JSON logs (64KB - Unix datagram limit)
|
||||
MaxDatagramSize = 65535
|
||||
// Rate limit: max events per second
|
||||
MaxEventsPerSecond = 10000
|
||||
)
|
||||
@ -34,14 +31,13 @@ type Config struct {
|
||||
SocketPermissions os.FileMode
|
||||
}
|
||||
|
||||
// UnixSocketSource reads JSON events from a Unix socket.
|
||||
// UnixSocketSource reads JSON events from a Unix datagram socket.
|
||||
type UnixSocketSource struct {
|
||||
config Config
|
||||
mu sync.Mutex
|
||||
listener net.Listener
|
||||
conn *net.UnixConn
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
semaphore chan struct{} // Limit concurrent connections
|
||||
stopOnce sync.Once
|
||||
logger *observability.Logger
|
||||
}
|
||||
@ -51,7 +47,6 @@ func NewUnixSocketSource(config Config) *UnixSocketSource {
|
||||
return &UnixSocketSource{
|
||||
config: config,
|
||||
done: make(chan struct{}),
|
||||
semaphore: make(chan struct{}, MaxConcurrentConnections),
|
||||
logger: observability.NewLogger("unixsocket:" + config.Name),
|
||||
}
|
||||
}
|
||||
@ -66,7 +61,7 @@ func (s *UnixSocketSource) Name() string {
|
||||
return s.config.Name
|
||||
}
|
||||
|
||||
// Start begins listening on the Unix socket.
|
||||
// Start begins listening on the Unix datagram socket.
|
||||
func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) error {
|
||||
if strings.TrimSpace(s.config.Path) == "" {
|
||||
return fmt.Errorf("socket path cannot be empty")
|
||||
@ -83,12 +78,17 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N
|
||||
}
|
||||
}
|
||||
|
||||
// Create listener
|
||||
listener, err := net.Listen("unix", s.config.Path)
|
||||
// Create Unix datagram socket
|
||||
addr, err := net.ResolveUnixAddr("unixgram", s.config.Path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create unix socket listener: %w", err)
|
||||
return fmt.Errorf("failed to resolve unix socket address: %w", err)
|
||||
}
|
||||
s.listener = listener
|
||||
|
||||
conn, err := net.ListenUnixgram("unixgram", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create unix datagram socket: %w", err)
|
||||
}
|
||||
s.conn = conn
|
||||
|
||||
// Set permissions - fail if we can't
|
||||
permissions := s.config.SocketPermissions
|
||||
@ -96,7 +96,7 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N
|
||||
permissions = 0666 // default
|
||||
}
|
||||
if err := os.Chmod(s.config.Path, permissions); err != nil {
|
||||
_ = listener.Close()
|
||||
_ = conn.Close()
|
||||
_ = os.Remove(s.config.Path)
|
||||
return fmt.Errorf("failed to set socket permissions: %w", err)
|
||||
}
|
||||
@ -104,13 +104,15 @@ func (s *UnixSocketSource) Start(ctx context.Context, eventChan chan<- *domain.N
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.acceptConnections(ctx, eventChan)
|
||||
s.readDatagrams(ctx, eventChan)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UnixSocketSource) acceptConnections(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) {
|
||||
func (s *UnixSocketSource) readDatagrams(ctx context.Context, eventChan chan<- *domain.NormalizedEvent) {
|
||||
buf := make([]byte, MaxDatagramSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
@ -120,60 +122,35 @@ func (s *UnixSocketSource) acceptConnections(ctx context.Context, eventChan chan
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := s.listener.Accept()
|
||||
// Set read deadline to allow periodic context checks
|
||||
_ = s.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
|
||||
n, _, err := s.conn.ReadFromUnix(buf)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
// Read timeout, continue to check context
|
||||
continue
|
||||
}
|
||||
// Other errors (e.g., closed socket)
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
s.logger.Warnf("read error: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Check semaphore for connection limiting
|
||||
select {
|
||||
case s.semaphore <- struct{}{}:
|
||||
// Connection accepted
|
||||
default:
|
||||
// Too many connections, reject
|
||||
_ = conn.Close()
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func(c net.Conn) {
|
||||
defer s.wg.Done()
|
||||
defer func() { <-s.semaphore }()
|
||||
defer c.Close()
|
||||
s.readEvents(ctx, c, eventChan)
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
data := make([]byte, n)
|
||||
copy(data, buf[:n])
|
||||
|
||||
func (s *UnixSocketSource) readEvents(ctx context.Context, conn net.Conn, eventChan chan<- *domain.NormalizedEvent) {
|
||||
// Set read deadline to prevent hanging
|
||||
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
|
||||
|
||||
scanner := bufio.NewScanner(conn)
|
||||
// Increase buffer size limit to 1MB
|
||||
buf := make([]byte, 0, 4096)
|
||||
scanner.Buffer(buf, MaxLineSize)
|
||||
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
event, err := parseJSONEvent(line, s.config.SourceType)
|
||||
event, err := parseJSONEvent(data, s.config.SourceType)
|
||||
if err != nil {
|
||||
// Log parse errors as warnings
|
||||
s.logger.Warnf("parse error: %v", err)
|
||||
@ -360,8 +337,8 @@ func (s *UnixSocketSource) Stop() error {
|
||||
|
||||
close(s.done)
|
||||
|
||||
if s.listener != nil {
|
||||
_ = s.listener.Close()
|
||||
if s.conn != nil {
|
||||
_ = s.conn.Close()
|
||||
}
|
||||
|
||||
s.wg.Wait()
|
||||
|
||||
@ -2,6 +2,9 @@ package unixsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -391,3 +394,167 @@ func TestParseJSONEvent_TimestampFallback(t *testing.T) {
|
||||
t.Error("expected non-zero timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketSource_StartStopDatagram(t *testing.T) {
|
||||
tmpPath := "/tmp/test_logcorrelator_datagram.sock"
|
||||
// Clean up any existing socket
|
||||
os.Remove(tmpPath)
|
||||
|
||||
source := NewUnixSocketSource(Config{
|
||||
Name: "test_datagram",
|
||||
Path: tmpPath,
|
||||
SourceType: "B",
|
||||
SocketPermissions: 0666,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
eventChan := make(chan *domain.NormalizedEvent, 10)
|
||||
|
||||
err := source.Start(ctx, eventChan)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start source: %v", err)
|
||||
}
|
||||
|
||||
// Give socket time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify socket file exists
|
||||
if _, err := os.Stat(tmpPath); os.IsNotExist(err) {
|
||||
t.Error("socket file should exist")
|
||||
}
|
||||
|
||||
// Stop the source
|
||||
err = source.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("failed to stop source: %v", err)
|
||||
}
|
||||
|
||||
// Socket file should be cleaned up
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if _, err := os.Stat(tmpPath); !os.IsNotExist(err) {
|
||||
t.Error("socket file should be removed after stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketSource_SendDatagram(t *testing.T) {
|
||||
tmpPath := "/tmp/test_logcorrelator_send.sock"
|
||||
os.Remove(tmpPath)
|
||||
|
||||
source := NewUnixSocketSource(Config{
|
||||
Name: "test_send",
|
||||
Path: tmpPath,
|
||||
SourceType: "B",
|
||||
SocketPermissions: 0666,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
eventChan := make(chan *domain.NormalizedEvent, 10)
|
||||
|
||||
err := source.Start(ctx, eventChan)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start source: %v", err)
|
||||
}
|
||||
|
||||
// Give socket time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Connect and send a datagram
|
||||
conn, err := net.Dial("unixgram", tmpPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial socket: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
data := []byte(`{"src_ip": "192.168.1.1", "src_port": 8080, "ja3": "test"}`)
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write: %v", err)
|
||||
}
|
||||
|
||||
// Wait for event
|
||||
select {
|
||||
case event := <-eventChan:
|
||||
if event.SrcIP != "192.168.1.1" {
|
||||
t.Errorf("expected src_ip 192.168.1.1, got %s", event.SrcIP)
|
||||
}
|
||||
if event.SrcPort != 8080 {
|
||||
t.Errorf("expected src_port 8080, got %d", event.SrcPort)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("timeout waiting for event")
|
||||
case <-ctx.Done():
|
||||
t.Error("context cancelled")
|
||||
}
|
||||
|
||||
err = source.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("failed to stop source: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketSource_MultipleDatagrams(t *testing.T) {
|
||||
tmpPath := "/tmp/test_logcorrelator_multi.sock"
|
||||
os.Remove(tmpPath)
|
||||
|
||||
source := NewUnixSocketSource(Config{
|
||||
Name: "test_multi",
|
||||
Path: tmpPath,
|
||||
SourceType: "B",
|
||||
SocketPermissions: 0666,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
eventChan := make(chan *domain.NormalizedEvent, 100)
|
||||
|
||||
err := source.Start(ctx, eventChan)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start source: %v", err)
|
||||
}
|
||||
|
||||
// Give socket time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Connect and send multiple datagrams
|
||||
conn, err := net.Dial("unixgram", tmpPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial socket: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
data := []byte(fmt.Sprintf(`{"src_ip": "192.168.1.%d", "src_port": %d, "ja3": "test%d"}`, i+1, 8080+i, i))
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write datagram %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all events
|
||||
received := 0
|
||||
timeout := time.After(3 * time.Second)
|
||||
for received < 5 {
|
||||
select {
|
||||
case event := <-eventChan:
|
||||
received++
|
||||
t.Logf("received event %d: src_ip=%s", received, event.SrcIP)
|
||||
case <-timeout:
|
||||
t.Errorf("timeout waiting for events, received %d/5", received)
|
||||
goto done
|
||||
case <-ctx.Done():
|
||||
t.Error("context cancelled")
|
||||
goto done
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
err = source.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("failed to stop source: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user