Files
dashboard/backend/routes/ml_features.py
2026-03-18 09:00:47 +01:00

410 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Endpoints pour les features ML / IA (scores d'anomalies, radar, scatter)
"""
from fastapi import APIRouter, HTTPException, Query
from ..database import db
# Router collecting the endpoints of this module: each route registered on it
# is served under the "/api/ml" prefix and tagged "ml_features".
router = APIRouter(prefix="/api/ml", tags=["ml_features"])
def _attack_type(fuzzing_index: float, hit_velocity: float,
is_fake_nav: int, ua_ch_mismatch: int) -> str:
if fuzzing_index > 50:
return "brute_force"
if hit_velocity > 1.0:
return "flood"
if is_fake_nav:
return "scraper"
if ua_ch_mismatch:
return "spoofing"
return "scanner"
@router.get("/top-anomalies")
async def get_top_anomalies(limit: int = Query(50, ge=1, le=500)):
    """Return the top anomalous IPs over the last 24 hours.

    The query reads mabase_prod.agg_host_ip_ja4_1h directly, LEFT JOINed with
    mabase_prod.agg_header_fingerprint_1h on (src_ip, window_start), instead of
    going through view_ai_features_1h (the original author's stated reason is
    to avoid that view's window functions). Rows are grouped by source IP and
    ordered by fuzzing_index, highest first.

    Args:
        limit: Maximum number of IPs to return (1-500, default 50).

    Returns:
        ``{"items": [...]}`` with one dict per IP holding the aggregated
        features plus an ``attack_type`` label derived by ``_attack_type``.

    Raises:
        HTTPException: 500 when the query or the row conversion fails.
    """
    try:
        sql = """
            SELECT
                replaceRegexpAll(toString(a.src_ip), '^::ffff:', '') AS ip,
                any(a.ja4) AS ja4,
                any(a.host) AS host,
                sum(a.hits) AS hits,
                round(max(uniqMerge(a.uniq_query_params))
                    / greatest(max(uniqMerge(a.uniq_paths)), 1), 4) AS fuzzing_index,
                round(sum(a.hits)
                    / greatest(dateDiff('second', min(a.first_seen), max(a.last_seen)), 1), 2) AS hit_velocity,
                round(sum(a.count_head) / greatest(sum(a.hits), 1), 4) AS head_ratio,
                round(sum(a.count_no_sec_fetch) / greatest(sum(a.hits), 1), 4) AS sec_fetch_absence,
                round(sum(a.tls12_count) / greatest(sum(a.hits), 1), 4) AS tls12_ratio,
                round(sum(a.count_generic_accept) / greatest(sum(a.hits), 1), 4) AS generic_accept_ratio,
                any(a.src_country_code) AS country,
                any(a.src_as_name) AS asn_name,
                max(h.ua_ch_mismatch) AS ua_ch_mismatch,
                max(h.modern_browser_score) AS browser_score,
                dictGetOrDefault('mabase_prod.dict_asn_reputation', 'label', toUInt64(any(a.src_asn)), 'unknown') AS asn_label,
                coalesce(
                    nullIf(dictGetOrDefault('mabase_prod.dict_bot_ja4', 'bot_name', tuple(any(a.ja4)), ''), ''),
                    ''
                ) AS bot_name
            FROM mabase_prod.agg_host_ip_ja4_1h a
            LEFT JOIN mabase_prod.agg_header_fingerprint_1h h
                ON a.src_ip = h.src_ip AND a.window_start = h.window_start
            WHERE a.window_start >= now() - INTERVAL 24 HOUR
            GROUP BY a.src_ip
            ORDER BY fuzzing_index DESC
            LIMIT %(limit)s
        """
        result = db.query(sql, {"limit": limit})

        items = []
        for row in result.result_rows:
            fuzzing = float(row[4] or 0)
            velocity = float(row[5] or 0)
            ua_mm = int(row[12] or 0)
            items.append({
                "ip": str(row[0]),
                # `or ""` keeps a NULL from being rendered as the string
                # "None", matching how the other text columns are handled.
                "ja4": str(row[1] or ""),
                "host": str(row[2] or ""),
                "hits": int(row[3] or 0),
                "fuzzing_index": fuzzing,
                "hit_velocity": velocity,
                "head_ratio": float(row[6] or 0),
                "sec_fetch_absence": float(row[7] or 0),
                "tls12_ratio": float(row[8] or 0),
                "generic_accept_ratio": float(row[9] or 0),
                "country": str(row[10] or ""),
                "asn_name": str(row[11] or ""),
                "ua_ch_mismatch": ua_mm,
                "browser_score": int(row[13] or 0),
                "asn_label": str(row[14] or ""),
                "bot_name": str(row[15] or ""),
                # This query does not compute a fake-navigation flag, so 0 is
                # passed and the "scraper" label is never produced here.
                "attack_type": _attack_type(fuzzing, velocity, 0, ua_mm),
            })
        return {"items": items}
    except Exception as e:
        # Chain the original error so the traceback keeps the root cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/ip/{ip}/radar")
async def get_ip_radar(ip: str):
    """Return the eight radar scores for one IP over the last 24 hours.

    Each score is the average of a feature from
    mabase_prod.view_ai_features_1h, scaled for display. Some scores are
    clamped to 100 (fuzzing, velocity, orphan, payload); the path score is
    inverted (``100 - diversity * 100``) and floored at 0; the remaining ones
    are the raw average multiplied by 100.

    Args:
        ip: Source IP as text; compared against ``src_ip`` with any leading
            ``::ffff:`` prefix removed.

    Returns:
        A dict with ``ip`` and the eight ``*_score`` values.

    Raises:
        HTTPException: 404 when the view holds no row for this IP in the
            window, 500 on any other failure.
    """
    import math

    def _f(v) -> float:
        """Coerce a DB value to float, mapping NULL and NaN to 0.0."""
        if v is None:
            return 0.0
        value = float(v)
        # NaN is truthy, so a plain `v or 0` would let it through and it
        # would then break the min()/max() clamps and the JSON encoding.
        return 0.0 if math.isnan(value) else value

    try:
        # An aggregate query without GROUP BY returns one row even when
        # nothing matches, so emptiness has to be detected through count().
        sql = """
            SELECT
                avg(fuzzing_index) AS fuzzing_index,
                avg(hit_velocity) AS hit_velocity,
                avg(is_fake_navigation) AS is_fake_navigation,
                avg(ua_ch_mismatch) AS ua_ch_mismatch,
                avg(sni_host_mismatch) AS sni_host_mismatch,
                avg(orphan_ratio) AS orphan_ratio,
                avg(path_diversity_ratio) AS path_diversity_ratio,
                avg(anomalous_payload_ratio) AS anomalous_payload_ratio,
                count() AS sample_count
            FROM mabase_prod.view_ai_features_1h
            WHERE replaceRegexpAll(toString(src_ip), '^::ffff:', '') = %(ip)s
              AND window_start >= now() - INTERVAL 24 HOUR
        """
        result = db.query(sql, {"ip": ip})
        row = result.result_rows[0] if result.result_rows else None
        if row is None or not int(row[8] or 0):
            raise HTTPException(status_code=404, detail="IP not found")

        return {
            "ip": ip,
            "fuzzing_score": min(100.0, _f(row[0])),
            "velocity_score": min(100.0, _f(row[1]) * 100),
            "fake_nav_score": _f(row[2]) * 100,
            "ua_mismatch_score": _f(row[3]) * 100,
            "sni_mismatch_score": _f(row[4]) * 100,
            "orphan_score": min(100.0, _f(row[5]) * 100),
            "path_repetition_score": max(0.0, 100 - _f(row[6]) * 100),
            "payload_anomaly_score": min(100.0, _f(row[7]) * 100),
        }
    except HTTPException:
        # Let the 404 above through instead of rewrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/score-distribution")
async def get_score_distribution():
    """Return the distribution of ML scores from ml_all_scores (last 3 days).

    A single query groups the table by (threat_level, model_name). The
    per-model breakdown and the global totals are both derived from those
    rows in Python, so the table is only scanned once.

    Returns:
        ``{"by_model": {...}, "totals": {...}}`` where ``by_model`` maps each
        model name to its per-threat-level rows and ``totals`` holds the
        normal / anomaly / known-bot counts, their percentages and the
        overall row count.

    Raises:
        HTTPException: 500 when the query or the row conversion fails.
    """
    try:
        # One row per (threat_level, model_name); the countIf columns let the
        # loop below accumulate the global totals without a second query.
        sql = """
            SELECT
                threat_level,
                model_name,
                count() AS total,
                round(avg(anomaly_score), 4) AS avg_score,
                round(min(anomaly_score), 4) AS min_score,
                countIf(threat_level = 'NORMAL') AS normal_count,
                countIf(threat_level NOT IN ('NORMAL','KNOWN_BOT')) AS anomaly_count,
                countIf(threat_level = 'KNOWN_BOT') AS bot_count
            FROM mabase_prod.ml_all_scores
            WHERE detected_at >= now() - INTERVAL 3 DAY
            GROUP BY threat_level, model_name
            ORDER BY model_name, total DESC
        """
        result = db.query(sql)

        by_model: dict = {}
        grand_total = 0
        total_normal = total_anomaly = total_bot = 0
        for row in result.result_rows:
            level = str(row[0])
            model = str(row[1])
            total = int(row[2])
            grand_total += total
            total_normal += int(row[5] or 0)
            total_anomaly += int(row[6] or 0)
            total_bot += int(row[7] or 0)
            by_model.setdefault(model, []).append({
                "threat_level": level,
                "total": total,
                "avg_score": float(row[3] or 0),
                "min_score": float(row[4] or 0),
            })

        # Guard the divisions only: the reported grand_total must stay 0 when
        # the table is empty rather than being bumped to 1.
        denominator = max(grand_total, 1)
        return {
            "by_model": by_model,
            "totals": {
                "normal": total_normal,
                "anomaly": total_anomaly,
                "known_bot": total_bot,
                "grand_total": grand_total,
                "normal_pct": round(total_normal / denominator * 100, 1),
                "anomaly_pct": round(total_anomaly / denominator * 100, 1),
                "bot_pct": round(total_bot / denominator * 100, 1),
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/score-trends")
async def get_score_trends(hours: int = Query(72, ge=1, le=168)):
    """Return the hourly evolution of ML scores from ml_all_scores.

    Rows are bucketed by ``toStartOfHour(window_start)`` and model name. Each
    bucket carries the number of NORMAL, anomalous (LOW / MEDIUM / HIGH /
    CRITICAL) and KNOWN_BOT rows, plus the average anomaly score of the
    anomalous rows only.

    Args:
        hours: Size of the look-back window in hours (1-168, default 72).

    Returns:
        ``{"points": [...], "hours": hours}`` with one point per
        (hour, model) bucket, ordered by hour then model.

    Raises:
        HTTPException: 500 when the query or the row conversion fails.
    """
    import math

    def _score(v) -> float:
        """Coerce a DB value to float, mapping NULL and NaN to 0.0."""
        if v is None:
            return 0.0
        value = float(v)
        # avgIf has nothing to average in a bucket with no anomalous row and
        # is expected to yield NaN there; NaN is truthy, so `v or 0` would
        # keep it and the response would not be valid JSON.
        return 0.0 if math.isnan(value) else value

    try:
        sql = """
            SELECT
                toStartOfHour(window_start) AS hour,
                model_name,
                countIf(threat_level = 'NORMAL') AS normal_count,
                countIf(threat_level IN ('LOW','MEDIUM','HIGH','CRITICAL')) AS anomaly_count,
                countIf(threat_level = 'KNOWN_BOT') AS bot_count,
                round(avgIf(anomaly_score, threat_level IN ('LOW','MEDIUM','HIGH','CRITICAL')), 4) AS avg_anomaly_score
            FROM mabase_prod.ml_all_scores
            WHERE window_start >= now() - INTERVAL %(hours)s HOUR
            GROUP BY hour, model_name
            ORDER BY hour ASC, model_name
        """
        result = db.query(sql, {"hours": hours})

        points = [
            {
                "hour": str(row[0]),
                "model": str(row[1]),
                "normal_count": int(row[2] or 0),
                "anomaly_count": int(row[3] or 0),
                "bot_count": int(row[4] or 0),
                "avg_anomaly_score": _score(row[5]),
            }
            for row in result.result_rows
        ]
        return {"points": points, "hours": hours}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/b-features")
async def get_b_features(limit: int = Query(50, ge=1, le=200)):
    """Return the B-features (pure HTTP ratios) of suspicious IPs (last 24h).

    Reads mabase_prod.agg_host_ip_ja4_1h, aggregates per source IP and
    exposes head_ratio, sec_fetch_absence, tls12_ratio, generic_accept_ratio
    and http10_ratio. Only IPs crossing at least one of the thresholds in the
    outer WHERE clause are kept; they are ranked by the sum of head_ratio,
    sec_fetch_absence and generic_accept_ratio.

    Args:
        limit: Maximum number of IPs to return (1-200, default 50).

    Returns:
        ``{"items": [...], "total": <number of items>}``.

    Raises:
        HTTPException: 500 when the query or the row conversion fails.
    """
    try:
        sql = """
            SELECT ip, ja4, country, asn_name, hits,
                head_ratio, sec_fetch_absence, tls12_ratio, generic_accept_ratio, http10_ratio
            FROM (
                SELECT
                    replaceRegexpAll(toString(src_ip), '^::ffff:', '') AS ip,
                    any(ja4) AS ja4,
                    any(src_country_code) AS country,
                    any(src_as_name) AS asn_name,
                    sum(hits) AS hits,
                    round(sum(count_head) / greatest(sum(hits),1), 4) AS head_ratio,
                    round(sum(count_no_sec_fetch) / greatest(sum(hits),1), 4) AS sec_fetch_absence,
                    round(sum(tls12_count) / greatest(sum(hits),1), 4) AS tls12_ratio,
                    round(sum(count_generic_accept) / greatest(sum(hits),1), 4) AS generic_accept_ratio,
                    round(sum(count_http10) / greatest(sum(hits),1), 4) AS http10_ratio
                FROM mabase_prod.agg_host_ip_ja4_1h
                WHERE window_start >= now() - INTERVAL 24 HOUR
                GROUP BY src_ip
            )
            WHERE sec_fetch_absence > 0.5 OR generic_accept_ratio > 0.3
                OR head_ratio > 0.1 OR tls12_ratio > 0.5
            ORDER BY (head_ratio + sec_fetch_absence + generic_accept_ratio) DESC
            LIMIT %(limit)s
        """
        rows = db.query(sql, {"limit": limit}).result_rows

        # The outer SELECT lists exactly these ten columns, in this order.
        items = [
            {
                "ip": str(ip),
                "ja4": str(ja4 or ""),
                "country": str(country or ""),
                "asn_name": str(asn_name or ""),
                "hits": int(hits or 0),
                "head_ratio": float(head or 0),
                "sec_fetch_absence": float(no_sec_fetch or 0),
                "tls12_ratio": float(tls12 or 0),
                "generic_accept_ratio": float(generic_accept or 0),
                "http10_ratio": float(http10 or 0),
            }
            for (ip, ja4, country, asn_name, hits,
                 head, no_sec_fetch, tls12, generic_accept, http10) in rows
        ]
        return {"items": items, "total": len(items)}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@router.get("/campaigns")
async def get_ml_campaigns(hours: int = Query(24, ge=1, le=168), limit: int = Query(20, ge=1, le=100)):
    """Return groups of detected anomalies ("campaigns") for the given window.

    First looks for rows of mabase_prod.ml_detected_anomalies carrying a
    ``campaign_id >= 0`` and groups them by campaign. When none exist, falls
    back to grouping MEDIUM / HIGH / CRITICAL detections by /24 subnet,
    keeping subnets with at least 3 distinct source IPs; one JA4 per subnet
    is returned as a sample.

    Args:
        hours: Size of the look-back window in hours (1-168, default 24).
        limit: Maximum number of groups to return (1-100, default 20).

    Returns:
        A dict with ``campaigns``, ``total``, ``dbscan_active`` (True when at
        least one returned entry has a non-negative campaign_id) and
        ``hours``. Fallback entries use ``campaign_id = -1`` and
        ``source = "subnet_cluster"``.

    Raises:
        HTTPException: 500 when a query or the row conversion fails.
    """
    try:
        # Step 1: campaigns that already carry an id.
        # NOTE(review): dominant_threat is filled from any(threat_level),
        # which is not a frequency-based pick - confirm this is intended.
        campaign_sql = """
            SELECT
                campaign_id,
                count() AS total_detections,
                uniq(src_ip) AS unique_ips,
                any(threat_level) AS dominant_threat,
                groupUniqArray(3)(threat_level) AS threat_levels,
                groupUniqArray(3)(bot_name) AS bot_names,
                min(detected_at) AS first_seen,
                max(detected_at) AS last_seen
            FROM mabase_prod.ml_detected_anomalies
            WHERE detected_at >= now() - INTERVAL %(hours)s HOUR
              AND campaign_id >= 0
            GROUP BY campaign_id
            ORDER BY total_detections DESC
            LIMIT %(limit)s
        """
        dbscan_rows = db.query(campaign_sql, {"hours": hours, "limit": limit}).result_rows
        campaigns = [
            {
                "id": f"C{rec[0]}",
                "campaign_id": int(rec[0]),
                "total_detections": int(rec[1]),
                "unique_ips": int(rec[2]),
                "dominant_threat": str(rec[3] or ""),
                "threat_levels": list(rec[4] or []),
                "bot_names": list(rec[5] or []),
                "first_seen": str(rec[6]),
                "last_seen": str(rec[7]),
                "source": "dbscan",
            }
            for rec in dbscan_rows
        ]

        # Step 2: nothing found above, so cluster by /24 subnet instead.
        if not campaigns:
            # NOTE(review): toIPv4() is applied to every src_ip once the
            # '::ffff:' prefix is stripped; behaviour for genuine IPv6
            # addresses should be confirmed against the server version.
            subnet_sql = """
                SELECT
                    IPv4CIDRToRange(toIPv4(replaceRegexpAll(toString(src_ip),'^::ffff:','')), 24).1 AS subnet,
                    count() AS total_detections,
                    uniq(src_ip) AS unique_ips,
                    groupArray(3)(threat_level) AS threat_levels,
                    any(bot_name) AS bot_name,
                    any(ja4) AS sample_ja4,
                    min(detected_at) AS first_seen,
                    max(detected_at) AS last_seen
                FROM mabase_prod.ml_detected_anomalies
                WHERE detected_at >= now() - INTERVAL %(hours)s HOUR
                  AND threat_level IN ('HIGH','CRITICAL','MEDIUM')
                GROUP BY subnet
                HAVING unique_ips >= 3
                ORDER BY total_detections DESC
                LIMIT %(limit)s
            """
            subnet_rows = db.query(subnet_sql, {"hours": hours, "limit": limit}).result_rows
            campaigns = [
                {
                    # Synthetic ids S001, S002, ... follow the query's ranking.
                    "id": f"S{rank:03d}",
                    "campaign_id": -1,
                    "subnet": str(rec[0]) + "/24",
                    "total_detections": int(rec[1]),
                    "unique_ips": int(rec[2]),
                    # First of the (up to three) collected threat levels.
                    "dominant_threat": str((rec[3] or [""])[0]),
                    "threat_levels": list(rec[3] or []),
                    "bot_names": [str(rec[4] or "")],
                    "sample_ja4": str(rec[5] or ""),
                    "first_seen": str(rec[6]),
                    "last_seen": str(rec[7]),
                    "source": "subnet_cluster",
                }
                for rank, rec in enumerate(subnet_rows, start=1)
            ]

        has_dbscan = any(entry["campaign_id"] >= 0 for entry in campaigns)
        return {
            "campaigns": campaigns,
            "total": len(campaigns),
            "dbscan_active": has_dbscan,
            "hours": hours,
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@router.get("/scatter")
async def get_ml_scatter(limit: int = Query(200, ge=1, le=1000)):
    """Return scatter-plot points (fuzzing_index x hit_velocity), last 24h.

    Like the top-anomalies endpoint, this reads
    mabase_prod.agg_host_ip_ja4_1h directly rather than view_ai_features_1h.
    One point is produced per source IP, ordered by fuzzing_index, highest
    first.

    Args:
        limit: Maximum number of points to return (1-1000, default 200).

    Returns:
        ``{"points": [...]}`` where each point holds ip, ja4, fuzzing_index,
        hit_velocity, hits and an ``attack_type`` label.

    Raises:
        HTTPException: 500 when the query or the row conversion fails.
    """

    def _as_point(rec) -> dict:
        """Convert one result row into the dict sent to the client."""
        fuzz = float(rec[2] or 0)
        rate = float(rec[3] or 0)
        return {
            "ip": str(rec[0]),
            "ja4": str(rec[1]),
            "fuzzing_index": fuzz,
            "hit_velocity": rate,
            "hits": int(rec[4] or 0),
            # Neither flag is computed by this query, so both are passed as 0
            # and only the fuzzing / velocity rules can fire.
            "attack_type": _attack_type(fuzz, rate, 0, 0),
        }

    try:
        # NOTE(review): head_ratio and correlated (columns 5 and 6) are
        # selected but never read when building the points below.
        sql = """
            SELECT
                replaceRegexpAll(toString(src_ip), '^::ffff:', '') AS ip,
                any(ja4) AS ja4,
                round(max(uniqMerge(uniq_query_params)) / greatest(max(uniqMerge(uniq_paths)), 1), 4) AS fuzzing_index,
                round(sum(hits) / greatest(dateDiff('second', min(first_seen), max(last_seen)), 1), 2) AS hit_velocity,
                sum(hits) AS hits,
                round(sum(count_head) / greatest(sum(hits), 1), 4) AS head_ratio,
                max(correlated_raw) AS correlated
            FROM mabase_prod.agg_host_ip_ja4_1h
            WHERE window_start >= now() - INTERVAL 24 HOUR
            GROUP BY src_ip
            ORDER BY fuzzing_index DESC
            LIMIT %(limit)s
        """
        rows = db.query(sql, {"limit": limit}).result_rows
        return {"points": [_as_point(rec) for rec in rows]}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))