Files
dashboard/backend/routes/clustering.py
SOC Analyst 08054cb571 fix: bulles plus petites + viewport auto-fit avec padding 18%
- Backend: radius = log1p(ip_count)*2.2 au lieu de sqrt*2 (max 30px vs 80px)
  ex: 60K IPs → 24px, 1K IPs → 15px, 100 IPs → 10px
- Frontend: zoom initial -0.5 (vue dézoomée par défaut)
- Fit viewport basé sur dimensions réelles canvas - panneaux latéraux
- Padding 18% autour de l'étendue des données pour éviter le débord

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-19 09:50:41 +01:00

496 lines
20 KiB
Python

"""
Clustering d'IPs multi-métriques — WebGL / deck.gl backend.
- Calcul sur la TOTALITÉ des IPs (GROUP BY src_ip, ja4 sans LIMIT)
- K-means++ vectorisé (numpy) + PCA-2D + enveloppes convexes (scipy)
- Calcul en background thread + cache 30 min
- Endpoints : /clusters, /status, /cluster/{id}/points
"""
from __future__ import annotations
import math
import time
import logging
import threading
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Any
import numpy as np
from fastapi import APIRouter, HTTPException, Query
from ..database import db
from ..services.clustering_engine import (
FEATURE_KEYS, FEATURE_NAMES, FEATURE_NORMS, N_FEATURES,
build_feature_vector, kmeans_pp, pca_2d, compute_hulls,
name_cluster, risk_score_from_centroid,
)
log = logging.getLogger(__name__)
router = APIRouter(prefix="/api/clustering", tags=["clustering"])
# ─── Global cache ─────────────────────────────────────────────────────────────
# Single-slot result cache shared by every endpoint; all reads/writes happen
# under _LOCK.
_CACHE: dict[str, Any] = {
    "status": "idle",   # idle | computing | ready | error
    "error": None,      # last error message (str) when status == "error"
    "result": None,     # full result dict (nodes / edges / stats)
    "ts": 0.0,          # timestamp of the last successful update
    "params": {},       # {"k": ..., "hours": ...} the cached result was computed with
    "cluster_ips": {},  # cluster_idx → [(ip, ja4, pca_x, pca_y, risk)]
}
_CACHE_TTL = 1800  # 30 minutes
_LOCK = threading.Lock()
# max_workers=1: at most one clustering computation runs at a time.
_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="clustering")
# ─── Couleurs menace ──────────────────────────────────────────────────────────
_THREAT_COLOR = {
0.70: "#dc2626", # Critique
0.45: "#f97316", # Élevé
0.25: "#eab308", # Modéré
0.00: "#22c55e", # Sain
}
def _risk_to_color(risk: float) -> str:
for threshold, color in sorted(_THREAT_COLOR.items(), reverse=True):
if risk >= threshold:
return color
return "#6b7280"
# ─── SQL: ALL IPs, no LIMIT ──────────────────────────────────────────────────
# One row per (src_ip, ja4) pair seen in the last %(hours)s hours. Raw TCP
# fingerprint aggregates come from agg_host_ip_ja4_1h; ML anomaly features are
# LEFT JOINed, so the ml.* columns may be NULL for IPs with no detection.
_SQL_ALL_IPS = """
SELECT
replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS ip,
t.ja4,
any(t.tcp_ttl_raw) AS ttl,
any(t.tcp_win_raw) AS win,
any(t.tcp_scale_raw) AS scale,
any(t.tcp_mss_raw) AS mss,
any(t.first_ua) AS ua,
sum(t.hits) AS hits,
avg(abs(ml.anomaly_score)) AS avg_score,
avg(ml.hit_velocity) AS avg_velocity,
avg(ml.fuzzing_index) AS avg_fuzzing,
avg(ml.is_headless) AS pct_headless,
avg(ml.post_ratio) AS avg_post,
avg(ml.ip_id_zero_ratio) AS ip_id_zero,
avg(ml.temporal_entropy) AS entropy,
avg(ml.modern_browser_score) AS browser_score,
avg(ml.alpn_http_mismatch) AS alpn_mismatch,
avg(ml.is_alpn_missing) AS alpn_missing,
avg(ml.multiplexing_efficiency) AS h2_eff,
avg(ml.header_order_confidence) AS hdr_conf,
avg(ml.ua_ch_mismatch) AS ua_ch_mismatch,
avg(ml.asset_ratio) AS asset_ratio,
avg(ml.direct_access_ratio) AS direct_ratio,
avg(ml.distinct_ja4_count) AS ja4_count,
max(ml.is_ua_rotating) AS ua_rotating,
max(ml.threat_level) AS threat,
any(ml.country_code) AS country,
any(ml.asn_org) AS asn_org
FROM mabase_prod.agg_host_ip_ja4_1h t
LEFT JOIN mabase_prod.ml_detected_anomalies ml
ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4
AND ml.detected_at >= now() - INTERVAL %(hours)s HOUR
WHERE t.window_start >= now() - INTERVAL %(hours)s HOUR
AND t.tcp_ttl_raw > 0
GROUP BY t.src_ip, t.ja4
"""
# Column names in the exact order of _SQL_ALL_IPS's SELECT clause — used to
# zip raw result rows into dicts. Keep both lists in sync.
_SQL_COLS = [
    "ip", "ja4", "ttl", "win", "scale", "mss", "ua", "hits",
    "avg_score", "avg_velocity", "avg_fuzzing", "pct_headless", "avg_post",
    "ip_id_zero", "entropy", "browser_score", "alpn_mismatch", "alpn_missing",
    "h2_eff", "hdr_conf", "ua_ch_mismatch", "asset_ratio", "direct_ratio",
    "ja4_count", "ua_rotating", "threat", "country", "asn_org",
]
# ─── Clustering worker (thread pool) ─────────────────────────────────────────
def _run_clustering_job(k: int, hours: int) -> None:
    """Clustering worker; runs on the single-thread pool and updates _CACHE.

    Pipeline: load every (ip, ja4) row from ClickHouse → build the feature
    matrix → k-means++ → PCA-2D projection → convex hulls → aggregate
    per-cluster nodes and inter-cluster k-NN edges. On success the result
    dict and the per-cluster IP map are stored in _CACHE under _LOCK; on
    failure _CACHE["status"] becomes "error" with the exception text.
    """
    t0 = time.time()
    with _LOCK:
        _CACHE["status"] = "computing"
        _CACHE["error"] = None
    try:
        log.info(f"[clustering] Démarrage du calcul k={k} hours={hours}")
        # ── 1. Load every IP (no LIMIT) ──────────────────────────────────
        result = db.query(_SQL_ALL_IPS, {"hours": hours})
        rows: list[dict] = []
        for row in result.result_rows:
            # Zip positional columns into a dict using the shared column order.
            rows.append({col: row[i] for i, col in enumerate(_SQL_COLS)})
        n = len(rows)
        log.info(f"[clustering] {n} IPs chargées")
        if n < k:
            # k-means needs at least k samples.
            raise ValueError(f"Seulement {n} IPs disponibles (k={k} requis)")
        # ── 2. Build the feature matrix (numpy) ──────────────────────────
        X = np.array([build_feature_vector(r) for r in rows], dtype=np.float32)
        log.info(f"[clustering] Matrice X: {X.shape}{X.nbytes/1024/1024:.1f} MB")
        # ── 3. Vectorized k-means++ ──────────────────────────────────────
        km = kmeans_pp(X.astype(np.float64), k=k, max_iter=80, n_init=3, seed=42)
        log.info(f"[clustering] K-means: {km.n_iter} iters, inertia={km.inertia:.2f}")
        # ── 4. PCA-2D for all IPs ────────────────────────────────────────
        coords = pca_2d(X.astype(np.float64))  # (n, 2), normalized [0,1]
        # ── 5. Convex hull per cluster ───────────────────────────────────
        hulls = compute_hulls(coords, km.labels, k)
        # ── 6. Group rows / coords / IP tuples by cluster label ──────────
        cluster_rows: list[list[dict]] = [[] for _ in range(k)]
        cluster_coords: list[list[list[float]]] = [[] for _ in range(k)]
        cluster_ips_map: dict[int, list] = {j: [] for j in range(k)}
        for i, label in enumerate(km.labels):
            j = int(label)
            cluster_rows[j].append(rows[i])
            cluster_coords[j].append(coords[i].tolist())
            # Risk is the cluster centroid's, not per-IP — every member of a
            # cluster carries the same risk value.
            cluster_ips_map[j].append((
                rows[i]["ip"],
                rows[i]["ja4"],
                float(coords[i][0]),
                float(coords[i][1]),
                float(risk_score_from_centroid(km.centroids[j])),
            ))
        # ── 7. Build one node per non-empty cluster ──────────────────────
        nodes = []
        for j in range(k):
            if not cluster_rows[j]:
                continue
            # Default argument binds this cluster's rows at definition time
            # (avoids the late-binding closure pitfall).
            def avg_f(key: str, crows: list[dict] = cluster_rows[j]) -> float:
                return float(np.mean([float(r.get(key) or 0) for r in crows]))
            mean_ttl = avg_f("ttl")
            mean_mss = avg_f("mss")
            mean_scale = avg_f("scale")
            mean_win = avg_f("win")
            raw_stats = {"mean_ttl": mean_ttl, "mean_mss": mean_mss, "mean_scale": mean_scale}
            label_name = name_cluster(km.centroids[j], raw_stats)
            risk = float(risk_score_from_centroid(km.centroids[j]))
            color = _risk_to_color(risk)
            # 2D centroid = mean of the cluster's PCA coordinates.
            cxy = np.mean(cluster_coords[j], axis=0).tolist() if cluster_coords[j] else [0.5, 0.5]
            ip_set = list({r["ip"] for r in cluster_rows[j]})
            ip_count = len(ip_set)
            hit_count = int(sum(float(r.get("hits") or 0) for r in cluster_rows[j]))
            threats = [str(r.get("threat") or "") for r in cluster_rows[j] if r.get("threat")]
            countries = [str(r.get("country") or "") for r in cluster_rows[j] if r.get("country")]
            orgs = [str(r.get("asn_org") or "") for r in cluster_rows[j] if r.get("asn_org")]
            def topk(lst: list[str], n: int = 5) -> list[str]:
                # Top-n most common non-empty values.
                return [v for v, _ in Counter(lst).most_common(n) if v]
            radar = [
                {"feature": name, "value": round(float(km.centroids[j][i]), 4)}
                for i, name in enumerate(FEATURE_NAMES)
            ]
            # log1p keeps bubbles small (clamped 8..30 px): ~24px at 60K IPs,
            # ~10px at 100 IPs.
            radius = max(8, min(30, int(math.log1p(ip_count) * 2.2)))
            # Sample the 8 heaviest (by hits) members for the tooltip.
            sample_rows = sorted(cluster_rows[j], key=lambda r: float(r.get("hits") or 0), reverse=True)[:8]
            sample_ips = [r["ip"] for r in sample_rows]
            sample_ua = str(cluster_rows[j][0].get("ua") or "")
            nodes.append({
                "id": f"c{j}_k{k}",
                "cluster_idx": j,
                "label": label_name,
                "pca_x": round(cxy[0], 6),
                "pca_y": round(cxy[1], 6),
                "radius": radius,
                "color": color,
                "risk_score": round(risk, 4),
                "mean_ttl": round(mean_ttl, 1),
                "mean_mss": round(mean_mss, 0),
                "mean_scale": round(mean_scale, 1),
                "mean_win": round(mean_win, 0),
                "mean_score": round(avg_f("avg_score"), 4),
                "mean_velocity": round(avg_f("avg_velocity"), 3),
                "mean_fuzzing": round(avg_f("avg_fuzzing"), 3),
                "mean_headless": round(avg_f("pct_headless"), 3),
                "mean_post": round(avg_f("avg_post"), 3),
                "mean_asset": round(avg_f("asset_ratio"), 3),
                "mean_direct": round(avg_f("direct_ratio"), 3),
                "mean_alpn_mismatch": round(avg_f("alpn_mismatch"), 3),
                "mean_h2_eff": round(avg_f("h2_eff"), 3),
                "mean_hdr_conf": round(avg_f("hdr_conf"), 3),
                "mean_ua_ch": round(avg_f("ua_ch_mismatch"), 3),
                "mean_entropy": round(avg_f("entropy"), 3),
                "mean_ja4_diversity": round(avg_f("ja4_count"), 3),
                "mean_ip_id_zero": round(avg_f("ip_id_zero"), 3),
                "mean_browser_score": round(avg_f("browser_score"), 1),
                "mean_ua_rotating": round(avg_f("ua_rotating"), 3),
                "ip_count": ip_count,
                "hit_count": hit_count,
                "top_threat": topk(threats, 1)[0] if threats else "",
                "top_countries": topk(countries, 5),
                "top_orgs": topk(orgs, 5),
                "sample_ips": sample_ips,
                "sample_ua": sample_ua,
                "radar": radar,
                # Hull for deck.gl PolygonLayer
                "hull": hulls.get(j, []),
            })
        # ── 8. k-NN edges between clusters ───────────────────────────────
        # Each node connects to its 2 nearest centroids (squared distance in
        # feature space); frozenset keys deduplicate undirected pairs.
        edges = []
        seen: set[frozenset] = set()
        for i, ni in enumerate(nodes):
            ci = ni["cluster_idx"]
            dists = sorted(
                [(j, nj["cluster_idx"],
                  float(np.sum((km.centroids[ci] - km.centroids[nj["cluster_idx"]]) ** 2)))
                 for j, nj in enumerate(nodes) if j != i],
                key=lambda x: x[2]
            )
            for j_idx, cj, d2 in dists[:2]:  # cj unused — kept for tuple shape
                key = frozenset([ni["id"], nodes[j_idx]["id"]])
                if key in seen:
                    continue
                seen.add(key)
                edges.append({
                    "id": f"e_{ni['id']}_{nodes[j_idx]['id']}",
                    "source": ni["id"],
                    "target": nodes[j_idx]["id"],
                    # Similarity decays with centroid distance, in (0, 1].
                    "similarity": round(1.0 / (1.0 + math.sqrt(d2)), 3),
                })
        # ── 9. Store result + per-cluster IP cache ───────────────────────
        total_ips = sum(n_["ip_count"] for n_ in nodes)
        total_hits = sum(n_["hit_count"] for n_ in nodes)
        bot_ips = sum(n_["ip_count"] for n_ in nodes if n_["risk_score"] > 0.45 or "🤖" in n_["label"])
        high_ips = sum(n_["ip_count"] for n_ in nodes if n_["risk_score"] > 0.25)
        elapsed = round(time.time() - t0, 2)
        result_dict = {
            "nodes": nodes,
            "edges": edges,
            "stats": {
                "total_clusters": len(nodes),
                "total_ips": total_ips,
                "total_hits": total_hits,
                "bot_ips": bot_ips,
                "high_risk_ips": high_ips,
                "n_samples": n,
                "k": k,
                "elapsed_s": elapsed,
            },
            "feature_names": FEATURE_NAMES,
        }
        with _LOCK:
            _CACHE["result"] = result_dict
            _CACHE["cluster_ips"] = cluster_ips_map
            _CACHE["status"] = "ready"
            _CACHE["ts"] = time.time()
            _CACHE["params"] = {"k": k, "hours": hours}
            _CACHE["error"] = None
        log.info(f"[clustering] Terminé en {elapsed}s — {total_ips} IPs, {len(nodes)} clusters")
    except Exception as e:
        log.exception("[clustering] Erreur lors du calcul")
        with _LOCK:
            _CACHE["status"] = "error"
            _CACHE["error"] = str(e)
def _maybe_trigger(k: int, hours: int) -> None:
    """Schedule a clustering job unless a fresh result with matching params exists.

    No-op while a computation is already running; otherwise submits
    _run_clustering_job to the single-worker executor when the cache is
    missing, expired, or was built with different (k, hours).
    """
    with _LOCK:
        status = _CACHE["status"]
        ts = _CACHE["ts"]
        params = dict(_CACHE["params"])
    if status == "computing":
        # A job is already running — let it finish.
        return
    fresh = (time.time() - ts) <= _CACHE_TTL
    same_params = params.get("k") == k and params.get("hours") == hours
    if status == "ready" and fresh and same_params:
        # Cache hit — nothing to do.
        return
    _EXECUTOR.submit(_run_clustering_job, k, hours)
# ─── Endpoints ────────────────────────────────────────────────────────────────
@router.get("/status")
async def get_status():
    """Current computation state (polled by the frontend)."""
    with _LOCK:
        ts = _CACHE["ts"]
        snapshot = {
            "status": _CACHE["status"],
            "error": _CACHE["error"],
            "ts": ts,
            "params": _CACHE["params"],
            # Age of the cached result in whole seconds; None before first run.
            "age_s": round(time.time() - ts, 0) if ts else None,
        }
    return snapshot
@router.get("/clusters")
async def get_clusters(
    k: int = Query(14, ge=4, le=30, description="Nombre de clusters"),
    hours: int = Query(24, ge=1, le=168, description="Fenêtre temporelle (heures)"),
    force: bool = Query(False, description="Forcer le recalcul"),
):
    """Multi-metric clustering over ALL IPs.

    Serves the cached result immediately when ready; otherwise triggers the
    background computation and asks the caller to poll again. `force=true`
    invalidates the cache first.
    """
    if force:
        # Reset freshness so _maybe_trigger schedules a new run.
        with _LOCK:
            _CACHE["status"] = "idle"
            _CACHE["ts"] = 0.0
    _maybe_trigger(k, hours)
    with _LOCK:
        status, result, error = _CACHE["status"], _CACHE["result"], _CACHE["error"]
    # Status values are mutually exclusive, so check order is free.
    if status == "error":
        raise HTTPException(status_code=500, detail=error or "Erreur inconnue")
    if status == "computing":
        return {"status": "computing", "message": "Calcul en cours, réessayez dans quelques secondes"}
    if result is None:
        return {"status": "idle", "message": "Calcul démarré, réessayez dans quelques secondes"}
    return dict(result, status="ready")
@router.get("/cluster/{cluster_id}/points")
async def get_cluster_points(
    cluster_id: str,
    limit: int = Query(5000, ge=1, le=20000),
    offset: int = Query(0, ge=0),
):
    """PCA coordinates + metadata for every IP of one cluster.

    Feeds the deck.gl ScatterplotLayer (drill-down / deep zoom). Requires a
    ready cache — 404 otherwise; 400 on a malformed cluster_id.
    """
    with _LOCK:
        ready = _CACHE["status"] == "ready"
        ips_map = _CACHE["cluster_ips"]
    if not ready or not ips_map:
        raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord")
    # cluster_id looks like "c{n}_k{k}" — extract the numeric index after "c".
    head = cluster_id.partition("_")[0]
    try:
        idx = int(head[1:])
    except ValueError:
        raise HTTPException(status_code=400, detail="cluster_id invalide (format: c{n}_k{k})")
    members = ips_map.get(idx, [])
    window = members[offset: offset + limit]
    points = []
    for ip, ja4, px, py, risk in window:
        points.append({
            "ip": ip,
            "ja4": ja4,
            "pca_x": round(px, 6),
            "pca_y": round(py, 6),
            "risk": round(risk, 3),
        })
    return {"points": points, "total": len(members), "offset": offset, "limit": limit}
@router.get("/cluster/{cluster_id}/ips")
async def get_cluster_ips(
    cluster_id: str,
    limit: int = Query(100, ge=1, le=500),
    offset: int = Query(0, ge=0),
):
    """IP details from SQL (backward-compat with the old UI).

    Pages through the cached membership of the cluster, then enriches up to
    200 IPs of the page with TCP/ML details from ClickHouse.
    Raises 404 when the cache is not ready, 400 on a malformed cluster_id,
    500 if the enrichment query fails.
    """
    with _LOCK:
        status = _CACHE["status"]
        ips_map = _CACHE["cluster_ips"]
    if status != "ready" or not ips_map:
        raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord")
    try:
        idx = int(cluster_id.split("_")[0][1:])
    except (ValueError, IndexError):
        raise HTTPException(status_code=400, detail="cluster_id invalide")
    members = ips_map.get(idx, [])
    total = len(members)
    page = members[offset: offset + limit]
    if not page:
        return {"ips": [], "total": total, "cluster_id": cluster_id}
    # Whitelist sanitization: keep only characters that can appear in an
    # IPv4/IPv6 literal. The previous code merely stripped single quotes,
    # which still let backslashes through — in ClickHouse string literals a
    # trailing backslash escapes the closing quote, an injection vector.
    allowed_chars = set("0123456789abcdefABCDEF:.")
    safe_ips = []
    for m in page[:200]:
        cleaned = "".join(ch for ch in m[0] if ch in allowed_chars)
        if cleaned:
            safe_ips.append(cleaned)
    if not safe_ips:
        # Nothing valid left to query — avoid emitting an invalid "IN ()".
        return {"ips": [], "total": total, "cluster_id": cluster_id}
    ip_filter = ", ".join(f"'{ip}'" for ip in safe_ips)
    # NOTE(review): the 24 HOUR window is hard-coded here while /clusters uses
    # a configurable `hours` — confirm whether it should follow cached params.
    sql = f"""
SELECT
replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS src_ip,
t.ja4,
any(t.tcp_ttl_raw) AS ttl,
any(t.tcp_win_raw) AS win,
any(t.tcp_scale_raw) AS scale,
any(t.tcp_mss_raw) AS mss,
sum(t.hits) AS hits,
any(t.first_ua) AS ua,
round(avg(abs(ml.anomaly_score)), 3) AS avg_score,
max(ml.threat_level) AS threat_level,
any(ml.country_code) AS country_code,
any(ml.asn_org) AS asn_org,
round(avg(ml.fuzzing_index), 2) AS fuzzing,
round(avg(ml.hit_velocity), 2) AS velocity
FROM mabase_prod.agg_host_ip_ja4_1h t
LEFT JOIN mabase_prod.ml_detected_anomalies ml
ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4
AND ml.detected_at >= now() - INTERVAL 24 HOUR
WHERE t.window_start >= now() - INTERVAL 24 HOUR
AND replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') IN ({ip_filter})
GROUP BY t.src_ip, t.ja4
ORDER BY hits DESC
"""
    try:
        result = db.query(sql)
    except Exception as e:
        # Chain the cause so the original ClickHouse error survives in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
    ips = []
    for row in result.result_rows:
        ips.append({
            "ip": str(row[0] or ""),
            "ja4": str(row[1] or ""),
            "tcp_ttl": int(row[2] or 0),
            "tcp_win": int(row[3] or 0),
            "tcp_scale": int(row[4] or 0),
            "tcp_mss": int(row[5] or 0),
            "hits": int(row[6] or 0),
            "ua": str(row[7] or ""),
            "avg_score": float(row[8] or 0),
            "threat_level": str(row[9] or ""),
            "country_code": str(row[10] or ""),
            "asn_org": str(row[11] or ""),
            "fuzzing": float(row[12] or 0),
            "velocity": float(row[13] or 0),
        })
    return {"ips": ips, "total": total, "cluster_id": cluster_id}