diff --git a/backend/routes/clustering.py b/backend/routes/clustering.py index 611f7fe..b775608 100644 --- a/backend/routes/clustering.py +++ b/backend/routes/clustering.py @@ -1,54 +1,53 @@ """ -Clustering d'IPs multi-métriques — backend ReactFlow. +Clustering d'IPs multi-métriques — WebGL / deck.gl backend. -Features utilisées (21 dimensions) : - TCP stack : TTL initial, MSS, scale, fenêtre TCP - Comportement : vélocité, POST ratio, fuzzing, assets, accès direct - Anomalie ML : score, IP-ID zéro - TLS/Protocole: ALPN mismatch, ALPN absent, efficacité H2 - Navigateur : browser score, headless, ordre headers, UA-CH mismatch - Temporel : entropie, diversité JA4, UA rotatif - -Algorithme : - 1. Échantillonnage stratifié (top détections + top hits) - 2. Construction + normalisation des vecteurs de features - 3. K-means++ (Arthur & Vassilvitskii, 2007) - 4. PCA-2D par power iteration pour les positions ReactFlow - 5. Nommage automatique par features dominantes du centroïde - 6. Calcul des arêtes : k-NN dans l'espace des features +- Calcul sur la TOTALITÉ des IPs (GROUP BY src_ip, ja4 sans LIMIT) +- K-means++ vectorisé (numpy) + PCA-2D + enveloppes convexes (scipy) +- Calcul en background thread + cache 30 min +- Endpoints : /clusters, /status, /cluster/{id}/points """ from __future__ import annotations + import math import time -import hashlib -from typing import Optional +import logging +import threading +from collections import Counter +from concurrent.futures import ThreadPoolExecutor +from typing import Optional, Any + +import numpy as np from fastapi import APIRouter, HTTPException, Query from ..database import db from ..services.clustering_engine import ( - FEATURES, FEATURE_KEYS, FEATURE_NORMS, FEATURE_NAMES, N_FEATURES, - build_feature_vector, kmeans_pp, pca_2d, - name_cluster, risk_score_from_centroid, _mean_vec, + FEATURE_KEYS, FEATURE_NAMES, FEATURE_NORMS, N_FEATURES, + build_feature_vector, kmeans_pp, pca_2d, compute_hulls, + name_cluster, risk_score_from_centroid, ) +log = logging.getLogger(__name__) router = APIRouter(prefix="/api/clustering", tags=["clustering"]) -# ─── Cache en mémoire ───────────────────────────────────────────────────────── -# Stocke (cluster_id → liste d'IPs) pour le drill-down -# + timestamp de dernière mise à jour -_cache: dict = { - "assignments": {}, # ip+ja4 → cluster_idx - "cluster_ips": {}, # cluster_idx → [(ip, ja4)] - "params": {}, # k, ts +# ─── Cache global ────────────────────────────────────────────────────────────── +_CACHE: dict[str, Any] = { + "status": "idle", # idle | computing | ready | error + "error": None, + "result": None, # dict résultat complet + "ts": 0.0, # timestamp dernière mise à jour + "params": {}, + "cluster_ips": {}, # cluster_idx → [(ip, ja4, pca_x, pca_y, risk)] } +_CACHE_TTL = 1800 # 30 minutes +_LOCK = threading.Lock() +_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="clustering") -# ─── Couleurs ───────────────────────────────────────────────────────────────── +# ─── Couleurs menace ────────────────────────────────────────────────────────── _THREAT_COLOR = { - 0.92: "#dc2626", # Bot scanner - 0.70: "#ef4444", # Critique + 0.70: "#dc2626", # Critique 0.45: "#f97316", # Élevé 0.25: "#eab308", # Modéré - 0.00: "#6b7280", # Sain / inconnu + 0.00: "#22c55e", # Sain } def _risk_to_color(risk: float) -> str: @@ -58,9 +57,8 @@ def _risk_to_color(risk: float) -> str: return "#6b7280" -# ─── SQL ────────────────────────────────────────────────────────────────────── - -_SQL_FEATURES = """ +# ─── SQL : TOUTES les IPs sans LIMIT ───────────────────────────────────────── +_SQL_ALL_IPS = """ SELECT replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS ip, t.ja4, @@ -71,43 +69,36 @@ SELECT any(t.first_ua) AS ua, sum(t.hits) AS hits, - avg(abs(ml.anomaly_score)) AS avg_score, - avg(ml.hit_velocity) AS avg_velocity, - avg(ml.fuzzing_index) AS avg_fuzzing, - avg(ml.is_headless) AS pct_headless, - avg(ml.post_ratio) AS avg_post, - avg(ml.ip_id_zero_ratio) AS ip_id_zero, - avg(ml.temporal_entropy) AS entropy, - avg(ml.modern_browser_score) AS browser_score, - avg(ml.alpn_http_mismatch) AS alpn_mismatch, - avg(ml.is_alpn_missing) AS alpn_missing, - avg(ml.multiplexing_efficiency) AS h2_eff, - avg(ml.header_order_confidence) AS hdr_conf, - avg(ml.ua_ch_mismatch) AS ua_ch_mismatch, - avg(ml.asset_ratio) AS asset_ratio, - avg(ml.direct_access_ratio) AS direct_ratio, - avg(ml.distinct_ja4_count) AS ja4_count, - max(ml.is_ua_rotating) AS ua_rotating, + avg(abs(ml.anomaly_score)) AS avg_score, + avg(ml.hit_velocity) AS avg_velocity, + avg(ml.fuzzing_index) AS avg_fuzzing, + avg(ml.is_headless) AS pct_headless, + avg(ml.post_ratio) AS avg_post, + avg(ml.ip_id_zero_ratio) AS ip_id_zero, + avg(ml.temporal_entropy) AS entropy, + avg(ml.modern_browser_score) AS browser_score, + avg(ml.alpn_http_mismatch) AS alpn_mismatch, + avg(ml.is_alpn_missing) AS alpn_missing, + avg(ml.multiplexing_efficiency) AS h2_eff, + avg(ml.header_order_confidence) AS hdr_conf, + avg(ml.ua_ch_mismatch) AS ua_ch_mismatch, + avg(ml.asset_ratio) AS asset_ratio, + avg(ml.direct_access_ratio) AS direct_ratio, + avg(ml.distinct_ja4_count) AS ja4_count, + max(ml.is_ua_rotating) AS ua_rotating, - max(ml.threat_level) AS threat, - any(ml.country_code) AS country, - any(ml.asn_org) AS asn_org + max(ml.threat_level) AS threat, + any(ml.country_code) AS country, + any(ml.asn_org) AS asn_org FROM mabase_prod.agg_host_ip_ja4_1h t LEFT JOIN mabase_prod.ml_detected_anomalies ml ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4 - AND ml.detected_at >= now() - INTERVAL 24 HOUR -WHERE t.window_start >= now() - INTERVAL 24 HOUR + AND ml.detected_at >= now() - INTERVAL %(hours)s HOUR +WHERE t.window_start >= now() - INTERVAL %(hours)s HOUR AND t.tcp_ttl_raw > 0 GROUP BY t.src_ip, t.ja4 -ORDER BY - -- Stratégie : IPs anormales en premier, puis fort trafic - -- Cela garantit que les bots Masscan (anomalie=0.97, hits=1-2) sont inclus - avg(abs(ml.anomaly_score)) DESC, - sum(t.hits) DESC -LIMIT %(limit)s """ -# Noms des colonnes SQL dans l'ordre _SQL_COLS = [ "ip", "ja4", "ttl", "win", "scale", "mss", "ua", "hits", "avg_score", "avg_velocity", "avg_fuzzing", "pct_headless", "avg_post", @@ -117,252 +108,311 @@ _SQL_COLS = [ ] +# ─── Worker de clustering (thread pool) ────────────────────────────────────── + +def _run_clustering_job(k: int, hours: int) -> None: + """Exécuté dans le thread pool. Met à jour _CACHE.""" + t0 = time.time() + with _LOCK: + _CACHE["status"] = "computing" + _CACHE["error"] = None + + try: + log.info(f"[clustering] Démarrage du calcul k={k} hours={hours}") + + # ── 1. Chargement de toutes les IPs ────────────────────────────── + result = db.query(_SQL_ALL_IPS, {"hours": hours}) + rows: list[dict] = [] + for row in result.result_rows: + rows.append({col: row[i] for i, col in enumerate(_SQL_COLS)}) + + n = len(rows) + log.info(f"[clustering] {n} IPs chargées") + if n < k: + raise ValueError(f"Seulement {n} IPs disponibles (k={k} requis)") + + # ── 2. Construction de la matrice de features (numpy) ──────────── + X = np.array([build_feature_vector(r) for r in rows], dtype=np.float32) + log.info(f"[clustering] Matrice X: {X.shape} — {X.nbytes/1024/1024:.1f} MB") + + # ── 3. K-means++ vectorisé ──────────────────────────────────────── + km = kmeans_pp(X.astype(np.float64), k=k, max_iter=80, n_init=3, seed=42) + log.info(f"[clustering] K-means: {km.n_iter} iters, inertia={km.inertia:.2f}") + + # ── 4. PCA-2D pour toutes les IPs ──────────────────────────────── + coords = pca_2d(X.astype(np.float64)) # (n, 2), normalisé [0,1] + + # ── 5. Enveloppes convexes par cluster ─────────────────────────── + hulls = compute_hulls(coords, km.labels, k) + + # ── 6. Agrégation par cluster ───────────────────────────────────── + cluster_rows: list[list[dict]] = [[] for _ in range(k)] + cluster_coords: list[list[list[float]]] = [[] for _ in range(k)] + cluster_ips_map: dict[int, list] = {j: [] for j in range(k)} + + for i, label in enumerate(km.labels): + j = int(label) + cluster_rows[j].append(rows[i]) + cluster_coords[j].append(coords[i].tolist()) + cluster_ips_map[j].append(( + rows[i]["ip"], + rows[i]["ja4"], + float(coords[i][0]), + float(coords[i][1]), + float(risk_score_from_centroid(km.centroids[j])), + )) + + # ── 7. Construction des nœuds ───────────────────────────────────── + nodes = [] + for j in range(k): + if not cluster_rows[j]: + continue + + def avg_f(key: str, crows: list[dict] = cluster_rows[j]) -> float: + return float(np.mean([float(r.get(key) or 0) for r in crows])) + + mean_ttl = avg_f("ttl") + mean_mss = avg_f("mss") + mean_scale = avg_f("scale") + mean_win = avg_f("win") + + raw_stats = {"mean_ttl": mean_ttl, "mean_mss": mean_mss, "mean_scale": mean_scale} + label_name = name_cluster(km.centroids[j], raw_stats) + risk = float(risk_score_from_centroid(km.centroids[j])) + color = _risk_to_color(risk) + + # Centroïde 2D = moyenne des coords du cluster + cxy = np.mean(cluster_coords[j], axis=0).tolist() if cluster_coords[j] else [0.5, 0.5] + ip_set = list({r["ip"] for r in cluster_rows[j]}) + ip_count = len(ip_set) + hit_count = int(sum(float(r.get("hits") or 0) for r in cluster_rows[j])) + + threats = [str(r.get("threat") or "") for r in cluster_rows[j] if r.get("threat")] + countries = [str(r.get("country") or "") for r in cluster_rows[j] if r.get("country")] + orgs = [str(r.get("asn_org") or "") for r in cluster_rows[j] if r.get("asn_org")] + + def topk(lst: list[str], n: int = 5) -> list[str]: + return [v for v, _ in Counter(lst).most_common(n) if v] + + radar = [ + {"feature": name, "value": round(float(km.centroids[j][i]), 4)} + for i, name in enumerate(FEATURE_NAMES) + ] + + radius = max(12, min(80, int(math.sqrt(ip_count) * 2))) + + sample_rows = sorted(cluster_rows[j], key=lambda r: float(r.get("hits") or 0), reverse=True)[:8] + sample_ips = [r["ip"] for r in sample_rows] + sample_ua = str(cluster_rows[j][0].get("ua") or "") + + nodes.append({ + "id": f"c{j}_k{k}", + "cluster_idx": j, + "label": label_name, + "pca_x": round(cxy[0], 6), + "pca_y": round(cxy[1], 6), + "radius": radius, + "color": color, + "risk_score": round(risk, 4), + + "mean_ttl": round(mean_ttl, 1), + "mean_mss": round(mean_mss, 0), + "mean_scale": round(mean_scale, 1), + "mean_win": round(mean_win, 0), + "mean_score": round(avg_f("avg_score"), 4), + "mean_velocity":round(avg_f("avg_velocity"),3), + "mean_fuzzing": round(avg_f("avg_fuzzing"), 3), + "mean_headless":round(avg_f("pct_headless"),3), + "mean_post": round(avg_f("avg_post"), 3), + "mean_asset": round(avg_f("asset_ratio"), 3), + "mean_direct": round(avg_f("direct_ratio"),3), + "mean_alpn_mismatch": round(avg_f("alpn_mismatch"),3), + "mean_h2_eff": round(avg_f("h2_eff"), 3), + "mean_hdr_conf":round(avg_f("hdr_conf"), 3), + "mean_ua_ch": round(avg_f("ua_ch_mismatch"),3), + "mean_entropy": round(avg_f("entropy"), 3), + "mean_ja4_diversity": round(avg_f("ja4_count"),3), + "mean_ip_id_zero": round(avg_f("ip_id_zero"),3), + "mean_browser_score": round(avg_f("browser_score"),1), + "mean_ua_rotating": round(avg_f("ua_rotating"),3), + + "ip_count": ip_count, + "hit_count": hit_count, + "top_threat": topk(threats, 1)[0] if threats else "", + "top_countries":topk(countries, 5), + "top_orgs": topk(orgs, 5), + "sample_ips": sample_ips, + "sample_ua": sample_ua, + "radar": radar, + + # Hull pour deck.gl PolygonLayer + "hull": hulls.get(j, []), + }) + + # ── 8. Arêtes k-NN entre clusters ──────────────────────────────── + edges = [] + seen: set[frozenset] = set() + for i, ni in enumerate(nodes): + ci = ni["cluster_idx"] + dists = sorted( + [(j, nj["cluster_idx"], + float(np.sum((km.centroids[ci] - km.centroids[nj["cluster_idx"]]) ** 2))) + for j, nj in enumerate(nodes) if j != i], + key=lambda x: x[2] + ) + for j_idx, cj, d2 in dists[:2]: + key = frozenset([ni["id"], nodes[j_idx]["id"]]) + if key in seen: + continue + seen.add(key) + edges.append({ + "id": f"e_{ni['id']}_{nodes[j_idx]['id']}", + "source": ni["id"], + "target": nodes[j_idx]["id"], + "similarity": round(1.0 / (1.0 + math.sqrt(d2)), 3), + }) + + # ── 9. Stockage résultat + cache IPs ───────────────────────────── + total_ips = sum(n_["ip_count"] for n_ in nodes) + total_hits = sum(n_["hit_count"] for n_ in nodes) + bot_ips = sum(n_["ip_count"] for n_ in nodes if n_["risk_score"] > 0.45 or "🤖" in n_["label"]) + high_ips = sum(n_["ip_count"] for n_ in nodes if n_["risk_score"] > 0.25) + elapsed = round(time.time() - t0, 2) + + result_dict = { + "nodes": nodes, + "edges": edges, + "stats": { + "total_clusters": len(nodes), + "total_ips": total_ips, + "total_hits": total_hits, + "bot_ips": bot_ips, + "high_risk_ips": high_ips, + "n_samples": n, + "k": k, + "elapsed_s": elapsed, + }, + "feature_names": FEATURE_NAMES, + } + + with _LOCK: + _CACHE["result"] = result_dict + _CACHE["cluster_ips"] = cluster_ips_map + _CACHE["status"] = "ready" + _CACHE["ts"] = time.time() + _CACHE["params"] = {"k": k, "hours": hours} + _CACHE["error"] = None + + log.info(f"[clustering] Terminé en {elapsed}s — {total_ips} IPs, {len(nodes)} clusters") + + except Exception as e: + log.exception("[clustering] Erreur lors du calcul") + with _LOCK: + _CACHE["status"] = "error" + _CACHE["error"] = str(e) + + +def _maybe_trigger(k: int, hours: int) -> None: + """Lance le calcul si cache absent, expiré ou paramètres différents.""" + with _LOCK: + status = _CACHE["status"] + params = _CACHE["params"] + ts = _CACHE["ts"] + + cache_stale = (time.time() - ts) > _CACHE_TTL + params_changed = params.get("k") != k or params.get("hours") != hours + + if status in ("computing",): + return # déjà en cours + + if status == "ready" and not cache_stale and not params_changed: + return # cache frais + + _EXECUTOR.submit(_run_clustering_job, k, hours) + + # ─── Endpoints ──────────────────────────────────────────────────────────────── +@router.get("/status") +async def get_status(): + """État du calcul en cours (polling frontend).""" + with _LOCK: + return { + "status": _CACHE["status"], + "error": _CACHE["error"], + "ts": _CACHE["ts"], + "params": _CACHE["params"], + "age_s": round(time.time() - _CACHE["ts"], 0) if _CACHE["ts"] else None, + } + + @router.get("/clusters") async def get_clusters( - k: int = Query(14, ge=4, le=30, description="Nombre de clusters"), - n_samples: int = Query(3000, ge=500, le=8000, description="Taille de l'échantillon"), + k: int = Query(14, ge=4, le=30, description="Nombre de clusters"), + hours: int = Query(24, ge=1, le=168, description="Fenêtre temporelle (heures)"), + force: bool = Query(False, description="Forcer le recalcul"), ): """ - Clustering multi-métriques des IPs. + Clustering multi-métriques sur TOUTES les IPs. - Retourne les nœuds (clusters) + arêtes pour ReactFlow, avec : - - positions 2D issues de PCA sur les 21 features - - profil radar des features par cluster (normalisé [0,1]) - - statistiques détaillées (moyennes brutes des features) - - sample d'IPs représentatives + Retourne immédiatement depuis le cache (status=ready). + Si le calcul est en cours ou non démarré → status=computing/idle + trigger. """ - t0 = time.time() + if force: + with _LOCK: + _CACHE["status"] = "idle" + _CACHE["ts"] = 0.0 + + _maybe_trigger(k, hours) + + with _LOCK: + status = _CACHE["status"] + result = _CACHE["result"] + error = _CACHE["error"] + + if status == "computing": + return {"status": "computing", "message": "Calcul en cours, réessayez dans quelques secondes"} + + if status == "error": + raise HTTPException(status_code=500, detail=error or "Erreur inconnue") + + if result is None: + return {"status": "idle", "message": "Calcul démarré, réessayez dans quelques secondes"} + + return {**result, "status": "ready"} + + +@router.get("/cluster/{cluster_id}/points") +async def get_cluster_points( + cluster_id: str, + limit: int = Query(5000, ge=1, le=20000), + offset: int = Query(0, ge=0), +): + """ + Coordonnées PCA + métadonnées de toutes les IPs d'un cluster. + Utilisé par deck.gl ScatterplotLayer (drill-down ou zoom avancé). + """ + with _LOCK: + status = _CACHE["status"] + ips_map = _CACHE["cluster_ips"] + + if status != "ready" or not ips_map: + raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord") + try: - result = db.query(_SQL_FEATURES, {"limit": n_samples}) - except Exception as e: - raise HTTPException(status_code=500, detail=f"ClickHouse: {e}") + idx = int(cluster_id.split("_")[0][1:]) + except (ValueError, IndexError): + raise HTTPException(status_code=400, detail="cluster_id invalide (format: c{n}_k{k})") - # ── Construction des vecteurs de features ───────────────────────────── - rows: list[dict] = [] - for row in result.result_rows: - d = {col: row[i] for i, col in enumerate(_SQL_COLS)} - rows.append(d) + members = ips_map.get(idx, []) + total = len(members) + page = members[offset: offset + limit] - if len(rows) < k: - raise HTTPException(status_code=400, detail="Pas assez de données pour ce k") - - points = [build_feature_vector(r) for r in rows] - - # ── K-means++ ──────────────────────────────────────────────────────── - km = kmeans_pp(points, k=k, max_iter=60, seed=42) - - # ── PCA-2D sur les centroïdes ───────────────────────────────────────── - # On projette les centroïdes dans l'espace PCA des données - # → les positions relatives reflètent la variance des données - coords_all = pca_2d(points) - # Moyenne des positions PCA par cluster = position 2D du centroïde - cluster_xs: list[list[float]] = [[] for _ in range(k)] - cluster_ys: list[list[float]] = [[] for _ in range(k)] - for i, label in enumerate(km.labels): - cluster_xs[label].append(coords_all[i][0]) - cluster_ys[label].append(coords_all[i][1]) - - centroid_2d: list[tuple[float, float]] = [] - for j in range(k): - if cluster_xs[j]: - cx = sum(cluster_xs[j]) / len(cluster_xs[j]) - cy = sum(cluster_ys[j]) / len(cluster_ys[j]) - else: - cx, cy = 0.5, 0.5 - centroid_2d.append((cx, cy)) - - # ── Agrégation des statistiques par cluster ─────────────────────────── - cluster_rows: list[list[dict]] = [[] for _ in range(k)] - cluster_members: list[list[tuple[str, str]]] = [[] for _ in range(k)] - for i, label in enumerate(km.labels): - cluster_rows[label].append(rows[i]) - cluster_members[label].append((rows[i]["ip"], rows[i]["ja4"])) - - # Mise à jour du cache pour le drill-down - _cache["cluster_ips"] = {j: cluster_members[j] for j in range(k)} - _cache["params"] = {"k": k, "ts": t0} - - # ── Construction des nœuds ReactFlow ───────────────────────────────── - CANVAS_W, CANVAS_H = 1400, 780 - - nodes = [] - for j in range(k): - if not cluster_rows[j]: - continue - - # Statistiques brutes moyennées - def avg_feat(key: str) -> float: - vals = [float(r.get(key) or 0) for r in cluster_rows[j]] - return sum(vals) / len(vals) if vals else 0.0 - - mean_ttl = avg_feat("ttl") - mean_mss = avg_feat("mss") - mean_scale = avg_feat("scale") - mean_win = avg_feat("win") - mean_score = avg_feat("avg_score") - mean_vel = avg_feat("avg_velocity") - mean_fuzz = avg_feat("avg_fuzzing") - mean_hless = avg_feat("pct_headless") - mean_post = avg_feat("avg_post") - mean_asset = avg_feat("asset_ratio") - mean_direct= avg_feat("direct_ratio") - mean_alpn = avg_feat("alpn_mismatch") - mean_h2 = avg_feat("h2_eff") - mean_hconf = avg_feat("hdr_conf") - mean_ua_ch = avg_feat("ua_ch_mismatch") - mean_entr = avg_feat("entropy") - mean_ja4 = avg_feat("ja4_count") - mean_ip_id = avg_feat("ip_id_zero") - mean_brow = avg_feat("browser_score") - mean_uarot = avg_feat("ua_rotating") - - ip_count = len(set(r["ip"] for r in cluster_rows[j])) - hit_count = int(sum(float(r.get("hits") or 0) for r in cluster_rows[j])) - - # Pays / ASN / Menace dominants - threats = [str(r.get("threat") or "") for r in cluster_rows[j] if r.get("threat")] - countries = [str(r.get("country") or "") for r in cluster_rows[j] if r.get("country")] - orgs = [str(r.get("asn_org") or "") for r in cluster_rows[j] if r.get("asn_org")] - - def topk(lst: list[str], n: int = 5) -> list[str]: - from collections import Counter - return [v for v, _ in Counter(lst).most_common(n) if v] - - raw_stats = { - "mean_ttl": mean_ttl, "mean_mss": mean_mss, - "mean_scale": mean_scale, - } - label = name_cluster(km.centroids[j], raw_stats) - risk = risk_score_from_centroid(km.centroids[j]) - color = _risk_to_color(risk) - - # Profil radar normalisé (valeurs centroïde [0,1]) - radar = [ - {"feature": name, "value": round(km.centroids[j][i], 4)} - for i, name in enumerate(FEATURE_NAMES) - ] - - # Position 2D (PCA normalisée → pixels ReactFlow) - px_x = centroid_2d[j][0] * CANVAS_W * 0.85 + 80 - px_y = (1 - centroid_2d[j][1]) * CANVAS_H * 0.85 + 50 # inverser y (haut=risque) - - # Rayon ∝ √ip_count - radius = max(18, min(90, int(math.sqrt(ip_count) * 0.3))) - - # Sample IPs (top 8 par hits) - sample_rows = sorted(cluster_rows[j], key=lambda r: float(r.get("hits") or 0), reverse=True)[:8] - sample_ips = [r["ip"] for r in sample_rows] - sample_ua = str(cluster_rows[j][0].get("ua") or "") - - cluster_id = f"c{j}_k{k}" - - nodes.append({ - "id": cluster_id, - "label": label, - "cluster_idx": j, - "x": round(px_x, 1), - "y": round(px_y, 1), - "radius": radius, - "color": color, - "risk_score": risk, - - # Caractéristiques TCP - "mean_ttl": round(mean_ttl, 1), - "mean_mss": round(mean_mss, 0), - "mean_scale": round(mean_scale, 1), - "mean_win": round(mean_win, 0), - - # Comportement HTTP - "mean_score": round(mean_score, 4), - "mean_velocity": round(mean_vel, 3), - "mean_fuzzing": round(mean_fuzz, 3), - "mean_headless": round(mean_hless, 3), - "mean_post": round(mean_post, 3), - "mean_asset": round(mean_asset, 3), - "mean_direct": round(mean_direct, 3), - - # TLS / Protocole - "mean_alpn_mismatch": round(mean_alpn, 3), - "mean_h2_eff": round(mean_h2, 3), - "mean_hdr_conf": round(mean_hconf, 3), - "mean_ua_ch": round(mean_ua_ch, 3), - - # Temporel - "mean_entropy": round(mean_entr, 3), - "mean_ja4_diversity": round(mean_ja4, 3), - "mean_ip_id_zero": round(mean_ip_id, 3), - "mean_browser_score": round(mean_brow, 1), - "mean_ua_rotating": round(mean_uarot, 3), - - # Meta - "ip_count": ip_count, - "hit_count": hit_count, - "top_threat": topk(threats, 1)[0] if topk(threats, 1) else "", - "top_countries": topk(countries, 5), - "top_orgs": topk(orgs, 5), - "sample_ips": sample_ips, - "sample_ua": sample_ua, - - # Profil radar pour visualisation - "radar": radar, - }) - - # ── Arêtes : k-NN dans l'espace des features ────────────────────────── - # Chaque cluster est connecté à ses 2 voisins les plus proches - edges = [] - seen: set[frozenset] = set() - centroids = km.centroids - - for i, ni in enumerate(nodes): - ci = ni["cluster_idx"] - # Distance² aux autres centroïdes - dists = [ - (j, nj["cluster_idx"], - sum((centroids[ci][d] - centroids[nj["cluster_idx"]][d]) ** 2 - for d in range(N_FEATURES))) - for j, nj in enumerate(nodes) if j != i - ] - dists.sort(key=lambda x: x[2]) - # 2 voisins les plus proches - for j, cj, dist2 in dists[:2]: - key = frozenset([ni["id"], nodes[j]["id"]]) - if key in seen: - continue - seen.add(key) - similarity = round(1.0 / (1.0 + math.sqrt(dist2)), 3) - edges.append({ - "id": f"e_{ni['id']}_{nodes[j]['id']}", - "source": ni["id"], - "target": nodes[j]["id"], - "similarity": similarity, - "weight": round(similarity * 5, 1), - }) - - # ── Stats globales ──────────────────────────────────────────────────── - total_ips = sum(n["ip_count"] for n in nodes) - total_hits = sum(n["hit_count"] for n in nodes) - bot_ips = sum(n["ip_count"] for n in nodes if n["risk_score"] > 0.40 or "🤖" in n["label"]) - high_risk = sum(n["ip_count"] for n in nodes if n["risk_score"] > 0.20) - - elapsed = round(time.time() - t0, 2) - - return { - "nodes": nodes, - "edges": edges, - "stats": { - "total_clusters": len(nodes), - "total_ips": total_ips, - "total_hits": total_hits, - "bot_ips": bot_ips, - "high_risk_ips": high_risk, - "n_samples": len(rows), - "k": k, - "elapsed_s": elapsed, - }, - "feature_names": FEATURE_NAMES, - } + points = [ + {"ip": m[0], "ja4": m[1], "pca_x": round(m[2], 6), "pca_y": round(m[3], 6), "risk": round(m[4], 3)} + for m in page + ] + return {"points": points, "total": total, "offset": offset, "limit": limit} @router.get("/cluster/{cluster_id}/ips") @@ -371,57 +421,44 @@ async def get_cluster_ips( limit: int = Query(100, ge=1, le=500), offset: int = Query(0, ge=0), ): - """ - IPs appartenant à un cluster (depuis le cache de la dernière exécution). - Si le cache est expiré, retourne une erreur guidant vers /clusters. - """ - if not _cache.get("cluster_ips"): - raise HTTPException( - status_code=404, - detail="Cache expiré — appelez /api/clustering/clusters d'abord" - ) + """IPs avec détails SQL (backward-compat avec l'ancienne UI).""" + with _LOCK: + status = _CACHE["status"] + ips_map = _CACHE["cluster_ips"] + + if status != "ready" or not ips_map: + raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord") - # Extrait l'index cluster depuis l'id (format: c{idx}_k{k}) try: idx = int(cluster_id.split("_")[0][1:]) except (ValueError, IndexError): raise HTTPException(status_code=400, detail="cluster_id invalide") - members = _cache["cluster_ips"].get(idx, []) - if not members: - return {"ips": [], "total": 0, "cluster_id": cluster_id} - - total = len(members) - page_members = members[offset: offset + limit] - - # Requête SQL pour les détails de ces IPs spécifiques - ip_list = [m[0] for m in page_members] - ja4_list = [m[1] for m in page_members] - - if not ip_list: + members = ips_map.get(idx, []) + total = len(members) + page = members[offset: offset + limit] + if not page: return {"ips": [], "total": total, "cluster_id": cluster_id} - # On ne peut pas facilement passer une liste en paramètre ClickHouse — - # on la construit directement (valeurs nettoyées) - safe_ips = [ip.replace("'", "") for ip in ip_list[:100]] + safe_ips = [m[0].replace("'", "") for m in page[:200]] ip_filter = ", ".join(f"'{ip}'" for ip in safe_ips) sql = f""" SELECT replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS src_ip, t.ja4, - any(t.tcp_ttl_raw) AS ttl, - any(t.tcp_win_raw) AS win, - any(t.tcp_scale_raw) AS scale, - any(t.tcp_mss_raw) AS mss, - sum(t.hits) AS hits, - any(t.first_ua) AS ua, - round(avg(abs(ml.anomaly_score)), 3) AS avg_score, - max(ml.threat_level) AS threat_level, - any(ml.country_code) AS country_code, - any(ml.asn_org) AS asn_org, - round(avg(ml.fuzzing_index), 2) AS fuzzing, - round(avg(ml.hit_velocity), 2) AS velocity + any(t.tcp_ttl_raw) AS ttl, + any(t.tcp_win_raw) AS win, + any(t.tcp_scale_raw) AS scale, + any(t.tcp_mss_raw) AS mss, + sum(t.hits) AS hits, + any(t.first_ua) AS ua, + round(avg(abs(ml.anomaly_score)), 3) AS avg_score, + max(ml.threat_level) AS threat_level, + any(ml.country_code) AS country_code, + any(ml.asn_org) AS asn_org, + round(avg(ml.fuzzing_index), 2) AS fuzzing, + round(avg(ml.hit_velocity), 2) AS velocity FROM mabase_prod.agg_host_ip_ja4_1h t LEFT JOIN mabase_prod.ml_detected_anomalies ml ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4 @@ -439,7 +476,7 @@ async def get_cluster_ips( ips = [] for row in result.result_rows: ips.append({ - "ip": str(row[0]), + "ip": str(row[0] or ""), "ja4": str(row[1] or ""), "tcp_ttl": int(row[2] or 0), "tcp_win": int(row[3] or 0), diff --git a/backend/services/clustering_engine.py b/backend/services/clustering_engine.py index 213425c..0b5f1a0 100644 --- a/backend/services/clustering_engine.py +++ b/backend/services/clustering_engine.py @@ -1,12 +1,14 @@ """ -Moteur de clustering K-means++ multi-métriques (pur Python). +Moteur de clustering K-means++ multi-métriques (numpy + scipy vectorisé). -Ref: Arthur & Vassilvitskii (2007) — k-means++: The Advantages of Careful Seeding - Hotelling (1933) — PCA par puissance itérative (deflation) +Ref: + Arthur & Vassilvitskii (2007) — k-means++: The Advantages of Careful Seeding + scipy.spatial.ConvexHull — enveloppe convexe (Graham/Qhull) + sklearn-style API — centroids, labels_, inertia_ Features (21 dimensions, normalisées [0,1]) : - 0 ttl_n : TTL initial normalisé (hops-count estimé) - 1 mss_n : MSS normalisé → type réseau (Ethernet/PPPoE/VPN) + 0 ttl_n : TTL initial normalisé + 1 mss_n : MSS normalisé → type réseau 2 scale_n : facteur de mise à l'échelle TCP 3 win_n : fenêtre TCP normalisée 4 score_n : score anomalie ML (abs) @@ -16,7 +18,7 @@ Features (21 dimensions, normalisées [0,1]) : 8 post_n : ratio POST/total 9 ip_id_zero_n : ratio IP-ID=0 (Linux/spoofé) 10 entropy_n : entropie temporelle - 11 browser_n : score navigateur moderne (normalisé max 50) + 11 browser_n : score navigateur moderne 12 alpn_n : mismatch ALPN/protocole 13 alpn_absent_n : ratio ALPN absent 14 h2_n : efficacité H2 multiplexing (log1p) @@ -28,301 +30,248 @@ Features (21 dimensions, normalisées [0,1]) : 20 ua_rot_n : UA rotatif (booléen) """ from __future__ import annotations + import math -import random +import logging +import numpy as np from dataclasses import dataclass, field +from scipy.spatial import ConvexHull + +log = logging.getLogger(__name__) # ─── Définition des features ────────────────────────────────────────────────── -# (clé SQL, nom lisible, fonction de normalisation) -FEATURES = [ +FEATURES: list[tuple[str, str, object]] = [ # TCP stack - ("ttl", "TTL Initial", lambda v: min(1.0, (v or 0) / 255.0)), - ("mss", "MSS Réseau", lambda v: min(1.0, (v or 0) / 1460.0)), - ("scale", "Scale TCP", lambda v: min(1.0, (v or 0) / 14.0)), - ("win", "Fenêtre TCP", lambda v: min(1.0, (v or 0) / 65535.0)), + ("ttl", "TTL Initial", lambda v: min(1.0, (v or 0) / 255.0)), + ("mss", "MSS Réseau", lambda v: min(1.0, (v or 0) / 1460.0)), + ("scale", "Scale TCP", lambda v: min(1.0, (v or 0) / 14.0)), + ("win", "Fenêtre TCP", lambda v: min(1.0, (v or 0) / 65535.0)), # Anomalie ML - ("avg_score", "Score Anomalie", lambda v: min(1.0, float(v or 0))), - ("avg_velocity", "Vélocité (rps)", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(100))), - ("avg_fuzzing", "Fuzzing", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(300))), - ("pct_headless", "Headless", lambda v: min(1.0, float(v or 0))), - ("avg_post", "Ratio POST", lambda v: min(1.0, float(v or 0))), + ("avg_score", "Score Anomalie", lambda v: min(1.0, float(v or 0))), + ("avg_velocity", "Vélocité (rps)", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(100))), + ("avg_fuzzing", "Fuzzing", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(300))), + ("pct_headless", "Headless", lambda v: min(1.0, float(v or 0))), + ("avg_post", "Ratio POST", lambda v: min(1.0, float(v or 0))), # IP-ID - ("ip_id_zero", "IP-ID Zéro", lambda v: min(1.0, float(v or 0))), + ("ip_id_zero", "IP-ID Zéro", lambda v: min(1.0, float(v or 0))), # Temporel - ("entropy", "Entropie Temporelle", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(10))), + ("entropy", "Entropie Temporelle", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(10))), # Navigateur - ("browser_score","Score Navigateur", lambda v: min(1.0, float(v or 0) / 50.0)), + ("browser_score", "Score Navigateur", lambda v: min(1.0, float(v or 0) / 50.0)), # TLS / Protocole - ("alpn_mismatch","ALPN Mismatch", lambda v: min(1.0, float(v or 0))), - ("alpn_missing", "ALPN Absent", lambda v: min(1.0, float(v or 0))), - ("h2_eff", "H2 Multiplexing", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(20))), - ("hdr_conf", "Ordre Headers", lambda v: min(1.0, float(v or 0))), - ("ua_ch_mismatch","UA-CH Mismatch", lambda v: min(1.0, float(v or 0))), + ("alpn_mismatch", "ALPN Mismatch", lambda v: min(1.0, float(v or 0))), + ("alpn_missing", "ALPN Absent", lambda v: min(1.0, float(v or 0))), + ("h2_eff", "H2 Multiplexing", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(20))), + ("hdr_conf", "Ordre Headers", lambda v: min(1.0, float(v or 0))), + ("ua_ch_mismatch","UA-CH Mismatch", lambda v: min(1.0, float(v or 0))), # Comportement HTTP - ("asset_ratio", "Ratio Assets", lambda v: min(1.0, float(v or 0))), - ("direct_ratio", "Accès Direct", lambda v: min(1.0, float(v or 0))), + ("asset_ratio", "Ratio Assets", lambda v: min(1.0, float(v or 0))), + ("direct_ratio", "Accès Direct", lambda v: min(1.0, float(v or 0))), # Diversité JA4 - ("ja4_count", "Diversité JA4", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(30))), + ("ja4_count", "Diversité JA4", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(30))), # UA rotatif - ("ua_rotating", "UA Rotatif", lambda v: 1.0 if float(v or 0) > 0 else 0.0), + ("ua_rotating", "UA Rotatif", lambda v: 1.0 if float(v or 0) > 0 else 0.0), ] -FEATURE_KEYS = [f[0] for f in FEATURES] -FEATURE_NAMES = [f[1] for f in FEATURES] -FEATURE_NORMS = [f[2] for f in FEATURES] -N_FEATURES = len(FEATURES) - - -# ─── Utilitaires vectoriels (pur Python) ────────────────────────────────────── - -def _dist2(a: list[float], b: list[float]) -> float: - return sum((x - y) ** 2 for x, y in zip(a, b)) - -def _mean_vec(vecs: list[list[float]]) -> list[float]: - n = len(vecs) - if n == 0: - return [0.0] * N_FEATURES - return [sum(v[i] for v in vecs) / n for i in range(N_FEATURES)] +FEATURE_KEYS = [f[0] for f in FEATURES] +FEATURE_NAMES = [f[1] for f in FEATURES] +FEATURE_NORMS = [f[2] for f in FEATURES] +N_FEATURES = len(FEATURES) # ─── Construction du vecteur de features ───────────────────────────────────── def build_feature_vector(row: dict) -> list[float]: - """Normalise un dict de colonnes SQL → vecteur [0,1]^N_FEATURES.""" - return [fn(row.get(key)) for key, fn in zip(FEATURE_KEYS, FEATURE_NORMS)] + """Construit le vecteur normalisé [0,1]^21 depuis un dict SQL.""" + return [norm(row.get(key, 0)) for key, _, norm in FEATURES] -# ─── K-means++ ─────────────────────────────────────────────────────────────── +# ─── K-means++ vectorisé (numpy) ───────────────────────────────────────────── @dataclass class KMeansResult: - centroids: list[list[float]] - labels: list[int] - inertia: float - n_iter: int + centroids: np.ndarray # (k, n_features) + labels: np.ndarray # (n_points,) int32 + inertia: float + n_iter: int -def kmeans_pp( - points: list[list[float]], - k: int, - max_iter: int = 60, - seed: int = 42, - n_init: int = 3, -) -> KMeansResult: +def kmeans_pp(X: np.ndarray, k: int, max_iter: int = 60, n_init: int = 3, + seed: int = 42) -> KMeansResult: """ - K-means avec initialisation k-means++ (Arthur & Vassilvitskii, 2007). - Lance `n_init` fois et retourne le meilleur résultat (inertie minimale). + K-means++ entièrement vectorisé avec numpy. + n_init exécutions, meilleure inertie conservée. """ - rng = random.Random(seed) + rng = np.random.default_rng(seed) + n, d = X.shape best: KMeansResult | None = None - for attempt in range(n_init): - # ── Initialisation k-means++ ──────────────────────────────────── - first_idx = rng.randrange(len(points)) - centroids = [points[first_idx][:]] - + for _ in range(n_init): + # ── Initialisation K-means++ ────────────────────────────────────── + centers = [X[rng.integers(n)].copy()] for _ in range(k - 1): - d2 = [min(_dist2(p, c) for c in centroids) for p in points] - total = sum(d2) - if total == 0: - break - r = rng.random() * total - cumul = 0.0 - for i, d in enumerate(d2): - cumul += d - if cumul >= r: - centroids.append(points[i][:]) - break + D = _min_sq_dist(X, np.array(centers)) + # Garantit des probabilités non-négatives (erreurs float, points dupliqués) + D = np.clip(D, 0.0, None) + total = D.sum() + if total < 1e-12: + # Tous les points sont confondus — tirage aléatoire + centers.append(X[rng.integers(n)].copy()) else: - centroids.append(points[rng.randrange(len(points))][:]) + probs = D / total + centers.append(X[rng.choice(n, p=probs)].copy()) + centers_arr = np.array(centers) # (k, d) - # ── Itérations EM ─────────────────────────────────────────────── - labels: list[int] = [0] * len(points) - for iteration in range(max_iter): - # E-step : affectation - new_labels = [ - min(range(len(centroids)), key=lambda c: _dist2(p, centroids[c])) - for p in points - ] - if new_labels == labels and iteration > 0: - break + # ── Iterations ─────────────────────────────────────────────────── + labels = np.zeros(n, dtype=np.int32) + for it in range(max_iter): + # Assignation vectorisée : (n, k) distance² + dists = _sq_dists(X, centers_arr) # (n, k) + new_labels = np.argmin(dists, axis=1).astype(np.int32) + + if it > 0 and np.all(new_labels == labels): + break # convergence labels = new_labels - # M-step : mise à jour - clusters: list[list[list[float]]] = [[] for _ in range(k)] - for i, l in enumerate(labels): - clusters[l].append(points[i]) + # Mise à jour des centroïdes for j in range(k): - if clusters[j]: - centroids[j] = _mean_vec(clusters[j]) + mask = labels == j + if mask.any(): + centers_arr[j] = X[mask].mean(axis=0) - inertia = sum(_dist2(points[i], centroids[labels[i]]) for i in range(len(points))) - result = KMeansResult( - centroids=centroids, - labels=labels, - inertia=inertia, - n_iter=iteration + 1, - ) + inertia = float(np.sum(np.min(_sq_dists(X, centers_arr), axis=1))) + result = KMeansResult(centers_arr, labels, inertia, it + 1) if best is None or inertia < best.inertia: best = result - return best # type: ignore + return best # type: ignore[return-value] -# ─── PCA 2D par puissance itérative ────────────────────────────────────────── +def _sq_dists(X: np.ndarray, C: np.ndarray) -> np.ndarray: + """Distance² entre chaque point de X et chaque centroïde de C. O(n·k·d).""" + # ||x - c||² = ||x||² + ||c||² - 2·x·cᵀ + X2 = np.sum(X ** 2, axis=1, keepdims=True) # (n, 1) + C2 = np.sum(C ** 2, axis=1, keepdims=True).T # (1, k) + return X2 + C2 - 2.0 * X @ C.T # (n, k) -def pca_2d(points: list[list[float]]) -> list[tuple[float, float]]: + +def _min_sq_dist(X: np.ndarray, C: np.ndarray) -> np.ndarray: + """Distance² minimale de chaque point aux centroïdes existants.""" + return np.min(_sq_dists(X, C), axis=1) + + +# ─── PCA 2D (numpy) ────────────────────────────────────────────────────────── + +def pca_2d(X: np.ndarray) -> np.ndarray: """ - Projection PCA 2D par puissance itérative avec déflation (Hotelling). - Retourne les coordonnées (pc1, pc2) normalisées dans [0,1]. + PCA-2D vectorisée. Retourne les coordonnées normalisées [0,1] × [0,1]. """ - n = len(points) - if n == 0: - return [] - - # Centrage - mean = _mean_vec(points) - X = [[p[i] - mean[i] for i in range(N_FEATURES)] for p in points] - - def power_iter(X_centered: list[list[float]], n_iter: int = 30) -> list[float]: - """Trouve le premier vecteur propre de X^T X par puissance itérative.""" - v = [1.0 / math.sqrt(N_FEATURES)] * N_FEATURES - for _ in range(n_iter): - # Xv = X @ v - Xv = [sum(row[j] * v[j] for j in range(N_FEATURES)) for row in X_centered] - # Xtxv = X^T @ Xv - xtxv = [sum(X_centered[i][j] * Xv[i] for i in range(len(X_centered))) for j in range(N_FEATURES)] - norm = math.sqrt(sum(x ** 2 for x in xtxv)) or 1e-10 - v = [x / norm for x in xtxv] - return v - - # PC1 - v1 = power_iter(X) - proj1 = [sum(row[j] * v1[j] for j in range(N_FEATURES)) for row in X] - - # Déflation : retire la composante PC1 de X - X2 = [ - [X[i][j] - proj1[i] * v1[j] for j in range(N_FEATURES)] - for i in range(n) - ] - - # PC2 - v2 = power_iter(X2) - proj2 = [sum(row[j] * v2[j] for j in range(N_FEATURES)) for row in X2] + mean = X.mean(axis=0) + Xc = X - mean + # Power iteration pour les 2 premières composantes + rng = np.random.default_rng(0) + v1 = _power_iter(Xc, rng.standard_normal(Xc.shape[1])) + proj1 = Xc @ v1 + # Déflation (Hotelling) + Xc2 = Xc - np.outer(proj1, v1) + v2 = _power_iter(Xc2, rng.standard_normal(Xc.shape[1])) + proj2 = Xc2 @ v2 + coords = np.column_stack([proj1, proj2]) # Normalisation [0,1] - def _norm01(vals: list[float]) -> list[float]: - lo, hi = min(vals), max(vals) - rng = hi - lo or 1e-10 - return [(v - lo) / rng for v in vals] - - p1 = _norm01(proj1) - p2 = _norm01(proj2) - - return list(zip(p1, p2)) + mn, mx = coords.min(axis=0), coords.max(axis=0) + rng_ = mx - mn + rng_[rng_ == 0] = 1.0 + return (coords - mn) / rng_ -# ─── Nommage automatique des clusters ──────────────────────────────────────── +def _power_iter(X: np.ndarray, v: np.ndarray, n_iter: int = 30) -> np.ndarray: + """Power iteration : trouve le premier vecteur propre de XᵀX.""" + for _ in range(n_iter): + v = X.T @ (X @ v) + norm = np.linalg.norm(v) + if norm < 1e-12: + break + v /= norm + return v -def name_cluster(centroid: list[float], raw_stats: dict | None = None) -> str: + +# ─── Enveloppe convexe (hull) par cluster ──────────────────────────────────── + +def compute_hulls(coords_2d: np.ndarray, labels: np.ndarray, + k: int, min_pts: int = 4) -> dict[int, list[list[float]]]: """ - Génère un nom lisible à partir du centroïde normalisé et de statistiques brutes. - Priorité : signaux les plus discriminants en premier. + Calcule l'enveloppe convexe (convex hull) des points PCA pour chaque cluster. + Retourne {cluster_idx: [[x,y], ...]} (polygone fermé). """ - score = centroid[4] # anomalie ML - vel = centroid[5] # vélocité - fuzz = centroid[6] # fuzzing (log1p normalisé, >0.35 ≈ fuzzing_index > 100) - hless = centroid[7] # headless - post = centroid[8] # POST ratio - alpn = centroid[12] # ALPN mismatch - h2 = centroid[14] # H2 eff - ua_ch = centroid[16] # UA-CH mismatch - ja4d = centroid[19] # JA4 diversité - ua_rot = centroid[20] # UA rotatif - - raw_mss = (raw_stats or {}).get("mean_mss", 0) - raw_ttl = (raw_stats or {}).get("mean_ttl", 0) or (centroid[0] * 255) - raw_scale = (raw_stats or {}).get("mean_scale", 0) - - # ── Signaux forts (déterministes) ──────────────────────────────────── - - # Pattern Masscan : mss≈1452, scale≈4, TTL 48-57 - if raw_mss and 1440 <= raw_mss <= 1460 and raw_scale and 3 <= raw_scale <= 5 and raw_ttl < 60: - return "🤖 Masscan / Scanner IP" - - # Fuzzer agressif (fuzzing_index normalisé > 0.35 ≈ valeur brute > 100) - if fuzz > 0.35: - return "🤖 Bot Fuzzer / Scanner" - - # UA rotatif + UA-CH mismatch : bot sophistiqué simulant un navigateur - if ua_rot > 0.5 and ua_ch > 0.7: - return "🤖 Bot UA Rotatif + CH Mismatch" - - # UA-CH mismatch fort seul (navigateur simulé sans headers CH) - if ua_ch > 0.8: - return "⚠️ Bot UA-CH Incohérent" - - # ── Score ML modéré + signal comportemental ────────────────────────── - - if score > 0.20: - if hless > 0.3: - return "⚠️ Navigateur Headless Suspect" - if vel > 0.25: - return "⚠️ Bot Haute Vélocité" - if post > 0.4: - return "⚠️ Bot POST Automatisé" - if alpn > 0.5 or h2 > 0.5: - return "⚠️ TLS/H2 Anormal" - if ua_ch > 0.4: - return "⚠️ Anomalie UA-CH" - return "⚠️ Anomalie ML Modérée" - - # ── Signaux faibles ─────────────────────────────────────────────────── - - if ua_ch > 0.4: - return "🔎 UA-CH Incohérent" - - if ja4d > 0.5: - return "🔄 Client Multi-Fingerprint" - - # ── Classification réseau / OS ──────────────────────────────────────── - - # MSS bas → VPN ou tunnel - if raw_mss and raw_mss < 1360: - return "🌐 VPN / Tunnel" - - if raw_ttl < 70: - return "🐧 Linux / Mobile" - if raw_ttl > 110: - return "🪟 Windows" - - return "✅ Trafic Légitime" + hulls: dict[int, list[list[float]]] = {} + for j in range(k): + pts = coords_2d[labels == j] + if len(pts) < min_pts: + # Pas assez de points : bounding box + if len(pts) > 0: + mx_, my_ = pts.mean(axis=0) + r = max(0.01, pts.std(axis=0).max()) + hulls[j] = [ + [mx_ - r, my_ - r], [mx_ + r, my_ - r], + [mx_ + r, my_ + r], [mx_ - r, my_ + r], + ] + continue + try: + hull = ConvexHull(pts) + hull_pts = pts[hull.vertices].tolist() + # Fermer le polygone + hull_pts.append(hull_pts[0]) + hulls[j] = hull_pts + except Exception: + hulls[j] = [] + return hulls -def risk_score_from_centroid(centroid: list[float]) -> float: - """Score de risque [0,1] pondéré. Calibré pour les valeurs observées (score ML ~0.3).""" - # Normalisation de score ML : x / 0.5 pour étendre la plage utile (0-0.5 → 0-1) - score_n = min(1.0, centroid[4] / 0.5) - fuzz_n = centroid[6] - ua_ch_n = centroid[16] - ua_rot_n = centroid[20] - vel_n = centroid[5] - hless_n = centroid[7] - ip_id_n = centroid[9] - alpn_n = centroid[12] - ja4d_n = centroid[19] - post_n = centroid[8] +# ─── Nommage et scoring ─────────────────────────────────────────────────────── - return min(1.0, - 0.25 * score_n + - 0.20 * ua_ch_n + - 0.15 * fuzz_n + - 0.12 * ua_rot_n + - 0.10 * hless_n + - 0.07 * vel_n + - 0.04 * ip_id_n + - 0.04 * alpn_n + - 0.03 * ja4d_n + - 0.03 * post_n - ) +def name_cluster(centroid: np.ndarray, raw_stats: dict) -> str: + """Nom lisible basé sur les features dominantes du centroïde.""" + s = centroid # alias + ttl_raw = float(raw_stats.get("mean_ttl", 0)) + mss_raw = float(raw_stats.get("mean_mss", 0)) + + # Scanners / bots masscan + if s[0] > 0.16 and s[0] < 0.25 and mss_raw in range(1440, 1460) and s[2] > 0.25: + return "🤖 Masscan Scanner" + if s[4] > 0.70 and s[6] > 0.5: + return "🤖 Bot agressif" + if s[16] > 0.80: + return "🤖 UA-CH Mismatch" + if s[7] > 0.70: + return "🤖 Headless Browser" + if s[4] > 0.50: + return "⚠️ Anomalie ML haute" + if s[3] > 0.85 and ttl_raw > 120: + return "🖥️ Windows" + if s[0] > 0.22 and s[0] < 0.28 and mss_raw > 1400: + return "🐧 Linux" + if s[1] < 0.90 and s[1] > 0.95: + return "📡 VPN/Proxy" + if mss_raw < 1380 and mss_raw > 0: + return "🌐 Tunnel réseau" + if s[5] > 0.60: + return "⚡ Trafic rapide" + if s[4] < 0.10 and s[5] < 0.10: + return "✅ Trafic sain" + return "📊 Cluster mixte" + + +def risk_score_from_centroid(centroid: np.ndarray) -> float: + """Score de risque [0,1] agrégé depuis le centroïde.""" + s = centroid + return float(np.clip( + 0.40 * s[4] + # score ML + 0.15 * s[6] + # fuzzing + 0.15 * s[16] + # UA-CH mismatch + 0.10 * s[7] + # headless + 0.10 * s[5] + # vélocité + 0.10 * s[9], # IP-ID zéro + 0.0, 1.0 + )) diff --git a/frontend/package.json b/frontend/package.json index db1ee93..7ec975f 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -15,7 +15,10 @@ "recharts": "^2.10.0", "@tanstack/react-table": "^8.11.0", "date-fns": "^3.0.0", - "reactflow": "^11.10.0" + "reactflow": "^11.10.0", + "@deck.gl/react": "^9.0.0", + "@deck.gl/core": "^9.0.0", + "@deck.gl/layers": "^9.0.0" }, "devDependencies": { "@types/react": "^18.2.0", diff --git a/frontend/src/components/ClusteringView.tsx b/frontend/src/components/ClusteringView.tsx index d44d3ae..dd5f8bd 100644 --- a/frontend/src/components/ClusteringView.tsx +++ b/frontend/src/components/ClusteringView.tsx @@ -1,847 +1,617 @@ /** - * Clustering IPs — visualisation multi-métriques + * ClusteringView — Visualisation WebGL des clusters d'IPs via deck.gl * - * Deux vues : - * 1. "Cartes" (défaut) — grille de cartes triées par risque, toujours lisibles - * 2. "Graphe" — ReactFlow avec nœuds-cartes et disposition par colonne de menace + * Architecture LOD : + * - Vue globale : PolygonLayer (hulls) + ScatterplotLayer (centroïdes) + * - Sur sélection : ScatterplotLayer dense (toutes les IPs du cluster) + * - Sidebar : profil radar, stats, liste IPs paginée * - * Chaque cluster affiche : - * • Label + emoji de menace - * • Compteur IPs / hits - * • Score de risque (barre colorée) - * • 4 métriques clés (barres horizontales) - * • Top pays + ASN - * • Radar dans la sidebar + * Rendu WebGL via @deck.gl/react + OrthographicView */ -import { useCallback, useEffect, useState, useMemo } from 'react'; -import ReactFlow, { - Background, Controls, MiniMap, ReactFlowProvider, - useNodesState, useEdgesState, useReactFlow, - Node, Edge, Handle, Position, NodeProps, - Panel, -} from 'reactflow'; -import 'reactflow/dist/style.css'; -import { - RadarChart, Radar, PolarGrid, PolarAngleAxis, PolarRadiusAxis, - ResponsiveContainer, Tooltip as RechartsTooltip, -} from 'recharts'; + +import React, { useState, useEffect, useCallback, useRef } from 'react'; +import DeckGL from '@deck.gl/react'; +import { OrthographicView } from '@deck.gl/core'; +import { ScatterplotLayer, PolygonLayer, TextLayer, LineLayer } from '@deck.gl/layers'; +import { RadarChart, PolarGrid, PolarAngleAxis, Radar, ResponsiveContainer, Tooltip } from 'recharts'; +import axios from 'axios'; // ─── Types ──────────────────────────────────────────────────────────────────── +interface RadarEntry { feature: string; value: number; } + interface ClusterNode { id: string; - label: string; cluster_idx: number; - x: number; y: number; + label: string; + pca_x: number; + pca_y: number; radius: number; color: string; risk_score: number; ip_count: number; hit_count: number; - mean_score: number; - mean_ua_ch: number; - mean_ua_rotating: number; - mean_fuzzing: number; - mean_headless: number; - mean_velocity: number; mean_ttl: number; mean_mss: number; - mean_scale: number; - mean_alpn_mismatch: number; - mean_ip_id_zero: number; - mean_browser_score: number; - mean_entropy: number; - mean_ja4_diversity: number; + mean_score: number; + mean_velocity: number; + mean_fuzzing: number; + mean_headless: number; + mean_ua_ch: number; top_threat: string; top_countries: string[]; top_orgs: string[]; sample_ips: string[]; sample_ua: string; - radar: { feature: string; value: number }[]; + radar: RadarEntry[]; + hull: [number, number][]; } -interface ClusteringData { +interface ClusterEdge { + id: string; + source: string; + target: string; + similarity: number; +} + +interface ClusterStats { + total_clusters: number; + total_ips: number; + total_hits: number; + bot_ips: number; + high_risk_ips: number; + n_samples: number; + k: number; + elapsed_s: number; +} + +interface ClusterResult { + status: string; nodes: ClusterNode[]; - edges: { id: string; source: string; target: string; similarity: number; weight: number }[]; - stats: { - total_clusters: number; - total_ips: number; - total_hits: number; - bot_ips: number; - high_risk_ips: number; - n_samples: number; - k: number; - elapsed_s: number; - }; + edges: ClusterEdge[]; + stats: ClusterStats; + feature_names: string[]; + message?: string; } -interface ClusterIP { - ip: string; ja4: string; tcp_ttl: number; tcp_mss: number; - hits: number; ua: string; avg_score: number; - threat_level: string; country_code: string; asn_org: string; - fuzzing: number; velocity: number; +interface IPPoint { ip: string; ja4: string; pca_x: number; pca_y: number; risk: number; } +interface IPDetail { ip: string; ja4: string; tcp_ttl: number; tcp_mss: number; hits: number; ua: string; avg_score: number; threat_level: string; country_code: string; asn_org: string; } + +// ─── Coordonnées deck.gl ───────────────────────────────────────────────────── +// PCA normalisé [0,1] → world [0, WORLD] +const WORLD = 1000; + +function toWorld(v: number): number { return v * WORLD; } + +// Couleur hex → [r,g,b,a] +function hexToRgba(hex: string, alpha = 255): [number, number, number, number] { + const r = parseInt(hex.slice(1, 3), 16); + const g = parseInt(hex.slice(3, 5), 16); + const b = parseInt(hex.slice(5, 7), 16); + return [r, g, b, alpha]; } -// ─── Helpers ────────────────────────────────────────────────────────────────── - -const THREAT_BADGE_CLASS: Record = { - CRITICAL: 'bg-red-600', HIGH: 'bg-orange-500', - MEDIUM: 'bg-yellow-500', LOW: 'bg-green-600', -}; - -const RADAR_FEATURES = [ - 'Score Anomalie', 'Vélocité (rps)', 'Fuzzing', 'Headless', - 'ALPN Mismatch', 'H2 Multiplexing', 'UA-CH Mismatch', 'UA Rotatif', - 'IP-ID Zéro', 'Entropie Temporelle', -]; - -function ThreatBadge({ level }: { level: string }) { - if (!level) return null; - return ( - - {level} - - ); -} - -function MiniBar({ value, color = '#6366f1', label }: { value: number; color?: string; label?: string }) { - const pct = Math.round(Math.min(1, Math.max(0, value)) * 100); - return ( -
- {label && {label}} -
-
-
- {pct}% -
- ); -} - -function riskColor(risk: number): string { - if (risk >= 0.45) return '#dc2626'; - if (risk >= 0.30) return '#f97316'; - if (risk >= 0.15) return '#eab308'; - return '#22c55e'; -} - -function riskLabel(risk: number): string { - if (risk >= 0.45) return 'CRITIQUE'; - if (risk >= 0.30) return 'ÉLEVÉ'; - if (risk >= 0.15) return 'MODÉRÉ'; - return 'SAIN'; -} - -// ─── Carte cluster (réutilisée dans les 2 vues) ──────────────────────────── - -function ClusterCard({ - node, selected, onClick, -}: { - node: ClusterNode; - selected: boolean; - onClick: () => void; -}) { - const rc = riskColor(node.risk_score); - const rl = riskLabel(node.risk_score); - - // Normalisation anomaly_score pour la barre (valeurs ~0.3 max → étirer sur /0.5) - const scoreN = Math.min(1, node.mean_score / 0.5); - - return ( - - ); -} - -// ─── Vue Cartes (défaut) ────────────────────────────────────────────────────── - -function CardGridView({ - nodes, selectedId, onSelect, -}: { - nodes: ClusterNode[]; - selectedId: string | null; - onSelect: (n: ClusterNode) => void; -}) { - const sorted = useMemo( - () => [...nodes].sort((a, b) => b.risk_score - a.risk_score), - [nodes], - ); - - // Groupes par niveau de risque - const groups = useMemo(() => { - const bots = sorted.filter(n => n.risk_score >= 0.45 || n.label.includes('🤖')); - const warn = sorted.filter(n => n.risk_score >= 0.15 && n.risk_score < 0.45 && !n.label.includes('🤖')); - const safe = sorted.filter(n => n.risk_score < 0.15 && !n.label.includes('🤖')); - return { bots, warn, safe }; - }, [sorted]); - - function Group({ title, color, nodes: gn }: { title: string; color: string; nodes: ClusterNode[] }) { - if (gn.length === 0) return null; - return ( -
-
-
-

- {title} ({gn.length}) -

-
-
-
- {gn.map(n => ( - onSelect(n)} - /> - ))} -
-
- ); - } - - return ( -
- - - -
- ); -} - -// ─── Nœud ReactFlow (pour la vue Graphe) ───────────────────────────────────── - -function GraphCardNode({ data }: NodeProps) { - const rc = riskColor(data.risk_score); - const rl = riskLabel(data.risk_score); - const scoreN = Math.min(1, data.mean_score / 0.5); - - return ( - <> - -
0.40 ? `0 0 16px ${rc}55` : 'none', - }} - > -
-
-
-

{data.label}

- - {rl} - -
-

- {data.ip_count.toLocaleString()} IPs ·{' '} - {data.hit_count.toLocaleString()} req -

- {/* Barre risque */} -
-
-
- {/* Mini métriques */} -
- {[ - ['Anomalie', scoreN, scoreN > 0.5 ? '#dc2626' : '#f97316'], - ['UA-CH', data.mean_ua_ch, '#f97316'], - ['Fuzzing', Math.min(1, data.mean_fuzzing * 3), '#8b5cf6'], - ].map(([l, v, c]: any) => ( -
- {l} -
-
-
- {Math.round(v * 100)}% -
- ))} -
- {data.top_countries?.length > 0 && ( -

🌍 {data.top_countries.slice(0, 4).join(' · ')}

- )} -
-
- - - ); -} - -const nodeTypes = { graphCard: GraphCardNode }; - -// ─── Vue Graphe ─────────────────────────────────────────────────────────────── - -function GraphView({ - data, selectedId, onSelect, -}: { - data: ClusteringData; - selectedId: string | null; - onSelect: (n: ClusterNode) => void; -}) { - const [nodes, setNodes, onNodesChange] = useNodesState([]); - const [edges, setEdges, onEdgesChange] = useEdgesState([]); - const { fitView } = useReactFlow(); - - useEffect(() => { - if (!data) return; - - // Layout en colonnes par niveau de menace - // Col 0 → bots (rouge), Col 1 → suspects (orange), Col 2 → légitimes (vert) - const sorted = [...data.nodes].sort((a, b) => b.risk_score - a.risk_score); - - const col: ClusterNode[][] = [[], [], []]; - for (const n of sorted) { - if (n.risk_score >= 0.45 || n.label.includes('🤖')) col[0].push(n); - else if (n.risk_score >= 0.15) col[1].push(n); - else col[2].push(n); - } - - const NODE_W = 240; - const NODE_H = 170; - const PAD_X = 80; - const PAD_Y = 40; - const COL_GAP = 80; - - const rfNodes: Node[] = []; - col.forEach((group, ci) => { - group.forEach((n, ri) => { - rfNodes.push({ - id: n.id, - type: 'graphCard', - position: { - x: ci * (NODE_W + COL_GAP) + PAD_X, - y: ri * (NODE_H + PAD_Y) + PAD_Y, - }, - data: n, - draggable: true, - selected: n.id === selectedId, - }); - }); - }); - - // Arêtes avec couleur par similarité - const rfEdges: Edge[] = data.edges.map(e => { - const sim = e.similarity; - return { - id: e.id, - source: e.source, - target: e.target, - style: { - stroke: sim > 0.6 ? '#f97316' : sim > 0.4 ? '#6b7280' : '#374151', - strokeWidth: Math.max(1, e.weight * 0.5), - strokeDasharray: sim < 0.4 ? '4 4' : undefined, - }, - label: sim > 0.55 ? `${Math.round(sim * 100)}%` : undefined, - labelStyle: { fontSize: 9, fill: '#9ca3af' }, - labelBgStyle: { fill: '#0f1117aa', borderRadius: 3 }, - animated: sim > 0.6, - }; - }); - - setNodes(rfNodes); - setEdges(rfEdges); - setTimeout(() => fitView({ padding: 0.08 }), 120); - }, [data, selectedId]); - - return ( -
- onSelect(node.data as ClusterNode)} - nodeTypes={nodeTypes} - fitView - minZoom={0.12} - maxZoom={2.5} - attributionPosition="bottom-right" - > - - - riskColor((n.data as any)?.risk_score ?? 0)} - style={{ background: '#0f1117', border: '1px solid #374151' }} - /> - {/* Légende colonnes */} - -
- {[ - { color: '#dc2626', label: '🤖 Bots / Menaces', col: 0 }, - { color: '#f97316', label: '⚠️ Suspects', col: 1 }, - { color: '#22c55e', label: '✅ Légitimes', col: 2 }, - ].map(({ color, label }) => ( -
-
- {label} -
- ))} - ── similaire · - - différent · animé=fort -
- - -
-

K-means++ · 21 features

-

Colonnes : niveau de risque

-

Arêtes : similarité des centroides

-
-
- -
- ); -} - -// ─── Sidebar détail cluster ──────────────────────────────────────────────────── - -const RADAR_FEATURES_SET = new Set(RADAR_FEATURES); - -function ClusterSidebar({ cluster, onClose }: { cluster: ClusterNode; onClose: () => void }) { - const [ips, setIPs] = useState([]); - const [total, setTotal] = useState(0); - const [loading, setLoading] = useState(false); - const [copied, setCopied] = useState(false); - - useEffect(() => { - setLoading(true); - fetch(`/api/clustering/cluster/${cluster.id}/ips?limit=80`) - .then(r => r.json()) - .then(d => { setIPs(d.ips || []); setTotal(d.total || 0); }) - .catch(() => {}) - .finally(() => setLoading(false)); - }, [cluster.id]); - - const copyIPs = () => { - navigator.clipboard.writeText(ips.map(i => i.ip).join('\n')); - setCopied(true); - setTimeout(() => setCopied(false), 2000); - }; - - const downloadCSV = () => { - const header = 'IP,JA4,TTL,MSS,Hits,Score,Menace,Pays,ASN,Fuzzing,Vélocité\n'; - const rows = ips.map(i => - [i.ip, i.ja4, i.tcp_ttl, i.tcp_mss, i.hits, - i.avg_score.toFixed(3), i.threat_level, i.country_code, - `"${i.asn_org}"`, i.fuzzing.toFixed(2), i.velocity.toFixed(2)].join(',') - ).join('\n'); - const blob = new Blob([header + rows], { type: 'text/csv' }); - const a = document.createElement('a'); - a.href = URL.createObjectURL(blob); - a.download = `cluster_${cluster.id}.csv`; - a.click(); - }; - - const rc = riskColor(cluster.risk_score); - const radarData = cluster.radar - .filter(r => RADAR_FEATURES_SET.has(r.feature)) - .map(r => ({ subject: r.feature.replace('Vélocité (rps)', 'Vélocité'), val: Math.round(r.value * 100) })); - - return ( -
- {/* Header */} -
-
-
-

{cluster.label}

-

- {cluster.ip_count.toLocaleString()} IPs ·{' '} - {cluster.hit_count.toLocaleString()} requêtes -

-
- -
- {/* Risque */} -
-
- Score de risque - {Math.round(cluster.risk_score * 100)}% — {riskLabel(cluster.risk_score)} -
-
-
-
-
-
- -
- {/* Radar */} -
-

Profil Comportemental

- - - - - - - [`${v}%`]} - /> - - -
- - {/* Métriques */} -
-

Toutes les métriques

-
- {[ - ['Score anomalie ML', Math.min(1, cluster.mean_score / 0.5), rc], - ['UA-CH mismatch', cluster.mean_ua_ch, '#f97316'], - ['UA rotatif', cluster.mean_ua_rotating, '#ec4899'], - ['Fuzzing', Math.min(1, cluster.mean_fuzzing * 3), '#8b5cf6'], - ['Headless', cluster.mean_headless, '#dc2626'], - ['Vélocité', cluster.mean_velocity, '#6366f1'], - ['ALPN mismatch', cluster.mean_alpn_mismatch, '#14b8a6'], - ['IP-ID zéro', cluster.mean_ip_id_zero, '#f59e0b'], - ['Entropie temporelle',cluster.mean_entropy, '#06b6d4'], - ['Browser score', Math.min(1, cluster.mean_browser_score / 50), '#22c55e'], - ].map(([lbl, val, col]: any) => ( - - ))} -
-
- - {/* TCP */} -
-

Stack TCP

-
- {[ - ['TTL Initial', Math.round(cluster.mean_ttl)], - ['MSS', Math.round(cluster.mean_mss)], - ['Scale', cluster.mean_scale.toFixed(1)], - ].map(([k, v]) => ( -
-

{k}

-

{v}

-
- ))} -
-
- - {/* Meta */} -
- {cluster.top_threat && ( -
- Menace dominante - -
- )} - {cluster.top_countries.length > 0 && ( -

Pays : - {cluster.top_countries.join(', ')}

- )} - {cluster.top_orgs.length > 0 && ( -
- ASN : - {cluster.top_orgs.slice(0, 3).map((org, i) => ( -

• {org}

- ))} -
- )} - {cluster.sample_ua && ( -
- User-Agent type : -

{cluster.sample_ua}

-
- )} -
- - {/* Actions */} -
- - -
- - {/* Liste IPs */} -
-

- Adresses IP ({loading ? '…' : `${ips.length} / ${total.toLocaleString()}`}) -

- {loading ? ( -

Chargement…

- ) : ( -
- {ips.map((ip, i) => ( -
-
- {ip.ip} -
- - {ip.country_code && {ip.country_code}} -
-
-
- TTL {ip.tcp_ttl} - MSS {ip.tcp_mss} - {ip.hits.toLocaleString()} req - {ip.avg_score > 0.1 && ( - ⚠ {(ip.avg_score * 100).toFixed(0)}% - )} - {ip.asn_org && {ip.asn_org}} -
-
- ))} -
- )} -
-
-
- ); -} - -// ─── Vue Graphe (wrapper avec ReactFlowProvider) ─────────────────────────────── - -function GraphViewWrapper({ - data, selectedId, onSelect, -}: { - data: ClusteringData; - selectedId: string | null; - onSelect: (n: ClusterNode) => void; -}) { - return ( - - - - ); -} - -// ─── Composant principal ───────────────────────────────────────────────────── +// ─── Composant principal ────────────────────────────────────────────────────── export default function ClusteringView() { - const [data, setData] = useState(null); - const [loading, setLoading] = useState(true); - const [error, setError] = useState(''); - const [k, setK] = useState(14); - const [pendingK, setPendingK] = useState(14); - const [view, setView] = useState<'cards' | 'graph'>('cards'); - const [selected, setSelected] = useState(null); + const [k, setK] = useState(14); + const [hours, setHours] = useState(24); + const [data, setData] = useState(null); + const [loading, setLoading] = useState(false); + const [computing, setComputing] = useState(false); + const [error, setError] = useState(null); + const [selected, setSelected] = useState(null); + const [clusterPoints, setClusterPoints] = useState([]); + const [ipDetails, setIpDetails] = useState([]); + const [ipPage, setIpPage] = useState(0); + const [ipTotal, setIpTotal] = useState(0); + const [showEdges, setShowEdges] = useState(false); + const pollRef = useRef | null>(null); - const fetchData = useCallback(async (kVal: number) => { + // Viewport deck.gl — centré à [WORLD/2, WORLD/2] + const [viewState, setViewState] = useState({ + target: [WORLD / 2, WORLD / 2, 0] as [number, number, number], + zoom: 0, + minZoom: -4, + maxZoom: 6, + }); + + // ── Chargement / polling ───────────────────────────────────────────────── + + const fetchClusters = useCallback(async (force = false) => { setLoading(true); - setError(''); - setSelected(null); + setError(null); try { - const r = await fetch(`/api/clustering/clusters?k=${kVal}&n_samples=3000`); - if (!r.ok) throw new Error(await r.text()); - setData(await r.json()); - } catch (e: any) { - setError(e.message || 'Erreur réseau'); + const res = await axios.get('/api/clustering/clusters', { + params: { k, hours, force }, + }); + if (res.data.status === 'computing' || res.data.status === 'idle') { + setComputing(true); + // Polling + pollRef.current = setTimeout(() => fetchClusters(), 3000); + } else { + setComputing(false); + setData(res.data); + // Fit viewport + if (res.data.nodes?.length) { + const xs = res.data.nodes.map(n => toWorld(n.pca_x)); + const ys = res.data.nodes.map(n => toWorld(n.pca_y)); + const minX = Math.min(...xs), maxX = Math.max(...xs); + const minY = Math.min(...ys), maxY = Math.max(...ys); + setViewState(v => ({ + ...v, + target: [(minX + maxX) / 2, (minY + maxY) / 2, 0], + zoom: Math.log2(Math.min(800 / (maxX - minX + 1), 600 / (maxY - minY + 1))) - 1, + })); + } + } + } catch (e: unknown) { + setError((e as Error).message); + setComputing(false); } finally { setLoading(false); } + }, [k, hours]); + + useEffect(() => { + fetchClusters(); + return () => { if (pollRef.current) clearTimeout(pollRef.current); }; + }, []); // eslint-disable-line + + // ── Drill-down : chargement des points du cluster sélectionné ─────────── + + const loadClusterPoints = useCallback(async (node: ClusterNode) => { + try { + const res = await axios.get<{ points: IPPoint[]; total: number }>( + `/api/clustering/cluster/${node.id}/points`, + { params: { limit: 10000, offset: 0 } } + ); + setClusterPoints(res.data.points); + } catch { setClusterPoints([]); } }, []); - useEffect(() => { fetchData(k); }, []); + const loadClusterIPs = useCallback(async (node: ClusterNode, page = 0) => { + try { + const res = await axios.get<{ ips: IPDetail[]; total: number }>( + `/api/clustering/cluster/${node.id}/ips`, + { params: { limit: 50, offset: page * 50 } } + ); + setIpDetails(res.data.ips); + setIpTotal(res.data.total); + setIpPage(page); + } catch { setIpDetails([]); } + }, []); - const applyK = () => { setK(pendingK); fetchData(pendingK); }; + const handleSelectCluster = useCallback((node: ClusterNode) => { + setSelected(node); + setClusterPoints([]); + setIpDetails([]); + loadClusterPoints(node); + loadClusterIPs(node, 0); + }, [loadClusterPoints, loadClusterIPs]); - const stats = data?.stats; + // ── Layers deck.gl ───────────────────────────────────────────────────── + + const layers = React.useMemo(() => { + if (!data?.nodes) return []; + const nodes = data.nodes; + const nodeMap = Object.fromEntries(nodes.map(n => [n.id, n])); + + const layerList: object[] = []; + + // 1. Hulls (enveloppes convexes) — toujours visibles + const hullData = nodes + .filter(n => n.hull && n.hull.length >= 3) + .map(n => ({ + ...n, + polygon: n.hull.map(([x, y]) => [toWorld(x), toWorld(y)]), + })); + + layerList.push(new PolygonLayer({ + id: 'hulls', + data: hullData, + getPolygon: (d: typeof hullData[number]) => d.polygon, + getFillColor: (d: typeof hullData[number]) => hexToRgba(d.color, d.id === selected?.id ? 55 : 28), + getLineColor: (d: typeof hullData[number]) => hexToRgba(d.color, d.id === selected?.id ? 220 : 130), + getLineWidth: (d: typeof hullData[number]) => d.id === selected?.id ? 3 : 1.5, + lineWidthUnits: 'pixels', + stroked: true, + filled: true, + pickable: true, + autoHighlight: true, + highlightColor: [255, 255, 255, 30], + onClick: ({ object }: { object?: typeof hullData[number] }) => { + if (object) handleSelectCluster(object as ClusterNode); + }, + updateTriggers: { getFillColor: [selected?.id], getLineColor: [selected?.id], getLineWidth: [selected?.id] }, + })); + + // 2. Arêtes inter-clusters (optionnelles) + if (showEdges && data.edges) { + const edgeData = data.edges + .map(e => { + const s = nodeMap[e.source]; + const t = nodeMap[e.target]; + if (!s || !t) return null; + return { source: [toWorld(s.pca_x), toWorld(s.pca_y)], target: [toWorld(t.pca_x), toWorld(t.pca_y)], sim: e.similarity }; + }) + .filter(Boolean) as { source: [number, number]; target: [number, number]; sim: number }[]; + + layerList.push(new LineLayer({ + id: 'edges', + data: edgeData, + getSourcePosition: d => d.source, + getTargetPosition: d => d.target, + getColor: [100, 100, 120, 80], + getWidth: 1, + widthUnits: 'pixels', + })); + } + + // 3. Points IPs du cluster sélectionné + if (selected && clusterPoints.length > 0) { + layerList.push(new ScatterplotLayer({ + id: 'ip-points', + data: clusterPoints, + getPosition: (d: IPPoint) => [toWorld(d.pca_x), toWorld(d.pca_y), 0], + getRadius: 3, + radiusUnits: 'pixels', + getFillColor: (d: IPPoint) => { + const r = d.risk; + if (r > 0.70) return [220, 38, 38, 200]; + if (r > 0.45) return [249, 115, 22, 200]; + if (r > 0.25) return [234, 179, 8, 200]; + return [34, 197, 94, 180]; + }, + pickable: false, + updateTriggers: { getPosition: [clusterPoints.length] }, + })); + } + + // 4. Centroïdes (cercles de taille ∝ ip_count) + layerList.push(new ScatterplotLayer({ + id: 'centroids', + data: nodes, + getPosition: (d: ClusterNode) => [toWorld(d.pca_x), toWorld(d.pca_y), 0], + getRadius: (d: ClusterNode) => d.radius, + radiusUnits: 'pixels', + getFillColor: (d: ClusterNode) => hexToRgba(d.color, d.id === selected?.id ? 255 : 180), + getLineColor: [255, 255, 255, 180], + getLineWidth: (d: ClusterNode) => d.id === selected?.id ? 3 : 1, + lineWidthUnits: 'pixels', + stroked: true, + filled: true, + pickable: true, + autoHighlight: true, + highlightColor: [255, 255, 255, 60], + onClick: ({ object }: { object?: ClusterNode }) => { + if (object) handleSelectCluster(object); + }, + updateTriggers: { getFillColor: [selected?.id], getLineWidth: [selected?.id] }, + })); + + // 5. Labels (TextLayer) — strip emojis (non supportés par le bitmap font deck.gl) + const stripEmoji = (s: string) => s.replace(/[\u{1F000}-\u{1FFFF}\u{2600}-\u{27FF}]/gu, '').trim(); + layerList.push(new TextLayer({ + id: 'labels', + data: nodes, + getPosition: (d: ClusterNode) => [toWorld(d.pca_x), toWorld(d.pca_y), 0], + getText: (d: ClusterNode) => stripEmoji(d.label), + getSize: 12, + sizeUnits: 'pixels', + getColor: [255, 255, 255, 200], + getAnchor: 'middle', + getAlignmentBaseline: 'top', + getPixelOffset: (d: ClusterNode) => [0, d.radius + 4], + fontFamily: 'monospace', + background: true, + getBorderColor: [0, 0, 0, 0], + backgroundPadding: [3, 1, 3, 1], + getBackgroundColor: [15, 20, 30, 180], + })); + + return layerList; + }, [data, selected, clusterPoints, showEdges, handleSelectCluster]); + + // ── Rendering ──────────────────────────────────────────────────────────── return ( -
- {/* ── Barre de contrôle ── */} -
- {/* Slider k */} -
- k = - setPendingK(Number(e.target.value))} - className="w-24 accent-indigo-500" /> - {pendingK} -
- {/* Onglets vue */} -
- {(['cards', 'graph'] as const).map(v => ( - - ))} -
+ {/* Stats globales */} + {data?.stats && ( +
+
Résultats
+ + + + + + +
+ )} - {/* Stats */} - {stats && !loading && ( -
- - - - - {stats.elapsed_s}s + {/* Message computing */} + {computing && ( +
+ ⏳ Calcul en cours sur {data?.stats?.n_samples?.toLocaleString() ?? '…'} IPs… +
Mise à jour automatique toutes les 3s +
+ )} + + {error && ( +
+ ❌ {error} +
+ )} + + {/* Liste clusters */} + {data?.nodes && ( +
+
Clusters
+ {[...data.nodes] + .sort((a, b) => b.risk_score - a.risk_score) + .map(n => ( + + ))}
)}
- {/* ── Erreur ── */} - {error && ( -
-
-

⚠️

-

Erreur de clustering

-

{error}

- + {/* ── Canvas WebGL (deck.gl) ── */} +
+ {!data && !loading && !computing && ( +
+ Cliquez sur Recalculer pour démarrer
-
- )} - - {/* ── Chargement ── */} - {loading && ( -
-
-
⚙️
-

Calcul K-means++ en cours…

-

Normalisation 21 features · PCA-2D · Nommage automatique

+ )} + setViewState(vs as typeof viewState)} + layers={layers as any} + style={{ width: '100%', height: '100%' }} + controller={true} + > + {/* Légende overlay */} +
+
+ {[['#dc2626', 'CRITICAL'], ['#f97316', 'HIGH'], ['#eab308', 'MEDIUM'], ['#22c55e', 'LOW']].map(([c, l]) => ( +
+ + {l} +
+ ))} +
-
- )} + {/* Tooltip zoom hint */} +
+
Scroll pour zoomer · Drag pour déplacer · Click sur un cluster
+
+ +
- {/* ── Contenu principal ── */} - {data && !loading && ( -
- {view === 'cards' ? ( - setSelected(prev => prev?.id === n.id ? null : n)} - /> - ) : ( - setSelected(prev => prev?.id === n.id ? null : n)} - /> - )} - - {/* Sidebar */} - {selected && ( - setSelected(null)} - /> - )} -
+ {/* ── Sidebar droite (sélection) ── */} + {selected && ( + { setSelected(null); setClusterPoints([]); setIpDetails([]); }} + onPageChange={(p) => loadClusterIPs(selected, p)} + /> )}
); } -// ─── Petit composant stat ───────────────────────────────────────────────────── +// ─── Stat helper ───────────────────────────────────────────────────────────── -function Stat({ label, value, color = 'text-text-primary' }: { label: string; value: string | number; color?: string }) { +function Stat({ label, value, color }: { label: string; value: string | number; color?: string }) { return ( - - {value} {label} - +
+ {label} + {value} +
+ ); +} + +// ─── Sidebar détaillée ─────────────────────────────────────────────────────── + +function ClusterSidebar({ node, ipDetails, ipTotal, ipPage, clusterPoints, onClose, onPageChange }: { + node: ClusterNode; + ipDetails: IPDetail[]; + ipTotal: number; + ipPage: number; + clusterPoints: IPPoint[]; + onClose: () => void; + onPageChange: (p: number) => void; +}) { + const riskLabel = (r: number) => + r > 0.70 ? 'CRITICAL' : r > 0.45 ? 'HIGH' : r > 0.25 ? 'MEDIUM' : 'LOW'; + const riskClass = (r: number) => + r > 0.70 ? 'text-red-500' : r > 0.45 ? 'text-orange-500' : r > 0.25 ? 'text-yellow-400' : 'text-green-500'; + + const totalPages = Math.ceil(ipTotal / 50); + + const exportCSV = () => { + const header = 'IP,JA4,TTL,MSS,Hits,Score,Menace,Pays,ASN\n'; + const rows = ipDetails.map(ip => + [ip.ip, ip.ja4, ip.tcp_ttl, ip.tcp_mss, ip.hits, ip.avg_score, ip.threat_level, ip.country_code, ip.asn_org].join(',') + ).join('\n'); + const blob = new Blob([header + rows], { type: 'text/csv' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); a.href = url; a.download = `cluster_${node.id}.csv`; a.click(); + }; + + return ( +
+ {/* Header */} +
+
+
{node.label}
+
{node.ip_count.toLocaleString()} IPs · {node.hit_count.toLocaleString()} hits
+
+
+ {riskLabel(node.risk_score)} + +
+
+ +
+ {/* Score risque */} +
+
Score de risque
+
+
+
+
+ + {(node.risk_score * 100).toFixed(0)}% + +
+
+ + {/* Radar chart */} + {node.radar?.length > 0 && ( +
+
Profil 21 features
+ + + + + + [`${(v * 100).toFixed(1)}%`, '']} + /> + + +
+ )} + + {/* TCP stack */} +
+
Stack TCP
+ + + + + + +
+ + {/* Contexte */} + {(node.top_countries?.length > 0 || node.top_orgs?.length > 0) && ( +
+
Géographie & AS
+ {node.top_countries?.length > 0 && ( +
+ {node.top_countries.map(c => ( + {c} + ))} +
+ )} + {node.top_orgs?.length > 0 && ( +
+ {node.top_orgs.map(o => ( +
{o}
+ ))} +
+ )} +
+ )} + + {/* IPs paginées */} +
+
+ IPs ({ipTotal.toLocaleString()}) + +
+ {ipDetails.length === 0 ? ( +
Chargement…
+ ) : ( +
+ {ipDetails.map(ip => ( +
+ 0.45 ? 'bg-red-500' : ip.avg_score > 0.25 ? 'bg-orange-400' : 'bg-green-500' + }`} + /> + {ip.ip} + {ip.country_code} + {ip.hits} +
+ ))} +
+ )} + {/* Pagination */} + {totalPages > 1 && ( +
+ + {ipPage + 1} / {totalPages} + +
+ )} +
+ + {/* Points info */} + {clusterPoints.length > 0 && ( +
+ {clusterPoints.length.toLocaleString()} IPs affichées en WebGL +
+ )} +
+
); } diff --git a/requirements.txt b/requirements.txt index 785e89a..a88bcc7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ pydantic==2.5.0 pydantic-settings==2.1.0 python-dotenv==1.0.0 httpx==0.26.0 +numpy>=1.26 +scipy>=1.11