""" Clustering d'IPs multi-métriques — WebGL / deck.gl backend. - Calcul sur la TOTALITÉ des IPs (GROUP BY src_ip, ja4 sans LIMIT) - K-means++ vectorisé (numpy) + PCA-2D + enveloppes convexes (scipy) - Calcul en background thread + cache 30 min - Endpoints : /clusters, /status, /cluster/{id}/points """ from __future__ import annotations import math import time import logging import threading from collections import Counter from concurrent.futures import ThreadPoolExecutor from typing import Optional, Any import numpy as np from fastapi import APIRouter, HTTPException, Query from ..database import db from ..services.clustering_engine import ( FEATURE_KEYS, FEATURE_NAMES, FEATURE_NORMS, N_FEATURES, build_feature_vector, kmeans_pp, pca_2d, compute_hulls, name_cluster, risk_score_from_centroid, standardize, ) log = logging.getLogger(__name__) router = APIRouter(prefix="/api/clustering", tags=["clustering"]) # ─── Cache global ────────────────────────────────────────────────────────────── _CACHE: dict[str, Any] = { "status": "idle", # idle | computing | ready | error "error": None, "result": None, # dict résultat complet "ts": 0.0, # timestamp dernière mise à jour "params": {}, "cluster_ips": {}, # cluster_idx → [(ip, ja4, pca_x, pca_y, risk)] } _CACHE_TTL = 1800 # 30 minutes _LOCK = threading.Lock() _EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="clustering") # ─── Couleurs menace ────────────────────────────────────────────────────────── _THREAT_COLOR = { 0.70: "#dc2626", # Critique 0.45: "#f97316", # Élevé 0.25: "#eab308", # Modéré 0.00: "#22c55e", # Sain } def _risk_to_color(risk: float) -> str: for threshold, color in sorted(_THREAT_COLOR.items(), reverse=True): if risk >= threshold: return color return "#6b7280" # ─── SQL : TOUTES les IPs sans LIMIT ───────────────────────────────────────── _SQL_ALL_IPS = """ SELECT replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS ip, t.ja4, any(t.tcp_ttl_raw) AS ttl, any(t.tcp_win_raw) AS win, any(t.tcp_scale_raw) AS scale, any(t.tcp_mss_raw) AS mss, any(t.first_ua) AS ua, sum(t.hits) AS hits, avg(abs(ml.anomaly_score)) AS avg_score, avg(ml.hit_velocity) AS avg_velocity, avg(ml.fuzzing_index) AS avg_fuzzing, avg(ml.is_headless) AS pct_headless, avg(ml.post_ratio) AS avg_post, avg(ml.ip_id_zero_ratio) AS ip_id_zero, avg(ml.temporal_entropy) AS entropy, avg(ml.modern_browser_score) AS browser_score, avg(ml.alpn_http_mismatch) AS alpn_mismatch, avg(ml.is_alpn_missing) AS alpn_missing, avg(ml.multiplexing_efficiency) AS h2_eff, avg(ml.header_order_confidence) AS hdr_conf, avg(ml.ua_ch_mismatch) AS ua_ch_mismatch, avg(ml.asset_ratio) AS asset_ratio, avg(ml.direct_access_ratio) AS direct_ratio, avg(ml.distinct_ja4_count) AS ja4_count, max(ml.is_ua_rotating) AS ua_rotating, max(ml.threat_level) AS threat, any(ml.country_code) AS country, any(ml.asn_org) AS asn_org, -- Features headers HTTP (depuis view_dashboard_entities) avg(ml.has_accept_language) AS hdr_accept_lang, any(vh.hdr_enc) AS hdr_has_encoding, any(vh.hdr_sec_fetch) AS hdr_has_sec_fetch, any(vh.hdr_count) AS hdr_count_raw FROM mabase_prod.agg_host_ip_ja4_1h t LEFT JOIN mabase_prod.ml_detected_anomalies ml ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4 AND ml.detected_at >= now() - INTERVAL %(hours)s HOUR LEFT JOIN ( SELECT toIPv6(concat('::ffff:', toString(src_ip))) AS src_ip_v6, ja4, any(arrayExists(x -> x LIKE '%%Accept-Encoding%%', client_headers)) AS hdr_enc, any(arrayExists(x -> x LIKE '%%Sec-Fetch%%', client_headers)) AS hdr_sec_fetch, any(length(splitByChar(',', client_headers[1]))) AS hdr_count FROM mabase_prod.view_dashboard_entities WHERE length(client_headers) > 0 AND log_date >= today() - 2 GROUP BY src_ip_v6, ja4 ) vh ON t.src_ip = vh.src_ip_v6 AND t.ja4 = vh.ja4 WHERE t.window_start >= now() - INTERVAL %(hours)s HOUR AND t.tcp_ttl_raw > 0 GROUP BY t.src_ip, t.ja4 """ _SQL_COLS = [ "ip", "ja4", "ttl", "win", "scale", "mss", "ua", "hits", "avg_score", "avg_velocity", "avg_fuzzing", "pct_headless", "avg_post", "ip_id_zero", "entropy", "browser_score", "alpn_mismatch", "alpn_missing", "h2_eff", "hdr_conf", "ua_ch_mismatch", "asset_ratio", "direct_ratio", "ja4_count", "ua_rotating", "threat", "country", "asn_org", "hdr_accept_lang", "hdr_has_encoding", "hdr_has_sec_fetch", "hdr_count_raw", ] # ─── Worker de clustering (thread pool) ────────────────────────────────────── def _run_clustering_job(k: int, hours: int, sensitivity: float = 1.0) -> None: """Exécuté dans le thread pool. Met à jour _CACHE. sensitivity : multiplicateur de k [0.5 – 3.0]. 1.0 = comportement par défaut 2.0 = deux fois plus de clusters → groupes plus homogènes 0.5 = moitié → vue très agrégée """ k_actual = max(4, min(50, round(k * sensitivity))) t0 = time.time() with _LOCK: _CACHE["status"] = "computing" _CACHE["error"] = None try: log.info(f"[clustering] Démarrage k={k_actual} (base={k}×sens={sensitivity}) hours={hours}") # ── 1. Chargement de toutes les IPs ────────────────────────────── result = db.query(_SQL_ALL_IPS, {"hours": hours}) rows: list[dict] = [] for row in result.result_rows: rows.append({col: row[i] for i, col in enumerate(_SQL_COLS)}) n = len(rows) log.info(f"[clustering] {n} IPs chargées") if n < k_actual: raise ValueError(f"Seulement {n} IPs disponibles (k={k_actual} requis)") # ── 2. Construction de la matrice de features (numpy) ──────────── X = np.array([build_feature_vector(r) for r in rows], dtype=np.float32) log.info(f"[clustering] Matrice X: {X.shape} — {X.nbytes/1024/1024:.1f} MB") # ── 3. Standardisation z-score ──────────────────────────────────── # Normalise par variance : features discriminantes (forte std) # contribuent plus que les features quasi-constantes. X64 = X.astype(np.float64) X_std, feat_mean, feat_std = standardize(X64) # ── 4. K-means++ sur l'espace standardisé ──────────────────────── km = kmeans_pp(X_std, k=k_actual, max_iter=80, n_init=3, seed=42) log.info(f"[clustering] K-means: {km.n_iter} iters, inertia={km.inertia:.2f}") # Centroïdes dans l'espace original [0,1] pour affichage radar # (dé-standardisation : c_orig = c_std * std + mean, puis clip [0,1]) centroids_orig = np.clip(km.centroids * feat_std + feat_mean, 0.0, 1.0) # ── 5. PCA-2D sur les features ORIGINALES (normalisées [0,1]) ──── coords = pca_2d(X64) # (n, 2), normalisé [0,1] # ── 5b. Enveloppes convexes par cluster ────────────────────────── hulls = compute_hulls(coords, km.labels, k_actual) # ── 6. Agrégation par cluster ───────────────────────────────────── cluster_rows: list[list[dict]] = [[] for _ in range(k_actual)] cluster_coords: list[list[list[float]]] = [[] for _ in range(k_actual)] cluster_ips_map: dict[int, list] = {j: [] for j in range(k_actual)} for i, label in enumerate(km.labels): j = int(label) cluster_rows[j].append(rows[i]) cluster_coords[j].append(coords[i].tolist()) cluster_ips_map[j].append(( rows[i]["ip"], rows[i]["ja4"], float(coords[i][0]), float(coords[i][1]), float(risk_score_from_centroid(centroids_orig[j])), )) # ── 7. Construction des nœuds ───────────────────────────────────── nodes = [] for j in range(k_actual): if not cluster_rows[j]: continue def avg_f(key: str, crows: list[dict] = cluster_rows[j]) -> float: return float(np.mean([float(r.get(key) or 0) for r in crows])) mean_ttl = avg_f("ttl") mean_mss = avg_f("mss") mean_scale = avg_f("scale") mean_win = avg_f("win") raw_stats = {"mean_ttl": mean_ttl, "mean_mss": mean_mss, "mean_scale": mean_scale} label_name = name_cluster(centroids_orig[j], raw_stats) risk = float(risk_score_from_centroid(centroids_orig[j])) color = _risk_to_color(risk) # Centroïde 2D = moyenne des coords du cluster cxy = np.mean(cluster_coords[j], axis=0).tolist() if cluster_coords[j] else [0.5, 0.5] ip_set = list({r["ip"] for r in cluster_rows[j]}) ip_count = len(ip_set) hit_count = int(sum(float(r.get("hits") or 0) for r in cluster_rows[j])) threats = [str(r.get("threat") or "") for r in cluster_rows[j] if r.get("threat")] countries = [str(r.get("country") or "") for r in cluster_rows[j] if r.get("country")] orgs = [str(r.get("asn_org") or "") for r in cluster_rows[j] if r.get("asn_org")] def topk(lst: list[str], n: int = 5) -> list[str]: return [v for v, _ in Counter(lst).most_common(n) if v] radar = [ {"feature": name, "value": round(float(centroids_orig[j][i]), 4)} for i, name in enumerate(FEATURE_NAMES) ] radius = max(8, min(30, int(math.log1p(ip_count) * 2.2))) sample_rows = sorted(cluster_rows[j], key=lambda r: float(r.get("hits") or 0), reverse=True)[:8] sample_ips = [r["ip"] for r in sample_rows] sample_ua = str(cluster_rows[j][0].get("ua") or "") nodes.append({ "id": f"c{j}_k{k_actual}", "cluster_idx": j, "label": label_name, "pca_x": round(cxy[0], 6), "pca_y": round(cxy[1], 6), "radius": radius, "color": color, "risk_score": round(risk, 4), "mean_ttl": round(mean_ttl, 1), "mean_mss": round(mean_mss, 0), "mean_scale": round(mean_scale, 1), "mean_win": round(mean_win, 0), "mean_score": round(avg_f("avg_score"), 4), "mean_velocity":round(avg_f("avg_velocity"),3), "mean_fuzzing": round(avg_f("avg_fuzzing"), 3), "mean_headless":round(avg_f("pct_headless"),3), "mean_post": round(avg_f("avg_post"), 3), "mean_asset": round(avg_f("asset_ratio"), 3), "mean_direct": round(avg_f("direct_ratio"),3), "mean_alpn_mismatch": round(avg_f("alpn_mismatch"),3), "mean_h2_eff": round(avg_f("h2_eff"), 3), "mean_hdr_conf":round(avg_f("hdr_conf"), 3), "mean_ua_ch": round(avg_f("ua_ch_mismatch"),3), "mean_entropy": round(avg_f("entropy"), 3), "mean_ja4_diversity": round(avg_f("ja4_count"),3), "mean_ip_id_zero": round(avg_f("ip_id_zero"),3), "mean_browser_score": round(avg_f("browser_score"),1), "mean_ua_rotating": round(avg_f("ua_rotating"),3), "ip_count": ip_count, "hit_count": hit_count, "top_threat": topk(threats, 1)[0] if threats else "", "top_countries":topk(countries, 5), "top_orgs": topk(orgs, 5), "sample_ips": sample_ips, "sample_ua": sample_ua, "radar": radar, # Hull pour deck.gl PolygonLayer "hull": hulls.get(j, []), }) # ── 8. Arêtes k-NN entre clusters ──────────────────────────────── edges = [] seen: set[frozenset] = set() for i, ni in enumerate(nodes): ci = ni["cluster_idx"] dists = sorted( [(j, nj["cluster_idx"], float(np.sum((centroids_orig[ci] - centroids_orig[nj["cluster_idx"]]) ** 2))) for j, nj in enumerate(nodes) if j != i], key=lambda x: x[2] ) for j_idx, cj, d2 in dists[:2]: key = frozenset([ni["id"], nodes[j_idx]["id"]]) if key in seen: continue seen.add(key) edges.append({ "id": f"e_{ni['id']}_{nodes[j_idx]['id']}", "source": ni["id"], "target": nodes[j_idx]["id"], "similarity": round(1.0 / (1.0 + math.sqrt(d2)), 3), }) # ── 9. Stockage résultat + cache IPs ───────────────────────────── total_ips = sum(n_["ip_count"] for n_ in nodes) total_hits = sum(n_["hit_count"] for n_ in nodes) bot_ips = sum(n_["ip_count"] for n_ in nodes if n_["risk_score"] > 0.45 or "🤖" in n_["label"]) high_ips = sum(n_["ip_count"] for n_ in nodes if n_["risk_score"] > 0.25) elapsed = round(time.time() - t0, 2) result_dict = { "nodes": nodes, "edges": edges, "stats": { "total_clusters": len(nodes), "total_ips": total_ips, "total_hits": total_hits, "bot_ips": bot_ips, "high_risk_ips": high_ips, "n_samples": n, "k": k_actual, "k_base": k, "sensitivity": sensitivity, "elapsed_s": elapsed, }, "feature_names": FEATURE_NAMES, } with _LOCK: _CACHE["result"] = result_dict _CACHE["cluster_ips"] = cluster_ips_map _CACHE["status"] = "ready" _CACHE["ts"] = time.time() _CACHE["params"] = {"k": k, "hours": hours, "sensitivity": sensitivity} _CACHE["error"] = None log.info(f"[clustering] Terminé en {elapsed}s — {total_ips} IPs, {len(nodes)} clusters") except Exception as e: log.exception("[clustering] Erreur lors du calcul") with _LOCK: _CACHE["status"] = "error" _CACHE["error"] = str(e) def _maybe_trigger(k: int, hours: int, sensitivity: float) -> None: """Lance le calcul si cache absent, expiré ou paramètres différents.""" with _LOCK: status = _CACHE["status"] params = _CACHE["params"] ts = _CACHE["ts"] cache_stale = (time.time() - ts) > _CACHE_TTL params_changed = ( params.get("k") != k or params.get("hours") != hours or params.get("sensitivity") != sensitivity ) if status in ("computing",): return # déjà en cours if status == "ready" and not cache_stale and not params_changed: return # cache frais _EXECUTOR.submit(_run_clustering_job, k, hours, sensitivity) # ─── Endpoints ──────────────────────────────────────────────────────────────── @router.get("/status") async def get_status(): """État du calcul en cours (polling frontend).""" with _LOCK: return { "status": _CACHE["status"], "error": _CACHE["error"], "ts": _CACHE["ts"], "params": _CACHE["params"], "age_s": round(time.time() - _CACHE["ts"], 0) if _CACHE["ts"] else None, } @router.get("/clusters") async def get_clusters( k: int = Query(14, ge=4, le=30, description="Nombre de clusters de base"), hours: int = Query(24, ge=1, le=168, description="Fenêtre temporelle (heures)"), sensitivity: float = Query(1.0, ge=0.5, le=3.0, description="Sensibilité : multiplicateur de k"), force: bool = Query(False, description="Forcer le recalcul"), ): """ Clustering multi-métriques sur TOUTES les IPs. k_actual = round(k × sensitivity) — la sensibilité contrôle la granularité. Retourne immédiatement depuis le cache. Déclenche le calcul si nécessaire. """ if force: with _LOCK: _CACHE["status"] = "idle" _CACHE["ts"] = 0.0 _maybe_trigger(k, hours, sensitivity) with _LOCK: status = _CACHE["status"] result = _CACHE["result"] error = _CACHE["error"] if status == "computing": return {"status": "computing", "message": "Calcul en cours, réessayez dans quelques secondes"} if status == "error": raise HTTPException(status_code=500, detail=error or "Erreur inconnue") if result is None: return {"status": "idle", "message": "Calcul démarré, réessayez dans quelques secondes"} return {**result, "status": "ready"} @router.get("/cluster/{cluster_id}/points") async def get_cluster_points( cluster_id: str, limit: int = Query(5000, ge=1, le=20000), offset: int = Query(0, ge=0), ): """ Coordonnées PCA + métadonnées de toutes les IPs d'un cluster. Utilisé par deck.gl ScatterplotLayer (drill-down ou zoom avancé). """ with _LOCK: status = _CACHE["status"] ips_map = _CACHE["cluster_ips"] if status != "ready" or not ips_map: raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord") try: idx = int(cluster_id.split("_")[0][1:]) except (ValueError, IndexError): raise HTTPException(status_code=400, detail="cluster_id invalide (format: c{n}_k{k})") members = ips_map.get(idx, []) total = len(members) page = members[offset: offset + limit] points = [ {"ip": m[0], "ja4": m[1], "pca_x": round(m[2], 6), "pca_y": round(m[3], 6), "risk": round(m[4], 3)} for m in page ] return {"points": points, "total": total, "offset": offset, "limit": limit} @router.get("/cluster/{cluster_id}/ips") async def get_cluster_ips( cluster_id: str, limit: int = Query(100, ge=1, le=500), offset: int = Query(0, ge=0), ): """IPs avec détails SQL (backward-compat avec l'ancienne UI).""" with _LOCK: status = _CACHE["status"] ips_map = _CACHE["cluster_ips"] if status != "ready" or not ips_map: raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord") try: idx = int(cluster_id.split("_")[0][1:]) except (ValueError, IndexError): raise HTTPException(status_code=400, detail="cluster_id invalide") members = ips_map.get(idx, []) total = len(members) page = members[offset: offset + limit] if not page: return {"ips": [], "total": total, "cluster_id": cluster_id} safe_ips = [m[0].replace("'", "") for m in page[:200]] ip_filter = ", ".join(f"'{ip}'" for ip in safe_ips) sql = f""" SELECT replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS src_ip, t.ja4, any(t.tcp_ttl_raw) AS ttl, any(t.tcp_win_raw) AS win, any(t.tcp_scale_raw) AS scale, any(t.tcp_mss_raw) AS mss, sum(t.hits) AS hits, any(t.first_ua) AS ua, round(avg(abs(ml.anomaly_score)), 3) AS avg_score, max(ml.threat_level) AS threat_level, any(ml.country_code) AS country_code, any(ml.asn_org) AS asn_org, round(avg(ml.fuzzing_index), 2) AS fuzzing, round(avg(ml.hit_velocity), 2) AS velocity FROM mabase_prod.agg_host_ip_ja4_1h t LEFT JOIN mabase_prod.ml_detected_anomalies ml ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4 AND ml.detected_at >= now() - INTERVAL 24 HOUR WHERE t.window_start >= now() - INTERVAL 24 HOUR AND replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') IN ({ip_filter}) GROUP BY t.src_ip, t.ja4 ORDER BY hits DESC """ try: result = db.query(sql) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) ips = [] for row in result.result_rows: ips.append({ "ip": str(row[0] or ""), "ja4": str(row[1] or ""), "tcp_ttl": int(row[2] or 0), "tcp_win": int(row[3] or 0), "tcp_scale": int(row[4] or 0), "tcp_mss": int(row[5] or 0), "hits": int(row[6] or 0), "ua": str(row[7] or ""), "avg_score": float(row[8] or 0), "threat_level": str(row[9] or ""), "country_code": str(row[10] or ""), "asn_org": str(row[11] or ""), "fuzzing": float(row[12] or 0), "velocity": float(row[13] or 0), }) return {"ips": ips, "total": total, "cluster_id": cluster_id}