""" Clustering d'IPs multi-métriques — backend ReactFlow. Features utilisées (21 dimensions) : TCP stack : TTL initial, MSS, scale, fenêtre TCP Comportement : vélocité, POST ratio, fuzzing, assets, accès direct Anomalie ML : score, IP-ID zéro TLS/Protocole: ALPN mismatch, ALPN absent, efficacité H2 Navigateur : browser score, headless, ordre headers, UA-CH mismatch Temporel : entropie, diversité JA4, UA rotatif Algorithme : 1. Échantillonnage stratifié (top détections + top hits) 2. Construction + normalisation des vecteurs de features 3. K-means++ (Arthur & Vassilvitskii, 2007) 4. PCA-2D par power iteration pour les positions ReactFlow 5. Nommage automatique par features dominantes du centroïde 6. Calcul des arêtes : k-NN dans l'espace des features """ from __future__ import annotations import math import time import hashlib from typing import Optional from fastapi import APIRouter, HTTPException, Query from ..database import db from ..services.clustering_engine import ( FEATURES, FEATURE_KEYS, FEATURE_NORMS, FEATURE_NAMES, N_FEATURES, build_feature_vector, kmeans_pp, pca_2d, name_cluster, risk_score_from_centroid, _mean_vec, ) router = APIRouter(prefix="/api/clustering", tags=["clustering"]) # ─── Cache en mémoire ───────────────────────────────────────────────────────── # Stocke (cluster_id → liste d'IPs) pour le drill-down # + timestamp de dernière mise à jour _cache: dict = { "assignments": {}, # ip+ja4 → cluster_idx "cluster_ips": {}, # cluster_idx → [(ip, ja4)] "params": {}, # k, ts } # ─── Couleurs ───────────────────────────────────────────────────────────────── _THREAT_COLOR = { 0.92: "#dc2626", # Bot scanner 0.70: "#ef4444", # Critique 0.45: "#f97316", # Élevé 0.25: "#eab308", # Modéré 0.00: "#6b7280", # Sain / inconnu } def _risk_to_color(risk: float) -> str: for threshold, color in sorted(_THREAT_COLOR.items(), reverse=True): if risk >= threshold: return color return "#6b7280" # ─── SQL 
──────────────────────────────────────────────────────────────────────

# Feature-extraction query: one row per (src_ip, ja4) pair seen in the last
# 24 h, joined with the ML anomaly table. `%(limit)s` is bound by db.query.
_SQL_FEATURES = """
SELECT
    replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS ip,
    t.ja4,
    any(t.tcp_ttl_raw)                AS ttl,
    any(t.tcp_win_raw)                AS win,
    any(t.tcp_scale_raw)              AS scale,
    any(t.tcp_mss_raw)                AS mss,
    any(t.first_ua)                   AS ua,
    sum(t.hits)                       AS hits,
    avg(abs(ml.anomaly_score))        AS avg_score,
    avg(ml.hit_velocity)              AS avg_velocity,
    avg(ml.fuzzing_index)             AS avg_fuzzing,
    avg(ml.is_headless)               AS pct_headless,
    avg(ml.post_ratio)                AS avg_post,
    avg(ml.ip_id_zero_ratio)          AS ip_id_zero,
    avg(ml.temporal_entropy)          AS entropy,
    avg(ml.modern_browser_score)      AS browser_score,
    avg(ml.alpn_http_mismatch)        AS alpn_mismatch,
    avg(ml.is_alpn_missing)           AS alpn_missing,
    avg(ml.multiplexing_efficiency)   AS h2_eff,
    avg(ml.header_order_confidence)   AS hdr_conf,
    avg(ml.ua_ch_mismatch)            AS ua_ch_mismatch,
    avg(ml.asset_ratio)               AS asset_ratio,
    avg(ml.direct_access_ratio)       AS direct_ratio,
    avg(ml.distinct_ja4_count)        AS ja4_count,
    max(ml.is_ua_rotating)            AS ua_rotating,
    max(ml.threat_level)              AS threat,
    any(ml.country_code)              AS country,
    any(ml.asn_org)                   AS asn_org
FROM mabase_prod.agg_host_ip_ja4_1h t
LEFT JOIN mabase_prod.ml_detected_anomalies ml
    ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4
    AND ml.detected_at >= now() - INTERVAL 24 HOUR
WHERE t.window_start >= now() - INTERVAL 24 HOUR
  AND t.tcp_ttl_raw > 0
GROUP BY t.src_ip, t.ja4
ORDER BY
    -- Stratégie : IPs anormales en premier, puis fort trafic
    -- Cela garantit que les bots Masscan (anomalie=0.97, hits=1-2) sont inclus
    avg(abs(ml.anomaly_score)) DESC,
    sum(t.hits) DESC
LIMIT %(limit)s
"""

# SQL column names, in SELECT order — used to turn result rows into dicts.
# Must stay in sync with _SQL_FEATURES above.
_SQL_COLS = [
    "ip", "ja4", "ttl", "win", "scale", "mss", "ua", "hits",
    "avg_score", "avg_velocity", "avg_fuzzing", "pct_headless",
    "avg_post", "ip_id_zero", "entropy", "browser_score",
    "alpn_mismatch", "alpn_missing", "h2_eff", "hdr_conf",
    "ua_ch_mismatch", "asset_ratio", "direct_ratio", "ja4_count",
    "ua_rotating", "threat", "country", "asn_org",
]

# ─── Endpoints
──────────────────────────────────────────────────────────────── @router.get("/clusters") async def get_clusters( k: int = Query(14, ge=4, le=30, description="Nombre de clusters"), n_samples: int = Query(3000, ge=500, le=8000, description="Taille de l'échantillon"), ): """ Clustering multi-métriques des IPs. Retourne les nœuds (clusters) + arêtes pour ReactFlow, avec : - positions 2D issues de PCA sur les 21 features - profil radar des features par cluster (normalisé [0,1]) - statistiques détaillées (moyennes brutes des features) - sample d'IPs représentatives """ t0 = time.time() try: result = db.query(_SQL_FEATURES, {"limit": n_samples}) except Exception as e: raise HTTPException(status_code=500, detail=f"ClickHouse: {e}") # ── Construction des vecteurs de features ───────────────────────────── rows: list[dict] = [] for row in result.result_rows: d = {col: row[i] for i, col in enumerate(_SQL_COLS)} rows.append(d) if len(rows) < k: raise HTTPException(status_code=400, detail="Pas assez de données pour ce k") points = [build_feature_vector(r) for r in rows] # ── K-means++ ──────────────────────────────────────────────────────── km = kmeans_pp(points, k=k, max_iter=60, seed=42) # ── PCA-2D sur les centroïdes ───────────────────────────────────────── # On projette les centroïdes dans l'espace PCA des données # → les positions relatives reflètent la variance des données coords_all = pca_2d(points) # Moyenne des positions PCA par cluster = position 2D du centroïde cluster_xs: list[list[float]] = [[] for _ in range(k)] cluster_ys: list[list[float]] = [[] for _ in range(k)] for i, label in enumerate(km.labels): cluster_xs[label].append(coords_all[i][0]) cluster_ys[label].append(coords_all[i][1]) centroid_2d: list[tuple[float, float]] = [] for j in range(k): if cluster_xs[j]: cx = sum(cluster_xs[j]) / len(cluster_xs[j]) cy = sum(cluster_ys[j]) / len(cluster_ys[j]) else: cx, cy = 0.5, 0.5 centroid_2d.append((cx, cy)) # ── Agrégation des statistiques par cluster 
─────────────────────────── cluster_rows: list[list[dict]] = [[] for _ in range(k)] cluster_members: list[list[tuple[str, str]]] = [[] for _ in range(k)] for i, label in enumerate(km.labels): cluster_rows[label].append(rows[i]) cluster_members[label].append((rows[i]["ip"], rows[i]["ja4"])) # Mise à jour du cache pour le drill-down _cache["cluster_ips"] = {j: cluster_members[j] for j in range(k)} _cache["params"] = {"k": k, "ts": t0} # ── Construction des nœuds ReactFlow ───────────────────────────────── CANVAS_W, CANVAS_H = 1400, 780 nodes = [] for j in range(k): if not cluster_rows[j]: continue # Statistiques brutes moyennées def avg_feat(key: str) -> float: vals = [float(r.get(key) or 0) for r in cluster_rows[j]] return sum(vals) / len(vals) if vals else 0.0 mean_ttl = avg_feat("ttl") mean_mss = avg_feat("mss") mean_scale = avg_feat("scale") mean_win = avg_feat("win") mean_score = avg_feat("avg_score") mean_vel = avg_feat("avg_velocity") mean_fuzz = avg_feat("avg_fuzzing") mean_hless = avg_feat("pct_headless") mean_post = avg_feat("avg_post") mean_asset = avg_feat("asset_ratio") mean_direct= avg_feat("direct_ratio") mean_alpn = avg_feat("alpn_mismatch") mean_h2 = avg_feat("h2_eff") mean_hconf = avg_feat("hdr_conf") mean_ua_ch = avg_feat("ua_ch_mismatch") mean_entr = avg_feat("entropy") mean_ja4 = avg_feat("ja4_count") mean_ip_id = avg_feat("ip_id_zero") mean_brow = avg_feat("browser_score") mean_uarot = avg_feat("ua_rotating") ip_count = len(set(r["ip"] for r in cluster_rows[j])) hit_count = int(sum(float(r.get("hits") or 0) for r in cluster_rows[j])) # Pays / ASN / Menace dominants threats = [str(r.get("threat") or "") for r in cluster_rows[j] if r.get("threat")] countries = [str(r.get("country") or "") for r in cluster_rows[j] if r.get("country")] orgs = [str(r.get("asn_org") or "") for r in cluster_rows[j] if r.get("asn_org")] def topk(lst: list[str], n: int = 5) -> list[str]: from collections import Counter return [v for v, _ in Counter(lst).most_common(n) if 
v] raw_stats = { "mean_ttl": mean_ttl, "mean_mss": mean_mss, "mean_scale": mean_scale, } label = name_cluster(km.centroids[j], raw_stats) risk = risk_score_from_centroid(km.centroids[j]) color = _risk_to_color(risk) # Profil radar normalisé (valeurs centroïde [0,1]) radar = [ {"feature": name, "value": round(km.centroids[j][i], 4)} for i, name in enumerate(FEATURE_NAMES) ] # Position 2D (PCA normalisée → pixels ReactFlow) px_x = centroid_2d[j][0] * CANVAS_W * 0.85 + 80 px_y = (1 - centroid_2d[j][1]) * CANVAS_H * 0.85 + 50 # inverser y (haut=risque) # Rayon ∝ √ip_count radius = max(18, min(90, int(math.sqrt(ip_count) * 0.3))) # Sample IPs (top 8 par hits) sample_rows = sorted(cluster_rows[j], key=lambda r: float(r.get("hits") or 0), reverse=True)[:8] sample_ips = [r["ip"] for r in sample_rows] sample_ua = str(cluster_rows[j][0].get("ua") or "") cluster_id = f"c{j}_k{k}" nodes.append({ "id": cluster_id, "label": label, "cluster_idx": j, "x": round(px_x, 1), "y": round(px_y, 1), "radius": radius, "color": color, "risk_score": risk, # Caractéristiques TCP "mean_ttl": round(mean_ttl, 1), "mean_mss": round(mean_mss, 0), "mean_scale": round(mean_scale, 1), "mean_win": round(mean_win, 0), # Comportement HTTP "mean_score": round(mean_score, 4), "mean_velocity": round(mean_vel, 3), "mean_fuzzing": round(mean_fuzz, 3), "mean_headless": round(mean_hless, 3), "mean_post": round(mean_post, 3), "mean_asset": round(mean_asset, 3), "mean_direct": round(mean_direct, 3), # TLS / Protocole "mean_alpn_mismatch": round(mean_alpn, 3), "mean_h2_eff": round(mean_h2, 3), "mean_hdr_conf": round(mean_hconf, 3), "mean_ua_ch": round(mean_ua_ch, 3), # Temporel "mean_entropy": round(mean_entr, 3), "mean_ja4_diversity": round(mean_ja4, 3), "mean_ip_id_zero": round(mean_ip_id, 3), "mean_browser_score": round(mean_brow, 1), "mean_ua_rotating": round(mean_uarot, 3), # Meta "ip_count": ip_count, "hit_count": hit_count, "top_threat": topk(threats, 1)[0] if topk(threats, 1) else "", "top_countries": 
topk(countries, 5), "top_orgs": topk(orgs, 5), "sample_ips": sample_ips, "sample_ua": sample_ua, # Profil radar pour visualisation "radar": radar, }) # ── Arêtes : k-NN dans l'espace des features ────────────────────────── # Chaque cluster est connecté à ses 2 voisins les plus proches edges = [] seen: set[frozenset] = set() centroids = km.centroids for i, ni in enumerate(nodes): ci = ni["cluster_idx"] # Distance² aux autres centroïdes dists = [ (j, nj["cluster_idx"], sum((centroids[ci][d] - centroids[nj["cluster_idx"]][d]) ** 2 for d in range(N_FEATURES))) for j, nj in enumerate(nodes) if j != i ] dists.sort(key=lambda x: x[2]) # 2 voisins les plus proches for j, cj, dist2 in dists[:2]: key = frozenset([ni["id"], nodes[j]["id"]]) if key in seen: continue seen.add(key) similarity = round(1.0 / (1.0 + math.sqrt(dist2)), 3) edges.append({ "id": f"e_{ni['id']}_{nodes[j]['id']}", "source": ni["id"], "target": nodes[j]["id"], "similarity": similarity, "weight": round(similarity * 5, 1), }) # ── Stats globales ──────────────────────────────────────────────────── total_ips = sum(n["ip_count"] for n in nodes) total_hits = sum(n["hit_count"] for n in nodes) bot_ips = sum(n["ip_count"] for n in nodes if n["risk_score"] > 0.40 or "🤖" in n["label"]) high_risk = sum(n["ip_count"] for n in nodes if n["risk_score"] > 0.20) elapsed = round(time.time() - t0, 2) return { "nodes": nodes, "edges": edges, "stats": { "total_clusters": len(nodes), "total_ips": total_ips, "total_hits": total_hits, "bot_ips": bot_ips, "high_risk_ips": high_risk, "n_samples": len(rows), "k": k, "elapsed_s": elapsed, }, "feature_names": FEATURE_NAMES, } @router.get("/cluster/{cluster_id}/ips") async def get_cluster_ips( cluster_id: str, limit: int = Query(100, ge=1, le=500), offset: int = Query(0, ge=0), ): """ IPs appartenant à un cluster (depuis le cache de la dernière exécution). Si le cache est expiré, retourne une erreur guidant vers /clusters. 
""" if not _cache.get("cluster_ips"): raise HTTPException( status_code=404, detail="Cache expiré — appelez /api/clustering/clusters d'abord" ) # Extrait l'index cluster depuis l'id (format: c{idx}_k{k}) try: idx = int(cluster_id.split("_")[0][1:]) except (ValueError, IndexError): raise HTTPException(status_code=400, detail="cluster_id invalide") members = _cache["cluster_ips"].get(idx, []) if not members: return {"ips": [], "total": 0, "cluster_id": cluster_id} total = len(members) page_members = members[offset: offset + limit] # Requête SQL pour les détails de ces IPs spécifiques ip_list = [m[0] for m in page_members] ja4_list = [m[1] for m in page_members] if not ip_list: return {"ips": [], "total": total, "cluster_id": cluster_id} # On ne peut pas facilement passer une liste en paramètre ClickHouse — # on la construit directement (valeurs nettoyées) safe_ips = [ip.replace("'", "") for ip in ip_list[:100]] ip_filter = ", ".join(f"'{ip}'" for ip in safe_ips) sql = f""" SELECT replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS src_ip, t.ja4, any(t.tcp_ttl_raw) AS ttl, any(t.tcp_win_raw) AS win, any(t.tcp_scale_raw) AS scale, any(t.tcp_mss_raw) AS mss, sum(t.hits) AS hits, any(t.first_ua) AS ua, round(avg(abs(ml.anomaly_score)), 3) AS avg_score, max(ml.threat_level) AS threat_level, any(ml.country_code) AS country_code, any(ml.asn_org) AS asn_org, round(avg(ml.fuzzing_index), 2) AS fuzzing, round(avg(ml.hit_velocity), 2) AS velocity FROM mabase_prod.agg_host_ip_ja4_1h t LEFT JOIN mabase_prod.ml_detected_anomalies ml ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4 AND ml.detected_at >= now() - INTERVAL 24 HOUR WHERE t.window_start >= now() - INTERVAL 24 HOUR AND replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') IN ({ip_filter}) GROUP BY t.src_ip, t.ja4 ORDER BY hits DESC """ try: result = db.query(sql) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) ips = [] for row in result.result_rows: ips.append({ "ip": str(row[0]), "ja4": str(row[1] 
or ""), "tcp_ttl": int(row[2] or 0), "tcp_win": int(row[3] or 0), "tcp_scale": int(row[4] or 0), "tcp_mss": int(row[5] or 0), "hits": int(row[6] or 0), "ua": str(row[7] or ""), "avg_score": float(row[8] or 0), "threat_level": str(row[9] or ""), "country_code": str(row[10] or ""), "asn_org": str(row[11] or ""), "fuzzing": float(row[12] or 0), "velocity": float(row[13] or 0), }) return {"ips": ips, "total": total, "cluster_id": cluster_id}