- Backend: radius = log1p(ip_count)*2.2 au lieu de sqrt*2 (max 30px vs 80px) ex: 60K IPs → 24px, 1K IPs → 15px, 100 IPs → 10px - Frontend: zoom initial -0.5 (vue dézoomée par défaut) - Fit viewport basé sur dimensions réelles canvas - panneaux latéraux - Padding 18% autour de l'étendue des données pour éviter le débord Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
496 lines · 20 KiB · Python
"""
|
|
Clustering d'IPs multi-métriques — WebGL / deck.gl backend.
|
|
|
|
- Calcul sur la TOTALITÉ des IPs (GROUP BY src_ip, ja4 sans LIMIT)
|
|
- K-means++ vectorisé (numpy) + PCA-2D + enveloppes convexes (scipy)
|
|
- Calcul en background thread + cache 30 min
|
|
- Endpoints : /clusters, /status, /cluster/{id}/points
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
import time
|
|
import logging
|
|
import threading
|
|
from collections import Counter
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Optional, Any
|
|
|
|
import numpy as np
|
|
from fastapi import APIRouter, HTTPException, Query
|
|
|
|
from ..database import db
|
|
from ..services.clustering_engine import (
|
|
FEATURE_KEYS, FEATURE_NAMES, FEATURE_NORMS, N_FEATURES,
|
|
build_feature_vector, kmeans_pp, pca_2d, compute_hulls,
|
|
name_cluster, risk_score_from_centroid,
|
|
)
|
|
|
|
log = logging.getLogger(__name__)
|
|
router = APIRouter(prefix="/api/clustering", tags=["clustering"])
|
|
|
|
# ─── Cache global ──────────────────────────────────────────────────────────────
|
|
_CACHE: dict[str, Any] = {
|
|
"status": "idle", # idle | computing | ready | error
|
|
"error": None,
|
|
"result": None, # dict résultat complet
|
|
"ts": 0.0, # timestamp dernière mise à jour
|
|
"params": {},
|
|
"cluster_ips": {}, # cluster_idx → [(ip, ja4, pca_x, pca_y, risk)]
|
|
}
|
|
_CACHE_TTL = 1800 # 30 minutes
|
|
_LOCK = threading.Lock()
|
|
_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="clustering")
|
|
|
|
# ─── Couleurs menace ──────────────────────────────────────────────────────────
|
|
_THREAT_COLOR = {
|
|
0.70: "#dc2626", # Critique
|
|
0.45: "#f97316", # Élevé
|
|
0.25: "#eab308", # Modéré
|
|
0.00: "#22c55e", # Sain
|
|
}
|
|
|
|
def _risk_to_color(risk: float) -> str:
|
|
for threshold, color in sorted(_THREAT_COLOR.items(), reverse=True):
|
|
if risk >= threshold:
|
|
return color
|
|
return "#6b7280"
|
|
|
|
|
|
# ─── SQL : TOUTES les IPs sans LIMIT ─────────────────────────────────────────
|
|
_SQL_ALL_IPS = """
|
|
SELECT
|
|
replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS ip,
|
|
t.ja4,
|
|
any(t.tcp_ttl_raw) AS ttl,
|
|
any(t.tcp_win_raw) AS win,
|
|
any(t.tcp_scale_raw) AS scale,
|
|
any(t.tcp_mss_raw) AS mss,
|
|
any(t.first_ua) AS ua,
|
|
sum(t.hits) AS hits,
|
|
|
|
avg(abs(ml.anomaly_score)) AS avg_score,
|
|
avg(ml.hit_velocity) AS avg_velocity,
|
|
avg(ml.fuzzing_index) AS avg_fuzzing,
|
|
avg(ml.is_headless) AS pct_headless,
|
|
avg(ml.post_ratio) AS avg_post,
|
|
avg(ml.ip_id_zero_ratio) AS ip_id_zero,
|
|
avg(ml.temporal_entropy) AS entropy,
|
|
avg(ml.modern_browser_score) AS browser_score,
|
|
avg(ml.alpn_http_mismatch) AS alpn_mismatch,
|
|
avg(ml.is_alpn_missing) AS alpn_missing,
|
|
avg(ml.multiplexing_efficiency) AS h2_eff,
|
|
avg(ml.header_order_confidence) AS hdr_conf,
|
|
avg(ml.ua_ch_mismatch) AS ua_ch_mismatch,
|
|
avg(ml.asset_ratio) AS asset_ratio,
|
|
avg(ml.direct_access_ratio) AS direct_ratio,
|
|
avg(ml.distinct_ja4_count) AS ja4_count,
|
|
max(ml.is_ua_rotating) AS ua_rotating,
|
|
|
|
max(ml.threat_level) AS threat,
|
|
any(ml.country_code) AS country,
|
|
any(ml.asn_org) AS asn_org
|
|
FROM mabase_prod.agg_host_ip_ja4_1h t
|
|
LEFT JOIN mabase_prod.ml_detected_anomalies ml
|
|
ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4
|
|
AND ml.detected_at >= now() - INTERVAL %(hours)s HOUR
|
|
WHERE t.window_start >= now() - INTERVAL %(hours)s HOUR
|
|
AND t.tcp_ttl_raw > 0
|
|
GROUP BY t.src_ip, t.ja4
|
|
"""
|
|
|
|
_SQL_COLS = [
|
|
"ip", "ja4", "ttl", "win", "scale", "mss", "ua", "hits",
|
|
"avg_score", "avg_velocity", "avg_fuzzing", "pct_headless", "avg_post",
|
|
"ip_id_zero", "entropy", "browser_score", "alpn_mismatch", "alpn_missing",
|
|
"h2_eff", "hdr_conf", "ua_ch_mismatch", "asset_ratio", "direct_ratio",
|
|
"ja4_count", "ua_rotating", "threat", "country", "asn_org",
|
|
]
|
|
|
|
|
|
# ─── Worker de clustering (thread pool) ──────────────────────────────────────
|
|
|
|
def _run_clustering_job(k: int, hours: int) -> None:
|
|
"""Exécuté dans le thread pool. Met à jour _CACHE."""
|
|
t0 = time.time()
|
|
with _LOCK:
|
|
_CACHE["status"] = "computing"
|
|
_CACHE["error"] = None
|
|
|
|
try:
|
|
log.info(f"[clustering] Démarrage du calcul k={k} hours={hours}")
|
|
|
|
# ── 1. Chargement de toutes les IPs ──────────────────────────────
|
|
result = db.query(_SQL_ALL_IPS, {"hours": hours})
|
|
rows: list[dict] = []
|
|
for row in result.result_rows:
|
|
rows.append({col: row[i] for i, col in enumerate(_SQL_COLS)})
|
|
|
|
n = len(rows)
|
|
log.info(f"[clustering] {n} IPs chargées")
|
|
if n < k:
|
|
raise ValueError(f"Seulement {n} IPs disponibles (k={k} requis)")
|
|
|
|
# ── 2. Construction de la matrice de features (numpy) ────────────
|
|
X = np.array([build_feature_vector(r) for r in rows], dtype=np.float32)
|
|
log.info(f"[clustering] Matrice X: {X.shape} — {X.nbytes/1024/1024:.1f} MB")
|
|
|
|
# ── 3. K-means++ vectorisé ────────────────────────────────────────
|
|
km = kmeans_pp(X.astype(np.float64), k=k, max_iter=80, n_init=3, seed=42)
|
|
log.info(f"[clustering] K-means: {km.n_iter} iters, inertia={km.inertia:.2f}")
|
|
|
|
# ── 4. PCA-2D pour toutes les IPs ────────────────────────────────
|
|
coords = pca_2d(X.astype(np.float64)) # (n, 2), normalisé [0,1]
|
|
|
|
# ── 5. Enveloppes convexes par cluster ───────────────────────────
|
|
hulls = compute_hulls(coords, km.labels, k)
|
|
|
|
# ── 6. Agrégation par cluster ─────────────────────────────────────
|
|
cluster_rows: list[list[dict]] = [[] for _ in range(k)]
|
|
cluster_coords: list[list[list[float]]] = [[] for _ in range(k)]
|
|
cluster_ips_map: dict[int, list] = {j: [] for j in range(k)}
|
|
|
|
for i, label in enumerate(km.labels):
|
|
j = int(label)
|
|
cluster_rows[j].append(rows[i])
|
|
cluster_coords[j].append(coords[i].tolist())
|
|
cluster_ips_map[j].append((
|
|
rows[i]["ip"],
|
|
rows[i]["ja4"],
|
|
float(coords[i][0]),
|
|
float(coords[i][1]),
|
|
float(risk_score_from_centroid(km.centroids[j])),
|
|
))
|
|
|
|
# ── 7. Construction des nœuds ─────────────────────────────────────
|
|
nodes = []
|
|
for j in range(k):
|
|
if not cluster_rows[j]:
|
|
continue
|
|
|
|
def avg_f(key: str, crows: list[dict] = cluster_rows[j]) -> float:
|
|
return float(np.mean([float(r.get(key) or 0) for r in crows]))
|
|
|
|
mean_ttl = avg_f("ttl")
|
|
mean_mss = avg_f("mss")
|
|
mean_scale = avg_f("scale")
|
|
mean_win = avg_f("win")
|
|
|
|
raw_stats = {"mean_ttl": mean_ttl, "mean_mss": mean_mss, "mean_scale": mean_scale}
|
|
label_name = name_cluster(km.centroids[j], raw_stats)
|
|
risk = float(risk_score_from_centroid(km.centroids[j]))
|
|
color = _risk_to_color(risk)
|
|
|
|
# Centroïde 2D = moyenne des coords du cluster
|
|
cxy = np.mean(cluster_coords[j], axis=0).tolist() if cluster_coords[j] else [0.5, 0.5]
|
|
ip_set = list({r["ip"] for r in cluster_rows[j]})
|
|
ip_count = len(ip_set)
|
|
hit_count = int(sum(float(r.get("hits") or 0) for r in cluster_rows[j]))
|
|
|
|
threats = [str(r.get("threat") or "") for r in cluster_rows[j] if r.get("threat")]
|
|
countries = [str(r.get("country") or "") for r in cluster_rows[j] if r.get("country")]
|
|
orgs = [str(r.get("asn_org") or "") for r in cluster_rows[j] if r.get("asn_org")]
|
|
|
|
def topk(lst: list[str], n: int = 5) -> list[str]:
|
|
return [v for v, _ in Counter(lst).most_common(n) if v]
|
|
|
|
radar = [
|
|
{"feature": name, "value": round(float(km.centroids[j][i]), 4)}
|
|
for i, name in enumerate(FEATURE_NAMES)
|
|
]
|
|
|
|
radius = max(8, min(30, int(math.log1p(ip_count) * 2.2)))
|
|
|
|
sample_rows = sorted(cluster_rows[j], key=lambda r: float(r.get("hits") or 0), reverse=True)[:8]
|
|
sample_ips = [r["ip"] for r in sample_rows]
|
|
sample_ua = str(cluster_rows[j][0].get("ua") or "")
|
|
|
|
nodes.append({
|
|
"id": f"c{j}_k{k}",
|
|
"cluster_idx": j,
|
|
"label": label_name,
|
|
"pca_x": round(cxy[0], 6),
|
|
"pca_y": round(cxy[1], 6),
|
|
"radius": radius,
|
|
"color": color,
|
|
"risk_score": round(risk, 4),
|
|
|
|
"mean_ttl": round(mean_ttl, 1),
|
|
"mean_mss": round(mean_mss, 0),
|
|
"mean_scale": round(mean_scale, 1),
|
|
"mean_win": round(mean_win, 0),
|
|
"mean_score": round(avg_f("avg_score"), 4),
|
|
"mean_velocity":round(avg_f("avg_velocity"),3),
|
|
"mean_fuzzing": round(avg_f("avg_fuzzing"), 3),
|
|
"mean_headless":round(avg_f("pct_headless"),3),
|
|
"mean_post": round(avg_f("avg_post"), 3),
|
|
"mean_asset": round(avg_f("asset_ratio"), 3),
|
|
"mean_direct": round(avg_f("direct_ratio"),3),
|
|
"mean_alpn_mismatch": round(avg_f("alpn_mismatch"),3),
|
|
"mean_h2_eff": round(avg_f("h2_eff"), 3),
|
|
"mean_hdr_conf":round(avg_f("hdr_conf"), 3),
|
|
"mean_ua_ch": round(avg_f("ua_ch_mismatch"),3),
|
|
"mean_entropy": round(avg_f("entropy"), 3),
|
|
"mean_ja4_diversity": round(avg_f("ja4_count"),3),
|
|
"mean_ip_id_zero": round(avg_f("ip_id_zero"),3),
|
|
"mean_browser_score": round(avg_f("browser_score"),1),
|
|
"mean_ua_rotating": round(avg_f("ua_rotating"),3),
|
|
|
|
"ip_count": ip_count,
|
|
"hit_count": hit_count,
|
|
"top_threat": topk(threats, 1)[0] if threats else "",
|
|
"top_countries":topk(countries, 5),
|
|
"top_orgs": topk(orgs, 5),
|
|
"sample_ips": sample_ips,
|
|
"sample_ua": sample_ua,
|
|
"radar": radar,
|
|
|
|
# Hull pour deck.gl PolygonLayer
|
|
"hull": hulls.get(j, []),
|
|
})
|
|
|
|
# ── 8. Arêtes k-NN entre clusters ────────────────────────────────
|
|
edges = []
|
|
seen: set[frozenset] = set()
|
|
for i, ni in enumerate(nodes):
|
|
ci = ni["cluster_idx"]
|
|
dists = sorted(
|
|
[(j, nj["cluster_idx"],
|
|
float(np.sum((km.centroids[ci] - km.centroids[nj["cluster_idx"]]) ** 2)))
|
|
for j, nj in enumerate(nodes) if j != i],
|
|
key=lambda x: x[2]
|
|
)
|
|
for j_idx, cj, d2 in dists[:2]:
|
|
key = frozenset([ni["id"], nodes[j_idx]["id"]])
|
|
if key in seen:
|
|
continue
|
|
seen.add(key)
|
|
edges.append({
|
|
"id": f"e_{ni['id']}_{nodes[j_idx]['id']}",
|
|
"source": ni["id"],
|
|
"target": nodes[j_idx]["id"],
|
|
"similarity": round(1.0 / (1.0 + math.sqrt(d2)), 3),
|
|
})
|
|
|
|
# ── 9. Stockage résultat + cache IPs ─────────────────────────────
|
|
total_ips = sum(n_["ip_count"] for n_ in nodes)
|
|
total_hits = sum(n_["hit_count"] for n_ in nodes)
|
|
bot_ips = sum(n_["ip_count"] for n_ in nodes if n_["risk_score"] > 0.45 or "🤖" in n_["label"])
|
|
high_ips = sum(n_["ip_count"] for n_ in nodes if n_["risk_score"] > 0.25)
|
|
elapsed = round(time.time() - t0, 2)
|
|
|
|
result_dict = {
|
|
"nodes": nodes,
|
|
"edges": edges,
|
|
"stats": {
|
|
"total_clusters": len(nodes),
|
|
"total_ips": total_ips,
|
|
"total_hits": total_hits,
|
|
"bot_ips": bot_ips,
|
|
"high_risk_ips": high_ips,
|
|
"n_samples": n,
|
|
"k": k,
|
|
"elapsed_s": elapsed,
|
|
},
|
|
"feature_names": FEATURE_NAMES,
|
|
}
|
|
|
|
with _LOCK:
|
|
_CACHE["result"] = result_dict
|
|
_CACHE["cluster_ips"] = cluster_ips_map
|
|
_CACHE["status"] = "ready"
|
|
_CACHE["ts"] = time.time()
|
|
_CACHE["params"] = {"k": k, "hours": hours}
|
|
_CACHE["error"] = None
|
|
|
|
log.info(f"[clustering] Terminé en {elapsed}s — {total_ips} IPs, {len(nodes)} clusters")
|
|
|
|
except Exception as e:
|
|
log.exception("[clustering] Erreur lors du calcul")
|
|
with _LOCK:
|
|
_CACHE["status"] = "error"
|
|
_CACHE["error"] = str(e)
|
|
|
|
|
|
def _maybe_trigger(k: int, hours: int) -> None:
|
|
"""Lance le calcul si cache absent, expiré ou paramètres différents."""
|
|
with _LOCK:
|
|
status = _CACHE["status"]
|
|
params = _CACHE["params"]
|
|
ts = _CACHE["ts"]
|
|
|
|
cache_stale = (time.time() - ts) > _CACHE_TTL
|
|
params_changed = params.get("k") != k or params.get("hours") != hours
|
|
|
|
if status in ("computing",):
|
|
return # déjà en cours
|
|
|
|
if status == "ready" and not cache_stale and not params_changed:
|
|
return # cache frais
|
|
|
|
_EXECUTOR.submit(_run_clustering_job, k, hours)
|
|
|
|
|
|
# ─── Endpoints ────────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/status")
|
|
async def get_status():
|
|
"""État du calcul en cours (polling frontend)."""
|
|
with _LOCK:
|
|
return {
|
|
"status": _CACHE["status"],
|
|
"error": _CACHE["error"],
|
|
"ts": _CACHE["ts"],
|
|
"params": _CACHE["params"],
|
|
"age_s": round(time.time() - _CACHE["ts"], 0) if _CACHE["ts"] else None,
|
|
}
|
|
|
|
|
|
@router.get("/clusters")
|
|
async def get_clusters(
|
|
k: int = Query(14, ge=4, le=30, description="Nombre de clusters"),
|
|
hours: int = Query(24, ge=1, le=168, description="Fenêtre temporelle (heures)"),
|
|
force: bool = Query(False, description="Forcer le recalcul"),
|
|
):
|
|
"""
|
|
Clustering multi-métriques sur TOUTES les IPs.
|
|
|
|
Retourne immédiatement depuis le cache (status=ready).
|
|
Si le calcul est en cours ou non démarré → status=computing/idle + trigger.
|
|
"""
|
|
if force:
|
|
with _LOCK:
|
|
_CACHE["status"] = "idle"
|
|
_CACHE["ts"] = 0.0
|
|
|
|
_maybe_trigger(k, hours)
|
|
|
|
with _LOCK:
|
|
status = _CACHE["status"]
|
|
result = _CACHE["result"]
|
|
error = _CACHE["error"]
|
|
|
|
if status == "computing":
|
|
return {"status": "computing", "message": "Calcul en cours, réessayez dans quelques secondes"}
|
|
|
|
if status == "error":
|
|
raise HTTPException(status_code=500, detail=error or "Erreur inconnue")
|
|
|
|
if result is None:
|
|
return {"status": "idle", "message": "Calcul démarré, réessayez dans quelques secondes"}
|
|
|
|
return {**result, "status": "ready"}
|
|
|
|
|
|
@router.get("/cluster/{cluster_id}/points")
|
|
async def get_cluster_points(
|
|
cluster_id: str,
|
|
limit: int = Query(5000, ge=1, le=20000),
|
|
offset: int = Query(0, ge=0),
|
|
):
|
|
"""
|
|
Coordonnées PCA + métadonnées de toutes les IPs d'un cluster.
|
|
Utilisé par deck.gl ScatterplotLayer (drill-down ou zoom avancé).
|
|
"""
|
|
with _LOCK:
|
|
status = _CACHE["status"]
|
|
ips_map = _CACHE["cluster_ips"]
|
|
|
|
if status != "ready" or not ips_map:
|
|
raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord")
|
|
|
|
try:
|
|
idx = int(cluster_id.split("_")[0][1:])
|
|
except (ValueError, IndexError):
|
|
raise HTTPException(status_code=400, detail="cluster_id invalide (format: c{n}_k{k})")
|
|
|
|
members = ips_map.get(idx, [])
|
|
total = len(members)
|
|
page = members[offset: offset + limit]
|
|
|
|
points = [
|
|
{"ip": m[0], "ja4": m[1], "pca_x": round(m[2], 6), "pca_y": round(m[3], 6), "risk": round(m[4], 3)}
|
|
for m in page
|
|
]
|
|
return {"points": points, "total": total, "offset": offset, "limit": limit}
|
|
|
|
|
|
@router.get("/cluster/{cluster_id}/ips")
|
|
async def get_cluster_ips(
|
|
cluster_id: str,
|
|
limit: int = Query(100, ge=1, le=500),
|
|
offset: int = Query(0, ge=0),
|
|
):
|
|
"""IPs avec détails SQL (backward-compat avec l'ancienne UI)."""
|
|
with _LOCK:
|
|
status = _CACHE["status"]
|
|
ips_map = _CACHE["cluster_ips"]
|
|
|
|
if status != "ready" or not ips_map:
|
|
raise HTTPException(status_code=404, detail="Cache absent — appelez /clusters d'abord")
|
|
|
|
try:
|
|
idx = int(cluster_id.split("_")[0][1:])
|
|
except (ValueError, IndexError):
|
|
raise HTTPException(status_code=400, detail="cluster_id invalide")
|
|
|
|
members = ips_map.get(idx, [])
|
|
total = len(members)
|
|
page = members[offset: offset + limit]
|
|
if not page:
|
|
return {"ips": [], "total": total, "cluster_id": cluster_id}
|
|
|
|
safe_ips = [m[0].replace("'", "") for m in page[:200]]
|
|
ip_filter = ", ".join(f"'{ip}'" for ip in safe_ips)
|
|
|
|
sql = f"""
|
|
SELECT
|
|
replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') AS src_ip,
|
|
t.ja4,
|
|
any(t.tcp_ttl_raw) AS ttl,
|
|
any(t.tcp_win_raw) AS win,
|
|
any(t.tcp_scale_raw) AS scale,
|
|
any(t.tcp_mss_raw) AS mss,
|
|
sum(t.hits) AS hits,
|
|
any(t.first_ua) AS ua,
|
|
round(avg(abs(ml.anomaly_score)), 3) AS avg_score,
|
|
max(ml.threat_level) AS threat_level,
|
|
any(ml.country_code) AS country_code,
|
|
any(ml.asn_org) AS asn_org,
|
|
round(avg(ml.fuzzing_index), 2) AS fuzzing,
|
|
round(avg(ml.hit_velocity), 2) AS velocity
|
|
FROM mabase_prod.agg_host_ip_ja4_1h t
|
|
LEFT JOIN mabase_prod.ml_detected_anomalies ml
|
|
ON t.src_ip = ml.src_ip AND t.ja4 = ml.ja4
|
|
AND ml.detected_at >= now() - INTERVAL 24 HOUR
|
|
WHERE t.window_start >= now() - INTERVAL 24 HOUR
|
|
AND replaceRegexpAll(toString(t.src_ip), '^::ffff:', '') IN ({ip_filter})
|
|
GROUP BY t.src_ip, t.ja4
|
|
ORDER BY hits DESC
|
|
"""
|
|
try:
|
|
result = db.query(sql)
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
ips = []
|
|
for row in result.result_rows:
|
|
ips.append({
|
|
"ip": str(row[0] or ""),
|
|
"ja4": str(row[1] or ""),
|
|
"tcp_ttl": int(row[2] or 0),
|
|
"tcp_win": int(row[3] or 0),
|
|
"tcp_scale": int(row[4] or 0),
|
|
"tcp_mss": int(row[5] or 0),
|
|
"hits": int(row[6] or 0),
|
|
"ua": str(row[7] or ""),
|
|
"avg_score": float(row[8] or 0),
|
|
"threat_level": str(row[9] or ""),
|
|
"country_code": str(row[10] or ""),
|
|
"asn_org": str(row[11] or ""),
|
|
"fuzzing": float(row[12] or 0),
|
|
"velocity": float(row[13] or 0),
|
|
})
|
|
|
|
return {"ips": ips, "total": total, "cluster_id": cluster_id}
|