fix: renforcer limites TLS, timeouts socket et validation config
Some checks failed
Build RPM Package / Build RPM Packages (CentOS 7, Rocky 8/9/10) (push) Has been cancelled
Some checks failed
Build RPM Package / Build RPM Packages (CentOS 7, Rocky 8/9/10) (push) Has been cancelled
Co-authored-by: aider (openrouter/openai/gpt-5.3-codex) <aider@aider.chat>
This commit is contained in:
@ -191,7 +191,13 @@ func main() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if closer, ok := outputWriter.(interface{ Close() error }); ok {
|
if mw, ok := outputWriter.(interface{ CloseAll() error }); ok {
|
||||||
|
if err := mw.CloseAll(); err != nil {
|
||||||
|
appLogger.Error("main", "Failed to close output writers", map[string]string{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else if closer, ok := outputWriter.(interface{ Close() error }); ok {
|
||||||
if err := closer.Close(); err != nil {
|
if err := closer.Close(); err != nil {
|
||||||
appLogger.Error("main", "Failed to close output writer", map[string]string{
|
appLogger.Error("main", "Failed to close output writer", map[string]string{
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
|
|||||||
@ -38,7 +38,7 @@ func (l *LoaderImpl) Load() (api.AppConfig, error) {
|
|||||||
fileConfig, err := l.loadFromFile(path)
|
fileConfig, err := l.loadFromFile(path)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
config = mergeConfigs(config, fileConfig)
|
config = mergeConfigs(config, fileConfig)
|
||||||
} else if !( !explicit && errors.Is(err, os.ErrNotExist)) {
|
} else if !(!explicit && errors.Is(err, os.ErrNotExist)) {
|
||||||
return config, fmt.Errorf("failed to load config file: %w", err)
|
return config, fmt.Errorf("failed to load config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,6 +115,7 @@ func parsePorts(s string) []uint16 {
|
|||||||
|
|
||||||
parts := strings.Split(s, ",")
|
parts := strings.Split(s, ",")
|
||||||
ports := make([]uint16, 0, len(parts))
|
ports := make([]uint16, 0, len(parts))
|
||||||
|
seen := make(map[uint16]struct{}, len(parts))
|
||||||
|
|
||||||
for _, part := range parts {
|
for _, part := range parts {
|
||||||
part = strings.TrimSpace(part)
|
part = strings.TrimSpace(part)
|
||||||
@ -123,9 +124,19 @@ func parsePorts(s string) []uint16 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
port, err := strconv.ParseUint(part, 10, 16)
|
port, err := strconv.ParseUint(part, 10, 16)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
ports = append(ports, uint16(port))
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p := uint16(port)
|
||||||
|
if p == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := seen[p]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[p] = struct{}{}
|
||||||
|
ports = append(ports, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ports
|
return ports
|
||||||
@ -164,19 +175,53 @@ func mergeConfigs(base, override api.AppConfig) api.AppConfig {
|
|||||||
|
|
||||||
// validate checks if the configuration is valid
|
// validate checks if the configuration is valid
|
||||||
func (l *LoaderImpl) validate(config api.AppConfig) error {
|
func (l *LoaderImpl) validate(config api.AppConfig) error {
|
||||||
if config.Core.Interface == "" {
|
if strings.TrimSpace(config.Core.Interface) == "" {
|
||||||
return fmt.Errorf("interface cannot be empty")
|
return fmt.Errorf("interface cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(config.Core.ListenPorts) == 0 {
|
if len(config.Core.ListenPorts) == 0 {
|
||||||
return fmt.Errorf("at least one listen port is required")
|
return fmt.Errorf("at least one listen port is required")
|
||||||
}
|
}
|
||||||
|
for _, p := range config.Core.ListenPorts {
|
||||||
|
if p == 0 {
|
||||||
|
return fmt.Errorf("listen port 0 is invalid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Core.FlowTimeoutSec <= 0 || config.Core.FlowTimeoutSec > 300 {
|
||||||
|
return fmt.Errorf("flow_timeout_sec must be between 1 and 300")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Core.PacketBufferSize <= 0 || config.Core.PacketBufferSize > 1_000_000 {
|
||||||
|
return fmt.Errorf("packet_buffer_size must be between 1 and 1000000")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowedTypes := map[string]struct{}{
|
||||||
|
"stdout": {},
|
||||||
|
"file": {},
|
||||||
|
"unix_socket": {},
|
||||||
|
}
|
||||||
|
|
||||||
// Validate outputs
|
// Validate outputs
|
||||||
for i, output := range config.Outputs {
|
for i, output := range config.Outputs {
|
||||||
if output.Type == "" {
|
outputType := strings.TrimSpace(output.Type)
|
||||||
|
if outputType == "" {
|
||||||
return fmt.Errorf("output[%d]: type cannot be empty", i)
|
return fmt.Errorf("output[%d]: type cannot be empty", i)
|
||||||
}
|
}
|
||||||
|
if _, ok := allowedTypes[outputType]; !ok {
|
||||||
|
return fmt.Errorf("output[%d]: unknown type %q", i, outputType)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch outputType {
|
||||||
|
case "file":
|
||||||
|
if strings.TrimSpace(output.Params["path"]) == "" {
|
||||||
|
return fmt.Errorf("output[%d]: file output requires non-empty path", i)
|
||||||
|
}
|
||||||
|
case "unix_socket":
|
||||||
|
if strings.TrimSpace(output.Params["socket_path"]) == "" {
|
||||||
|
return fmt.Errorf("output[%d]: unix_socket output requires non-empty socket_path", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -58,12 +58,28 @@ func TestParsePorts(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParsePorts_DeduplicateAndIgnoreZero(t *testing.T) {
|
||||||
|
got := parsePorts("443, 0, 443, 8443")
|
||||||
|
want := []uint16{443, 8443}
|
||||||
|
|
||||||
|
if len(got) != len(want) {
|
||||||
|
t.Fatalf("parsePorts() length = %d, want %d (got: %v)", len(got), len(want), got)
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if got[i] != want[i] {
|
||||||
|
t.Fatalf("parsePorts()[%d] = %d, want %d", i, got[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMergeConfigs(t *testing.T) {
|
func TestMergeConfigs(t *testing.T) {
|
||||||
base := api.AppConfig{
|
base := api.AppConfig{
|
||||||
Core: api.Config{
|
Core: api.Config{
|
||||||
Interface: "eth0",
|
Interface: "eth0",
|
||||||
ListenPorts: []uint16{443},
|
ListenPorts: []uint16{443},
|
||||||
BPFFilter: "",
|
BPFFilter: "",
|
||||||
|
FlowTimeoutSec: 30,
|
||||||
|
PacketBufferSize: 1000,
|
||||||
},
|
},
|
||||||
Outputs: []api.OutputConfig{},
|
Outputs: []api.OutputConfig{},
|
||||||
}
|
}
|
||||||
@ -73,6 +89,8 @@ func TestMergeConfigs(t *testing.T) {
|
|||||||
Interface: "lo",
|
Interface: "lo",
|
||||||
ListenPorts: []uint16{8443},
|
ListenPorts: []uint16{8443},
|
||||||
BPFFilter: "tcp",
|
BPFFilter: "tcp",
|
||||||
|
FlowTimeoutSec: 60,
|
||||||
|
PacketBufferSize: 2000,
|
||||||
},
|
},
|
||||||
Outputs: []api.OutputConfig{
|
Outputs: []api.OutputConfig{
|
||||||
{Type: "stdout", Enabled: true},
|
{Type: "stdout", Enabled: true},
|
||||||
@ -93,6 +111,12 @@ func TestMergeConfigs(t *testing.T) {
|
|||||||
if len(result.Outputs) != 1 {
|
if len(result.Outputs) != 1 {
|
||||||
t.Errorf("Outputs length = %v, want 1", len(result.Outputs))
|
t.Errorf("Outputs length = %v, want 1", len(result.Outputs))
|
||||||
}
|
}
|
||||||
|
if result.Core.FlowTimeoutSec != 60 {
|
||||||
|
t.Errorf("FlowTimeoutSec = %v, want 60", result.Core.FlowTimeoutSec)
|
||||||
|
}
|
||||||
|
if result.Core.PacketBufferSize != 2000 {
|
||||||
|
t.Errorf("PacketBufferSize = %v, want 2000", result.Core.PacketBufferSize)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidate(t *testing.T) {
|
func TestValidate(t *testing.T) {
|
||||||
@ -109,6 +133,8 @@ func TestValidate(t *testing.T) {
|
|||||||
Core: api.Config{
|
Core: api.Config{
|
||||||
Interface: "eth0",
|
Interface: "eth0",
|
||||||
ListenPorts: []uint16{443},
|
ListenPorts: []uint16{443},
|
||||||
|
FlowTimeoutSec: 30,
|
||||||
|
PacketBufferSize: 1000,
|
||||||
},
|
},
|
||||||
Outputs: []api.OutputConfig{
|
Outputs: []api.OutputConfig{
|
||||||
{Type: "stdout", Enabled: true},
|
{Type: "stdout", Enabled: true},
|
||||||
@ -122,6 +148,8 @@ func TestValidate(t *testing.T) {
|
|||||||
Core: api.Config{
|
Core: api.Config{
|
||||||
Interface: "",
|
Interface: "",
|
||||||
ListenPorts: []uint16{443},
|
ListenPorts: []uint16{443},
|
||||||
|
FlowTimeoutSec: 30,
|
||||||
|
PacketBufferSize: 1000,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
@ -132,6 +160,8 @@ func TestValidate(t *testing.T) {
|
|||||||
Core: api.Config{
|
Core: api.Config{
|
||||||
Interface: "eth0",
|
Interface: "eth0",
|
||||||
ListenPorts: []uint16{},
|
ListenPorts: []uint16{},
|
||||||
|
FlowTimeoutSec: 30,
|
||||||
|
PacketBufferSize: 1000,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
@ -142,6 +172,8 @@ func TestValidate(t *testing.T) {
|
|||||||
Core: api.Config{
|
Core: api.Config{
|
||||||
Interface: "eth0",
|
Interface: "eth0",
|
||||||
ListenPorts: []uint16{443},
|
ListenPorts: []uint16{443},
|
||||||
|
FlowTimeoutSec: 30,
|
||||||
|
PacketBufferSize: 1000,
|
||||||
},
|
},
|
||||||
Outputs: []api.OutputConfig{
|
Outputs: []api.OutputConfig{
|
||||||
{Type: "", Enabled: true},
|
{Type: "", Enabled: true},
|
||||||
@ -149,6 +181,18 @@ func TestValidate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "listen port zero",
|
||||||
|
config: api.AppConfig{
|
||||||
|
Core: api.Config{
|
||||||
|
Interface: "eth0",
|
||||||
|
ListenPorts: []uint16{0},
|
||||||
|
FlowTimeoutSec: 30,
|
||||||
|
PacketBufferSize: 1000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@ -161,6 +205,162 @@ func TestValidate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidate_InvalidCoreBounds(t *testing.T) {
|
||||||
|
loader := &LoaderImpl{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg api.AppConfig
|
||||||
|
hasErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "timeout zero",
|
||||||
|
cfg: api.AppConfig{
|
||||||
|
Core: api.Config{
|
||||||
|
Interface: "eth0",
|
||||||
|
ListenPorts: []uint16{443},
|
||||||
|
FlowTimeoutSec: 0,
|
||||||
|
PacketBufferSize: 1000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "timeout too high",
|
||||||
|
cfg: api.AppConfig{
|
||||||
|
Core: api.Config{
|
||||||
|
Interface: "eth0",
|
||||||
|
ListenPorts: []uint16{443},
|
||||||
|
FlowTimeoutSec: 301,
|
||||||
|
PacketBufferSize: 1000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "buffer zero",
|
||||||
|
cfg: api.AppConfig{
|
||||||
|
Core: api.Config{
|
||||||
|
Interface: "eth0",
|
||||||
|
ListenPorts: []uint16{443},
|
||||||
|
FlowTimeoutSec: 30,
|
||||||
|
PacketBufferSize: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "buffer too high",
|
||||||
|
cfg: api.AppConfig{
|
||||||
|
Core: api.Config{
|
||||||
|
Interface: "eth0",
|
||||||
|
ListenPorts: []uint16{443},
|
||||||
|
FlowTimeoutSec: 30,
|
||||||
|
PacketBufferSize: 1_000_001,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid bounds",
|
||||||
|
cfg: api.AppConfig{
|
||||||
|
Core: api.Config{
|
||||||
|
Interface: "eth0",
|
||||||
|
ListenPorts: []uint16{443},
|
||||||
|
FlowTimeoutSec: 30,
|
||||||
|
PacketBufferSize: 1000,
|
||||||
|
},
|
||||||
|
Outputs: []api.OutputConfig{
|
||||||
|
{Type: "stdout", Enabled: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hasErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := loader.validate(tt.cfg)
|
||||||
|
if (err != nil) != tt.hasErr {
|
||||||
|
t.Fatalf("validate() error = %v, wantErr %v", err, tt.hasErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate_InvalidOutputs(t *testing.T) {
|
||||||
|
loader := &LoaderImpl{}
|
||||||
|
|
||||||
|
baseCore := api.Config{
|
||||||
|
Interface: "eth0",
|
||||||
|
ListenPorts: []uint16{443},
|
||||||
|
FlowTimeoutSec: 30,
|
||||||
|
PacketBufferSize: 1000,
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
outputs []api.OutputConfig
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "unknown output type",
|
||||||
|
outputs: []api.OutputConfig{
|
||||||
|
{Type: "unknown", Enabled: true},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "file without path",
|
||||||
|
outputs: []api.OutputConfig{
|
||||||
|
{Type: "file", Enabled: true, Params: map[string]string{}},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unix socket without socket_path",
|
||||||
|
outputs: []api.OutputConfig{
|
||||||
|
{Type: "unix_socket", Enabled: true, Params: map[string]string{}},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid file output",
|
||||||
|
outputs: []api.OutputConfig{
|
||||||
|
{Type: "file", Enabled: true, Params: map[string]string{"path": "/tmp/x.log"}},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid unix socket output",
|
||||||
|
outputs: []api.OutputConfig{
|
||||||
|
{Type: "unix_socket", Enabled: true, Params: map[string]string{"socket_path": "/tmp/x.sock"}},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid stdout output",
|
||||||
|
outputs: []api.OutputConfig{
|
||||||
|
{Type: "stdout", Enabled: true},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := api.AppConfig{
|
||||||
|
Core: baseCore,
|
||||||
|
Outputs: tt.outputs,
|
||||||
|
}
|
||||||
|
err := loader.validate(cfg)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Fatalf("validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLoadFromEnv(t *testing.T) {
|
func TestLoadFromEnv(t *testing.T) {
|
||||||
// Save original env vars
|
// Save original env vars
|
||||||
origInterface := os.Getenv("JA4SENTINEL_INTERFACE")
|
origInterface := os.Getenv("JA4SENTINEL_INTERFACE")
|
||||||
@ -198,6 +398,8 @@ func TestToJSON(t *testing.T) {
|
|||||||
Interface: "eth0",
|
Interface: "eth0",
|
||||||
ListenPorts: []uint16{443, 8443},
|
ListenPorts: []uint16{443, 8443},
|
||||||
BPFFilter: "tcp",
|
BPFFilter: "tcp",
|
||||||
|
FlowTimeoutSec: 30,
|
||||||
|
PacketBufferSize: 1000,
|
||||||
},
|
},
|
||||||
Outputs: []api.OutputConfig{
|
Outputs: []api.OutputConfig{
|
||||||
{Type: "stdout", Enabled: true, Params: map[string]string{}},
|
{Type: "stdout", Enabled: true, Params: map[string]string{}},
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"ja4sentinel/api"
|
"ja4sentinel/api"
|
||||||
)
|
)
|
||||||
@ -79,22 +80,24 @@ type UnixSocketWriter struct {
|
|||||||
socketPath string
|
socketPath string
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
dialTimeout time.Duration
|
||||||
|
writeTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUnixSocketWriter creates a new UNIX socket writer
|
// NewUnixSocketWriter creates a new UNIX socket writer
|
||||||
func NewUnixSocketWriter(socketPath string) (*UnixSocketWriter, error) {
|
func NewUnixSocketWriter(socketPath string) (*UnixSocketWriter, error) {
|
||||||
w := &UnixSocketWriter{
|
w := &UnixSocketWriter{
|
||||||
socketPath: socketPath,
|
socketPath: socketPath,
|
||||||
|
dialTimeout: 2 * time.Second,
|
||||||
|
writeTimeout: 2 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to connect (socket may not exist yet)
|
// Try to connect (socket may not exist yet)
|
||||||
conn, err := net.Dial("unix", socketPath)
|
conn, err := net.DialTimeout("unix", socketPath, w.dialTimeout)
|
||||||
if err != nil {
|
if err == nil {
|
||||||
// Socket doesn't exist yet, we'll try to connect on first write
|
w.conn = conn
|
||||||
return w, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
w.conn = conn
|
|
||||||
return w, nil
|
return w, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,7 +110,7 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error {
|
|||||||
if w.conn != nil {
|
if w.conn != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
conn, err := net.Dial("unix", w.socketPath)
|
conn, err := net.DialTimeout("unix", w.socketPath, w.dialTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err)
|
return fmt.Errorf("failed to connect to socket %s: %w", w.socketPath, err)
|
||||||
}
|
}
|
||||||
@ -123,22 +126,32 @@ func (w *UnixSocketWriter) Write(rec api.LogRecord) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal record: %w", err)
|
return fmt.Errorf("failed to marshal record: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add newline for line-based protocols
|
|
||||||
data = append(data, '\n')
|
data = append(data, '\n')
|
||||||
|
|
||||||
if _, err = w.conn.Write(data); err != nil {
|
if err := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil {
|
||||||
|
return fmt.Errorf("failed to set write deadline: %w", err)
|
||||||
|
}
|
||||||
|
if _, err = w.conn.Write(data); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
_ = w.conn.Close()
|
_ = w.conn.Close()
|
||||||
w.conn = nil
|
w.conn = nil
|
||||||
|
|
||||||
if err2 := ensureConn(); err2 != nil {
|
if errConn := ensureConn(); errConn != nil {
|
||||||
return fmt.Errorf("failed to write to socket and reconnect failed: %w", err2)
|
return fmt.Errorf("failed to write to socket and reconnect failed: %w", errConn)
|
||||||
}
|
}
|
||||||
if _, err2 := w.conn.Write(data); err2 != nil {
|
|
||||||
|
if errDeadline := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); errDeadline != nil {
|
||||||
_ = w.conn.Close()
|
_ = w.conn.Close()
|
||||||
w.conn = nil
|
w.conn = nil
|
||||||
return fmt.Errorf("failed to write to socket after reconnect: %w", err2)
|
return fmt.Errorf("failed to set write deadline after reconnect: %w", errDeadline)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _, errRetry := w.conn.Write(data); errRetry != nil {
|
||||||
|
_ = w.conn.Close()
|
||||||
|
w.conn = nil
|
||||||
|
return fmt.Errorf("failed to write to socket after reconnect: %w", errRetry)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -4,9 +4,11 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -238,6 +240,112 @@ func TestUnixSocketWriter(t *testing.T) {
|
|||||||
writer.Close()
|
writer.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnixSocketWriter_Write_NonexistentSocket_ReturnsQuickly(t *testing.T) {
|
||||||
|
socketPath := filepath.Join(t.TempDir(), "ja4sentinel_missing.sock")
|
||||||
|
writer, err := NewUnixSocketWriter(socketPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewUnixSocketWriter() error = %v", err)
|
||||||
|
}
|
||||||
|
defer writer.Close()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
err = writer.Write(api.LogRecord{
|
||||||
|
SrcIP: "192.168.1.10",
|
||||||
|
SrcPort: 44444,
|
||||||
|
DstIP: "10.0.0.10",
|
||||||
|
DstPort: 443,
|
||||||
|
})
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Write() should fail for non-existent socket")
|
||||||
|
}
|
||||||
|
if elapsed >= 3*time.Second {
|
||||||
|
t.Fatalf("Write() took too long: %v (expected < 3s)", elapsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type timeoutError struct{}
|
||||||
|
|
||||||
|
func (timeoutError) Error() string { return "i/o timeout" }
|
||||||
|
func (timeoutError) Timeout() bool { return true }
|
||||||
|
func (timeoutError) Temporary() bool { return true }
|
||||||
|
|
||||||
|
type mockAddr string
|
||||||
|
|
||||||
|
func (a mockAddr) Network() string { return "unix" }
|
||||||
|
func (a mockAddr) String() string { return string(a) }
|
||||||
|
|
||||||
|
type mockConn struct {
|
||||||
|
writeCalls int
|
||||||
|
closeCalled bool
|
||||||
|
setWriteDeadlineCalled bool
|
||||||
|
setReadDeadlineCalled bool
|
||||||
|
setAnyDeadlineWasCalled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConn) Read(_ []byte) (int, error) { return 0, errors.New("not implemented") }
|
||||||
|
|
||||||
|
func (m *mockConn) Write(_ []byte) (int, error) {
|
||||||
|
m.writeCalls++
|
||||||
|
return 0, timeoutError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConn) Close() error {
|
||||||
|
m.closeCalled = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConn) LocalAddr() net.Addr { return mockAddr("local") }
|
||||||
|
func (m *mockConn) RemoteAddr() net.Addr { return mockAddr("remote") }
|
||||||
|
|
||||||
|
func (m *mockConn) SetDeadline(_ time.Time) error {
|
||||||
|
m.setAnyDeadlineWasCalled = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConn) SetReadDeadline(_ time.Time) error {
|
||||||
|
m.setReadDeadlineCalled = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockConn) SetWriteDeadline(_ time.Time) error {
|
||||||
|
m.setWriteDeadlineCalled = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnixSocketWriter_Write_UsesWriteDeadline(t *testing.T) {
|
||||||
|
mc := &mockConn{}
|
||||||
|
writer := &UnixSocketWriter{
|
||||||
|
socketPath: filepath.Join(t.TempDir(), "missing.sock"),
|
||||||
|
conn: mc,
|
||||||
|
dialTimeout: 100 * time.Millisecond,
|
||||||
|
writeTimeout: 100 * time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := writer.Write(api.LogRecord{
|
||||||
|
SrcIP: "192.168.1.20",
|
||||||
|
SrcPort: 55555,
|
||||||
|
DstIP: "10.0.0.20",
|
||||||
|
DstPort: 443,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Write() should fail because reconnect target does not exist")
|
||||||
|
}
|
||||||
|
if !mc.setWriteDeadlineCalled {
|
||||||
|
t.Fatal("expected SetWriteDeadline to be called before write")
|
||||||
|
}
|
||||||
|
if !mc.closeCalled {
|
||||||
|
t.Fatal("expected connection to be closed after first write failure")
|
||||||
|
}
|
||||||
|
if mc.writeCalls != 1 {
|
||||||
|
t.Fatalf("expected exactly 1 write on initial conn, got %d", mc.writeCalls)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "reconnect failed") {
|
||||||
|
t.Fatalf("expected reconnect failure error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type unixTestServer struct {
|
type unixTestServer struct {
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
received chan string
|
received chan string
|
||||||
|
|||||||
@ -48,6 +48,8 @@ type ParserImpl struct {
|
|||||||
cleanupDone chan struct{}
|
cleanupDone chan struct{}
|
||||||
cleanupClose chan struct{}
|
cleanupClose chan struct{}
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
|
maxTrackedFlows int
|
||||||
|
maxHelloBufferBytes int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewParser creates a new TLS parser with connection state tracking
|
// NewParser creates a new TLS parser with connection state tracking
|
||||||
@ -62,6 +64,8 @@ func NewParserWithTimeout(timeout time.Duration) *ParserImpl {
|
|||||||
flowTimeout: timeout,
|
flowTimeout: timeout,
|
||||||
cleanupDone: make(chan struct{}),
|
cleanupDone: make(chan struct{}),
|
||||||
cleanupClose: make(chan struct{}),
|
cleanupClose: make(chan struct{}),
|
||||||
|
maxTrackedFlows: 50000,
|
||||||
|
maxHelloBufferBytes: 256 * 1024, // 256 KiB
|
||||||
}
|
}
|
||||||
go p.cleanupLoop()
|
go p.cleanupLoop()
|
||||||
return p
|
return p
|
||||||
@ -164,15 +168,26 @@ func (p *ParserImpl) Process(pkt api.RawPacket) (*api.TLSClientHello, error) {
|
|||||||
return nil, nil // No payload
|
return nil, nil // No payload
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get or create connection flow
|
|
||||||
key := flowKey(srcIP, srcPort, dstIP, dstPort)
|
key := flowKey(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
p.mu.RLock()
|
||||||
|
_, flowExists := p.flows[key]
|
||||||
|
p.mu.RUnlock()
|
||||||
|
|
||||||
|
if !flowExists && payload[0] != 22 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
flow := p.getOrCreateFlow(key, srcIP, srcPort, dstIP, dstPort, ipMeta, tcpMeta)
|
flow := p.getOrCreateFlow(key, srcIP, srcPort, dstIP, dstPort, ipMeta, tcpMeta)
|
||||||
|
if flow == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Check if flow is already done
|
// Check if flow is already done
|
||||||
p.mu.RLock()
|
p.mu.RLock()
|
||||||
isDone := flow.State == JA4_DONE
|
state := flow.State
|
||||||
p.mu.RUnlock()
|
p.mu.RUnlock()
|
||||||
if isDone {
|
if state == JA4_DONE {
|
||||||
return nil, nil // Already processed this flow
|
return nil, nil // Already processed this flow
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -201,8 +216,13 @@ func (p *ParserImpl) Process(pkt api.RawPacket) (*api.TLSClientHello, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for fragmented ClientHello (accumulate segments)
|
// Check for fragmented ClientHello (accumulate segments)
|
||||||
if flow.State == WAIT_CLIENT_HELLO || flow.State == NEW {
|
if state == WAIT_CLIENT_HELLO || state == NEW {
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
|
if len(flow.HelloBuffer)+len(payload) > p.maxHelloBufferBytes {
|
||||||
|
delete(p.flows, key)
|
||||||
|
p.mu.Unlock()
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
flow.State = WAIT_CLIENT_HELLO
|
flow.State = WAIT_CLIENT_HELLO
|
||||||
flow.HelloBuffer = append(flow.HelloBuffer, payload...)
|
flow.HelloBuffer = append(flow.HelloBuffer, payload...)
|
||||||
bufferCopy := make([]byte, len(flow.HelloBuffer))
|
bufferCopy := make([]byte, len(flow.HelloBuffer))
|
||||||
@ -246,6 +266,10 @@ func (p *ParserImpl) getOrCreateFlow(key string, srcIP string, srcPort uint16, d
|
|||||||
return flow
|
return flow
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(p.flows) >= p.maxTrackedFlows {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
flow := &ConnectionFlow{
|
flow := &ConnectionFlow{
|
||||||
State: NEW,
|
State: NEW,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
|
|||||||
@ -1,8 +1,13 @@
|
|||||||
package tlsparse
|
package tlsparse
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"ja4sentinel/api"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -203,6 +208,8 @@ func createTLSServerHello(version uint16) []byte {
|
|||||||
|
|
||||||
func TestNewParser(t *testing.T) {
|
func TestNewParser(t *testing.T) {
|
||||||
parser := NewParser()
|
parser := NewParser()
|
||||||
|
defer parser.Close()
|
||||||
|
|
||||||
if parser == nil {
|
if parser == nil {
|
||||||
t.Error("NewParser() returned nil")
|
t.Error("NewParser() returned nil")
|
||||||
}
|
}
|
||||||
@ -288,3 +295,152 @@ func TestExtractTCPMeta_MSSInvalid_NoPanic(t *testing.T) {
|
|||||||
t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options)
|
t.Fatalf("expected MSS_INVALID in options, got %v", meta.Options)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetOrCreateFlow_RespectsMaxTrackedFlows(t *testing.T) {
|
||||||
|
parser := NewParser()
|
||||||
|
defer parser.Close()
|
||||||
|
|
||||||
|
parser.maxTrackedFlows = 1
|
||||||
|
|
||||||
|
flow1 := parser.getOrCreateFlow(
|
||||||
|
flowKey("192.168.1.1", 12345, "10.0.0.1", 443),
|
||||||
|
"192.168.1.1", 12345, "10.0.0.1", 443,
|
||||||
|
api.IPMeta{}, api.TCPMeta{},
|
||||||
|
)
|
||||||
|
if flow1 == nil {
|
||||||
|
t.Fatal("first flow should be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
flow2 := parser.getOrCreateFlow(
|
||||||
|
flowKey("192.168.1.2", 12346, "10.0.0.1", 443),
|
||||||
|
"192.168.1.2", 12346, "10.0.0.1", 443,
|
||||||
|
api.IPMeta{}, api.TCPMeta{},
|
||||||
|
)
|
||||||
|
if flow2 != nil {
|
||||||
|
t.Fatal("second flow should be nil when maxTrackedFlows is reached")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_DropsWhenHelloBufferExceedsLimit(t *testing.T) {
|
||||||
|
parser := NewParserWithTimeout(30 * time.Second)
|
||||||
|
defer parser.Close()
|
||||||
|
|
||||||
|
parser.maxHelloBufferBytes = 10
|
||||||
|
|
||||||
|
srcIP := "192.168.1.10"
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(443)
|
||||||
|
|
||||||
|
// TLS-like payload, but intentionally incomplete to trigger accumulation.
|
||||||
|
payloadChunk := []byte{0x16, 0x03, 0x03, 0x00, 0x20, 0x01} // len = 6
|
||||||
|
|
||||||
|
pkt1 := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, payloadChunk)
|
||||||
|
ch, err := parser.Process(pkt1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first Process() error = %v", err)
|
||||||
|
}
|
||||||
|
if ch != nil {
|
||||||
|
t.Fatal("first Process() should not return complete ClientHello")
|
||||||
|
}
|
||||||
|
|
||||||
|
key := flowKey(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
parser.mu.RLock()
|
||||||
|
_, existsAfterFirst := parser.flows[key]
|
||||||
|
parser.mu.RUnlock()
|
||||||
|
if !existsAfterFirst {
|
||||||
|
t.Fatal("flow should exist after first chunk")
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt2 := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, payloadChunk)
|
||||||
|
ch, err = parser.Process(pkt2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second Process() error = %v", err)
|
||||||
|
}
|
||||||
|
if ch != nil {
|
||||||
|
t.Fatal("second Process() should not return ClientHello")
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.mu.RLock()
|
||||||
|
_, existsAfterSecond := parser.flows[key]
|
||||||
|
parser.mu.RUnlock()
|
||||||
|
if existsAfterSecond {
|
||||||
|
t.Fatal("flow should be removed when hello buffer exceeds maxHelloBufferBytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_NonTLSNewFlowNotTracked(t *testing.T) {
|
||||||
|
parser := NewParser()
|
||||||
|
defer parser.Close()
|
||||||
|
|
||||||
|
srcIP := "192.168.1.20"
|
||||||
|
dstIP := "10.0.0.2"
|
||||||
|
srcPort := uint16(23456)
|
||||||
|
dstPort := uint16(443)
|
||||||
|
|
||||||
|
// Non-TLS content type (not 22)
|
||||||
|
payload := []byte{0x17, 0x03, 0x03, 0x00, 0x05, 0x00}
|
||||||
|
|
||||||
|
pkt := buildRawPacket(t, srcIP, dstIP, srcPort, dstPort, payload)
|
||||||
|
ch, err := parser.Process(pkt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Process() error = %v", err)
|
||||||
|
}
|
||||||
|
if ch != nil {
|
||||||
|
t.Fatal("Process() should return nil for non-TLS new flow")
|
||||||
|
}
|
||||||
|
|
||||||
|
key := flowKey(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
parser.mu.RLock()
|
||||||
|
_, exists := parser.flows[key]
|
||||||
|
parser.mu.RUnlock()
|
||||||
|
if exists {
|
||||||
|
t.Fatal("non-TLS new flow should not be tracked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildRawPacket(t *testing.T, srcIP, dstIP string, srcPort, dstPort uint16, payload []byte) api.RawPacket {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ip := &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
SrcIP: net.ParseIP(srcIP).To4(),
|
||||||
|
DstIP: net.ParseIP(dstIP).To4(),
|
||||||
|
Protocol: layers.IPProtocolTCP,
|
||||||
|
}
|
||||||
|
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
Seq: 1,
|
||||||
|
ACK: true,
|
||||||
|
Window: 65535,
|
||||||
|
}
|
||||||
|
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||||
|
t.Fatalf("SetNetworkLayerForChecksum() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
eth := &layers.Ethernet{
|
||||||
|
SrcMAC: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55},
|
||||||
|
DstMAC: net.HardwareAddr{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff},
|
||||||
|
EthernetType: layers.EthernetTypeIPv4,
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
FixLengths: true,
|
||||||
|
ComputeChecksums: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := gopacket.SerializeLayers(buf, opts, eth, ip, tcp, gopacket.Payload(payload)); err != nil {
|
||||||
|
t.Fatalf("SerializeLayers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.RawPacket{
|
||||||
|
Data: buf.Bytes(),
|
||||||
|
Timestamp: time.Now().UnixNano(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user