Files
dashboard/backend/services/clustering_engine.py
SOC Analyst 08d003a050 feat(clustering): palette diversifiée, suppression scores anomalie/robot, visualisation éclatée
- Suppression de 'Score Anomalie' (avg_score) des 31→30 features de clustering
- Suppression de 'Score de détection robot' (mean_score) de la sidebar et de l'API
- Suppression de bot_ips / high_risk_ips des stats (métriques dérivées des scores supprimés)
- Redistribution des poids dans risk_score_from_centroid: UA-CH mismatch +17%,
  fuzzing +14%, headless +10%, vélocité +9%, ip_id_zero +7%
- Mise à jour des indices feature dans name_cluster et risk_score_from_centroid
- Palette 24 couleurs spectrales (cluster_color) → bleu/violet/rose/teal/amber/cyan/lime...
  Les couleurs identifient les clusters, non leur niveau de risque
- Remplacement de la légende CRITICAL/HIGH/MEDIUM/LOW par la liste des clusters actifs
- Ajout de spread_clusters(): répulsion itérative des centroïdes trop proches (50 iter)
  min_dist=0.16 → les clusters se repoussent mutuellement → visualisation plus lisible
- Interface TypeScript mise à jour (suppression mean_score, bot_ips, high_risk_ips)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-19 14:01:14 +01:00

548 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Moteur de clustering K-means++ multi-métriques (numpy + scipy vectorisé).
Ref:
Arthur & Vassilvitskii (2007) — k-means++: The Advantages of Careful Seeding
scipy.spatial.ConvexHull — enveloppe convexe (Graham/Qhull)
sklearn-style API — centroids, labels_, inertia_
Features (31 dimensions, normalisées [0,1]) :
0 ttl_n : TTL initial normalisé
1 mss_n : MSS normalisé → type réseau
2 scale_n : facteur de mise à l'échelle TCP
3 win_n : fenêtre TCP normalisée
4 score_n : score anomalie ML (abs)
5 velocity_n : vélocité de requêtes (log1p)
6 fuzzing_n : index de fuzzing (log1p)
7 headless_n : ratio sessions headless
8 post_n : ratio POST/total
9 ip_id_zero_n : ratio IP-ID=0 (Linux/spoofé)
10 entropy_n : entropie temporelle
11 browser_n : score navigateur moderne
12 alpn_n : mismatch ALPN/protocole
13 alpn_absent_n : ratio ALPN absent
14 h2_n : efficacité H2 multiplexing (log1p)
15 hdr_conf_n : confiance ordre headers
16 ua_ch_n : mismatch User-Agent-Client-Hints
17 asset_n : ratio assets statiques
18 direct_n : ratio accès directs
19 ja4_div_n : diversité JA4 (log1p)
20 ua_rot_n : UA rotatif (booléen)
21 country_risk_n : risque pays source (CN/RU/KP → 1.0, US/DE/FR → 0.0)
22 asn_cloud_n : hébergeur cloud/CDN/VPN (Cloudflare/AWS/OVH → 1.0)
23 hdr_accept_lang_n : présence header Accept-Language (0=absent=bot-like)
24 hdr_encoding_n : présence header Accept-Encoding (0=absent=bot-like)
25 hdr_sec_fetch_n : présence headers Sec-Fetch-* (1=navigateur réel)
26 hdr_count_n : nombre de headers HTTP normalisé (3=bot, 15=browser)
27 hfp_popular_n : popularité du fingerprint headers (log-normalisé)
fingerprint rare = suspect ; très populaire = browser légitime
28 hfp_rotating_n : rotation de fingerprint (distinct_header_orders)
plusieurs fingerprints distincts → bot en rotation
29 hfp_cookie_n : présence header Cookie (engagement utilisateur réel)
30 hfp_referer_n : présence header Referer (navigation HTTP normale)
"""
from __future__ import annotations
import math
import logging
import numpy as np
from dataclasses import dataclass, field
from scipy.spatial import ConvexHull
log = logging.getLogger(__name__)
# ─── Encodage pays (risque source) ───────────────────────────────────────────
# Source: MISP threat intel, Spamhaus DROP list, géographie offensive connue
_COUNTRY_RISK: dict[str, float] = {
# Très haut risque : infrastructure offensive documentée
"CN": 1.0, "RU": 1.0, "KP": 1.0, "IR": 1.0,
"BY": 0.9, "SY": 0.9, "CU": 0.8,
# Haut risque : transit/hébergement permissif, bulletproof hosters
"HK": 0.75, "VN": 0.7, "UA": 0.65,
"RO": 0.6, "PK": 0.6, "NG": 0.6,
"BG": 0.55, "TR": 0.55, "BR": 0.5,
"TH": 0.5, "IN": 0.45, "ID": 0.45,
# Risque faible : pays à faible tolérance envers activité malveillante
"US": 0.1, "DE": 0.1, "FR": 0.1, "GB": 0.1,
"CA": 0.1, "JP": 0.1, "AU": 0.1, "NL": 0.15,
"CH": 0.1, "SE": 0.1, "NO": 0.1, "DK": 0.1,
"FI": 0.1, "AT": 0.1, "BE": 0.1, "IT": 0.15,
"SG": 0.3, "TW": 0.2, "KR": 0.2, "RS": 0.4,
}
_DEFAULT_COUNTRY_RISK = 0.35 # pays inconnu → risque modéré
def country_risk(cc: str | None) -> float:
"""Score de risque [0,1] d'un code pays ISO-3166."""
return _COUNTRY_RISK.get((cc or "").upper(), _DEFAULT_COUNTRY_RISK)
# ─── Encodage ASN (type d'infrastructure) ────────────────────────────────────
# Cloud/CDN/hosting → fort corrélé avec scanners automatisés et bots
_ASN_CLOUD_KEYWORDS = [
# Hyperscalers
"amazon", "aws", "google", "microsoft", "azure", "alibaba", "tencent", "huawei",
# CDN / edge
"cloudflare", "akamai", "fastly", "cloudfront", "incapsula", "imperva",
"sucuri", "stackpath", "keycdn",
# Hébergeurs
"ovh", "hetzner", "digitalocean", "vultr", "linode", "akamai-linode",
"leaseweb", "choopa", "packet", "equinix", "serverius", "combahton",
"m247", "b2 net", "hostinger", "contabo",
# Bulletproof / transit permissif connus
"hwclouds", "multacom", "psychz", "serverius", "colocrossing",
"frantech", "sharktech", "tzulo",
# VPN / proxy commerciaux
"nordvpn", "expressvpn", "mullvad", "protonvpn", "surfshark",
"privateinternetaccess", "pia ", "cyberghost", "hotspot shield",
"ipvanish", "hide.me",
# Bots search engines / crawlers
"facebook", "meta ", "twitter", "linkedin", "semrush", "ahrefs",
"majestic", "moz ", "babbar", "sistrix", "criteo", "peer39",
]
def asn_cloud_score(asn_org: str | None) -> float:
"""
Score [0,1] : 1.0 = cloud/CDN/hébergement/VPN confirmé.
Correspond à une infrastructure typiquement utilisée par les bots.
"""
if not asn_org:
return 0.2 # inconnu → légèrement suspect
s = asn_org.lower()
for kw in _ASN_CLOUD_KEYWORDS:
if kw in s:
return 1.0
return 0.0
# ─── Définition des features ──────────────────────────────────────────────────
FEATURES: list[tuple[str, str, object]] = [
# TCP stack
("ttl", "TTL Initial", lambda v: min(1.0, (v or 0) / 255.0)),
("mss", "MSS Réseau", lambda v: min(1.0, (v or 0) / 1460.0)),
("scale", "Scale TCP", lambda v: min(1.0, (v or 0) / 14.0)),
("win", "Fenêtre TCP", lambda v: min(1.0, (v or 0) / 65535.0)),
# Anomalie ML
("avg_velocity", "Vélocité (rps)", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(100))), ("avg_fuzzing", "Fuzzing", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(300))),
("pct_headless", "Headless", lambda v: min(1.0, float(v or 0))),
("avg_post", "Ratio POST", lambda v: min(1.0, float(v or 0))),
# IP-ID
("ip_id_zero", "IP-ID Zéro", lambda v: min(1.0, float(v or 0))),
# Temporel
("entropy", "Entropie Temporelle", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(10))),
# Navigateur
("browser_score", "Score Navigateur", lambda v: min(1.0, float(v or 0) / 50.0)),
# TLS / Protocole
("alpn_mismatch", "ALPN Mismatch", lambda v: min(1.0, float(v or 0))),
("alpn_missing", "ALPN Absent", lambda v: min(1.0, float(v or 0))),
("h2_eff", "H2 Multiplexing", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(20))),
("hdr_conf", "Ordre Headers", lambda v: min(1.0, float(v or 0))),
("ua_ch_mismatch","UA-CH Mismatch", lambda v: min(1.0, float(v or 0))),
# Comportement HTTP
("asset_ratio", "Ratio Assets", lambda v: min(1.0, float(v or 0))),
("direct_ratio", "Accès Direct", lambda v: min(1.0, float(v or 0))),
# Diversité JA4
("ja4_count", "Diversité JA4", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(30))),
# UA rotatif
("ua_rotating", "UA Rotatif", lambda v: 1.0 if float(v or 0) > 0 else 0.0),
# ── Géographie & infrastructure (nouvelles features) ──────────────────
("country", "Risque Pays", lambda v: country_risk(str(v) if v else None)),
("asn_org", "Hébergeur Cloud/VPN", lambda v: asn_cloud_score(str(v) if v else None)),
# ── Headers HTTP (présence / profil de la requête) ────────────────────
# Absence d'Accept-Language ou Accept-Encoding = fort signal bot (bots simples l'omettent)
# Sec-Fetch-* = exclusif aux navigateurs réels (fetch metadata)
("hdr_accept_lang", "Accept-Language", lambda v: min(1.0, float(v or 0))),
("hdr_has_encoding", "Accept-Encoding", lambda v: 1.0 if float(v or 0) > 0 else 0.0),
("hdr_has_sec_fetch", "Sec-Fetch Headers", lambda v: 1.0 if float(v or 0) > 0 else 0.0),
("hdr_count_raw", "Nb Headers", lambda v: min(1.0, float(v or 0) / 20.0)),
# ── Fingerprint HTTP Headers (agg_header_fingerprint_1h) ──────────────
# header_order_shared_count : nb d'IPs partageant ce fingerprint
# élevé → populaire → browser légitime (normalisé log1p / log1p(500000))
("hfp_shared_count", "FP Popularité", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(500_000))),
# distinct_header_orders : nb de fingerprints distincts pour cette IP
# élevé → rotation de fingerprint → bot (normalisé log1p / log1p(10))
("hfp_distinct_orders", "FP Rotation", lambda v: min(1.0, math.log1p(float(v or 0)) / math.log1p(10))),
# Cookie et Referer : signaux de navigation légitime
("hfp_cookie", "Cookie Présent", lambda v: min(1.0, float(v or 0))),
("hfp_referer", "Referer Présent", lambda v: min(1.0, float(v or 0))),
]
FEATURE_KEYS = [f[0] for f in FEATURES]
FEATURE_NAMES = [f[1] for f in FEATURES]
FEATURE_NORMS = [f[2] for f in FEATURES]
N_FEATURES = len(FEATURES)
# ─── Construction du vecteur de features ─────────────────────────────────────
def build_feature_vector(row: dict) -> list[float]:
"""Construit le vecteur normalisé [0,1]^23 depuis un dict SQL."""
return [norm(row.get(key, 0)) for key, _, norm in FEATURES]
# ─── Standardisation z-score ──────────────────────────────────────────────────
def standardize(X: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Z-score standardisation : chaque feature est centrée et mise à l'échelle
par sa déviation standard.
Ref: Bishop (2006) PRML §9.1 — preprocessing recommandé pour K-means.
Retourne (X_std, mean, std) pour pouvoir projeter de nouveaux points.
"""
mean = X.mean(axis=0)
std = X.std(axis=0)
std[std < 1e-8] = 1.0 # évite la division par zéro pour features constantes
return (X - mean) / std, mean, std
# ─── K-means++ vectorisé (numpy) ─────────────────────────────────────────────
@dataclass
class KMeansResult:
centroids: np.ndarray # (k, n_features)
labels: np.ndarray # (n_points,) int32
inertia: float
n_iter: int
def kmeans_pp(X: np.ndarray, k: int, max_iter: int = 60, n_init: int = 3,
seed: int = 42) -> KMeansResult:
"""
K-means++ entièrement vectorisé avec numpy.
n_init exécutions, meilleure inertie conservée.
"""
rng = np.random.default_rng(seed)
n, d = X.shape
best: KMeansResult | None = None
for _ in range(n_init):
# ── Initialisation K-means++ ──────────────────────────────────────
centers = [X[rng.integers(n)].copy()]
for _ in range(k - 1):
D = _min_sq_dist(X, np.array(centers))
# Garantit des probabilités non-négatives (erreurs float, points dupliqués)
D = np.clip(D, 0.0, None)
total = D.sum()
if total < 1e-12:
# Tous les points sont confondus — tirage aléatoire
centers.append(X[rng.integers(n)].copy())
else:
probs = D / total
centers.append(X[rng.choice(n, p=probs)].copy())
centers_arr = np.array(centers) # (k, d)
# ── Iterations ───────────────────────────────────────────────────
labels = np.zeros(n, dtype=np.int32)
for it in range(max_iter):
# Assignation vectorisée : (n, k) distance²
dists = _sq_dists(X, centers_arr) # (n, k)
new_labels = np.argmin(dists, axis=1).astype(np.int32)
if it > 0 and np.all(new_labels == labels):
break # convergence
labels = new_labels
# Mise à jour des centroïdes
for j in range(k):
mask = labels == j
if mask.any():
centers_arr[j] = X[mask].mean(axis=0)
inertia = float(np.sum(np.min(_sq_dists(X, centers_arr), axis=1)))
result = KMeansResult(centers_arr, labels, inertia, it + 1)
if best is None or inertia < best.inertia:
best = result
return best # type: ignore[return-value]
def _sq_dists(X: np.ndarray, C: np.ndarray) -> np.ndarray:
"""Distance² entre chaque point de X et chaque centroïde de C. O(n·k·d)."""
# ||x - c||² = ||x||² + ||c||² - 2·x·cᵀ
X2 = np.sum(X ** 2, axis=1, keepdims=True) # (n, 1)
C2 = np.sum(C ** 2, axis=1, keepdims=True).T # (1, k)
return X2 + C2 - 2.0 * X @ C.T # (n, k)
def _min_sq_dist(X: np.ndarray, C: np.ndarray) -> np.ndarray:
"""Distance² minimale de chaque point aux centroïdes existants."""
return np.min(_sq_dists(X, C), axis=1)
# ─── PCA 2D (numpy) ──────────────────────────────────────────────────────────
def pca_2d(X: np.ndarray) -> np.ndarray:
"""
PCA-2D vectorisée. Retourne les coordonnées normalisées [0,1] × [0,1].
"""
mean = X.mean(axis=0)
Xc = X - mean
# Power iteration pour les 2 premières composantes
rng = np.random.default_rng(0)
v1 = _power_iter(Xc, rng.standard_normal(Xc.shape[1]))
proj1 = Xc @ v1
# Déflation (Hotelling)
Xc2 = Xc - np.outer(proj1, v1)
v2 = _power_iter(Xc2, rng.standard_normal(Xc.shape[1]))
proj2 = Xc2 @ v2
coords = np.column_stack([proj1, proj2])
# Normalisation [0,1]
mn, mx = coords.min(axis=0), coords.max(axis=0)
rng_ = mx - mn
rng_[rng_ == 0] = 1.0
return (coords - mn) / rng_
def _power_iter(X: np.ndarray, v: np.ndarray, n_iter: int = 30) -> np.ndarray:
"""Power iteration : trouve le premier vecteur propre de XᵀX."""
for _ in range(n_iter):
v = X.T @ (X @ v)
norm = np.linalg.norm(v)
if norm < 1e-12:
break
v /= norm
return v
# ─── Enveloppe convexe (hull) par cluster ────────────────────────────────────
def compute_hulls(coords_2d: np.ndarray, labels: np.ndarray,
k: int, min_pts: int = 4) -> dict[int, list[list[float]]]:
"""
Calcule l'enveloppe convexe (convex hull) des points PCA pour chaque cluster.
Retourne {cluster_idx: [[x,y], ...]} (polygone fermé).
"""
hulls: dict[int, list[list[float]]] = {}
for j in range(k):
pts = coords_2d[labels == j]
if len(pts) < min_pts:
# Pas assez de points : bounding box
if len(pts) > 0:
mx_, my_ = pts.mean(axis=0)
r = max(0.01, pts.std(axis=0).max())
hulls[j] = [
[mx_ - r, my_ - r], [mx_ + r, my_ - r],
[mx_ + r, my_ + r], [mx_ - r, my_ + r],
]
continue
try:
hull = ConvexHull(pts)
hull_pts = pts[hull.vertices].tolist()
# Fermer le polygone
hull_pts.append(hull_pts[0])
hulls[j] = hull_pts
except Exception:
hulls[j] = []
return hulls
# ─── Nommage et scoring ───────────────────────────────────────────────────────
def name_cluster(centroid: np.ndarray, raw_stats: dict) -> str:
"""Nom lisible basé sur les features dominantes du centroïde [0,1]."""
s = centroid
n = len(s)
ttl_raw = float(raw_stats.get("mean_ttl", 0))
mss_raw = float(raw_stats.get("mean_mss", 0))
country_risk_v = s[20] if n > 20 else 0.0
asn_cloud = s[21] if n > 21 else 0.0
accept_lang = s[22] if n > 22 else 1.0
accept_enc = s[23] if n > 23 else 1.0
sec_fetch = s[24] if n > 24 else 0.0
hdr_count = s[25] if n > 25 else 0.5
hfp_popular = s[26] if n > 26 else 0.5
hfp_rotating = s[27] if n > 27 else 0.0
# Scanner pur : aucun header browser, fingerprint rare, peu de headers
if accept_lang < 0.15 and accept_enc < 0.15 and hdr_count < 0.25:
return "🤖 Scanner pur (no headers)"
# Fingerprint tournant : bot qui change de profil headers
if hfp_rotating > 0.6:
return "🔄 Bot fingerprint tournant"
# Fingerprint très rare : bot artisanal unique
if hfp_popular < 0.15:
return "🕵️ Fingerprint rare suspect"
# Scanners Masscan
if s[0] > 0.16 and s[0] < 0.25 and mss_raw in range(1440, 1460) and s[2] > 0.25:
return "🤖 Masscan Scanner"
# Bots offensifs agressifs (fuzzing élevé)
if s[4] > 0.40 and s[5] > 0.3:
return "🤖 Bot agressif"
# Bot qui simule un navigateur mais sans les vrais headers
if s[15] > 0.40 and sec_fetch < 0.2 and accept_lang < 0.3:
return "🤖 Bot UA simulé"
# Pays à très haut risque avec infrastructure cloud
if country_risk_v > 0.75 and asn_cloud > 0.5:
return "🌏 Source pays risqué"
# Cloud + UA-CH mismatch
if s[15] > 0.50 and asn_cloud > 0.70:
return "☁️ Bot cloud UA-CH"
if s[15] > 0.60:
return "🤖 UA-CH Mismatch"
# Headless browser (Puppeteer/Playwright) : a les headers Sec-Fetch mais headless
if s[6] > 0.50 and sec_fetch > 0.5:
return "🤖 Headless Browser"
if s[6] > 0.50:
return "🤖 Headless (no Sec-Fetch)"
# Cloud pur (CDN/crawler légitime ?)
if asn_cloud > 0.85:
return "☁️ Infrastructure cloud"
# Pays à risque élevé sans autre signal
if country_risk_v > 0.60:
return "🌏 Trafic suspect (pays)"
# Navigateur légitime : tous les signaux positifs y compris fingerprint populaire
if (accept_lang > 0.7 and accept_enc > 0.7 and sec_fetch > 0.5
and hdr_count > 0.5 and hfp_popular > 0.5):
return "🌐 Navigateur légitime"
# OS fingerprinting
if s[3] > 0.85 and ttl_raw > 120:
return "🖥️ Windows"
if s[0] > 0.22 and s[0] < 0.28 and mss_raw > 1400:
return "🐧 Linux"
if mss_raw < 1380 and mss_raw > 0:
return "🌐 Tunnel réseau"
if s[4] > 0.40:
return "⚡ Trafic rapide"
if s[4] < 0.10 and asn_cloud < 0.30:
return "✅ Trafic sain"
return "📊 Cluster mixte"
def risk_score_from_centroid(centroid: np.ndarray) -> float:
"""
Score de risque [0,1] depuis le centroïde (espace original [0,1]).
30 features (avg_score supprimé) — poids calibrés pour sommer à 1.0.
Indices décalés de -1 après suppression de avg_score (ancien idx 4).
"""
s = centroid
n = len(s)
country_risk_v = s[20] if n > 20 else 0.0
asn_cloud = s[21] if n > 21 else 0.0
no_accept_lang = 1.0 - (s[22] if n > 22 else 1.0)
no_encoding = 1.0 - (s[23] if n > 23 else 1.0)
no_sec_fetch = 1.0 - (s[24] if n > 24 else 0.0)
few_headers = 1.0 - (s[25] if n > 25 else 0.5)
hfp_rare = 1.0 - (s[26] if n > 26 else 0.5)
hfp_rotating = s[27] if n > 27 else 0.0
# [4]=vélocité [5]=fuzzing [6]=headless [8]=ip_id_zero [15]=ua_ch_mismatch
# Poids redistribués depuis l'ancien score ML anomalie (0.25) vers les signaux restants
return float(np.clip(
0.14 * s[5] + # fuzzing
0.17 * s[15] + # UA-CH mismatch (fort signal impersonation navigateur)
0.10 * s[6] + # headless
0.09 * s[4] + # vélocité (rps)
0.07 * s[8] + # IP-ID zéro
0.09 * country_risk_v+ # risque pays source
0.06 * asn_cloud + # infrastructure cloud/VPN
0.04 * no_accept_lang+ # absence Accept-Language
0.04 * no_encoding + # absence Accept-Encoding
0.04 * no_sec_fetch + # absence Sec-Fetch
0.04 * few_headers + # très peu de headers
0.06 * hfp_rare + # fingerprint rare = suspect
0.06 * hfp_rotating, # rotation de fingerprint = bot
0.0, 1.0
))
# ─── Palette de couleurs diversifiée (non liée au risque) ────────────────────
# 24 couleurs couvrant tout le spectre HSL pour distinguer les clusters visuellement.
# Choix: teintes espacées de ~15° avec alternance de saturation/luminosité.
_CLUSTER_PALETTE: list[str] = [
"#3b82f6", # blue
"#8b5cf6", # violet
"#ec4899", # pink
"#14b8a6", # teal
"#f59e0b", # amber
"#06b6d4", # cyan
"#a3e635", # lime
"#f97316", # orange
"#6366f1", # indigo
"#10b981", # emerald
"#e879f9", # fuchsia
"#fbbf24", # yellow
"#60a5fa", # light blue
"#c084fc", # light purple
"#fb7185", # rose
"#34d399", # light green
"#38bdf8", # sky
"#a78bfa", # lavender
"#fdba74", # peach
"#4ade80", # green
"#f472b6", # light pink
"#67e8f9", # light cyan
"#d97706", # dark amber
"#7c3aed", # dark violet
]
def cluster_color(cluster_idx: int) -> str:
"""Couleur distinctive pour un cluster, cyclique sur la palette."""
return _CLUSTER_PALETTE[cluster_idx % len(_CLUSTER_PALETTE)]
# ─── Dispersion des clusters dans l'espace 2D ────────────────────────────────
def spread_clusters(coords_2d: np.ndarray, labels: np.ndarray, k: int,
n_iter: int = 50, min_dist: float = 0.14) -> np.ndarray:
"""
Repousse les centroïdes trop proches par répulsion itérative (spring repulsion).
Chaque point suit le déplacement de son centroïde.
Paramètres
----------
min_dist : distance minimale souhaitée entre centroïdes (espace [0,1]).
Augmenter pour plus d'éclatement.
n_iter : nombre d'itérations de la physique de répulsion.
"""
rng = np.random.default_rng(0)
centroids = np.zeros((k, 2))
counts = np.zeros(k, dtype=int)
for j in range(k):
mask = labels == j
if mask.any():
centroids[j] = coords_2d[mask].mean(axis=0)
counts[j] = int(mask.sum())
orig = centroids.copy()
for _ in range(n_iter):
forces = np.zeros_like(centroids)
for i in range(k):
if counts[i] == 0:
continue
for j in range(k):
if i == j or counts[j] == 0:
continue
delta = centroids[i] - centroids[j]
dist = float(np.linalg.norm(delta))
if dist < 1e-8:
delta = rng.uniform(-0.02, 0.02, size=2)
dist = float(np.linalg.norm(delta)) + 1e-8
if dist < min_dist:
# Force inversement proportionnelle à l'écart
magnitude = (min_dist - dist) / min_dist
forces[i] += magnitude * (delta / dist)
centroids += forces * 0.10
# Déplace chaque point par le delta de son centroïde
displaced = coords_2d.copy()
for j in range(k):
if counts[j] == 0:
continue
displaced[labels == j] += centroids[j] - orig[j]
# Re-normalisation [0, 1]
mn, mx = displaced.min(axis=0), displaced.max(axis=0)
rng_ = mx - mn
rng_[rng_ < 1e-8] = 1.0
return (displaced - mn) / rng_