feature: add source IP exclusion with CIDR support
Features:
- Add exclude_source_ips configuration option
- Support single IPs (192.168.1.1) and CIDR ranges (10.0.0.0/8)
- Filter packets in parser before TLS processing
- Log exclusion configuration at startup
- New ipfilter package with IP/CIDR matching
- Unit tests for ipfilter package
Configuration example:
exclude_source_ips:
- "10.0.0.0/8" # Exclude private network
- "192.168.1.1" # Exclude specific IP
- "172.16.0.0/12" # Exclude another range
- "2001:db8::/32" # IPv6 support
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
This commit is contained in:
84
internal/ipfilter/ipfilter.go
Normal file
84
internal/ipfilter/ipfilter.go
Normal file
@ -0,0 +1,84 @@
|
||||
// Package ipfilter provides IP address and CIDR range matching for filtering
|
||||
package ipfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Filter checks if an IP address should be excluded based on a list of IPs or CIDR ranges
|
||||
type Filter struct {
|
||||
mu sync.RWMutex
|
||||
networks []*net.IPNet
|
||||
ips []net.IP
|
||||
}
|
||||
|
||||
// New creates a new IP filter from a list of IP addresses or CIDR ranges
|
||||
// Accepts formats like: "192.168.1.1", "10.0.0.0/8", "2001:db8::/32"
|
||||
func New(excludeList []string) (*Filter, error) {
|
||||
f := &Filter{
|
||||
networks: make([]*net.IPNet, 0),
|
||||
ips: make([]net.IP, 0),
|
||||
}
|
||||
|
||||
for _, entry := range excludeList {
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Try parsing as CIDR first
|
||||
if _, ipNet, err := net.ParseCIDR(entry); err == nil {
|
||||
f.networks = append(f.networks, ipNet)
|
||||
continue
|
||||
}
|
||||
|
||||
// Try parsing as single IP
|
||||
if ip := net.ParseIP(entry); ip != nil {
|
||||
f.ips = append(f.ips, ip)
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid IP or CIDR: %s", entry)
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// ShouldExclude checks if an IP address should be excluded
|
||||
func (f *Filter) ShouldExclude(ipStr string) bool {
|
||||
if f == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
// Check against single IPs
|
||||
for _, filterIP := range f.ips {
|
||||
if ip.Equal(filterIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check against CIDR ranges
|
||||
for _, network := range f.networks {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Count returns the number of loaded filter entries
|
||||
func (f *Filter) Count() (ips int, networks int) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
return len(f.ips), len(f.networks)
|
||||
}
|
||||
160
internal/ipfilter/ipfilter_test.go
Normal file
160
internal/ipfilter/ipfilter_test.go
Normal file
@ -0,0 +1,160 @@
|
||||
package ipfilter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFilter_New(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
list []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty list",
|
||||
list: []string{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "single IP",
|
||||
list: []string{"192.168.1.1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "single CIDR",
|
||||
list: []string{"10.0.0.0/8"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "mixed IPs and CIDRs",
|
||||
list: []string{"192.168.1.1", "10.0.0.0/8", "172.16.0.0/12"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid IP",
|
||||
list: []string{"999.999.999.999"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid CIDR",
|
||||
list: []string{"10.0.0.0/33"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
list: []string{"2001:db8::1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 CIDR",
|
||||
list: []string{"2001:db8::/32"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f, err := New(tt.list)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err == nil && f == nil {
|
||||
t.Error("New() should return non-nil filter on success")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter_ShouldExclude(t *testing.T) {
|
||||
f, err := New([]string{
|
||||
"192.168.1.1",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"2001:db8::1",
|
||||
"fc00::/7",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
want bool
|
||||
}{
|
||||
// Exact IP matches
|
||||
{"exact match", "192.168.1.1", true},
|
||||
{"exact IPv6 match", "2001:db8::1", true},
|
||||
|
||||
// CIDR matches
|
||||
{"CIDR match 10.0.0.1", "10.0.0.1", true},
|
||||
{"CIDR match 10.255.255.255", "10.255.255.255", true},
|
||||
{"CIDR match 172.16.0.1", "172.16.0.1", true},
|
||||
{"CIDR match 172.31.255.255", "172.31.255.255", true},
|
||||
{"CIDR IPv6 match", "fc00::1", true},
|
||||
|
||||
// No matches
|
||||
{"no match 192.168.2.1", "192.168.2.1", false},
|
||||
{"no match 11.0.0.1", "11.0.0.1", false},
|
||||
{"no match 172.32.0.1", "172.32.0.1", false},
|
||||
{"no match 8.8.8.8", "8.8.8.8", false},
|
||||
|
||||
// Invalid IP
|
||||
{"invalid IP", "invalid", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := f.ShouldExclude(tt.ip); got != tt.want {
|
||||
t.Errorf("ShouldExclude(%q) = %v, want %v", tt.ip, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter_ShouldExclude_NilFilter(t *testing.T) {
|
||||
var f *Filter
|
||||
if f.ShouldExclude("192.168.1.1") {
|
||||
t.Error("ShouldExclude on nil filter should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter_Count(t *testing.T) {
|
||||
f, err := New([]string{
|
||||
"192.168.1.1",
|
||||
"10.0.0.1",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ips, networks := f.Count()
|
||||
if ips != 2 {
|
||||
t.Errorf("Count() ips = %d, want 2", ips)
|
||||
}
|
||||
if networks != 2 {
|
||||
t.Errorf("Count() networks = %d, want 2", networks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter_EmptyEntries(t *testing.T) {
|
||||
f, err := New([]string{"", "192.168.1.1", ""})
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ips, _ := f.Count()
|
||||
if ips != 1 {
|
||||
t.Errorf("Count() ips = %d, want 1 (empty entries should be skipped)", ips)
|
||||
}
|
||||
|
||||
if !f.ShouldExclude("192.168.1.1") {
|
||||
t.Error("Should exclude 192.168.1.1")
|
||||
}
|
||||
if f.ShouldExclude("192.168.1.2") {
|
||||
t.Error("Should not exclude 192.168.1.2")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user