""" Clustering d'IPs multi-métriques — WebGL / deck.gl backend. - Calcul sur la TOTALITÉ des IPs (GROUP BY src_ip, ja4 sans LIMIT) - K-means++ vectorisé (numpy) + PCA-2D + enveloppes convexes (scipy) - Calcul en background thread + cache 30 min - Endpoints : /clusters, /status, /cluster/{id}/points """ import math import time import logging import threading from collections import Counter from concurrent.futures import ThreadPoolExecutor from typing import Any import numpy as np from fastapi import APIRouter, HTTPException, Query from ..database import db from ..services.clustering_engine import ( FEATURE_NAMES, build_feature_vector, kmeans_pp, pca_2d, compute_hulls, name_cluster, risk_score_from_centroid, standardize, risk_to_gradient_color, ) log = logging.getLogger(__name__) router = APIRouter(prefix="/api/clustering", tags=["clustering"]) # ─── Cache global ────────────────────────────────────────────────────────────── _CACHE: dict[str, Any] = { "status": "idle", # idle | computing | ready | error "error": None, "result": None, # dict résultat complet "ts": 0.0, # timestamp dernière mise à jour "params": {}, "cluster_ips": {}, # cluster_idx → [(ip, ja4, pca_x, pca_y, risk)] } _CACHE_TTL = 1800 # 30 minutes _LOCK = threading.Lock() _EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="clustering") # ─── Palette de couleurs (remplace l'ancienne logique menace) ───────────────── # Les couleurs sont désormais attribuées par index de cluster pour maximiser # la distinction visuelle, indépendamment du niveau de risque. 
# ─── SQL: ALL IPs, no LIMIT ──────────────────────────────────────────────────
# One row per (src_ip, ja4) pair over the requested window, enriched with ML
# anomaly features, HTTP-header features and header-fingerprint stats via
# three LEFT JOINs. Column order must stay in sync with _SQL_COLS below.
_SQL_ALL_IPS = """
SELECT
    replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS ip,
    t.ja4,
    any(t.tcp_ttl_raw) AS ttl,
    any(t.tcp_win_raw) AS win,
    any(t.tcp_scale_raw) AS scale,
    any(t.tcp_mss_raw) AS mss,
    any(t.first_ua) AS ua,
    sum(t.hits) AS hits,
    avg(abs(ml.anomaly_score)) AS avg_score,
    avg(ml.hit_velocity) AS avg_velocity,
    avg(ml.fuzzing_index) AS avg_fuzzing,
    avg(ml.is_headless) AS pct_headless,
    avg(ml.post_ratio) AS avg_post,
    avg(ml.ip_id_zero_ratio) AS ip_id_zero,
    avg(ml.temporal_entropy) AS entropy,
    avg(ml.modern_browser_score) AS browser_score,
    avg(ml.alpn_http_mismatch) AS alpn_mismatch,
    avg(ml.is_alpn_missing) AS alpn_missing,
    avg(ml.multiplexing_efficiency) AS h2_eff,
    avg(ml.header_order_confidence) AS hdr_conf,
    avg(ml.ua_ch_mismatch) AS ua_ch_mismatch,
    avg(ml.asset_ratio) AS asset_ratio,
    avg(ml.direct_access_ratio) AS direct_ratio,
    avg(ml.distinct_ja4_count) AS ja4_count,
    max(ml.is_ua_rotating) AS ua_rotating,
    max(ml.threat_level) AS threat,
    any(ml.country_code) AS country,
    any(ml.asn_org) AS asn_org,
    -- Features headers HTTP (depuis view_dashboard_entities)
    avg(ml.has_accept_language) AS hdr_accept_lang,
    any(vh.hdr_enc) AS hdr_has_encoding,
    any(vh.hdr_sec_fetch) AS hdr_has_sec_fetch,
    any(vh.hdr_count) AS hdr_count_raw,
    -- Fingerprint HTTP Headers (depuis agg_header_fingerprint_1h + ml_detected_anomalies)
    -- header_order_shared_count : nb d'IPs partageant le même fingerprint
    -- → faible = fingerprint rare = comportement suspect
    avg(ml.header_order_shared_count) AS hfp_shared_count,
    -- distinct_header_orders : nb de fingerprints distincts émis par cette IP
    -- → élevé = rotation de fingerprint = comportement bot
    avg(ml.distinct_header_orders) AS hfp_distinct_orders,
    -- Cookie et Referer issus de la table dédiée aux empreintes
    any(hfp.hfp_cookie) AS hfp_cookie,
    any(hfp.hfp_referer) AS hfp_referer
FROM mabase_prod.agg_host_ip_ja4_1h t
LEFT JOIN mabase_prod.ml_detected_anomalies ml
    ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4
    AND ml.detected_at >= now() - INTERVAL %(hours)s HOUR
LEFT JOIN (
    SELECT
        toIPv6(concat('::ffff:', toString(src_ip))) AS src_ip_v6,
        ja4,
        any(arrayExists(x -> x LIKE '%%Accept-Encoding%%', client_headers)) AS hdr_enc,
        any(arrayExists(x -> x LIKE '%%Sec-Fetch%%', client_headers)) AS hdr_sec_fetch,
        any(length(splitByChar(',', client_headers[1]))) AS hdr_count
    FROM mabase_prod.view_dashboard_entities
    WHERE length(client_headers) > 0 AND log_date >= today() - 2
    GROUP BY src_ip_v6, ja4
) vh ON t.src_ip = vh.src_ip_v6 AND t.ja4 = vh.ja4
LEFT JOIN (
    SELECT
        src_ip,
        avg(has_cookie) AS hfp_cookie,
        avg(has_referer) AS hfp_referer
    FROM mabase_prod.agg_header_fingerprint_1h
    WHERE window_start >= now() - INTERVAL %(hours)s HOUR
    GROUP BY src_ip
) hfp ON t.src_ip = hfp.src_ip
WHERE t.window_start >= now() - INTERVAL %(hours)s HOUR
    AND t.tcp_ttl_raw > 0
GROUP BY t.src_ip, t.ja4
"""

# Column names matching the SELECT order of _SQL_ALL_IPS (positional mapping).
_SQL_COLS = [
    "ip", "ja4", "ttl", "win", "scale", "mss", "ua", "hits",
    "avg_score", "avg_velocity", "avg_fuzzing", "pct_headless", "avg_post",
    "ip_id_zero", "entropy", "browser_score", "alpn_mismatch", "alpn_missing",
    "h2_eff", "hdr_conf", "ua_ch_mismatch", "asset_ratio", "direct_ratio",
    "ja4_count", "ua_rotating", "threat", "country", "asn_org",
    "hdr_accept_lang", "hdr_has_encoding", "hdr_has_sec_fetch", "hdr_count_raw",
    "hfp_shared_count", "hfp_distinct_orders", "hfp_cookie", "hfp_referer",
]


# ─── Clustering worker (thread pool) ─────────────────────────────────────────
def _run_clustering_job(k: int, hours: int, sensitivity: float = 1.0) -> None:
    """Run the full clustering pipeline; executed on the thread pool.

    Updates _CACHE in place (status/result/cluster_ips/ts/params/error).

    sensitivity: multiplier applied to k, range [0.5 – 5.0].
        0.5 = very aggregated view (k/2 clusters)
        1.0 = default behavior
        2.0 = twice as many clusters → more homogeneous groups
        5.0 = maximum granularity (finest classification)

    k_actual is capped at 300 to avoid excessive compute time.
    n_init is reduced to 1 when k_actual > 60 to keep things fast.
    """
    k_actual = max(4, min(300, round(k * sensitivity)))
    t0 = time.time()
    with _LOCK:
        _CACHE["status"] = "computing"
        _CACHE["error"] = None
    try:
        log.info(f"[clustering] Démarrage k={k_actual} (base={k}×sens={sensitivity}) hours={hours}")

        # ── 1. Load all IPs ──────────────────────────────────────────────
        result = db.query(_SQL_ALL_IPS, {"hours": hours})
        rows: list[dict] = []
        for row in result.result_rows:
            # Positional → named mapping; relies on _SQL_COLS matching the SELECT.
            rows.append({col: row[i] for i, col in enumerate(_SQL_COLS)})
        n = len(rows)
        log.info(f"[clustering] {n} IPs chargées")
        if n < k_actual:
            # K-means needs at least k samples.
            raise ValueError(f"Seulement {n} IPs disponibles (k={k_actual} requis)")

        # ── 2. Build the feature matrix (numpy) ──────────────────────────
        X = np.array([build_feature_vector(r) for r in rows], dtype=np.float32)
        log.info(f"[clustering] Matrice X: {X.shape} — {X.nbytes/1024/1024:.1f} MB")

        # ── 3. Z-score standardization ───────────────────────────────────
        # Normalize by variance: discriminating features (high std)
        # contribute more than quasi-constant features.
        X64 = X.astype(np.float64)
        X_std, feat_mean, feat_std = standardize(X64)

        # ── 4. K-means++ on the standardized space ───────────────────────
        # n_init reduced to 1 for high k (> 60) to limit compute time.
        n_init = 1 if k_actual > 60 else 3
        km = kmeans_pp(X_std, k=k_actual, max_iter=80, n_init=n_init, seed=42)
        log.info(f"[clustering] K-means: {km.n_iter} iters, inertia={km.inertia:.2f}")

        # Centroids back in the original [0,1] space for the radar display
        # (de-standardization: c_orig = c_std * std + mean, then clip [0,1]).
        centroids_orig = np.clip(km.centroids * feat_std + feat_mean, 0.0, 1.0)

        # ── 5. 2-D PCA on the ORIGINAL ([0,1]-normalized) features ───────
        coords = pca_2d(X64)  # (n, 2), normalized to [0,1]

        # ── 5b. Convex hull per cluster ──────────────────────────────────
        hulls = compute_hulls(coords, km.labels, k_actual)

        # ── 6. Per-cluster aggregation ───────────────────────────────────
        cluster_rows: list[list[dict]] = [[] for _ in range(k_actual)]
        cluster_coords: list[list[list[float]]] = [[] for _ in range(k_actual)]
        cluster_ips_map: dict[int, list] = {j: [] for j in range(k_actual)}
        for i, label in enumerate(km.labels):
            j = int(label)
            cluster_rows[j].append(rows[i])
            cluster_coords[j].append(coords[i].tolist())
            # Tuple layout consumed by /cluster/{id}/points and /ips:
            # (ip, ja4, pca_x, pca_y, risk). Risk is per-cluster (centroid-based),
            # not per-IP.
            cluster_ips_map[j].append((
                rows[i]["ip"],
                rows[i]["ja4"],
                float(coords[i][0]),
                float(coords[i][1]),
                float(risk_score_from_centroid(centroids_orig[j])),
            ))

        # ── 7. Node construction ─────────────────────────────────────────
        nodes = []
        for j in range(k_actual):
            if not cluster_rows[j]:
                # K-means can leave a cluster empty; skip it.
                continue

            # crows bound as a default arg to capture the CURRENT cluster's
            # rows (avoids the late-binding closure pitfall in loops).
            def avg_f(key: str, crows: list[dict] = cluster_rows[j]) -> float:
                return float(np.mean([float(r.get(key) or 0) for r in crows]))

            mean_ttl = avg_f("ttl")
            mean_mss = avg_f("mss")
            mean_scale = avg_f("scale")
            mean_win = avg_f("win")
            raw_stats = {"mean_ttl": mean_ttl, "mean_mss": mean_mss, "mean_scale": mean_scale}
            label_name = name_cluster(centroids_orig[j], raw_stats)
            risk = float(risk_score_from_centroid(centroids_orig[j]))
            color = risk_to_gradient_color(risk)

            # 2-D centroid = mean of the cluster's PCA coordinates.
            cxy = np.mean(cluster_coords[j], axis=0).tolist() if cluster_coords[j] else [0.5, 0.5]

            ip_set = list({r["ip"] for r in cluster_rows[j]})
            ip_count = len(ip_set)
            hit_count = int(sum(float(r.get("hits") or 0) for r in cluster_rows[j]))
            threats = [str(r.get("threat") or "") for r in cluster_rows[j] if r.get("threat")]
            countries = [str(r.get("country") or "") for r in cluster_rows[j] if r.get("country")]
            orgs = [str(r.get("asn_org") or "") for r in cluster_rows[j] if r.get("asn_org")]

            # Top-n most frequent non-empty values.
            def topk(lst: list[str], n: int = 5) -> list[str]:
                return [v for v, _ in Counter(lst).most_common(n) if v]

            radar = [
                {"feature": name, "value": round(float(centroids_orig[j][i]), 4)}
                for i, name in enumerate(FEATURE_NAMES)
            ]
            # Visual radius grows logarithmically with IP count, clamped [8, 30].
            radius = max(8, min(30, int(math.log1p(ip_count) * 2.2)))
            sample_rows = sorted(cluster_rows[j], key=lambda r: float(r.get("hits") or 0), reverse=True)[:8]
            sample_ips = [r["ip"] for r in sample_rows]
            sample_ua = str(cluster_rows[j][0].get("ua") or "")

            nodes.append({
                "id": f"c{j}_k{k_actual}",
                "cluster_idx": j,
                "label": label_name,
                "pca_x": round(cxy[0], 6),
                "pca_y": round(cxy[1], 6),
                "radius": radius,
                "color": color,
                "risk_score": round(risk, 4),
                "mean_ttl": round(mean_ttl, 1),
                "mean_mss": round(mean_mss, 0),
                "mean_scale": round(mean_scale, 1),
                "mean_win": round(mean_win, 0),
                "mean_velocity": round(avg_f("avg_velocity"), 3),
                "mean_fuzzing": round(avg_f("avg_fuzzing"), 3),
                "mean_headless": round(avg_f("pct_headless"), 3),
                "mean_post": round(avg_f("avg_post"), 3),
                "mean_asset": round(avg_f("asset_ratio"), 3),
                "mean_direct": round(avg_f("direct_ratio"), 3),
                "mean_alpn_mismatch": round(avg_f("alpn_mismatch"), 3),
                "mean_h2_eff": round(avg_f("h2_eff"), 3),
                "mean_hdr_conf": round(avg_f("hdr_conf"), 3),
                "mean_ua_ch": round(avg_f("ua_ch_mismatch"), 3),
                "mean_entropy": round(avg_f("entropy"), 3),
                "mean_ja4_diversity": round(avg_f("ja4_count"), 3),
                "mean_ip_id_zero": round(avg_f("ip_id_zero"), 3),
                "mean_browser_score": round(avg_f("browser_score"), 1),
                "mean_ua_rotating": round(avg_f("ua_rotating"), 3),
                "ip_count": ip_count,
                "hit_count": hit_count,
                "top_threat": topk(threats, 1)[0] if threats else "",
                "top_countries": topk(countries, 5),
                "top_orgs": topk(orgs, 5),
                "sample_ips": sample_ips,
                "sample_ua": sample_ua,
                "radar": radar,
                # Hull for the deck.gl PolygonLayer.
                "hull": hulls.get(j, []),
            })

        # ── 8. k-NN edges between clusters ───────────────────────────────
        # Each node is linked to its 2 nearest neighbors in (squared)
        # centroid distance; `seen` de-duplicates undirected pairs.
        edges = []
        seen: set[frozenset] = set()
        for i, ni in enumerate(nodes):
            ci = ni["cluster_idx"]
            dists = sorted(
                [(j, nj["cluster_idx"], float(np.sum((centroids_orig[ci] - centroids_orig[nj["cluster_idx"]]) ** 2)))
                 for j, nj in enumerate(nodes) if j != i],
                key=lambda x: x[2]
            )
            for j_idx, cj, d2 in dists[:2]:
                key = frozenset([ni["id"], nodes[j_idx]["id"]])
                if key in seen:
                    continue
                seen.add(key)
                edges.append({
                    "id": f"e_{ni['id']}_{nodes[j_idx]['id']}",
                    "source": ni["id"],
                    "target": nodes[j_idx]["id"],
                    # Similarity in (0, 1]: 1 for identical centroids.
                    "similarity": round(1.0 / (1.0 + math.sqrt(d2)), 3),
                })

        # ── 9. Store result + IP cache ───────────────────────────────────
        total_ips = sum(n_["ip_count"] for n_ in nodes)
        total_hits = sum(n_["hit_count"] for n_ in nodes)
        elapsed = round(time.time() - t0, 2)
        result_dict = {
            "nodes": nodes,
            "edges": edges,
            "stats": {
                "total_clusters": len(nodes),
                "total_ips": total_ips,
                "total_hits": total_hits,
                "n_samples": n,
                "k": k_actual,
                "k_base": k,
                "sensitivity": sensitivity,
                "elapsed_s": elapsed,
            },
            "feature_names": FEATURE_NAMES,
        }
        with _LOCK:
            _CACHE["result"] = result_dict
            _CACHE["cluster_ips"] = cluster_ips_map
            _CACHE["status"] = "ready"
            _CACHE["ts"] = time.time()
            _CACHE["params"] = {"k": k, "hours": hours, "sensitivity": sensitivity}
            _CACHE["error"] = None
        log.info(f"[clustering] Terminé en {elapsed}s — {total_ips} IPs, {len(nodes)} clusters")
    except Exception as e:
        # Any failure flips the cache into the "error" state; /clusters then
        # surfaces the message as a 500 until a recompute is triggered.
        log.exception("[clustering] Erreur lors du calcul")
        with _LOCK:
            _CACHE["status"] = "error"
            _CACHE["error"] = str(e)


def _maybe_trigger(k: int, hours: int, sensitivity: float) -> None:
    """Launch the computation if the cache is absent, expired, or the parameters differ."""
    with _LOCK:
        status = _CACHE["status"]
        params = _CACHE["params"]
        ts = _CACHE["ts"]
        cache_stale = (time.time() - ts) > _CACHE_TTL
        params_changed = (
            params.get("k") != k
            or params.get("hours") != hours
            or params.get("sensitivity") != sensitivity
        )
        if status in ("computing",):
            return  # already running
        if status == "ready" and not cache_stale and not params_changed:
            return  # cache is fresh
        # Fire-and-forget: the single-worker executor serializes jobs.
        _EXECUTOR.submit(_run_clustering_job, k, hours, sensitivity)


# ─── Endpoints ───────────────────────────────────────────────────────────────
@router.get("/status")
async def get_status():
    """Current computation state (polled by the frontend)."""
    with _LOCK:
        return {
            "status": _CACHE["status"],
            "error": _CACHE["error"],
            "ts": _CACHE["ts"],
            "params": _CACHE["params"],
            "age_s": round(time.time() - _CACHE["ts"], 0) if _CACHE["ts"] else None,
        }


@router.get("/clusters")
async def get_clusters(
    k: int = Query(20, ge=4, le=100, description="Nombre de clusters de base"),
    hours: int = Query(24, ge=1, le=168, description="Fenêtre temporelle (heures)"),
    sensitivity: float = Query(1.0, ge=0.5, le=5.0, description="Sensibilité : multiplicateur de k (5.0 = granularité maximale)"),
    force: bool = Query(False, description="Forcer le recalcul"),
):
    """Multi-metric clustering over ALL IPs.

    k_actual = round(k × sensitivity) — sensitivity controls granularity.
    Returns immediately from the cache. Triggers a computation if needed.
    """
    if force:
        # Invalidate the cache so _maybe_trigger resubmits a job.
        with _LOCK:
            _CACHE["status"] = "idle"
            _CACHE["ts"] = 0.0
            _CACHE["result"] = None
            _CACHE["cluster_ips"] = {}
    _maybe_trigger(k, hours, sensitivity)
    with _LOCK:
        status = _CACHE["status"]
        result = _CACHE["result"]
        error = _CACHE["error"]
    if status == "computing":
        return {"status": "computing", "message": "Calcul en cours, réessayez dans quelques secondes"}
    if status == "error":
        raise HTTPException(status_code=500, detail=error or "Erreur inconnue")
    if result is None:
        return {"status": "idle", "message": "Calcul démarré, réessayez dans quelques secondes"}
    return {**result, "status": "ready"}


@router.get("/cluster/{cluster_id}/points")
async def get_cluster_points(
    cluster_id: str,
    limit: int = Query(5000, ge=1, le=20000),
    offset: int = Query(0, ge=0),
):
    """PCA coordinates + metadata for every IP of one cluster.

    Used by the deck.gl ScatterplotLayer (drill-down or advanced zoom).
    """
    with _LOCK:
        status = _CACHE["status"]
        ips_map = _CACHE["cluster_ips"]
    if status != "ready" or not ips_map:
        raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord")
    try:
        # cluster_id format is "c{idx}_k{k}" — extract the numeric idx.
        idx = int(cluster_id.split("_")[0][1:])
    except (ValueError, IndexError):
        raise HTTPException(status_code=400, detail="cluster_id invalide (format: c{n}_k{k})")
    members = ips_map.get(idx, [])
    total = len(members)
    page = members[offset: offset + limit]
    points = [
        {"ip": m[0], "ja4": m[1], "pca_x": round(m[2], 6), "pca_y": round(m[3], 6), "risk": round(m[4], 3)}
        for m in page
    ]
    return {"points": points, "total": total, "offset": offset, "limit": limit}


@router.get("/cluster/{cluster_id}/ips")
async def get_cluster_ips(
    cluster_id: str,
    limit: int = Query(100, ge=1, le=500),
    offset: int = Query(0, ge=0),
):
    """IPs with SQL-level details (backward-compat with the old UI)."""
    with _LOCK:
        status = _CACHE["status"]
        ips_map = _CACHE["cluster_ips"]
    if status != "ready" or not ips_map:
        raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord")
    try:
        idx = int(cluster_id.split("_")[0][1:])
    except (ValueError, IndexError):
        raise HTTPException(status_code=400, detail="cluster_id invalide")
    members = ips_map.get(idx, [])
    total = len(members)
    page = members[offset: offset + limit]
    if not page:
        return {"ips": [], "total": total, "cluster_id": cluster_id}
    # IPs come from our own cache (not user input) but quotes are stripped
    # anyway before interpolation into the IN (...) list. Capped at 200.
    # NOTE(review): string-built SQL — prefer a parameterized query if the
    # driver supports array parameters.
    safe_ips = [m[0].replace("'", "") for m in page[:200]]
    ip_filter = ", ".join(f"'{ip}'" for ip in safe_ips)
    sql = f"""
    SELECT
        replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS src_ip,
        t.ja4,
        any(t.tcp_ttl_raw) AS ttl,
        any(t.tcp_win_raw) AS win,
        any(t.tcp_scale_raw) AS scale,
        any(t.tcp_mss_raw) AS mss,
        sum(t.hits) AS hits,
        any(t.first_ua) AS ua,
        round(avg(abs(ml.anomaly_score)), 3) AS avg_score,
        max(ml.threat_level) AS threat_level,
        any(ml.country_code) AS country_code,
        any(ml.asn_org) AS asn_org,
        round(avg(ml.fuzzing_index), 2) AS fuzzing,
        round(avg(ml.hit_velocity), 2) AS velocity
    FROM mabase_prod.agg_host_ip_ja4_1h t
    LEFT JOIN mabase_prod.ml_detected_anomalies ml
        ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4
        AND ml.detected_at >= now() - INTERVAL 24 HOUR
    WHERE t.window_start >= now() - INTERVAL 24 HOUR
        AND replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') IN ({ip_filter})
    GROUP BY t.src_ip, t.ja4
    ORDER BY hits DESC
    """
    try:
        result = db.query(sql)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    ips = []
    for row in result.result_rows:
        ips.append({
            "ip": str(row[0] or ""),
            "ja4": str(row[1] or ""),
            "tcp_ttl": int(row[2] or 0),
            "tcp_win": int(row[3] or 0),
            "tcp_scale": int(row[4] or 0),
            "tcp_mss": int(row[5] or 0),
            "hits": int(row[6] or 0),
            "ua": str(row[7] or ""),
            "avg_score": float(row[8] or 0),
            "threat_level": str(row[9] or ""),
            "country_code": str(row[10] or ""),
            "asn_org": str(row[11] or ""),
            "fuzzing": float(row[12] or 0),
            "velocity": float(row[13] or 0),
        })
    return {"ips": ips, "total": total, "cluster_id": cluster_id}