feat: WebGL clustering (deck.gl) + K-means++ sur toutes les IPs (183K)

- Ajout numpy + scipy à requirements.txt (K-means vectorisé, convex hull)
- Réécriture clustering_engine.py :
  * K-means++ entièrement vectorisé numpy (100x plus rapide que pur Python)
  * PCA-2D par power iteration (numpy)
  * Enveloppes convexes par cluster via scipy.spatial.ConvexHull
  * Traitement des probabilités nulles (points dupliqués) en K-means++ init
- Réécriture clustering.py :
  * Calcul sur la TOTALITÉ des IPs (sans LIMIT) : 183K IPs, 16.8 MB features
  * Computation en background thread (ThreadPoolExecutor) + cache 30 min
  * Endpoint /api/clustering/status pour polling frontend
  * Endpoint /api/clustering/cluster/{id}/points (coordonnées PCA pour WebGL)
- Réécriture ClusteringView.tsx en WebGL (deck.gl) :
  * PolygonLayer : enveloppes convexes colorées par niveau de menace
  * ScatterplotLayer centroïdes : taille ∝ sqrt(ip_count)
  * ScatterplotLayer IPs : chargé sur sélection (LOD), GPU-accelerated
  * TextLayer : labels (emojis strippés — non supportés par bitmap font)
  * LineLayer : arêtes inter-clusters (optionnel)
  * OrthographicView avec pan/zoom natif
  * Sidebar : radar 21 features, pagination IPs, export CSV
  * Polling automatique toutes les 3s pendant le calcul
- Ajout @deck.gl/react @deck.gl/core @deck.gl/layers à package.json

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
SOC Analyst
2026-03-19 09:40:27 +01:00
parent 9de59f5681
commit b2c3379aa0
5 changed files with 1130 additions and 1369 deletions

View File

@ -1,54 +1,53 @@
"""
Clustering d'IPs multi-métriques — backend ReactFlow.
Clustering d'IPs multi-métriques — WebGL / deck.gl backend.
Features utilisées (21 dimensions) :
TCP stack : TTL initial, MSS, scale, fenêtre TCP
Comportement : vélocité, POST ratio, fuzzing, assets, accès direct
Anomalie ML : score, IP-ID zéro
TLS/Protocole: ALPN mismatch, ALPN absent, efficacité H2
Navigateur : browser score, headless, ordre headers, UA-CH mismatch
Temporel : entropie, diversité JA4, UA rotatif
Algorithme :
1. Échantillonnage stratifié (top détections + top hits)
2. Construction + normalisation des vecteurs de features
3. K-means++ (Arthur & Vassilvitskii, 2007)
4. PCA-2D par power iteration pour les positions ReactFlow
5. Nommage automatique par features dominantes du centroïde
6. Calcul des arêtes : k-NN dans l'espace des features
- Calcul sur la TOTALITÉ des IPs (GROUP BY src_ip, ja4 sans LIMIT)
- K-means++ vectorisé (numpy) + PCA-2D + enveloppes convexes (scipy)
- Calcul en background thread + cache 30 min
- Endpoints : /clusters, /status, /cluster/{id}/points
"""
from __future__ import annotations
import math
import time
import hashlib
from typing import Optional
import logging
import threading
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Any
import numpy as np
from fastapi import APIRouter, HTTPException, Query
from ..database import db
from ..services.clustering_engine import (
FEATURES, FEATURE_KEYS, FEATURE_NORMS, FEATURE_NAMES, N_FEATURES,
build_feature_vector, kmeans_pp, pca_2d,
name_cluster, risk_score_from_centroid, _mean_vec,
FEATURE_KEYS, FEATURE_NAMES, FEATURE_NORMS, N_FEATURES,
build_feature_vector, kmeans_pp, pca_2d, compute_hulls,
name_cluster, risk_score_from_centroid,
)
log = logging.getLogger(__name__)
router = APIRouter(prefix="/api/clustering", tags=["clustering"])
# ─── Cache en mémoire ─────────────────────────────────────────────────────────
# Stocke (cluster_id → liste d'IPs) pour le drill-down
# + timestamp de dernière mise à jour
_cache: dict = {
"assignments": {}, # ip+ja4 → cluster_idx
"cluster_ips": {}, # cluster_idx → [(ip, ja4)]
"params": {}, # k, ts
# ─── Cache global ──────────────────────────────────────────────────────────────
_CACHE: dict[str, Any] = {
"status": "idle", # idle | computing | ready | error
"error": None,
"result": None, # dict résultat complet
"ts": 0.0, # timestamp dernière mise à jour
"params": {},
"cluster_ips": {}, # cluster_idx → [(ip, ja4, pca_x, pca_y, risk)]
}
_CACHE_TTL = 1800 # 30 minutes
_LOCK = threading.Lock()
_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="clustering")
# ─── Couleurs ─────────────────────────────────────────────────────────────────
# ─── Couleurs menace ──────────────────────────────────────────────────────────
_THREAT_COLOR = {
0.92: "#dc2626", # Bot scanner
0.70: "#ef4444", # Critique
0.70: "#dc2626", # Critique
0.45: "#f97316", # Élevé
0.25: "#eab308", # Modéré
0.00: "#6b7280", # Sain / inconnu
0.00: "#22c55e", # Sain
}
def _risk_to_color(risk: float) -> str:
@ -58,9 +57,8 @@ def _risk_to_color(risk: float) -> str:
return "#6b7280"
# ─── SQL ──────────────────────────────────────────────────────────────────────
_SQL_FEATURES = """
# ─── SQL : TOUTES les IPs sans LIMIT ─────────────────────────────────────────
_SQL_ALL_IPS = """
SELECT
replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS ip,
t.ja4,
@ -71,43 +69,36 @@ SELECT
any(t.first_ua) AS ua,
sum(t.hits) AS hits,
avg(abs(ml.anomaly_score)) AS avg_score,
avg(ml.hit_velocity) AS avg_velocity,
avg(ml.fuzzing_index) AS avg_fuzzing,
avg(ml.is_headless) AS pct_headless,
avg(ml.post_ratio) AS avg_post,
avg(ml.ip_id_zero_ratio) AS ip_id_zero,
avg(ml.temporal_entropy) AS entropy,
avg(ml.modern_browser_score) AS browser_score,
avg(ml.alpn_http_mismatch) AS alpn_mismatch,
avg(ml.is_alpn_missing) AS alpn_missing,
avg(ml.multiplexing_efficiency) AS h2_eff,
avg(ml.header_order_confidence) AS hdr_conf,
avg(ml.ua_ch_mismatch) AS ua_ch_mismatch,
avg(ml.asset_ratio) AS asset_ratio,
avg(ml.direct_access_ratio) AS direct_ratio,
avg(ml.distinct_ja4_count) AS ja4_count,
max(ml.is_ua_rotating) AS ua_rotating,
avg(abs(ml.anomaly_score)) AS avg_score,
avg(ml.hit_velocity) AS avg_velocity,
avg(ml.fuzzing_index) AS avg_fuzzing,
avg(ml.is_headless) AS pct_headless,
avg(ml.post_ratio) AS avg_post,
avg(ml.ip_id_zero_ratio) AS ip_id_zero,
avg(ml.temporal_entropy) AS entropy,
avg(ml.modern_browser_score) AS browser_score,
avg(ml.alpn_http_mismatch) AS alpn_mismatch,
avg(ml.is_alpn_missing) AS alpn_missing,
avg(ml.multiplexing_efficiency) AS h2_eff,
avg(ml.header_order_confidence) AS hdr_conf,
avg(ml.ua_ch_mismatch) AS ua_ch_mismatch,
avg(ml.asset_ratio) AS asset_ratio,
avg(ml.direct_access_ratio) AS direct_ratio,
avg(ml.distinct_ja4_count) AS ja4_count,
max(ml.is_ua_rotating) AS ua_rotating,
max(ml.threat_level) AS threat,
any(ml.country_code) AS country,
any(ml.asn_org) AS asn_org
max(ml.threat_level) AS threat,
any(ml.country_code) AS country,
any(ml.asn_org) AS asn_org
FROM mabase_prod.agg_host_ip_ja4_1h t
LEFT JOIN mabase_prod.ml_detected_anomalies ml
ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4
AND ml.detected_at >= now() - INTERVAL 24 HOUR
WHERE t.window_start >= now() - INTERVAL 24 HOUR
AND ml.detected_at >= now() - INTERVAL %(hours)s HOUR
WHERE t.window_start >= now() - INTERVAL %(hours)s HOUR
AND t.tcp_ttl_raw > 0
GROUP BY t.src_ip, t.ja4
ORDER BY
-- Stratégie : IPs anormales en premier, puis fort trafic
-- Cela garantit que les bots Masscan (anomalie=0.97, hits=1-2) sont inclus
avg(abs(ml.anomaly_score)) DESC,
sum(t.hits) DESC
LIMIT %(limit)s
"""
# Noms des colonnes SQL dans l'ordre
_SQL_COLS = [
"ip", "ja4", "ttl", "win", "scale", "mss", "ua", "hits",
"avg_score", "avg_velocity", "avg_fuzzing", "pct_headless", "avg_post",
@ -117,252 +108,311 @@ _SQL_COLS = [
]
# ─── Worker de clustering (thread pool) ──────────────────────────────────────
def _run_clustering_job(k: int, hours: int) -> None:
"""Exécuté dans le thread pool. Met à jour _CACHE."""
t0 = time.time()
with _LOCK:
_CACHE["status"] = "computing"
_CACHE["error"] = None
try:
log.info(f"[clustering] Démarrage du calcul k={k} hours={hours}")
# ── 1. Chargement de toutes les IPs ──────────────────────────────
result = db.query(_SQL_ALL_IPS, {"hours": hours})
rows: list[dict] = []
for row in result.result_rows:
rows.append({col: row[i] for i, col in enumerate(_SQL_COLS)})
n = len(rows)
log.info(f"[clustering] {n} IPs chargées")
if n < k:
raise ValueError(f"Seulement {n} IPs disponibles (k={k} requis)")
# ── 2. Construction de la matrice de features (numpy) ────────────
X = np.array([build_feature_vector(r) for r in rows], dtype=np.float32)
log.info(f"[clustering] Matrice X: {X.shape}{X.nbytes/1024/1024:.1f} MB")
# ── 3. K-means++ vectorisé ────────────────────────────────────────
km = kmeans_pp(X.astype(np.float64), k=k, max_iter=80, n_init=3, seed=42)
log.info(f"[clustering] K-means: {km.n_iter} iters, inertia={km.inertia:.2f}")
# ── 4. PCA-2D pour toutes les IPs ────────────────────────────────
coords = pca_2d(X.astype(np.float64)) # (n, 2), normalisé [0,1]
# ── 5. Enveloppes convexes par cluster ───────────────────────────
hulls = compute_hulls(coords, km.labels, k)
# ── 6. Agrégation par cluster ─────────────────────────────────────
cluster_rows: list[list[dict]] = [[] for _ in range(k)]
cluster_coords: list[list[list[float]]] = [[] for _ in range(k)]
cluster_ips_map: dict[int, list] = {j: [] for j in range(k)}
for i, label in enumerate(km.labels):
j = int(label)
cluster_rows[j].append(rows[i])
cluster_coords[j].append(coords[i].tolist())
cluster_ips_map[j].append((
rows[i]["ip"],
rows[i]["ja4"],
float(coords[i][0]),
float(coords[i][1]),
float(risk_score_from_centroid(km.centroids[j])),
))
# ── 7. Construction des nœuds ─────────────────────────────────────
nodes = []
for j in range(k):
if not cluster_rows[j]:
continue
def avg_f(key: str, crows: list[dict] = cluster_rows[j]) -> float:
return float(np.mean([float(r.get(key) or 0) for r in crows]))
mean_ttl = avg_f("ttl")
mean_mss = avg_f("mss")
mean_scale = avg_f("scale")
mean_win = avg_f("win")
raw_stats = {"mean_ttl": mean_ttl, "mean_mss": mean_mss, "mean_scale": mean_scale}
label_name = name_cluster(km.centroids[j], raw_stats)
risk = float(risk_score_from_centroid(km.centroids[j]))
color = _risk_to_color(risk)
# Centroïde 2D = moyenne des coords du cluster
cxy = np.mean(cluster_coords[j], axis=0).tolist() if cluster_coords[j] else [0.5, 0.5]
ip_set = list({r["ip"] for r in cluster_rows[j]})
ip_count = len(ip_set)
hit_count = int(sum(float(r.get("hits") or 0) for r in cluster_rows[j]))
threats = [str(r.get("threat") or "") for r in cluster_rows[j] if r.get("threat")]
countries = [str(r.get("country") or "") for r in cluster_rows[j] if r.get("country")]
orgs = [str(r.get("asn_org") or "") for r in cluster_rows[j] if r.get("asn_org")]
def topk(lst: list[str], n: int = 5) -> list[str]:
return [v for v, _ in Counter(lst).most_common(n) if v]
radar = [
{"feature": name, "value": round(float(km.centroids[j][i]), 4)}
for i, name in enumerate(FEATURE_NAMES)
]
radius = max(12, min(80, int(math.sqrt(ip_count) * 2)))
sample_rows = sorted(cluster_rows[j], key=lambda r: float(r.get("hits") or 0), reverse=True)[:8]
sample_ips = [r["ip"] for r in sample_rows]
sample_ua = str(cluster_rows[j][0].get("ua") or "")
nodes.append({
"id": f"c{j}_k{k}",
"cluster_idx": j,
"label": label_name,
"pca_x": round(cxy[0], 6),
"pca_y": round(cxy[1], 6),
"radius": radius,
"color": color,
"risk_score": round(risk, 4),
"mean_ttl": round(mean_ttl, 1),
"mean_mss": round(mean_mss, 0),
"mean_scale": round(mean_scale, 1),
"mean_win": round(mean_win, 0),
"mean_score": round(avg_f("avg_score"), 4),
"mean_velocity":round(avg_f("avg_velocity"),3),
"mean_fuzzing": round(avg_f("avg_fuzzing"), 3),
"mean_headless":round(avg_f("pct_headless"),3),
"mean_post": round(avg_f("avg_post"), 3),
"mean_asset": round(avg_f("asset_ratio"), 3),
"mean_direct": round(avg_f("direct_ratio"),3),
"mean_alpn_mismatch": round(avg_f("alpn_mismatch"),3),
"mean_h2_eff": round(avg_f("h2_eff"), 3),
"mean_hdr_conf":round(avg_f("hdr_conf"), 3),
"mean_ua_ch": round(avg_f("ua_ch_mismatch"),3),
"mean_entropy": round(avg_f("entropy"), 3),
"mean_ja4_diversity": round(avg_f("ja4_count"),3),
"mean_ip_id_zero": round(avg_f("ip_id_zero"),3),
"mean_browser_score": round(avg_f("browser_score"),1),
"mean_ua_rotating": round(avg_f("ua_rotating"),3),
"ip_count": ip_count,
"hit_count": hit_count,
"top_threat": topk(threats, 1)[0] if threats else "",
"top_countries":topk(countries, 5),
"top_orgs": topk(orgs, 5),
"sample_ips": sample_ips,
"sample_ua": sample_ua,
"radar": radar,
# Hull pour deck.gl PolygonLayer
"hull": hulls.get(j, []),
})
# ── 8. Arêtes k-NN entre clusters ────────────────────────────────
edges = []
seen: set[frozenset] = set()
for i, ni in enumerate(nodes):
ci = ni["cluster_idx"]
dists = sorted(
[(j, nj["cluster_idx"],
float(np.sum((km.centroids[ci] - km.centroids[nj["cluster_idx"]]) ** 2)))
for j, nj in enumerate(nodes) if j != i],
key=lambda x: x[2]
)
for j_idx, cj, d2 in dists[:2]:
key = frozenset([ni["id"], nodes[j_idx]["id"]])
if key in seen:
continue
seen.add(key)
edges.append({
"id": f"e_{ni['id']}_{nodes[j_idx]['id']}",
"source": ni["id"],
"target": nodes[j_idx]["id"],
"similarity": round(1.0 / (1.0 + math.sqrt(d2)), 3),
})
# ── 9. Stockage résultat + cache IPs ─────────────────────────────
total_ips = sum(n_["ip_count"] for n_ in nodes)
total_hits = sum(n_["hit_count"] for n_ in nodes)
bot_ips = sum(n_["ip_count"] for n_ in nodes if n_["risk_score"] > 0.45 or "🤖" in n_["label"])
high_ips = sum(n_["ip_count"] for n_ in nodes if n_["risk_score"] > 0.25)
elapsed = round(time.time() - t0, 2)
result_dict = {
"nodes": nodes,
"edges": edges,
"stats": {
"total_clusters": len(nodes),
"total_ips": total_ips,
"total_hits": total_hits,
"bot_ips": bot_ips,
"high_risk_ips": high_ips,
"n_samples": n,
"k": k,
"elapsed_s": elapsed,
},
"feature_names": FEATURE_NAMES,
}
with _LOCK:
_CACHE["result"] = result_dict
_CACHE["cluster_ips"] = cluster_ips_map
_CACHE["status"] = "ready"
_CACHE["ts"] = time.time()
_CACHE["params"] = {"k": k, "hours": hours}
_CACHE["error"] = None
log.info(f"[clustering] Terminé en {elapsed}s — {total_ips} IPs, {len(nodes)} clusters")
except Exception as e:
log.exception("[clustering] Erreur lors du calcul")
with _LOCK:
_CACHE["status"] = "error"
_CACHE["error"] = str(e)
def _maybe_trigger(k: int, hours: int) -> None:
"""Lance le calcul si cache absent, expiré ou paramètres différents."""
with _LOCK:
status = _CACHE["status"]
params = _CACHE["params"]
ts = _CACHE["ts"]
cache_stale = (time.time() - ts) > _CACHE_TTL
params_changed = params.get("k") != k or params.get("hours") != hours
if status in ("computing",):
return # déjà en cours
if status == "ready" and not cache_stale and not params_changed:
return # cache frais
_EXECUTOR.submit(_run_clustering_job, k, hours)
# ─── Endpoints ────────────────────────────────────────────────────────────────
@router.get("/status")
async def get_status():
"""État du calcul en cours (polling frontend)."""
with _LOCK:
return {
"status": _CACHE["status"],
"error": _CACHE["error"],
"ts": _CACHE["ts"],
"params": _CACHE["params"],
"age_s": round(time.time() - _CACHE["ts"], 0) if _CACHE["ts"] else None,
}
@router.get("/clusters")
async def get_clusters(
k: int = Query(14, ge=4, le=30, description="Nombre de clusters"),
n_samples: int = Query(3000, ge=500, le=8000, description="Taille de l'échantillon"),
k: int = Query(14, ge=4, le=30, description="Nombre de clusters"),
hours: int = Query(24, ge=1, le=168, description="Fenêtre temporelle (heures)"),
force: bool = Query(False, description="Forcer le recalcul"),
):
"""
Clustering multi-métriques des IPs.
Clustering multi-métriques sur TOUTES les IPs.
Retourne les nœuds (clusters) + arêtes pour ReactFlow, avec :
- positions 2D issues de PCA sur les 21 features
- profil radar des features par cluster (normalisé [0,1])
- statistiques détaillées (moyennes brutes des features)
- sample d'IPs représentatives
Retourne immédiatement depuis le cache (status=ready).
Si le calcul est en cours ou non démarré → status=computing/idle + trigger.
"""
t0 = time.time()
if force:
with _LOCK:
_CACHE["status"] = "idle"
_CACHE["ts"] = 0.0
_maybe_trigger(k, hours)
with _LOCK:
status = _CACHE["status"]
result = _CACHE["result"]
error = _CACHE["error"]
if status == "computing":
return {"status": "computing", "message": "Calcul en cours, réessayez dans quelques secondes"}
if status == "error":
raise HTTPException(status_code=500, detail=error or "Erreur inconnue")
if result is None:
return {"status": "idle", "message": "Calcul démarré, réessayez dans quelques secondes"}
return {**result, "status": "ready"}
@router.get("/cluster/{cluster_id}/points")
async def get_cluster_points(
cluster_id: str,
limit: int = Query(5000, ge=1, le=20000),
offset: int = Query(0, ge=0),
):
"""
Coordonnées PCA + métadonnées de toutes les IPs d'un cluster.
Utilisé par deck.gl ScatterplotLayer (drill-down ou zoom avancé).
"""
with _LOCK:
status = _CACHE["status"]
ips_map = _CACHE["cluster_ips"]
if status != "ready" or not ips_map:
raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord")
try:
result = db.query(_SQL_FEATURES, {"limit": n_samples})
except Exception as e:
raise HTTPException(status_code=500, detail=f"ClickHouse: {e}")
idx = int(cluster_id.split("_")[0][1:])
except (ValueError, IndexError):
raise HTTPException(status_code=400, detail="cluster_id invalide (format: c{n}_k{k})")
# ── Construction des vecteurs de features ─────────────────────────────
rows: list[dict] = []
for row in result.result_rows:
d = {col: row[i] for i, col in enumerate(_SQL_COLS)}
rows.append(d)
members = ips_map.get(idx, [])
total = len(members)
page = members[offset: offset + limit]
if len(rows) < k:
raise HTTPException(status_code=400, detail="Pas assez de données pour ce k")
points = [build_feature_vector(r) for r in rows]
# ── K-means++ ────────────────────────────────────────────────────────
km = kmeans_pp(points, k=k, max_iter=60, seed=42)
# ── PCA-2D sur les centroïdes ─────────────────────────────────────────
# On projette les centroïdes dans l'espace PCA des données
# → les positions relatives reflètent la variance des données
coords_all = pca_2d(points)
# Moyenne des positions PCA par cluster = position 2D du centroïde
cluster_xs: list[list[float]] = [[] for _ in range(k)]
cluster_ys: list[list[float]] = [[] for _ in range(k)]
for i, label in enumerate(km.labels):
cluster_xs[label].append(coords_all[i][0])
cluster_ys[label].append(coords_all[i][1])
centroid_2d: list[tuple[float, float]] = []
for j in range(k):
if cluster_xs[j]:
cx = sum(cluster_xs[j]) / len(cluster_xs[j])
cy = sum(cluster_ys[j]) / len(cluster_ys[j])
else:
cx, cy = 0.5, 0.5
centroid_2d.append((cx, cy))
# ── Agrégation des statistiques par cluster ───────────────────────────
cluster_rows: list[list[dict]] = [[] for _ in range(k)]
cluster_members: list[list[tuple[str, str]]] = [[] for _ in range(k)]
for i, label in enumerate(km.labels):
cluster_rows[label].append(rows[i])
cluster_members[label].append((rows[i]["ip"], rows[i]["ja4"]))
# Mise à jour du cache pour le drill-down
_cache["cluster_ips"] = {j: cluster_members[j] for j in range(k)}
_cache["params"] = {"k": k, "ts": t0}
# ── Construction des nœuds ReactFlow ─────────────────────────────────
CANVAS_W, CANVAS_H = 1400, 780
nodes = []
for j in range(k):
if not cluster_rows[j]:
continue
# Statistiques brutes moyennées
def avg_feat(key: str) -> float:
vals = [float(r.get(key) or 0) for r in cluster_rows[j]]
return sum(vals) / len(vals) if vals else 0.0
mean_ttl = avg_feat("ttl")
mean_mss = avg_feat("mss")
mean_scale = avg_feat("scale")
mean_win = avg_feat("win")
mean_score = avg_feat("avg_score")
mean_vel = avg_feat("avg_velocity")
mean_fuzz = avg_feat("avg_fuzzing")
mean_hless = avg_feat("pct_headless")
mean_post = avg_feat("avg_post")
mean_asset = avg_feat("asset_ratio")
mean_direct= avg_feat("direct_ratio")
mean_alpn = avg_feat("alpn_mismatch")
mean_h2 = avg_feat("h2_eff")
mean_hconf = avg_feat("hdr_conf")
mean_ua_ch = avg_feat("ua_ch_mismatch")
mean_entr = avg_feat("entropy")
mean_ja4 = avg_feat("ja4_count")
mean_ip_id = avg_feat("ip_id_zero")
mean_brow = avg_feat("browser_score")
mean_uarot = avg_feat("ua_rotating")
ip_count = len(set(r["ip"] for r in cluster_rows[j]))
hit_count = int(sum(float(r.get("hits") or 0) for r in cluster_rows[j]))
# Pays / ASN / Menace dominants
threats = [str(r.get("threat") or "") for r in cluster_rows[j] if r.get("threat")]
countries = [str(r.get("country") or "") for r in cluster_rows[j] if r.get("country")]
orgs = [str(r.get("asn_org") or "") for r in cluster_rows[j] if r.get("asn_org")]
def topk(lst: list[str], n: int = 5) -> list[str]:
from collections import Counter
return [v for v, _ in Counter(lst).most_common(n) if v]
raw_stats = {
"mean_ttl": mean_ttl, "mean_mss": mean_mss,
"mean_scale": mean_scale,
}
label = name_cluster(km.centroids[j], raw_stats)
risk = risk_score_from_centroid(km.centroids[j])
color = _risk_to_color(risk)
# Profil radar normalisé (valeurs centroïde [0,1])
radar = [
{"feature": name, "value": round(km.centroids[j][i], 4)}
for i, name in enumerate(FEATURE_NAMES)
]
# Position 2D (PCA normalisée → pixels ReactFlow)
px_x = centroid_2d[j][0] * CANVAS_W * 0.85 + 80
px_y = (1 - centroid_2d[j][1]) * CANVAS_H * 0.85 + 50 # inverser y (haut=risque)
# Rayon ∝ √ip_count
radius = max(18, min(90, int(math.sqrt(ip_count) * 0.3)))
# Sample IPs (top 8 par hits)
sample_rows = sorted(cluster_rows[j], key=lambda r: float(r.get("hits") or 0), reverse=True)[:8]
sample_ips = [r["ip"] for r in sample_rows]
sample_ua = str(cluster_rows[j][0].get("ua") or "")
cluster_id = f"c{j}_k{k}"
nodes.append({
"id": cluster_id,
"label": label,
"cluster_idx": j,
"x": round(px_x, 1),
"y": round(px_y, 1),
"radius": radius,
"color": color,
"risk_score": risk,
# Caractéristiques TCP
"mean_ttl": round(mean_ttl, 1),
"mean_mss": round(mean_mss, 0),
"mean_scale": round(mean_scale, 1),
"mean_win": round(mean_win, 0),
# Comportement HTTP
"mean_score": round(mean_score, 4),
"mean_velocity": round(mean_vel, 3),
"mean_fuzzing": round(mean_fuzz, 3),
"mean_headless": round(mean_hless, 3),
"mean_post": round(mean_post, 3),
"mean_asset": round(mean_asset, 3),
"mean_direct": round(mean_direct, 3),
# TLS / Protocole
"mean_alpn_mismatch": round(mean_alpn, 3),
"mean_h2_eff": round(mean_h2, 3),
"mean_hdr_conf": round(mean_hconf, 3),
"mean_ua_ch": round(mean_ua_ch, 3),
# Temporel
"mean_entropy": round(mean_entr, 3),
"mean_ja4_diversity": round(mean_ja4, 3),
"mean_ip_id_zero": round(mean_ip_id, 3),
"mean_browser_score": round(mean_brow, 1),
"mean_ua_rotating": round(mean_uarot, 3),
# Meta
"ip_count": ip_count,
"hit_count": hit_count,
"top_threat": topk(threats, 1)[0] if topk(threats, 1) else "",
"top_countries": topk(countries, 5),
"top_orgs": topk(orgs, 5),
"sample_ips": sample_ips,
"sample_ua": sample_ua,
# Profil radar pour visualisation
"radar": radar,
})
# ── Arêtes : k-NN dans l'espace des features ──────────────────────────
# Chaque cluster est connecté à ses 2 voisins les plus proches
edges = []
seen: set[frozenset] = set()
centroids = km.centroids
for i, ni in enumerate(nodes):
ci = ni["cluster_idx"]
# Distance² aux autres centroïdes
dists = [
(j, nj["cluster_idx"],
sum((centroids[ci][d] - centroids[nj["cluster_idx"]][d]) ** 2
for d in range(N_FEATURES)))
for j, nj in enumerate(nodes) if j != i
]
dists.sort(key=lambda x: x[2])
# 2 voisins les plus proches
for j, cj, dist2 in dists[:2]:
key = frozenset([ni["id"], nodes[j]["id"]])
if key in seen:
continue
seen.add(key)
similarity = round(1.0 / (1.0 + math.sqrt(dist2)), 3)
edges.append({
"id": f"e_{ni['id']}_{nodes[j]['id']}",
"source": ni["id"],
"target": nodes[j]["id"],
"similarity": similarity,
"weight": round(similarity * 5, 1),
})
# ── Stats globales ────────────────────────────────────────────────────
total_ips = sum(n["ip_count"] for n in nodes)
total_hits = sum(n["hit_count"] for n in nodes)
bot_ips = sum(n["ip_count"] for n in nodes if n["risk_score"] > 0.40 or "🤖" in n["label"])
high_risk = sum(n["ip_count"] for n in nodes if n["risk_score"] > 0.20)
elapsed = round(time.time() - t0, 2)
return {
"nodes": nodes,
"edges": edges,
"stats": {
"total_clusters": len(nodes),
"total_ips": total_ips,
"total_hits": total_hits,
"bot_ips": bot_ips,
"high_risk_ips": high_risk,
"n_samples": len(rows),
"k": k,
"elapsed_s": elapsed,
},
"feature_names": FEATURE_NAMES,
}
points = [
{"ip": m[0], "ja4": m[1], "pca_x": round(m[2], 6), "pca_y": round(m[3], 6), "risk": round(m[4], 3)}
for m in page
]
return {"points": points, "total": total, "offset": offset, "limit": limit}
@router.get("/cluster/{cluster_id}/ips")
@ -371,57 +421,44 @@ async def get_cluster_ips(
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
):
"""
IPs appartenant à un cluster (depuis le cache de la dernière exécution).
Si le cache est expiré, retourne une erreur guidant vers /clusters.
"""
if not _cache.get("cluster_ips"):
raise HTTPException(
status_code=404,
detail="Cache expiré — appelez /api/clustering/clusters d'abord"
)
"""IPs avec détails SQL (backward-compat avec l'ancienne UI)."""
with _LOCK:
status = _CACHE["status"]
ips_map = _CACHE["cluster_ips"]
if status != "ready" or not ips_map:
raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord")
# Extrait l'index cluster depuis l'id (format: c{idx}_k{k})
try:
idx = int(cluster_id.split("_")[0][1:])
except (ValueError, IndexError):
raise HTTPException(status_code=400, detail="cluster_id invalide")
members = _cache["cluster_ips"].get(idx, [])
if not members:
return {"ips": [], "total": 0, "cluster_id": cluster_id}
total = len(members)
page_members = members[offset: offset + limit]
# Requête SQL pour les détails de ces IPs spécifiques
ip_list = [m[0] for m in page_members]
ja4_list = [m[1] for m in page_members]
if not ip_list:
members = ips_map.get(idx, [])
total = len(members)
page = members[offset: offset + limit]
if not page:
return {"ips": [], "total": total, "cluster_id": cluster_id}
# On ne peut pas facilement passer une liste en paramètre ClickHouse —
# on la construit directement (valeurs nettoyées)
safe_ips = [ip.replace("'", "") for ip in ip_list[:100]]
safe_ips = [m[0].replace("'", "") for m in page[:200]]
ip_filter = ", ".join(f"'{ip}'" for ip in safe_ips)
sql = f"""
SELECT
replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS src_ip,
t.ja4,
any(t.tcp_ttl_raw) AS ttl,
any(t.tcp_win_raw) AS win,
any(t.tcp_scale_raw) AS scale,
any(t.tcp_mss_raw) AS mss,
sum(t.hits) AS hits,
any(t.first_ua) AS ua,
round(avg(abs(ml.anomaly_score)), 3) AS avg_score,
max(ml.threat_level) AS threat_level,
any(ml.country_code) AS country_code,
any(ml.asn_org) AS asn_org,
round(avg(ml.fuzzing_index), 2) AS fuzzing,
round(avg(ml.hit_velocity), 2) AS velocity
any(t.tcp_ttl_raw) AS ttl,
any(t.tcp_win_raw) AS win,
any(t.tcp_scale_raw) AS scale,
any(t.tcp_mss_raw) AS mss,
sum(t.hits) AS hits,
any(t.first_ua) AS ua,
round(avg(abs(ml.anomaly_score)), 3) AS avg_score,
max(ml.threat_level) AS threat_level,
any(ml.country_code) AS country_code,
any(ml.asn_org) AS asn_org,
round(avg(ml.fuzzing_index), 2) AS fuzzing,
round(avg(ml.hit_velocity), 2) AS velocity
FROM mabase_prod.agg_host_ip_ja4_1h t
LEFT JOIN mabase_prod.ml_detected_anomalies ml
ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4
@ -439,7 +476,7 @@ async def get_cluster_ips(
ips = []
for row in result.result_rows:
ips.append({
"ip": str(row[0]),
"ip": str(row[0] or ""),
"ja4": str(row[1] or ""),
"tcp_ttl": int(row[2] or 0),
"tcp_win": int(row[3] or 0),