// Package config provides configuration loading and validation for ja4sentinel package config import ( "encoding/json" "errors" "fmt" "os" "strconv" "strings" "gopkg.in/yaml.v3" "ja4sentinel/api" ) // LoaderImpl implements the api.Loader interface for configuration loading type LoaderImpl struct { configPath string } // NewLoader creates a new configuration loader func NewLoader(configPath string) *LoaderImpl { return &LoaderImpl{ configPath: configPath, } } // Load reads and merges configuration from file, environment variables, and CLI func (l *LoaderImpl) Load() (api.AppConfig, error) { config := api.DefaultConfig() path := l.configPath explicit := path != "" if !explicit { path = "config.yml" } fileConfig, err := l.loadFromFile(path) if err == nil { config = mergeConfigs(config, fileConfig) } else if !(!explicit && errors.Is(err, os.ErrNotExist)) { return config, fmt.Errorf("failed to load config file: %w", err) } // Override with environment variables config = l.loadFromEnv(config) // Validate the final configuration if err := l.validate(config); err != nil { return config, fmt.Errorf("invalid configuration: %w", err) } return config, nil } // loadFromFile reads configuration from a YAML file func (l *LoaderImpl) loadFromFile(path string) (api.AppConfig, error) { config := api.AppConfig{} data, err := os.ReadFile(path) if err != nil { return config, fmt.Errorf("failed to read config file: %w", err) } err = yaml.Unmarshal(data, &config) if err != nil { return config, fmt.Errorf("failed to parse config file: %w", err) } return config, nil } // loadFromEnv overrides configuration with environment variables func (l *LoaderImpl) loadFromEnv(config api.AppConfig) api.AppConfig { // JA4SENTINEL_INTERFACE if val := os.Getenv("JA4SENTINEL_INTERFACE"); val != "" { config.Core.Interface = val } // JA4SENTINEL_PORTS (comma-separated list) if val := os.Getenv("JA4SENTINEL_PORTS"); val != "" { ports := parsePorts(val) if len(ports) > 0 { config.Core.ListenPorts = ports } } // JA4SENTINEL_BPF_FILTER if val := os.Getenv("JA4SENTINEL_BPF_FILTER"); val != "" { config.Core.BPFFilter = val } // JA4SENTINEL_FLOW_TIMEOUT (in seconds) if val := os.Getenv("JA4SENTINEL_FLOW_TIMEOUT"); val != "" { if timeout, err := strconv.Atoi(val); err == nil && timeout > 0 { config.Core.FlowTimeoutSec = timeout } } // JA4SENTINEL_PACKET_BUFFER_SIZE if val := os.Getenv("JA4SENTINEL_PACKET_BUFFER_SIZE"); val != "" { if size, err := strconv.Atoi(val); err == nil && size > 0 { config.Core.PacketBufferSize = size } } // JA4SENTINEL_LOG_LEVEL if val := os.Getenv("JA4SENTINEL_LOG_LEVEL"); val != "" { config.Core.LogLevel = val } return config } // parsePorts parses a comma-separated list of ports func parsePorts(s string) []uint16 { if s == "" { return nil } parts := strings.Split(s, ",") ports := make([]uint16, 0, len(parts)) seen := make(map[uint16]struct{}, len(parts)) for _, part := range parts { part = strings.TrimSpace(part) if part == "" { continue } port, err := strconv.ParseUint(part, 10, 16) if err != nil { continue } p := uint16(port) if p == 0 { continue } if _, exists := seen[p]; exists { continue } seen[p] = struct{}{} ports = append(ports, p) } return ports } // mergeConfigs merges two configs, with override taking precedence func mergeConfigs(base, override api.AppConfig) api.AppConfig { result := base if override.Core.Interface != "" { result.Core.Interface = override.Core.Interface } if len(override.Core.ListenPorts) > 0 { result.Core.ListenPorts = override.Core.ListenPorts } if override.Core.BPFFilter != "" { result.Core.BPFFilter = override.Core.BPFFilter } if override.Core.FlowTimeoutSec > 0 { result.Core.FlowTimeoutSec = override.Core.FlowTimeoutSec } if override.Core.PacketBufferSize > 0 { result.Core.PacketBufferSize = override.Core.PacketBufferSize } if override.Core.LogLevel != "" { result.Core.LogLevel = override.Core.LogLevel } // Merge exclude_source_ips (override takes precedence) if len(override.Core.ExcludeSourceIPs) > 0 { result.Core.ExcludeSourceIPs = override.Core.ExcludeSourceIPs } if len(override.Outputs) > 0 { result.Outputs = override.Outputs } return result } // validate checks if the configuration is valid func (l *LoaderImpl) validate(config api.AppConfig) error { if strings.TrimSpace(config.Core.Interface) == "" { return fmt.Errorf("interface cannot be empty") } if len(config.Core.ListenPorts) == 0 { return fmt.Errorf("at least one listen port is required") } for _, p := range config.Core.ListenPorts { if p == 0 { return fmt.Errorf("listen port 0 is invalid") } } if config.Core.FlowTimeoutSec <= 0 || config.Core.FlowTimeoutSec > 300 { return fmt.Errorf("flow_timeout_sec must be between 1 and 300") } if config.Core.PacketBufferSize <= 0 || config.Core.PacketBufferSize > 1_000_000 { return fmt.Errorf("packet_buffer_size must be between 1 and 1000000") } // Validate log level validLogLevels := map[string]struct{}{ "debug": {}, "info": {}, "warn": {}, "error": {}, } if config.Core.LogLevel != "" { if _, ok := validLogLevels[config.Core.LogLevel]; !ok { return fmt.Errorf("log_level must be one of: debug, info, warn, error") } } // Validate exclude_source_ips (if provided) if len(config.Core.ExcludeSourceIPs) > 0 { for i, ip := range config.Core.ExcludeSourceIPs { if ip == "" { return fmt.Errorf("exclude_source_ips[%d]: entry cannot be empty", i) } // Basic validation: check if it looks like an IP or CIDR if !strings.Contains(ip, "/") { // Single IP - basic check if !isValidIP(ip) { return fmt.Errorf("exclude_source_ips[%d]: invalid IP address %q", i, ip) } } else { // CIDR - basic check if !isValidCIDR(ip) { return fmt.Errorf("exclude_source_ips[%d]: invalid CIDR %q", i, ip) } } } } allowedTypes := map[string]struct{}{ "stdout": {}, "file": {}, "unix_socket": {}, } // Validate outputs for i, output := range config.Outputs { outputType := strings.TrimSpace(output.Type) if outputType == "" { return fmt.Errorf("output[%d]: type cannot be empty", i) } if _, ok := allowedTypes[outputType]; !ok { return fmt.Errorf("output[%d]: unknown type %q", i, outputType) } switch outputType { case "file": if strings.TrimSpace(output.Params["path"]) == "" { return fmt.Errorf("output[%d]: file output requires non-empty path", i) } case "unix_socket": if strings.TrimSpace(output.Params["socket_path"]) == "" { return fmt.Errorf("output[%d]: unix_socket output requires non-empty socket_path", i) } } } return nil } // ToJSON converts config to JSON string for debugging func ToJSON(config api.AppConfig) string { data, err := json.MarshalIndent(config, "", " ") if err != nil { return fmt.Sprintf("error marshaling config: %v", err) } return string(data) } // isValidIP checks if a string is a valid IP address func isValidIP(ip string) bool { if ip == "" { return false } // Simple validation: check if it contains only valid IP characters for _, ch := range ip { if !((ch >= '0' && ch <= '9') || ch == '.') { // Could be IPv6 if ch == ':' { return true // Accept IPv6 without detailed validation } return false } } return true } // isValidCIDR checks if a string is a valid CIDR notation func isValidCIDR(cidr string) bool { if cidr == "" { return false } parts := strings.Split(cidr, "/") if len(parts) != 2 { return false } // Check IP part if !isValidIP(parts[0]) { return false } // Check prefix length prefix, err := strconv.Atoi(parts[1]) if err != nil { return false } if strings.Contains(parts[0], ":") { // IPv6 return prefix >= 0 && prefix <= 128 } // IPv4 return prefix >= 0 && prefix <= 32 }