feat(ml): replace NetworkX/Louvain with PyTorch Geometric GraphSAGE for fleet detection

Rewrite fleet.py to use a GNN-based approach: nodes are src_ip with ML feature
vectors, edges connect IPs sharing (JA4, ASN) pairs, GraphSAGE (2 SAGEConv
layers, in→64→32) produces 32D embeddings clustered by HDBSCAN. PyG NeighborLoader
activates for >50k nodes. Update thesis docs (§5.2, §6.4, §2, §8) to reflect
GraphSAGE architecture and PyG scalability.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jacquin Antoine
2026-04-13 15:45:34 +02:00
parent c1821dcbc4
commit c6cb12981c
8 changed files with 378 additions and 264 deletions

View File

@ -1,174 +1,291 @@
"""Détection de flottes de bots via graphe bipartite JA4×ASN.
"""Détection de flottes de bots par apprentissage de représentations de graphe (GraphSAGE).
§5.2 — Analyse de graphe bipartite G=(JA4 ASN, E) pour identifier les flottes
de bots coordonnées qui font tourner leurs fingerprints JA4 et ASN.
§5.2 — GNN GraphSAGE sur graphe d'IPs pour identifier les flottes coordonnées
qui font tourner leurs fingerprints JA4 et ASN.
Algorithme :
1. Construire le graphe bipartite G depuis les sessions du cycle courant
2. Projeter sur les nœuds JA4 (shared-ASN weighted projection)
3. Détecter les communautés Louvain (python-louvain)
4. Calculer fleet_score = taille × densité / log2(n_asn + 2) pour chaque communauté
5. Retourner les IPs appartenant aux communautés suspectes avec leur fleet_score
1. Agréger les features ML par src_ip (nœuds du graphe)
2. Construire le graphe : arête entre IP_A et IP_B si elles partagent
même JA4 + même ASN dans un groupe de >= min_ips IPs
3. Inférence GraphSAGE (SAGEConv × 2 : in_channels → 64 → 32, ReLU + Dropout)
4. Clustering DBSCAN/HDBSCAN sur les embeddings 32D
5. fleet_score = cluster_size × compactness / log2(n_asn + 2)
"""
import logging
from typing import Optional
from typing import Dict, Any
import pandas as pd
import numpy as np
import pandas as pd
logger = logging.getLogger(__name__)
# Seuil de fleet_score à partir duquel une communauté est considérée suspecte
# ── Configuration via variables d'environnement ──────────────────────────────
FLEET_SCORE_THRESHOLD = float(__import__('os').getenv('FLEET_SCORE_THRESHOLD', '2.0'))
# Poids du fleet_score dans le score final (malus supplémentaire)
FLEET_SCORE_WEIGHT = float(__import__('os').getenv('FLEET_SCORE_WEIGHT', '0.10'))
# Nombre minimal d'arêtes pour inclure un JA4 dans l'analyse
FLEET_MIN_EDGES = int(__import__('os').getenv('FLEET_MIN_EDGES', '3'))
FLEET_MIN_IPS = int(__import__('os').getenv('FLEET_MIN_IPS', '3'))
FLEET_BATCH_THRESHOLD = int(__import__('os').getenv('FLEET_BATCH_THRESHOLD', '50000'))
FLEET_HIDDEN_DIM = int(__import__('os').getenv('FLEET_HIDDEN_DIM', '64'))
FLEET_EMBED_DIM = int(__import__('os').getenv('FLEET_EMBED_DIM', '32'))
FLEET_DROPOUT = float(__import__('os').getenv('FLEET_DROPOUT', '0.1'))
def build_fleet_graph(df: pd.DataFrame) -> Optional[object]:
"""Construit le graphe bipartite JA4×ASN à partir du cycle courant.
# ═══════════════════════════════════════════════════════════════════════════════
# Construction du graphe d'IPs
# ═══════════════════════════════════════════════════════════════════════════════
Nœuds : ensemble JA4 (préfixe 'ja4:') + ensemble ASN (préfixe 'asn:')
Arêtes : (ja4, asn) avec weight = nombre d'IPs distinctes sur ce couple
def _build_ip_graph(df: pd.DataFrame, min_ips: int = FLEET_MIN_IPS):
"""Construit le graphe d'IPs pour l'inférence GraphSAGE.
Exige que df ait les colonnes : ja4, asn_number, src_ip
Retourne None si networkx n'est pas disponible ou données insuffisantes.
Nœuds : src_ip uniques, features = vecteur ML agrégé par IP (moyenne).
Arêtes : (IP_A, IP_B) si elles co-occurent dans un groupe (JA4, ASN)
comptant >= min_ips IPs distinctes.
Retourne (unique_ips, node_features, edge_index) ou None si données
insuffisantes.
"""
try:
import networkx as nx
from networkx.algorithms import bipartite
except ImportError:
logger.warning("[Fleet] networkx non disponible — analyse de flotte désactivée.")
return None
if df.empty or 'ja4' not in df.columns or 'asn_number' not in df.columns:
return None
# Filtrer les JA4 vides et ASN 0
# Filtrer JA4 vides et ASN 0
mask = (df['ja4'].fillna('') != '') & (df['asn_number'].fillna('0') != '0')
sub = df[mask][['src_ip', 'ja4', 'asn_number']].copy()
sub = df[mask].copy()
if len(sub) < 10:
return None
# Compter les IPs par (ja4, asn) — poids de l'arête
edge_weights = (
sub.groupby(['ja4', 'asn_number'])['src_ip']
.nunique()
.reset_index(name='n_ips')
)
# Garder seulement les arêtes avec au moins FLEET_MIN_EDGES IPs distinctes
edge_weights = edge_weights[edge_weights['n_ips'] >= FLEET_MIN_EDGES]
if len(edge_weights) < 5:
# Index des IPs uniques
unique_ips = sub['src_ip'].unique()
if len(unique_ips) < 5:
return None
ip_to_idx = {ip: i for i, ip in enumerate(unique_ips)}
# ── Features par nœud : agrégation des colonnes numériques ──
numeric_cols = sub.select_dtypes(include=[np.number]).columns.tolist()
skip_cols = {'asn_number', 'fleet_campaign_flag'}
feature_cols = [c for c in numeric_cols if c not in skip_cols]
if not feature_cols:
return None
G = nx.Graph()
ja4_nodes = set()
asn_nodes = set()
for _, row in edge_weights.iterrows():
ja4_node = f"ja4:{row['ja4']}"
asn_node = f"asn:{row['asn_number']}"
G.add_node(ja4_node, bipartite=0)
G.add_node(asn_node, bipartite=1)
G.add_edge(ja4_node, asn_node, weight=int(row['n_ips']))
ja4_nodes.add(ja4_node)
asn_nodes.add(asn_node)
node_features = (
sub.groupby('src_ip')[feature_cols]
.mean()
.reindex(unique_ips)
.fillna(0)
.values
.astype(np.float32)
)
# Normalisation min-max par colonne
mins = node_features.min(axis=0, keepdims=True)
maxs = node_features.max(axis=0, keepdims=True)
ranges = np.where(maxs - mins == 0, 1.0, maxs - mins)
node_features = (node_features - mins) / ranges
return G, ja4_nodes, asn_nodes
# ── Construction des arêtes : co-occurrence (JA4, ASN) ──
groups = (
sub.groupby(['ja4', 'asn_number'])['src_ip']
.agg(lambda s: list(s.unique()))
.reset_index(name='ips')
)
groups = groups[groups['ips'].map(len) >= min_ips]
if groups.empty:
return None
edge_set: set = set()
for ips in groups['ips']:
idx = sorted(ip_to_idx[ip] for ip in ips if ip in ip_to_idx)
n = len(idx)
if n < 2 or n > 500: # Sauter les très grands groupes (CDN/infra)
continue
for a in range(n):
for b in range(a + 1, n):
edge_set.add((idx[a], idx[b]))
edge_set.add((idx[b], idx[a]))
if len(edge_set) < 5:
return None
edge_index = np.array(sorted(edge_set), dtype=np.int64).T # [2, num_edges]
return unique_ips, node_features, edge_index
def detect_fleet_communities(df: pd.DataFrame) -> dict:
"""Analyse le graphe bipartite et retourne un dict {src_ip: fleet_score}.
# ═══════════════════════════════════════════════════════════════════════════════
# Inférence GraphSAGE
# ═══════════════════════════════════════════════════════════════════════════════
fleet_score > FLEET_SCORE_THRESHOLD → IP appartient à une flotte suspectée.
fleet_score = 0 pour toutes les autres IPs.
def _infer_embeddings(node_features: np.ndarray, edge_index: np.ndarray,
n_nodes: int, batch_threshold: int = FLEET_BATCH_THRESHOLD):
"""Passe le graphe dans GraphSAGE pour obtenir les embeddings 32D par nœud.
fleet_score = community_size * graph_density / log2(n_asn + 2)
Utilise NeighborLoader (batching PyG) si n_nodes > batch_threshold.
Sinon inférence full-batch.
"""
result = build_fleet_graph(df)
if result is None:
return {}
import torch
import torch.nn as nn
from torch_geometric.nn import SAGEConv
G, ja4_nodes, asn_nodes = result
class GraphSAGE(nn.Module):
def __init__(self, in_ch: int, hidden_ch: int, out_ch: int,
dropout: float):
super().__init__()
self.conv1 = SAGEConv(in_ch, hidden_ch)
self.conv2 = SAGEConv(hidden_ch, out_ch)
self.drop = nn.Dropout(dropout)
try:
import networkx as nx
from networkx.algorithms import bipartite
try:
from community import best_partition as louvain_partition
LOUVAIN_AVAILABLE = True
except ImportError:
LOUVAIN_AVAILABLE = False
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.drop(x)
x = self.conv2(x, edge_index)
return x
# Projection bipartite : graphe des JA4 partageant des ASN
G_ja4 = bipartite.weighted_projected_graph(G, ja4_nodes)
if G_ja4.number_of_edges() == 0:
return {}
in_dim = node_features.shape[1]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Détection de communautés
if LOUVAIN_AVAILABLE:
partition = louvain_partition(G_ja4, weight='weight', random_state=42)
# partition = {node: community_id}
communities: dict = {}
for node, cid in partition.items():
communities.setdefault(cid, set()).add(node)
model = GraphSAGE(in_dim, FLEET_HIDDEN_DIM, FLEET_EMBED_DIM,
FLEET_DROPOUT).to(device)
model.eval()
x = torch.tensor(node_features, dtype=torch.float32, device=device)
ei = torch.tensor(edge_index, dtype=torch.long, device=device)
with torch.no_grad():
if n_nodes > batch_threshold:
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
data = Data(x=x, edge_index=ei)
loader = NeighborLoader(
data,
num_neighbors=[25, 10], # échantillonnage 2-hop
batch_size=4096,
shuffle=False,
)
embeddings = torch.zeros(n_nodes, FLEET_EMBED_DIM, device=device)
for batch in loader:
batch = batch.to(device)
out = model(batch.x, batch.edge_index)
embeddings[batch.n_id[:batch.batch_size]] = out[:batch.batch_size]
else:
# Fallback : composantes connexes
communities = {
i: set(c)
for i, c in enumerate(nx.connected_components(G_ja4))
if len(c) >= 2
}
embeddings = model(x, ei)
# Calculer le fleet_score de chaque communauté
fleet_scores: dict = {} # {ja4: fleet_score}
for cid, members in communities.items():
if len(members) < 2:
continue
sub_g = G.subgraph(
list(members) + [n for n in asn_nodes if any(G.has_edge(n, m) for m in members)]
)
n_asn = len([n for n in sub_g.nodes if n.startswith('asn:')])
density = nx.density(G_ja4.subgraph(members))
score = len(members) * density / max(np.log2(n_asn + 2), 0.1)
for ja4_node in members:
ja4 = ja4_node.replace('ja4:', '')
fleet_scores[ja4] = round(float(score), 3)
return embeddings.cpu().numpy()
# Mapper les fleet_scores sur les IPs
if not fleet_scores:
return {}
ip_scores: dict = {}
for _, row in df.iterrows():
ja4 = str(row.get('ja4', ''))
score = fleet_scores.get(ja4, 0.0)
if score >= FLEET_SCORE_THRESHOLD:
src_ip = str(row.get('src_ip', ''))
if src_ip:
ip_scores[src_ip] = max(ip_scores.get(src_ip, 0.0), score)
# ═══════════════════════════════════════════════════════════════════════════════
# Pipeline de détection
# ═══════════════════════════════════════════════════════════════════════════════
n_fleets = len(set(fleet_scores.values()))
if ip_scores:
logger.info(
f"[Fleet] {len(ip_scores)} IPs dans {n_fleets} communauté(s) suspecte(s) "
f"(score max={max(ip_scores.values()):.2f})"
)
return ip_scores
def detect_fleet_communities(df: pd.DataFrame) -> Dict[str, Dict[str, Any]]:
"""Détecte les flottes via GraphSAGE + DBSCAN/HDBSCAN.
except Exception as e:
logger.warning(f"[Fleet] Erreur détection de flotte : {e}")
Retourne {src_ip: {"cluster_id": int, "fleet_score": float}}.
Les IPs non assignées (bruit DBSCAN) n'apparaissent pas dans le dict.
"""
graph = _build_ip_graph(df)
if graph is None:
return {}
unique_ips, node_features, edge_index = graph
n_nodes = len(unique_ips)
# Vérifier les dépendances PyTorch Geometric
try:
import torch # noqa: F401
from torch_geometric.nn import SAGEConv # noqa: F401
except ImportError:
logger.warning(
"[Fleet] torch/torch_geometric non disponible — "
"analyse de flotte désactivée."
)
return {}
# ── Inférence GraphSAGE ──
try:
embeddings = _infer_embeddings(node_features, edge_index, n_nodes)
except Exception as e:
logger.warning(f"[Fleet] Erreur GraphSAGE : {e}")
return {}
# ── Clustering sur les embeddings 32D ──
try:
import hdbscan
labels = hdbscan.HDBSCAN(
min_cluster_size=3,
metric='euclidean',
cluster_selection_method='eom',
).fit_predict(embeddings)
except ImportError:
from sklearn.cluster import DBSCAN
labels = DBSCAN(eps=0.5, min_samples=3, metric='euclidean').fit_predict(
embeddings
)
except Exception as e:
logger.warning(f"[Fleet] Erreur clustering : {e}")
return {}
# ── Calcul du fleet_score par cluster ──
ip_asn = (
df.groupby('src_ip')['asn_number']
.agg(lambda s: set(s.dropna().unique()))
.to_dict()
)
cluster_scores: Dict[int, float] = {}
for cid in set(labels) - {-1}:
members_idx = np.where(labels == cid)[0]
cluster_ips = unique_ips[members_idx]
cluster_size = len(members_idx)
# Nombre d'ASN distincts dans le cluster
n_asn = len(
set().union(*(ip_asn.get(str(ip), set()) for ip in cluster_ips))
)
# Compacité : inverse de la distance moyenne au centroïde
centroid = embeddings[members_idx].mean(axis=0)
avg_dist = np.linalg.norm(
embeddings[members_idx] - centroid, axis=1
).mean()
compactness = 1.0 / (1.0 + avg_dist)
# fleet_score = taille × compacité / log2(n_asn + 2)
score = cluster_size * compactness / max(np.log2(n_asn + 2), 0.1)
cluster_scores[int(cid)] = round(float(score), 3)
# ── Construction du résultat ──
ip_results: Dict[str, Dict[str, Any]] = {}
for i, ip in enumerate(unique_ips):
cid = int(labels[i])
score = cluster_scores.get(cid, 0.0)
if score > 0:
ip_results[str(ip)] = {"cluster_id": cid, "fleet_score": score}
n_flagged = sum(
1 for v in ip_results.values()
if v["fleet_score"] >= FLEET_SCORE_THRESHOLD
)
if ip_results:
max_score = max(v["fleet_score"] for v in ip_results.values())
logger.info(
f"[Fleet] {n_flagged} IPs dans {len(cluster_scores)} flotte(s) "
f"(score max={max_score:.2f}, {n_nodes} noeuds, "
f"{edge_index.shape[1]} arêtes)"
)
return ip_results
def enrich_with_fleet_score(df: pd.DataFrame) -> pd.DataFrame:
"""Enrichit le DataFrame avec fleet_score et fleet_campaign_flag.
fleet_campaign_flag = 1 si l'IP appartient à une flotte suspectée.
fleet_score = score de la communauté (0 = pas de flotte).
fleet_campaign_flag = 1 si fleet_score >= FLEET_SCORE_THRESHOLD.
fleet_score = 0 pour les IPs n'appartenant à aucune flotte détectée.
"""
df = df.copy()
fleet_map = detect_fleet_communities(df)
df['fleet_score'] = df['src_ip'].map(fleet_map).fillna(0.0).astype(float)
df['fleet_campaign_flag'] = (df['fleet_score'] >= FLEET_SCORE_THRESHOLD).astype(int)
df['fleet_score'] = df['src_ip'].map(
{ip: v["fleet_score"] for ip, v in fleet_map.items()}
).fillna(0.0).astype(float)
df['fleet_campaign_flag'] = (
df['fleet_score'] >= FLEET_SCORE_THRESHOLD
).astype(int)
return df

View File

@ -6,6 +6,7 @@ scipy>=1.14
hdbscan>=0.8.38
isotree>=0.6.1
torch>=2.0
torch_geometric>=2.4
FrEIA>=0.2
xgboost>=2.0
cleanlab>=2.6