Rewrite fleet.py to use a GNN-based approach: nodes are src_ip with ML feature vectors, edges connect IPs sharing (JA4, ASN) pairs, GraphSAGE (2 SAGEConv layers, in→64→32) produces 32D embeddings clustered by HDBSCAN. PyG NeighborLoader activates for >50k nodes. Update thesis docs (§5.2, §6.4, §2, §8) to reflect GraphSAGE architecture and PyG scalability. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
292 lines
11 KiB
Python
292 lines
11 KiB
Python
"""Détection de flottes de bots par apprentissage de représentations de graphe (GraphSAGE).

§5.2 — GNN GraphSAGE sur un graphe d'IPs, pour identifier les flottes coordonnées
qui font tourner leurs fingerprints JA4 et ASN.

Algorithme :
  1. Agréger les features ML par src_ip (nœuds du graphe).
  2. Construire le graphe : arête entre IP_A et IP_B si elles partagent
     le même JA4 et le même ASN dans un groupe d'au moins min_ips IPs
     (les groupes de plus de 500 IPs, typiquement CDN/infra, sont ignorés).
  3. Inférence GraphSAGE (SAGEConv × 2 : in_channels → 64 → 32 par défaut,
     ReLU + Dropout).
  4. Clustering HDBSCAN (repli sur DBSCAN) des embeddings 32D.
  5. fleet_score = cluster_size × compactness / log2(n_asn + 2).
"""

import logging
import os
from typing import Any, Dict

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


# ── Configuration via environment variables ──────────────────────────────────
# Every tunable below is read once, when the module is imported. A value that
# cannot be parsed (e.g. FLEET_MIN_IPS=abc) raises ValueError at import time.

# Minimum fleet_score for an IP to receive fleet_campaign_flag = 1.
FLEET_SCORE_THRESHOLD = float(os.getenv('FLEET_SCORE_THRESHOLD', '2.0'))
# Weight of the fleet score; not read in this module — presumably consumed by
# downstream scoring code (TODO confirm).
FLEET_SCORE_WEIGHT = float(os.getenv('FLEET_SCORE_WEIGHT', '0.10'))
# Minimum number of distinct IPs a (JA4, ASN) group needs to create edges.
FLEET_MIN_IPS = int(os.getenv('FLEET_MIN_IPS', '3'))
# Node count above which inference switches to mini-batch NeighborLoader.
FLEET_BATCH_THRESHOLD = int(os.getenv('FLEET_BATCH_THRESHOLD', '50000'))
# GraphSAGE layer widths: in_channels -> FLEET_HIDDEN_DIM -> FLEET_EMBED_DIM.
FLEET_HIDDEN_DIM = int(os.getenv('FLEET_HIDDEN_DIM', '64'))
FLEET_EMBED_DIM = int(os.getenv('FLEET_EMBED_DIM', '32'))
# Dropout between the two SAGEConv layers (inactive in eval mode).
FLEET_DROPOUT = float(os.getenv('FLEET_DROPOUT', '0.1'))

# ═══════════════════════════════════════════════════════════════════════════════
# Construction du graphe d'IPs
# ═══════════════════════════════════════════════════════════════════════════════

def _build_ip_graph(df: pd.DataFrame, min_ips: int = FLEET_MIN_IPS,
                    max_group_size: int = 500):
    """Build the IP graph consumed by GraphSAGE inference.

    Nodes are the distinct ``src_ip`` values of rows that carry a known IP, a
    non-empty JA4 and a non-zero ASN. Each node's feature vector is the per-IP
    mean of every numeric column (except ``asn_number`` and this module's own
    ``fleet_score`` / ``fleet_campaign_flag`` outputs), min-max normalised
    column by column into [0, 1].

    Edges link every pair of IPs that appear together in a (JA4, ASN) group
    holding at least ``min_ips`` distinct IPs, stored in both directions.
    Groups with more than ``max_group_size`` IPs are skipped: they are
    typically shared CDN / infrastructure fingerprints and would add a
    quadratic number of edges.

    Args:
        df: Request-level frame with ``src_ip``, ``ja4`` and ``asn_number``.
        min_ips: Minimum distinct IPs in a (JA4, ASN) group to create edges.
        max_group_size: Groups larger than this are ignored (default 500).

    Returns:
        ``(unique_ips, node_features, edge_index)``: ``unique_ips`` lists the
        node IPs (node ``i`` is ``unique_ips[i]``); ``node_features`` is a
        float32 array of shape (n_nodes, n_features); ``edge_index`` is an
        int64 array of shape (2, n_edges) sorted by (source, target).
        ``None`` when the data is too sparse: a required column is missing,
        fewer than 10 usable rows, fewer than 5 IPs, no numeric feature
        column, no qualifying group, or fewer than 5 directed edges.
    """
    required = ('src_ip', 'ja4', 'asn_number')
    if df.empty or not all(col in df.columns for col in required):
        return None

    # Keep rows with a known IP, a non-empty JA4 and a non-zero ASN. The ASN
    # may be stored as int, float or str, so test its text form ('0', '0.0')
    # instead of comparing the raw value to the string '0'.
    asn = df['asn_number']
    asn_text = asn.astype(str).str.strip()
    mask = (
        df['src_ip'].notna()
        & (df['ja4'].fillna('') != '')
        & asn.notna()
        & ~asn_text.isin(['', '0', '0.0'])
    )
    sub = df[mask]
    if len(sub) < 10:
        return None

    unique_ips = sub['src_ip'].unique()
    if len(unique_ips) < 5:
        return None
    ip_to_idx = {ip: i for i, ip in enumerate(unique_ips)}

    # ── Node features: per-IP mean of the numeric columns ──
    # fleet_score / fleet_campaign_flag are outputs of this module; excluding
    # them keeps a re-enriched frame from feeding its previous result back in.
    skip_cols = {'asn_number', 'fleet_score', 'fleet_campaign_flag'}
    feature_cols = [
        col for col in sub.select_dtypes(include=[np.number]).columns
        if col not in skip_cols
    ]
    if not feature_cols:
        return None

    node_features = (
        sub.groupby('src_ip')[feature_cols]
        .mean()
        .reindex(unique_ips)
        .fillna(0)
        .to_numpy(dtype=np.float64)
    )
    # ±inf would turn the min-max step into NaN (inf - inf) and break the
    # downstream clustering for every IP; neutralise non-finite values first.
    node_features = np.nan_to_num(node_features, nan=0.0, posinf=0.0, neginf=0.0)

    # Min-max normalisation per column; a constant column maps to 0, not NaN.
    col_min = node_features.min(axis=0, keepdims=True)
    col_range = node_features.max(axis=0, keepdims=True) - col_min
    col_range[col_range == 0] = 1.0
    node_features = ((node_features - col_min) / col_range).astype(np.float32)

    # ── Edges: co-occurrence within a (JA4, ASN) group ──
    ips_per_group = sub.groupby(['ja4', 'asn_number'])['src_ip'].unique()
    smallest = max(min_ips, 2)  # a pair needs at least two IPs
    pieces = []
    for ips in ips_per_group:
        n = len(ips)
        if n < smallest or n > max_group_size:
            continue
        idx = np.array([ip_to_idx[ip] for ip in ips], dtype=np.int64)
        a, b = np.triu_indices(n, k=1)  # every unordered pair once
        pieces.append(np.stack([idx[a], idx[b]]))
        pieces.append(np.stack([idx[b], idx[a]]))  # reverse direction

    if not pieces:
        return None
    # Deduplicate pairs shared by several groups; columns come out sorted
    # lexicographically by (source, target).
    edge_index = np.unique(np.concatenate(pieces, axis=1), axis=1)
    if edge_index.shape[1] < 5:
        return None

    return unique_ips, node_features, edge_index


# ═══════════════════════════════════════════════════════════════════════════════
# Inférence GraphSAGE
# ═══════════════════════════════════════════════════════════════════════════════

def _infer_embeddings(node_features: np.ndarray, edge_index: np.ndarray,
                      n_nodes: int, batch_threshold: int = FLEET_BATCH_THRESHOLD,
                      seed: int = 0) -> np.ndarray:
    """Run the IP graph through a 2-layer GraphSAGE and return node embeddings.

    The network (SAGEConv in -> FLEET_HIDDEN_DIM, ReLU, dropout, SAGEConv ->
    FLEET_EMBED_DIM) is built fresh on every call and is not trained here:
    the embeddings come from its initial weights. The initialisation is
    therefore seeded with ``seed`` so that the same graph always yields the
    same embeddings (and hence the same clusters); the caller's global CPU
    RNG state is restored afterwards.

    When ``n_nodes > batch_threshold`` inference runs in mini-batches through
    PyG ``NeighborLoader`` (2-hop sampling, 25 then 10 neighbours per node);
    otherwise the whole graph goes through the model at once. Because
    neighbourhoods are sampled, the batched path approximates full-batch
    inference.

    Args:
        node_features: Float array of shape (n_nodes, in_dim).
        edge_index: Integer array of shape (2, n_edges) of node indices.
        n_nodes: Number of nodes (rows of ``node_features``).
        batch_threshold: Node count above which NeighborLoader is used.
        seed: Seed for the weight initialisation.

    Returns:
        Float32 array of shape (n_nodes, FLEET_EMBED_DIM).
    """
    import torch
    import torch.nn as nn
    from torch_geometric.nn import SAGEConv

    class GraphSAGE(nn.Module):
        """SAGEConv(in -> hidden) + ReLU + Dropout, then SAGEConv(hidden -> out)."""

        def __init__(self, in_ch: int, hidden_ch: int, out_ch: int,
                     dropout: float):
            super().__init__()
            self.conv1 = SAGEConv(in_ch, hidden_ch)
            self.conv2 = SAGEConv(hidden_ch, out_ch)
            self.drop = nn.Dropout(dropout)

        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index).relu()
            x = self.drop(x)
            return self.conv2(x, edge_index)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The weights are never trained, so they fully determine the embeddings:
    # seed their initialisation. The model is constructed on the CPU, so only
    # a forked copy of the CPU generator is reseeded and the global state is
    # restored when the block exits.
    with torch.random.fork_rng(devices=[]):
        torch.default_generator.manual_seed(seed)
        model = GraphSAGE(node_features.shape[1], FLEET_HIDDEN_DIM,
                          FLEET_EMBED_DIM, FLEET_DROPOUT)
    model = model.to(device)
    model.eval()  # disables dropout

    # Graph tensors stay on the CPU: PyG's neighbour sampler expects a
    # CPU-resident graph. Only what actually enters the model is moved.
    x = torch.as_tensor(node_features, dtype=torch.float32)
    ei = torch.as_tensor(edge_index, dtype=torch.long)

    with torch.no_grad():
        if n_nodes > batch_threshold:
            from torch_geometric.data import Data
            from torch_geometric.loader import NeighborLoader

            loader = NeighborLoader(
                Data(x=x, edge_index=ei, num_nodes=n_nodes),
                num_neighbors=[25, 10],  # 2-hop sampling
                batch_size=4096,
                shuffle=False,
            )
            embeddings = torch.zeros(n_nodes, FLEET_EMBED_DIM)
            for batch in loader:
                batch = batch.to(device)
                out = model(batch.x, batch.edge_index)
                # The first batch_size nodes of each sampled subgraph are its
                # seed nodes; n_id maps them back to global node indices.
                seeds = batch.n_id[:batch.batch_size].cpu()
                embeddings[seeds] = out[:batch.batch_size].cpu()
        else:
            embeddings = model(x.to(device), ei.to(device))

    return embeddings.cpu().numpy()


# ═══════════════════════════════════════════════════════════════════════════════
# Pipeline de détection
# ═══════════════════════════════════════════════════════════════════════════════

def _cluster_embeddings(embeddings: np.ndarray) -> np.ndarray:
    """Cluster embeddings into one integer label per row (-1 means noise).

    Uses ``hdbscan.HDBSCAN`` (min_cluster_size=3, EOM selection) when that
    package is importable, otherwise sklearn ``DBSCAN`` (eps=0.5,
    min_samples=3). Any failure, including sklearn itself being missing,
    propagates to the caller.
    """
    try:
        import hdbscan
    except ImportError:
        from sklearn.cluster import DBSCAN
        return DBSCAN(eps=0.5, min_samples=3, metric='euclidean').fit_predict(
            embeddings
        )
    return hdbscan.HDBSCAN(
        min_cluster_size=3,
        metric='euclidean',
        cluster_selection_method='eom',
    ).fit_predict(embeddings)


def _score_clusters(labels: np.ndarray, embeddings: np.ndarray,
                    unique_ips: np.ndarray,
                    ip_asn: Dict[Any, set]) -> Dict[int, float]:
    """Return ``{cluster_id: fleet_score}`` for every non-noise label.

    fleet_score = cluster_size × compactness / log2(n_asn + 2), where
    compactness = 1 / (1 + mean L2 distance of the members to their centroid)
    and n_asn is the number of distinct ASN values ``ip_asn`` lists for the
    members. As n_asn >= 0 the divisor is always >= 1. Scores are rounded to
    3 decimals.
    """
    scores: Dict[int, float] = {}
    for cid in np.unique(labels):
        if cid == -1:
            continue
        members = np.flatnonzero(labels == cid)
        points = embeddings[members]
        avg_dist = np.linalg.norm(points - points.mean(axis=0), axis=1).mean()
        compactness = 1.0 / (1.0 + avg_dist)
        # ip_asn is keyed by the raw src_ip values, exactly like unique_ips.
        asns = set().union(*(ip_asn.get(ip, set()) for ip in unique_ips[members]))
        score = len(members) * compactness / np.log2(len(asns) + 2)
        scores[int(cid)] = round(float(score), 3)
    return scores


def detect_fleet_communities(df: pd.DataFrame) -> Dict[str, Dict[str, Any]]:
    """Detect bot fleets with GraphSAGE embeddings + HDBSCAN (DBSCAN fallback).

    Pipeline: build the IP graph (``_build_ip_graph``), embed its nodes
    (``_infer_embeddings``), cluster the embeddings (``_cluster_embeddings``)
    and score each cluster (``_score_clusters``).

    Returns:
        ``{str(src_ip): {"cluster_id": int, "fleet_score": float}}`` for the
        IPs that belong to a cluster; noise IPs (label -1) are omitted.
        An empty dict is returned when the graph cannot be built, and — with
        a logged warning — when torch/torch_geometric are missing or when
        inference or clustering raises.
    """
    graph = _build_ip_graph(df)
    if graph is None:
        return {}
    unique_ips, node_features, edge_index = graph
    n_nodes = len(unique_ips)

    # Check the PyTorch Geometric dependencies.
    try:
        import torch  # noqa: F401
        from torch_geometric.nn import SAGEConv  # noqa: F401
    except ImportError:
        logger.warning(
            "[Fleet] torch/torch_geometric non disponible — "
            "analyse de flotte désactivée."
        )
        return {}

    # ── GraphSAGE inference ──
    try:
        embeddings = _infer_embeddings(node_features, edge_index, n_nodes)
    except Exception as e:
        logger.warning("[Fleet] Erreur GraphSAGE : %s", e)
        return {}

    # ── Clustering of the embeddings (fallback errors are caught too) ──
    try:
        labels = np.asarray(_cluster_embeddings(embeddings))
    except Exception as e:
        logger.warning("[Fleet] Erreur clustering : %s", e)
        return {}

    # ── fleet_score per cluster ──
    ip_asn = (
        df.groupby('src_ip')['asn_number']
        .agg(lambda s: set(s.dropna().unique()))
        .to_dict()
    )
    cluster_scores = _score_clusters(labels, embeddings, unique_ips, ip_asn)

    # ── Result ──
    ip_results: Dict[str, Dict[str, Any]] = {}
    for ip, label in zip(unique_ips, labels):
        cid = int(label)
        score = cluster_scores.get(cid, 0.0)
        if score > 0:
            ip_results[str(ip)] = {"cluster_id": cid, "fleet_score": score}

    if ip_results:
        n_flagged = sum(
            1 for v in ip_results.values()
            if v["fleet_score"] >= FLEET_SCORE_THRESHOLD
        )
        # Report the fleets that actually hold flagged IPs, not every cluster.
        n_fleets = sum(
            1 for s in cluster_scores.values() if s >= FLEET_SCORE_THRESHOLD
        )
        max_score = max(v["fleet_score"] for v in ip_results.values())
        logger.info(
            "[Fleet] %d IPs dans %d flotte(s) "
            "(score max=%.2f, %d noeuds, %d arêtes)",
            n_flagged, n_fleets, max_score, n_nodes, edge_index.shape[1],
        )

    return ip_results


def enrich_with_fleet_score(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with ``fleet_score`` and ``fleet_campaign_flag``.

    ``fleet_score`` is the score of the fleet (cluster) the row's ``src_ip``
    belongs to, or 0.0 when the IP is in no detected fleet.
    ``fleet_campaign_flag`` is 1 when ``fleet_score >= FLEET_SCORE_THRESHOLD``,
    else 0. The input frame is not modified.
    """
    df = df.copy()
    fleet_map = detect_fleet_communities(df)
    scores = {ip: info["fleet_score"] for ip, info in fleet_map.items()}

    # detect_fleet_communities keys its result by str(src_ip); look rows up
    # with the same representation so non-str IP values still match.
    df['fleet_score'] = (
        df['src_ip'].astype(str).map(scores).fillna(0.0).astype(float)
    )
    df['fleet_campaign_flag'] = (
        df['fleet_score'] >= FLEET_SCORE_THRESHOLD
    ).astype(int)
    return df