Files
ja4-platform/services/bot-detector/bot_detector/fleet.py
Jacquin Antoine c6cb12981c feat(ml): replace NetworkX/Louvain with PyTorch Geometric GraphSAGE for fleet detection
Rewrite fleet.py to use a GNN-based approach: nodes are src_ip with ML feature
vectors, edges connect IPs sharing (JA4, ASN) pairs, GraphSAGE (2 SAGEConv
layers, in→64→32) produces 32D embeddings clustered by HDBSCAN. PyG NeighborLoader
activates for >50k nodes. Update thesis docs (§5.2, §6.4, §2, §8) to reflect
GraphSAGE architecture and PyG scalability.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-13 15:45:34 +02:00

292 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Détection de flottes de bots par apprentissage de représentations de graphe (GraphSAGE).
§5.2 — GNN GraphSAGE sur graphe d'IPs pour identifier les flottes coordonnées
qui font tourner leurs fingerprints JA4 et ASN.
Algorithme :
1. Agréger les features ML par src_ip (nœuds du graphe)
2. Construire le graphe : arête entre IP_A et IP_B si elles partagent
même JA4 + même ASN dans un groupe de >= min_ips IPs
3. Inférence GraphSAGE (SAGEConv × 2 : in_channels → 64 → 32, ReLU + Dropout)
4. Clustering DBSCAN/HDBSCAN sur les embeddings 32D
5. fleet_score = cluster_size × compactness / log2(n_asn + 2)
"""
import logging
import os
from typing import Any, Dict

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)

# ── Configuration via environment variables ──────────────────────────────────
# Read once at import time; a malformed value raises ValueError on import.

# Minimum fleet_score for an IP to get fleet_campaign_flag = 1.
FLEET_SCORE_THRESHOLD = float(os.getenv('FLEET_SCORE_THRESHOLD', '2.0'))
# Not read in this module — presumably consumed by the caller's scoring; verify.
FLEET_SCORE_WEIGHT = float(os.getenv('FLEET_SCORE_WEIGHT', '0.10'))
# Minimum number of distinct IPs a (JA4, ASN) group needs to create edges.
FLEET_MIN_IPS = int(os.getenv('FLEET_MIN_IPS', '3'))
# Above this node count, GraphSAGE inference is mini-batched with NeighborLoader.
FLEET_BATCH_THRESHOLD = int(os.getenv('FLEET_BATCH_THRESHOLD', '50000'))
# GraphSAGE layer widths: in_channels -> HIDDEN -> EMBED.
FLEET_HIDDEN_DIM = int(os.getenv('FLEET_HIDDEN_DIM', '64'))
FLEET_EMBED_DIM = int(os.getenv('FLEET_EMBED_DIM', '32'))
# Dropout between the two SAGEConv layers (inactive in eval mode).
FLEET_DROPOUT = float(os.getenv('FLEET_DROPOUT', '0.1'))
# ═══════════════════════════════════════════════════════════════════════════════
# Construction du graphe d'IPs
# ═══════════════════════════════════════════════════════════════════════════════
def _build_ip_graph(df: pd.DataFrame, min_ips: int = FLEET_MIN_IPS,
                    max_group_size: int = 500):
    """Build the IP graph used for GraphSAGE inference.

    Nodes are the distinct ``src_ip`` values of the usable rows (non-empty
    JA4, non-null IP, non-missing and non-zero ASN). A node's feature vector
    is the per-IP mean of the numeric columns (the IP key, ``asn_number`` and
    the fleet outputs excluded), min-max normalised per column.

    Edges link every pair of IPs that co-occur in a (JA4, ASN) group counting
    at least ``min_ips`` distinct IPs; each edge is stored in both directions.

    Args:
        df: Input rows; must contain ``src_ip``, ``ja4`` and ``asn_number``.
        min_ips: Minimum distinct IPs for a (JA4, ASN) group to create edges.
        max_group_size: Groups with more IPs than this are skipped
            (CDN / shared infrastructure).

    Returns:
        ``(unique_ips, node_features, edge_index)``. ``node_features`` is a
        float32 array of shape (n_nodes, n_features). ``edge_index`` is an
        int64 array of shape (2, n_edges).

        Returns ``None`` when the data is insufficient: a missing column,
        fewer than 10 usable rows, fewer than 5 IPs, no numeric feature, no
        qualifying group, or fewer than 5 directed edges.
    """
    from itertools import combinations

    if df.empty or not {'src_ip', 'ja4', 'asn_number'}.issubset(df.columns):
        return None

    # Keep rows with a JA4, an IP and a real ASN. asn_number may be numeric or
    # textual, so zero is tested numerically (0, 0.0 and '0' are all rejected);
    # non-numeric labels such as 'AS13335' are kept.
    asn = df['asn_number']
    mask = (
        (df['ja4'].fillna('') != '')
        & df['src_ip'].notna()
        & asn.notna()
        & (pd.to_numeric(asn, errors='coerce') != 0)
        & (asn.astype(str).str.strip() != '')
    )
    sub = df[mask]
    if len(sub) < 10:
        return None

    unique_ips = sub['src_ip'].unique()
    if len(unique_ips) < 5:
        return None
    ip_to_idx = {ip: i for i, ip in enumerate(unique_ips)}

    # ── Node features: per-IP mean of the numeric columns ──
    # The IP key and this module's own output columns are never features.
    skip_cols = {'src_ip', 'asn_number', 'fleet_campaign_flag', 'fleet_score'}
    feature_cols = [
        c for c in sub.select_dtypes(include=[np.number]).columns
        if c not in skip_cols
    ]
    if not feature_cols:
        return None
    node_features = (
        sub[['src_ip', *feature_cols]]
        # +/-inf would break the min-max scaling below (inf/inf -> NaN).
        .replace([np.inf, -np.inf], np.nan)
        .groupby('src_ip')[feature_cols]
        .mean()
        .reindex(unique_ips)
        .fillna(0)
        .to_numpy(dtype=np.float32)
    )

    # Per-column min-max scaling to [0, 1]; constant columns map to 0.
    mins = node_features.min(axis=0, keepdims=True)
    spans = node_features.max(axis=0, keepdims=True) - mins
    spans[spans == 0] = 1.0
    node_features = (node_features - mins) / spans

    # ── Edges: IPs co-occurring in the same (JA4, ASN) group ──
    group_ips = sub.groupby(['ja4', 'asn_number'])['src_ip'].unique()
    group_ips = group_ips[group_ips.map(len) >= min_ips]
    if group_ips.empty:
        return None

    edge_set: set = set()
    for ips in group_ips:
        # Skip degenerate groups and very large ones (CDN/infra).
        if not 2 <= len(ips) <= max_group_size:
            continue
        idx = [ip_to_idx[ip] for ip in ips]
        for a, b in combinations(idx, 2):
            edge_set.add((a, b))
            edge_set.add((b, a))
    if len(edge_set) < 5:
        return None

    edge_index = np.array(sorted(edge_set), dtype=np.int64).T  # shape (2, n_edges)
    return unique_ips, node_features, edge_index
# ═══════════════════════════════════════════════════════════════════════════════
# Inférence GraphSAGE
# ═══════════════════════════════════════════════════════════════════════════════
def _infer_embeddings(node_features: np.ndarray, edge_index: np.ndarray,
                      n_nodes: int, batch_threshold: int = FLEET_BATCH_THRESHOLD,
                      seed: int = 0):
    """Run the IP graph through a GraphSAGE encoder and return node embeddings.

    The encoder is two ``SAGEConv`` layers (in -> FLEET_HIDDEN_DIM ->
    FLEET_EMBED_DIM) with ReLU and dropout in between. Dropout is inactive
    because the model is put in ``eval()`` mode.

    No training happens here: embeddings are a projection through freshly
    initialised weights. Initialisation is seeded with ``seed`` inside a
    forked CPU RNG scope. Repeated calls on the same graph therefore use the
    same weights, and the caller's global torch RNG state is left untouched.

    When ``n_nodes > batch_threshold``, inference is mini-batched with PyG's
    ``NeighborLoader`` (2-hop sampling: 25 then 10 neighbours). That path
    depends on neighbour sampling, so bit-exact reproducibility is not
    guaranteed there. Otherwise the whole graph is processed in one pass.

    Args:
        node_features: Array of shape (n_nodes, in_dim); cast to float32.
        edge_index: Integer array of shape (2, n_edges).
        n_nodes: Number of nodes (rows of ``node_features``).
        batch_threshold: Node count above which mini-batching is used.
        seed: Seed for the weight initialisation.

    Returns:
        float32 NumPy array of shape (n_nodes, FLEET_EMBED_DIM).
    """
    import torch
    import torch.nn as nn
    from torch_geometric.nn import SAGEConv

    class GraphSAGE(nn.Module):
        """Two-layer encoder: SAGEConv -> ReLU -> Dropout -> SAGEConv."""

        def __init__(self, in_ch: int, hidden_ch: int, out_ch: int,
                     dropout: float):
            super().__init__()
            self.conv1 = SAGEConv(in_ch, hidden_ch)
            self.conv2 = SAGEConv(hidden_ch, out_ch)
            self.drop = nn.Dropout(dropout)

        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index).relu()
            x = self.drop(x)
            return self.conv2(x, edge_index)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The graph stays on CPU: NeighborLoader samples on CPU, and each batch
    # (or the full graph) is moved to `device` only when it is consumed.
    x_cpu = torch.tensor(node_features, dtype=torch.float32)
    ei_cpu = torch.tensor(edge_index, dtype=torch.long)

    # fork_rng(devices=[]) saves/restores only the CPU generator. The model
    # is initialised on CPU, so that is the only RNG it consumes.
    with torch.random.fork_rng(devices=[]), torch.no_grad():
        torch.default_generator.manual_seed(seed)
        model = GraphSAGE(x_cpu.shape[1], FLEET_HIDDEN_DIM, FLEET_EMBED_DIM,
                          FLEET_DROPOUT).to(device)
        model.eval()

        if n_nodes > batch_threshold:
            from torch_geometric.data import Data
            from torch_geometric.loader import NeighborLoader

            loader = NeighborLoader(
                Data(x=x_cpu, edge_index=ei_cpu),
                num_neighbors=[25, 10],  # 2-hop sampling
                batch_size=4096,
                shuffle=False,
            )
            embeddings = torch.zeros(n_nodes, FLEET_EMBED_DIM, device=device)
            for batch in loader:
                batch = batch.to(device)
                out = model(batch.x, batch.edge_index)
                # The first `batch_size` nodes are the seed nodes of this batch.
                # The remaining rows are sampled neighbours.
                embeddings[batch.n_id[:batch.batch_size]] = out[:batch.batch_size]
        else:
            embeddings = model(x_cpu.to(device), ei_cpu.to(device))

    return embeddings.cpu().numpy()
# ═══════════════════════════════════════════════════════════════════════════════
# Pipeline de détection
# ═══════════════════════════════════════════════════════════════════════════════
def _cluster_embeddings(embeddings: np.ndarray):
    """Cluster node embeddings; return one label per node (-1 = noise), or None.

    Uses HDBSCAN (min_cluster_size=3, EOM selection) when ``hdbscan`` is
    importable, otherwise scikit-learn's DBSCAN (eps=0.5, min_samples=3).
    Any failure, including neither library being available, is logged and
    yields None.
    """
    try:
        try:
            import hdbscan
        except ImportError:
            from sklearn.cluster import DBSCAN
            clusterer = DBSCAN(eps=0.5, min_samples=3, metric='euclidean')
        else:
            clusterer = hdbscan.HDBSCAN(
                min_cluster_size=3,
                metric='euclidean',
                cluster_selection_method='eom',
            )
        return np.asarray(clusterer.fit_predict(embeddings))
    except Exception as e:
        logger.warning("[Fleet] Erreur clustering : %s", e)
        return None


def _score_clusters(labels, embeddings, unique_ips, ip_asn) -> Dict[int, float]:
    """Return ``{cluster_id: fleet_score}`` for every non-noise cluster.

    fleet_score = cluster_size * compactness / log2(n_asn + 2), rounded to 3
    decimals. Its terms are:

    * compactness = 1 / (1 + mean Euclidean distance to the cluster centroid
      in embedding space);
    * n_asn = number of distinct ASNs in ``ip_asn`` for the cluster's IPs.
    """
    scores: Dict[int, float] = {}
    for cid in np.unique(labels):
        if cid == -1:
            continue
        members = np.flatnonzero(labels == cid)
        member_emb = embeddings[members]

        asns: set = set()
        for ip in unique_ips[members]:
            asns |= ip_asn.get(ip, set())

        centroid = member_emb.mean(axis=0)
        avg_dist = np.linalg.norm(member_emb - centroid, axis=1).mean()
        compactness = 1.0 / (1.0 + avg_dist)
        # log2(n_asn + 2) >= 1 for any n_asn >= 0, so no zero-division guard.
        score = len(members) * compactness / np.log2(len(asns) + 2)
        scores[int(cid)] = round(float(score), 3)
    return scores


def detect_fleet_communities(df: pd.DataFrame) -> Dict[str, Dict[str, Any]]:
    """Detect bot fleets with GraphSAGE embeddings + HDBSCAN/DBSCAN clustering.

    Returns ``{str(src_ip): {"cluster_id": int, "fleet_score": float}}`` for
    every IP assigned to a cluster; IPs labelled as noise (-1) are omitted.

    Returns ``{}`` in any of these cases:

    * the graph cannot be built (insufficient data);
    * torch / torch_geometric cannot be imported (a warning is logged);
    * inference or clustering fails (a warning is logged).
    """
    graph = _build_ip_graph(df)
    if graph is None:
        return {}
    unique_ips, node_features, edge_index = graph
    n_nodes = len(unique_ips)

    # Optional PyTorch Geometric dependency.
    try:
        import torch  # noqa: F401
        from torch_geometric.nn import SAGEConv  # noqa: F401
    except ImportError:
        logger.warning(
            "[Fleet] torch/torch_geometric non disponible — "
            "analyse de flotte désactivée."
        )
        return {}

    # ── GraphSAGE inference ──
    try:
        embeddings = _infer_embeddings(node_features, edge_index, n_nodes)
    except Exception as e:
        logger.warning("[Fleet] Erreur GraphSAGE : %s", e)
        return {}

    # ── Clustering of the embeddings ──
    labels = _cluster_embeddings(embeddings)
    if labels is None:
        return {}

    # ── fleet_score per cluster ──
    # Distinct ASNs per IP, keyed by the raw src_ip value (same values as
    # unique_ips). Missing or zero ASNs are excluded, mirroring the row filter
    # applied when building the graph.
    asn = df['asn_number']
    has_asn = (
        asn.notna()
        & (pd.to_numeric(asn, errors='coerce') != 0)
        & (asn.astype(str).str.strip() != '')
    )
    ip_asn = {
        ip: set(values)
        for ip, values in df[has_asn].groupby('src_ip')['asn_number'].unique().items()
    }
    cluster_scores = _score_clusters(labels, embeddings, unique_ips, ip_asn)

    # ── Result ──
    ip_results: Dict[str, Dict[str, Any]] = {}
    for ip, cid in zip(unique_ips, labels):
        score = cluster_scores.get(int(cid), 0.0)
        if score > 0:
            ip_results[str(ip)] = {"cluster_id": int(cid), "fleet_score": score}

    if ip_results:
        flagged = [
            v for v in ip_results.values()
            if v["fleet_score"] >= FLEET_SCORE_THRESHOLD
        ]
        # Count only the fleets that actually contain flagged IPs.
        n_fleets = len({v["cluster_id"] for v in flagged})
        max_score = max(v["fleet_score"] for v in ip_results.values())
        logger.info(
            "[Fleet] %d IPs dans %d flotte(s) "
            "(score max=%.2f, %d noeuds, %d arêtes)",
            len(flagged), n_fleets, max_score, n_nodes, edge_index.shape[1],
        )
    return ip_results
def enrich_with_fleet_score(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with ``fleet_score`` and ``fleet_campaign_flag``.

    * fleet_score: score of the detected fleet the row's ``src_ip`` belongs
      to, or 0.0 when the IP is in no detected fleet (or is null).
    * fleet_campaign_flag: 1 when fleet_score >= FLEET_SCORE_THRESHOLD,
      else 0.

    The input DataFrame is not modified.
    """
    df = df.copy()
    fleet_map = detect_fleet_communities(df)
    scores = {ip: v["fleet_score"] for ip, v in fleet_map.items()}

    # detect_fleet_communities keys its result by str(src_ip). Match on the
    # same representation so non-str IP values (ints, ipaddress objects, ...)
    # are found. Null IPs stay NaN so they cannot hit a literal 'nan' key.
    keys = df['src_ip'].astype(str).where(df['src_ip'].notna())
    df['fleet_score'] = keys.map(scores).fillna(0.0).astype(float)
    df['fleet_campaign_flag'] = (
        df['fleet_score'] >= FLEET_SCORE_THRESHOLD
    ).astype(int)
    return df