Files
dashboard/backend/routes/clustering.py
SOC Analyst fc3392779b feat: slider sensibilité + z-score standardization pour clustering plus précis
Sensibilité (0.5x–3.0x) :
- Multiplie k : sensibilité=2x avec k=14 → 28 clusters effectifs
- Labels UI : Grossière / Normale / Fine / Très fine / Maximum
- Paramètres avancés (k, fenêtre) masqués dans un <details>
- Cache invalidé si sensibilité change

Z-score standardisation (Bishop 2006 PRML §9.1) :
- Normalise par variance de chaque feature avant K-means
- Features discriminantes (forte std) pèsent plus
- Résultat : risque 0→1.00 sur clusters bots vs 0→0.27 avant
- Bots détectés : 4 337 IPs vs 1 604 (2.7x plus)
- Nouveaux clusters : Bot agressif, Tunnel réseau, UA-CH Mismatch distincts

Fix TextLayer deck.gl :
- Translittération des accents (é→e, à→a, ç→c…) + strip emojis
- Évite les warnings 'Missing character' sur caractères non-ASCII

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-19 10:07:23 +01:00

518 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Clustering d'IPs multi-métriques — WebGL / deck.gl backend.
- Calcul sur la TOTALITÉ des IPs (GROUP BY src_ip, ja4 sans LIMIT)
- K-means++ vectorisé (numpy) + PCA-2D + enveloppes convexes (scipy)
- Calcul en background thread + cache 30 min
- Endpoints : /clusters, /status, /cluster/{id}/points, /cluster/{id}/ips
"""
from __future__ import annotations
import math
import time
import logging
import threading
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Any
import numpy as np
from fastapi import APIRouter, HTTPException, Query
from ..database import db
from ..services.clustering_engine import (
FEATURE_KEYS, FEATURE_NAMES, FEATURE_NORMS, N_FEATURES,
build_feature_vector, kmeans_pp, pca_2d, compute_hulls,
name_cluster, risk_score_from_centroid, standardize,
)
log = logging.getLogger(__name__)
# All endpoints of this module are mounted under /api/clustering.
router = APIRouter(prefix="/api/clustering", tags=["clustering"])
# ─── Cache global ──────────────────────────────────────────────────────────────
# Single in-process cache shared by the background worker and the request
# handlers; every read and write of it in this module happens under _LOCK.
_CACHE: dict[str, Any] = {
    "status": "idle",  # idle | computing | ready | error
    "error": None,  # str() of the exception raised by the last failed job, else None
    "result": None,  # full result dict (nodes, edges, stats, feature_names)
    "ts": 0.0,  # time.time() of the last successful job; reset to 0.0 by force=True
    "params": {},  # {"k", "hours", "sensitivity"} the cached result was computed with
    "cluster_ips": {},  # cluster_idx → [(ip, ja4, pca_x, pca_y, risk)]
}
_CACHE_TTL = 1800  # seconds (30 minutes) before a "ready" result is considered stale
_LOCK = threading.Lock()
# One worker only: at most one clustering job runs at a time, extra submissions queue.
_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="clustering")
# ─── Couleurs menace ──────────────────────────────────────────────────────────
_THREAT_COLOR = {
0.70: "#dc2626", # Critique
0.45: "#f97316", # Élevé
0.25: "#eab308", # Modéré
0.00: "#22c55e", # Sain
}
def _risk_to_color(risk: float) -> str:
for threshold, color in sorted(_THREAT_COLOR.items(), reverse=True):
if risk >= threshold:
return color
return "#6b7280"
# ─── SQL : TOUTES les IPs sans LIMIT ─────────────────────────────────────────
_SQL_ALL_IPS = """
SELECT
replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS ip,
t.ja4,
any(t.tcp_ttl_raw) AS ttl,
any(t.tcp_win_raw) AS win,
any(t.tcp_scale_raw) AS scale,
any(t.tcp_mss_raw) AS mss,
any(t.first_ua) AS ua,
sum(t.hits) AS hits,
avg(abs(ml.anomaly_score)) AS avg_score,
avg(ml.hit_velocity) AS avg_velocity,
avg(ml.fuzzing_index) AS avg_fuzzing,
avg(ml.is_headless) AS pct_headless,
avg(ml.post_ratio) AS avg_post,
avg(ml.ip_id_zero_ratio) AS ip_id_zero,
avg(ml.temporal_entropy) AS entropy,
avg(ml.modern_browser_score) AS browser_score,
avg(ml.alpn_http_mismatch) AS alpn_mismatch,
avg(ml.is_alpn_missing) AS alpn_missing,
avg(ml.multiplexing_efficiency) AS h2_eff,
avg(ml.header_order_confidence) AS hdr_conf,
avg(ml.ua_ch_mismatch) AS ua_ch_mismatch,
avg(ml.asset_ratio) AS asset_ratio,
avg(ml.direct_access_ratio) AS direct_ratio,
avg(ml.distinct_ja4_count) AS ja4_count,
max(ml.is_ua_rotating) AS ua_rotating,
max(ml.threat_level) AS threat,
any(ml.country_code) AS country,
any(ml.asn_org) AS asn_org
FROM mabase_prod.agg_host_ip_ja4_1h t
LEFT JOIN mabase_prod.ml_detected_anomalies ml
ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4
AND ml.detected_at >= now() - INTERVAL %(hours)s HOUR
WHERE t.window_start >= now() - INTERVAL %(hours)s HOUR
AND t.tcp_ttl_raw > 0
GROUP BY t.src_ip, t.ja4
"""
_SQL_COLS = [
"ip", "ja4", "ttl", "win", "scale", "mss", "ua", "hits",
"avg_score", "avg_velocity", "avg_fuzzing", "pct_headless", "avg_post",
"ip_id_zero", "entropy", "browser_score", "alpn_mismatch", "alpn_missing",
"h2_eff", "hdr_conf", "ua_ch_mismatch", "asset_ratio", "direct_ratio",
"ja4_count", "ua_rotating", "threat", "country", "asn_org",
]
# ─── Clustering worker (thread pool) ─────────────────────────────────────────
def _top_values(values: list[str], n: int = 5) -> list[str]:
    """Return the ``n`` most frequent non-empty values, most common first."""
    return [v for v, _ in Counter(values).most_common(n) if v]


def _cluster_node(
    j: int,
    crows: list[dict],
    ccoords: list[list[float]],
    centroid: Any,
    risk: float,
    hull: Any,
    k_actual: int,
) -> dict:
    """Build the deck.gl node dict of cluster ``j`` from its member rows.

    ``crows`` are the SQL rows of the members, ``ccoords`` their PCA-2D
    coordinates (same order), ``centroid`` the K-means centroid of the cluster
    and ``risk`` the score already derived from it.
    """

    def avg_f(key: str) -> float:
        # Mean of a SQL column over the members; None/missing count as 0.
        return float(np.mean([float(r.get(key) or 0) for r in crows]))

    mean_ttl = avg_f("ttl")
    mean_mss = avg_f("mss")
    mean_scale = avg_f("scale")
    mean_win = avg_f("win")
    raw_stats = {"mean_ttl": mean_ttl, "mean_mss": mean_mss, "mean_scale": mean_scale}
    label_name = name_cluster(centroid, raw_stats)
    # 2-D centroid = mean of the members' PCA coordinates.
    cxy = np.mean(ccoords, axis=0).tolist() if ccoords else [0.5, 0.5]
    ip_count = len({r["ip"] for r in crows})
    hit_count = int(sum(float(r.get("hits") or 0) for r in crows))
    threats = [str(r.get("threat") or "") for r in crows if r.get("threat")]
    countries = [str(r.get("country") or "") for r in crows if r.get("country")]
    orgs = [str(r.get("asn_org") or "") for r in crows if r.get("asn_org")]
    radar = [
        {"feature": name, "value": round(float(centroid[i]), 4)}
        for i, name in enumerate(FEATURE_NAMES)
    ]
    # Node radius grows logarithmically with the IP count, clamped to [8, 30].
    radius = max(8, min(30, int(math.log1p(ip_count) * 2.2)))
    top_rows = sorted(crows, key=lambda r: float(r.get("hits") or 0), reverse=True)[:8]
    return {
        "id": f"c{j}_k{k_actual}",
        "cluster_idx": j,
        "label": label_name,
        "pca_x": round(cxy[0], 6),
        "pca_y": round(cxy[1], 6),
        "radius": radius,
        "color": _risk_to_color(risk),
        "risk_score": round(risk, 4),
        "mean_ttl": round(mean_ttl, 1),
        "mean_mss": round(mean_mss, 0),
        "mean_scale": round(mean_scale, 1),
        "mean_win": round(mean_win, 0),
        "mean_score": round(avg_f("avg_score"), 4),
        "mean_velocity": round(avg_f("avg_velocity"), 3),
        "mean_fuzzing": round(avg_f("avg_fuzzing"), 3),
        "mean_headless": round(avg_f("pct_headless"), 3),
        "mean_post": round(avg_f("avg_post"), 3),
        "mean_asset": round(avg_f("asset_ratio"), 3),
        "mean_direct": round(avg_f("direct_ratio"), 3),
        "mean_alpn_mismatch": round(avg_f("alpn_mismatch"), 3),
        "mean_h2_eff": round(avg_f("h2_eff"), 3),
        "mean_hdr_conf": round(avg_f("hdr_conf"), 3),
        "mean_ua_ch": round(avg_f("ua_ch_mismatch"), 3),
        "mean_entropy": round(avg_f("entropy"), 3),
        "mean_ja4_diversity": round(avg_f("ja4_count"), 3),
        "mean_ip_id_zero": round(avg_f("ip_id_zero"), 3),
        "mean_browser_score": round(avg_f("browser_score"), 1),
        "mean_ua_rotating": round(avg_f("ua_rotating"), 3),
        "ip_count": ip_count,
        "hit_count": hit_count,
        "top_threat": _top_values(threats, 1)[0] if threats else "",
        "top_countries": _top_values(countries, 5),
        "top_orgs": _top_values(orgs, 5),
        "sample_ips": [r["ip"] for r in top_rows],
        "sample_ua": str(crows[0].get("ua") or ""),
        "radar": radar,
        # Hull for the deck.gl PolygonLayer
        "hull": hull,
    }


def _knn_edges(nodes: list[dict], centroids: Any, n_neighbors: int = 2) -> list[dict]:
    """Link each node to its ``n_neighbors`` closest clusters.

    Closeness is the squared Euclidean distance between K-means centroids; an
    undirected pair is emitted only once. ``similarity`` = 1 / (1 + distance).
    """
    edges: list[dict] = []
    seen: set[frozenset] = set()
    for i, ni in enumerate(nodes):
        ci = ni["cluster_idx"]
        dists = sorted(
            (
                (j, float(np.sum((centroids[ci] - centroids[nj["cluster_idx"]]) ** 2)))
                for j, nj in enumerate(nodes)
                if j != i
            ),
            key=lambda x: x[1],
        )
        for j, d2 in dists[:n_neighbors]:
            nj = nodes[j]
            key = frozenset([ni["id"], nj["id"]])
            if key in seen:
                continue
            seen.add(key)
            edges.append({
                "id": f"e_{ni['id']}_{nj['id']}",
                "source": ni["id"],
                "target": nj["id"],
                "similarity": round(1.0 / (1.0 + math.sqrt(d2)), 3),
            })
    return edges


def _run_clustering_job(k: int, hours: int, sensitivity: float = 1.0) -> None:
    """Run one clustering pass (executed in the thread pool) and publish it to ``_CACHE``.

    Args:
        k: base number of clusters.
        hours: time window (in hours) of traffic to load.
        sensitivity: multiplier applied to ``k`` (the API accepts 0.5–3.0).
            The effective count is ``round(k * sensitivity)`` clamped to [4, 50]:
            1.0 = default, 2.0 = twice as many (more homogeneous) clusters,
            0.5 = half as many (very aggregated view).

    Failures during the computation are logged and stored in the cache as
    ``status="error"`` rather than propagated.
    """
    k_actual = max(4, min(50, round(k * sensitivity)))
    t0 = time.time()
    with _LOCK:
        _CACHE["status"] = "computing"
        _CACHE["error"] = None
    try:
        log.info(f"[clustering] Démarrage k={k_actual} (base={k}×sens={sensitivity}) hours={hours}")
        # ── 1. Load every (ip, ja4) row ─────────────────────────────────────
        result = db.query(_SQL_ALL_IPS, {"hours": hours})
        rows: list[dict] = [
            {col: row[i] for i, col in enumerate(_SQL_COLS)}
            for row in result.result_rows
        ]
        n = len(rows)
        log.info(f"[clustering] {n} IPs chargées")
        if n < k_actual:
            raise ValueError(f"Seulement {n} IPs disponibles (k={k_actual} requis)")

        # ── 2. Feature matrix ───────────────────────────────────────────────
        X = np.array([build_feature_vector(r) for r in rows], dtype=np.float32)
        log.info(f"[clustering] Matrice X: {X.shape} — {X.nbytes/1024/1024:.1f} MB")

        # ── 3. Z-score standardisation: high-variance (discriminant) features
        #       weigh more than near-constant ones in the K-means distance.
        X64 = X.astype(np.float64)
        X_std, feat_mean, feat_std = standardize(X64)

        # ── 4. K-means++ in the standardised space ─────────────────────────
        km = kmeans_pp(X_std, k=k_actual, max_iter=80, n_init=3, seed=42)
        log.info(f"[clustering] K-means: {km.n_iter} iters, inertia={km.inertia:.2f}")
        # NOTE(review): km.centroids live in the z-scored space (X_std), yet they
        # are passed to name_cluster / risk_score_from_centroid and exported as
        # radar values, while the PCA below uses the unstandardised features.
        # feat_mean / feat_std are returned but unused; presumably
        # `centroid * feat_std + feat_mean` maps a centroid back — confirm which
        # space clustering_engine's helpers expect.

        # ── 5. PCA-2D on the ORIGINAL (unstandardised) features so visual
        #       distances stay interpretable; 5b. convex hull per cluster.
        coords = pca_2d(X64)
        hulls = compute_hulls(coords, km.labels, k_actual)

        # ── 6. Group row indices by cluster label ───────────────────────────
        members: list[list[int]] = [[] for _ in range(k_actual)]
        for i, label in enumerate(km.labels):
            members[int(label)].append(i)

        # ── 7. Nodes + per-IP drill-down cache ──────────────────────────────
        nodes: list[dict] = []
        cluster_ips_map: dict[int, list] = {j: [] for j in range(k_actual)}
        for j, idx in enumerate(members):
            if not idx:
                continue
            # The risk depends only on the centroid: compute it once per
            # cluster instead of once per member IP.
            risk = float(risk_score_from_centroid(km.centroids[j]))
            cluster_ips_map[j] = [
                (rows[i]["ip"], rows[i]["ja4"], float(coords[i][0]), float(coords[i][1]), risk)
                for i in idx
            ]
            nodes.append(_cluster_node(
                j,
                [rows[i] for i in idx],
                [coords[i].tolist() for i in idx],
                km.centroids[j],
                risk,
                hulls.get(j, []),
                k_actual,
            ))

        # ── 8. k-NN edges between clusters ──────────────────────────────────
        edges = _knn_edges(nodes, km.centroids)

        # ── 9. Publish result + IP cache ────────────────────────────────────
        # Per-cluster distinct IP counts are summed: an IP present in several
        # clusters (one per JA4) is counted once per cluster.
        total_ips = sum(n_["ip_count"] for n_ in nodes)
        total_hits = sum(n_["hit_count"] for n_ in nodes)
        bot_ips = sum(n_["ip_count"] for n_ in nodes if n_["risk_score"] > 0.45 or "🤖" in n_["label"])
        high_ips = sum(n_["ip_count"] for n_ in nodes if n_["risk_score"] > 0.25)
        elapsed = round(time.time() - t0, 2)
        result_dict = {
            "nodes": nodes,
            "edges": edges,
            "stats": {
                "total_clusters": len(nodes),
                "total_ips": total_ips,
                "total_hits": total_hits,
                "bot_ips": bot_ips,
                "high_risk_ips": high_ips,
                "n_samples": n,
                "k": k_actual,
                "k_base": k,
                "sensitivity": sensitivity,
                "elapsed_s": elapsed,
            },
            "feature_names": FEATURE_NAMES,
        }
        with _LOCK:
            _CACHE["result"] = result_dict
            _CACHE["cluster_ips"] = cluster_ips_map
            _CACHE["status"] = "ready"
            _CACHE["ts"] = time.time()
            _CACHE["params"] = {"k": k, "hours": hours, "sensitivity": sensitivity}
            _CACHE["error"] = None
        log.info(f"[clustering] Terminé en {elapsed}s — {total_ips} IPs, {len(nodes)} clusters")
    except Exception as e:
        # Thread-pool boundary: record the failure for /status and /clusters.
        log.exception("[clustering] Erreur lors du calcul")
        with _LOCK:
            _CACHE["status"] = "error"
            _CACHE["error"] = str(e)
def _maybe_trigger(k: int, hours: int, sensitivity: float) -> None:
    """Start a background clustering job if the cache is missing, stale or for other params.

    The check and the claim happen atomically under ``_LOCK``: the status is
    switched to ``"computing"`` *before* the job is submitted. Otherwise two
    concurrent requests could both see a non-computing status and queue
    duplicate jobs, and a caller reading the cache right after this call
    would still see the previous ``"ready"`` result for the old parameters.
    """
    with _LOCK:
        if _CACHE["status"] == "computing":
            return  # a job is already running or queued
        params = _CACHE["params"]
        cache_fresh = (time.time() - _CACHE["ts"]) <= _CACHE_TTL
        same_params = (
            params.get("k") == k
            and params.get("hours") == hours
            and params.get("sensitivity") == sensitivity
        )
        if _CACHE["status"] == "ready" and cache_fresh and same_params:
            return  # fresh cache for exactly these parameters
        # Claim the computation slot while still holding the lock.
        _CACHE["status"] = "computing"
        _CACHE["error"] = None
    try:
        _EXECUTOR.submit(_run_clustering_job, k, hours, sensitivity)
    except RuntimeError as exc:
        # submit() raises RuntimeError once the executor has been shut down;
        # don't leave the cache stuck in "computing" forever.
        with _LOCK:
            _CACHE["status"] = "error"
            _CACHE["error"] = str(exc)
# ─── Endpoints ────────────────────────────────────────────────────────────────
@router.get("/status")
async def get_status():
    """Snapshot of the background computation state (polled by the frontend).

    ``age_s`` is the number of seconds since the last successful run, or None
    when no run has completed (or the cache was force-reset).
    """
    with _LOCK:
        last_ts = _CACHE["ts"]
        snapshot = {
            "status": _CACHE["status"],
            "error": _CACHE["error"],
            "ts": last_ts,
            "params": _CACHE["params"],
        }
    snapshot["age_s"] = round(time.time() - last_ts, 0) if last_ts else None
    return snapshot
@router.get("/clusters")
async def get_clusters(
    k: int = Query(14, ge=4, le=30, description="Nombre de clusters de base"),
    hours: int = Query(24, ge=1, le=168, description="Fenêtre temporelle (heures)"),
    sensitivity: float = Query(1.0, ge=0.5, le=3.0, description="Sensibilité : multiplicateur de k"),
    force: bool = Query(False, description="Forcer le recalcul"),
):
    """Multi-metric clustering over ALL IPs.

    The effective cluster count is ``round(k × sensitivity)`` (clamped to
    [4, 50] by the worker), so ``sensitivity`` controls granularity.
    Answers immediately from the cache and starts a background computation
    when needed. A cached result is returned as ``"ready"`` only if it was
    computed with exactly the requested (k, hours, sensitivity); otherwise the
    caller is told to poll again.

    Raises:
        HTTPException 500: the last computation failed.
    """
    if force:
        with _LOCK:
            # Resetting while a job is in flight would only queue a duplicate
            # job on the single-worker executor; the running job refreshes the
            # cache anyway.
            if _CACHE["status"] != "computing":
                _CACHE["status"] = "idle"
                _CACHE["ts"] = 0.0
    _maybe_trigger(k, hours, sensitivity)
    with _LOCK:
        status = _CACHE["status"]
        result = _CACHE["result"]
        error = _CACHE["error"]
        params = dict(_CACHE["params"])
    if status == "computing":
        return {"status": "computing", "message": "Calcul en cours, réessayez dans quelques secondes"}
    if status == "error":
        raise HTTPException(status_code=500, detail=error or "Erreur inconnue")
    requested = {"k": k, "hours": hours, "sensitivity": sensitivity}
    if status != "ready" or result is None or params != requested:
        # Either nothing is cached yet, or the cached result belongs to other
        # parameters / was force-invalidated: never serve it as "ready".
        return {"status": "idle", "message": "Calcul démarré, réessayez dans quelques secondes"}
    return {**result, "status": "ready"}
@router.get("/cluster/{cluster_id}/points")
async def get_cluster_points(
    cluster_id: str,
    limit: int = Query(5000, ge=1, le=20000),
    offset: int = Query(0, ge=0),
):
    """PCA coordinates + metadata of every IP of a cluster (paged).

    Used by the deck.gl ScatterplotLayer (drill-down or advanced zoom).
    ``cluster_id`` is a node id from /clusters, of the form ``c{n}_k{k}``.

    Raises:
        HTTPException 404: no ready clustering in cache.
        HTTPException 400: ``cluster_id`` cannot be parsed.
        HTTPException 409: the id's ``_k{k}`` suffix does not match the cached
            result, i.e. it refers to an older computation whose cluster
            indices no longer mean the same thing.
    """
    with _LOCK:
        status = _CACHE["status"]
        ips_map = _CACHE["cluster_ips"]
        result = _CACHE["result"]
    if status != "ready" or not ips_map:
        raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord")
    head, _, tail = cluster_id.partition("_")
    try:
        idx = int(head[1:])
    except ValueError:
        raise HTTPException(status_code=400, detail="cluster_id invalide (format: c{n}_k{k})") from None
    cached_k = (result or {}).get("stats", {}).get("k")
    if tail[:1] == "k" and tail[1:].isdigit() and cached_k is not None and int(tail[1:]) != cached_k:
        raise HTTPException(status_code=409, detail="cluster_id périmé — rechargez /clusters")
    members = ips_map.get(idx, [])
    total = len(members)
    page = members[offset: offset + limit]
    points = [
        {"ip": m[0], "ja4": m[1], "pca_x": round(m[2], 6), "pca_y": round(m[3], 6), "risk": round(m[4], 3)}
        for m in page
    ]
    return {"points": points, "total": total, "offset": offset, "limit": limit}
@router.get("/cluster/{cluster_id}/ips")
async def get_cluster_ips(
    cluster_id: str,
    limit: int = Query(100, ge=1, le=500),
    offset: int = Query(0, ge=0),
):
    """IPs of a cluster enriched with SQL details (backward-compat with the old UI).

    Pages through the cached members of the cluster and re-queries TCP / UA
    fields plus ML aggregates for them over the same window (``hours``) the
    cached clustering was computed on.

    Raises:
        HTTPException 404: no ready clustering in cache.
        HTTPException 400: ``cluster_id`` cannot be parsed.
        HTTPException 409: the id's ``_k{k}`` suffix does not match the cached result.
        HTTPException 500: the enrichment query failed.
    """
    with _LOCK:
        status = _CACHE["status"]
        ips_map = _CACHE["cluster_ips"]
        result = _CACHE["result"]
        hours = int(_CACHE["params"].get("hours") or 24)
    if status != "ready" or not ips_map:
        raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord")
    head, _, tail = cluster_id.partition("_")
    try:
        idx = int(head[1:])
    except ValueError:
        raise HTTPException(status_code=400, detail="cluster_id invalide") from None
    cached_k = (result or {}).get("stats", {}).get("k")
    if tail[:1] == "k" and tail[1:].isdigit() and cached_k is not None and int(tail[1:]) != cached_k:
        raise HTTPException(status_code=409, detail="cluster_id périmé — rechargez /clusters")
    members = ips_map.get(idx, [])
    total = len(members)
    page = members[offset: offset + limit]
    if not page:
        return {"ips": [], "total": total, "cluster_id": cluster_id}
    # Query the whole page (limit is already capped at 500 by the validator):
    # truncating it silently broke offset-based paging. IPs come from the cache,
    # not from the request; quotes/backslashes are still stripped before inlining.
    safe_ips = list(dict.fromkeys(m[0].replace("'", "").replace("\\", "") for m in page))
    ip_filter = ", ".join(f"'{ip}'" for ip in safe_ips)
    # NOTE(review): the filter is on IP only, so rows for JA4s of these IPs that
    # belong to other clusters are returned too.
    # ML detections are pre-aggregated per (src_ip, ja4) so the join cannot
    # multiply sum(t.hits) by the number of detections.
    sql = f"""
    SELECT
        replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS src_ip,
        t.ja4,
        any(t.tcp_ttl_raw) AS ttl,
        any(t.tcp_win_raw) AS win,
        any(t.tcp_scale_raw) AS scale,
        any(t.tcp_mss_raw) AS mss,
        sum(t.hits) AS hits,
        any(t.first_ua) AS ua,
        round(any(ml.m_score), 3) AS avg_score,
        any(ml.m_threat) AS threat_level,
        any(ml.m_country) AS country_code,
        any(ml.m_asn_org) AS asn_org,
        round(any(ml.m_fuzzing), 2) AS fuzzing,
        round(any(ml.m_velocity), 2) AS velocity
    FROM mabase_prod.agg_host_ip_ja4_1h t
    LEFT JOIN (
        SELECT
            src_ip,
            ja4,
            avg(abs(anomaly_score)) AS m_score,
            max(threat_level) AS m_threat,
            any(country_code) AS m_country,
            any(asn_org) AS m_asn_org,
            avg(fuzzing_index) AS m_fuzzing,
            avg(hit_velocity) AS m_velocity
        FROM mabase_prod.ml_detected_anomalies
        WHERE detected_at >= now() - INTERVAL {hours} HOUR
        GROUP BY src_ip, ja4
    ) ml ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4
    WHERE t.window_start >= now() - INTERVAL {hours} HOUR
      AND replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') IN ({ip_filter})
    GROUP BY t.src_ip, t.ja4
    ORDER BY hits DESC
    """
    try:
        result = db.query(sql)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    ips = [
        {
            "ip": str(row[0] or ""),
            "ja4": str(row[1] or ""),
            "tcp_ttl": int(row[2] or 0),
            "tcp_win": int(row[3] or 0),
            "tcp_scale": int(row[4] or 0),
            "tcp_mss": int(row[5] or 0),
            "hits": int(row[6] or 0),
            "ua": str(row[7] or ""),
            "avg_score": float(row[8] or 0),
            "threat_level": str(row[9] or ""),
            "country_code": str(row[10] or ""),
            "asn_org": str(row[11] or ""),
            "fuzzing": float(row[12] or 0),
            "velocity": float(row[13] or 0),
        }
        for row in result.result_rows
    ]
    return {"ips": ips, "total": total, "cluster_id": cluster_id}