"""Endpoints for ML / AI features (anomaly scores, radar, scatter)."""
from fastapi import APIRouter, HTTPException, Query

from ..database import db

router = APIRouter(prefix="/api/ml", tags=["ml_features"])


def _attack_type(fuzzing_index: float, hit_velocity: float,
                 is_fake_nav: int, ua_ch_mismatch: int) -> str:
    """Coarse attack-type label from a few anomaly signals.

    Rules are checked in priority order; the first match wins and
    "scanner" is the catch-all default.
    """
    if fuzzing_index > 50:
        return "brute_force"
    if hit_velocity > 1.0:
        return "flood"
    if is_fake_nav:
        return "scraper"
    if ua_ch_mismatch:
        return "spoofing"
    return "scanner"


@router.get("/top-anomalies")
async def get_top_anomalies(limit: int = Query(50, ge=1, le=500)):
    """Top anomalous IPs (24h) — bypasses view_ai_features_1h to avoid
    window functions.

    Direct query on agg_host_ip_ja4_1h + LEFT JOIN agg_header_fingerprint_1h.

    Returns ``{"items": [...]}`` sorted by descending fuzzing_index.
    Raises HTTPException 500 on any database error.
    """
    try:
        sql = """
            SELECT
                replaceRegexpAll(toString(a.src_ip), '^::ffff:', '') AS ip,
                any(a.ja4) AS ja4,
                any(a.host) AS host,
                sum(a.hits) AS hits,
                round(max(uniqMerge(a.uniq_query_params)) / greatest(max(uniqMerge(a.uniq_paths)), 1), 4) AS fuzzing_index,
                round(sum(a.hits) / greatest(dateDiff('second', min(a.first_seen), max(a.last_seen)), 1), 2) AS hit_velocity,
                round(sum(a.count_head) / greatest(sum(a.hits), 1), 4) AS head_ratio,
                round(sum(a.count_no_sec_fetch) / greatest(sum(a.hits), 1), 4) AS sec_fetch_absence,
                round(sum(a.tls12_count) / greatest(sum(a.hits), 1), 4) AS tls12_ratio,
                round(sum(a.count_generic_accept) / greatest(sum(a.hits), 1), 4) AS generic_accept_ratio,
                any(a.src_country_code) AS country,
                any(a.src_as_name) AS asn_name,
                max(h.ua_ch_mismatch) AS ua_ch_mismatch,
                max(h.modern_browser_score) AS browser_score,
                dictGetOrDefault('mabase_prod.dict_asn_reputation', 'label', toUInt64(any(a.src_asn)), 'unknown') AS asn_label,
                coalesce(
                    nullIf(dictGetOrDefault('mabase_prod.dict_bot_ja4', 'bot_name', tuple(any(a.ja4)), ''), ''),
                    ''
                ) AS bot_name
            FROM mabase_prod.agg_host_ip_ja4_1h a
            LEFT JOIN mabase_prod.agg_header_fingerprint_1h h
                ON a.src_ip = h.src_ip AND a.window_start = h.window_start
            WHERE a.window_start >= now() - INTERVAL 24 HOUR
            GROUP BY a.src_ip
            ORDER BY fuzzing_index DESC
            LIMIT %(limit)s
        """
        result = db.query(sql, {"limit": limit})
        items = []
        for row in result.result_rows:
            fuzzing = float(row[4] or 0)
            velocity = float(row[5] or 0)
            ua_mm = int(row[12] or 0)
            items.append({
                "ip": str(row[0]),
                "ja4": str(row[1]),
                "host": str(row[2]),
                "hits": int(row[3] or 0),
                "fuzzing_index": fuzzing,
                "hit_velocity": velocity,
                "head_ratio": float(row[6] or 0),
                "sec_fetch_absence": float(row[7] or 0),
                "tls12_ratio": float(row[8] or 0),
                "generic_accept_ratio": float(row[9] or 0),
                "country": str(row[10] or ""),
                "asn_name": str(row[11] or ""),
                "ua_ch_mismatch": ua_mm,
                "browser_score": int(row[13] or 0),
                "asn_label": str(row[14] or ""),
                "bot_name": str(row[15] or ""),
                # is_fake_nav is not available from this query, so pass 0.
                "attack_type": _attack_type(fuzzing, velocity, 0, ua_mm),
            })
        return {"items": items}
    except Exception as e:
        # Chain the original exception so the traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/ip/{ip}/radar")
async def get_ip_radar(ip: str):
    """Radar scores for a specific IP (8 anomaly dimensions), over 24h.

    Raises HTTPException 404 when no row comes back, 500 on database error.
    """
    try:
        sql = """
            SELECT
                avg(fuzzing_index) AS fuzzing_index,
                avg(hit_velocity) AS hit_velocity,
                avg(is_fake_navigation) AS is_fake_navigation,
                avg(ua_ch_mismatch) AS ua_ch_mismatch,
                avg(sni_host_mismatch) AS sni_host_mismatch,
                avg(orphan_ratio) AS orphan_ratio,
                avg(path_diversity_ratio) AS path_diversity_ratio,
                avg(anomalous_payload_ratio) AS anomalous_payload_ratio
            FROM mabase_prod.view_ai_features_1h
            WHERE replaceRegexpAll(toString(src_ip), '^::ffff:', '') = %(ip)s
              AND window_start >= now() - INTERVAL 24 HOUR
        """
        result = db.query(sql, {"ip": ip})
        # NOTE(review): an aggregate-only SELECT may still yield one row of
        # NULL/NaN for an unknown IP, in which case this 404 never fires and
        # all scores come back as 0 — confirm against the ClickHouse client.
        if not result.result_rows:
            raise HTTPException(status_code=404, detail="IP not found")
        row = result.result_rows[0]

        def _f(v) -> float:
            return float(v or 0)

        return {
            "ip": ip,
            "fuzzing_score": min(100.0, _f(row[0])),
            "velocity_score": min(100.0, _f(row[1]) * 100),
            "fake_nav_score": _f(row[2]) * 100,
            "ua_mismatch_score": _f(row[3]) * 100,
            "sni_mismatch_score": _f(row[4]) * 100,
            "orphan_score": min(100.0, _f(row[5]) * 100),
            # path_diversity is inverted: low diversity = high repetition.
            "path_repetition_score": max(0.0, 100 - _f(row[6]) * 100),
            "payload_anomaly_score": min(100.0, _f(row[7]) * 100),
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/score-distribution")
async def get_score_distribution():
    """Distribution of ALL ML scores from ml_all_scores (last 3 days).

    Single query with conditional aggregates to avoid a double scan.
    Returns per-model breakdowns plus global totals and percentages.
    """
    try:
        # Single scan — per (threat_level, model) rows; the countIf columns
        # let us accumulate the global totals in Python without a second query.
        sql = """
            SELECT
                threat_level,
                model_name,
                count() AS total,
                round(avg(anomaly_score), 4) AS avg_score,
                round(min(anomaly_score), 4) AS min_score,
                countIf(threat_level = 'NORMAL') AS normal_count,
                countIf(threat_level NOT IN ('NORMAL','KNOWN_BOT')) AS anomaly_count,
                countIf(threat_level = 'KNOWN_BOT') AS bot_count
            FROM mabase_prod.ml_all_scores
            WHERE detected_at >= now() - INTERVAL 3 DAY
            GROUP BY threat_level, model_name
            ORDER BY model_name, total DESC
        """
        result = db.query(sql)
        by_model: dict = {}
        grand_total = 0
        total_normal = total_anomaly = total_bot = 0
        for row in result.result_rows:
            level = str(row[0])
            model = str(row[1])
            total = int(row[2])
            grand_total += total
            total_normal += int(row[5] or 0)
            total_anomaly += int(row[6] or 0)
            total_bot += int(row[7] or 0)
            by_model.setdefault(model, []).append({
                "threat_level": level,
                "total": total,
                "avg_score": float(row[3] or 0),
                "min_score": float(row[4] or 0),
            })
        # Keep the true grand_total in the payload (0 when the table is
        # empty); only the percentage divisor needs to be non-zero.
        denom = grand_total or 1
        return {
            "by_model": by_model,
            "totals": {
                "normal": total_normal,
                "anomaly": total_anomaly,
                "known_bot": total_bot,
                "grand_total": grand_total,
                "normal_pct": round(total_normal / denom * 100, 1),
                "anomaly_pct": round(total_anomaly / denom * 100, 1),
                "bot_pct": round(total_bot / denom * 100, 1),
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/score-trends")
async def get_score_trends(hours: int = Query(72, ge=1, le=168)):
    """Time evolution of ML scores from ml_all_scores.

    Returns average anomaly score and counts per hour and per model.
    """
    try:
        sql = """
            SELECT
                toStartOfHour(window_start) AS hour,
                model_name,
                countIf(threat_level = 'NORMAL') AS normal_count,
                countIf(threat_level IN ('LOW','MEDIUM','HIGH','CRITICAL')) AS anomaly_count,
                countIf(threat_level = 'KNOWN_BOT') AS bot_count,
                round(avgIf(anomaly_score, threat_level IN ('LOW','MEDIUM','HIGH','CRITICAL')), 4) AS avg_anomaly_score
            FROM mabase_prod.ml_all_scores
            WHERE window_start >= now() - INTERVAL %(hours)s HOUR
            GROUP BY hour, model_name
            ORDER BY hour ASC, model_name
        """
        result = db.query(sql, {"hours": hours})
        points = []
        for row in result.result_rows:
            points.append({
                "hour": str(row[0]),
                "model": str(row[1]),
                "normal_count": int(row[2] or 0),
                "anomaly_count": int(row[3] or 0),
                "bot_count": int(row[4] or 0),
                "avg_anomaly_score": float(row[5] or 0),
            })
        return {"points": points, "hours": hours}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/b-features")
async def get_b_features(limit: int = Query(50, ge=1, le=200)):
    """Aggregated B-features (pure HTTP signals) for the top anomalous IPs.

    Source: agg_host_ip_ja4_1h (SimpleAggregateFunction columns).
    Exposes: head_ratio, sec_fetch_absence, tls12_ratio,
    generic_accept_ratio, http10_ratio.
    These features are computed in view_ai_features_1h but were never
    visualized in the dashboard.
    """
    try:
        sql = """
            SELECT
                ip, ja4, country, asn_name, hits,
                head_ratio, sec_fetch_absence, tls12_ratio,
                generic_accept_ratio, http10_ratio
            FROM (
                SELECT
                    replaceRegexpAll(toString(src_ip), '^::ffff:', '') AS ip,
                    any(ja4) AS ja4,
                    any(src_country_code) AS country,
                    any(src_as_name) AS asn_name,
                    sum(hits) AS hits,
                    round(sum(count_head) / greatest(sum(hits),1), 4) AS head_ratio,
                    round(sum(count_no_sec_fetch) / greatest(sum(hits),1), 4) AS sec_fetch_absence,
                    round(sum(tls12_count) / greatest(sum(hits),1), 4) AS tls12_ratio,
                    round(sum(count_generic_accept) / greatest(sum(hits),1), 4) AS generic_accept_ratio,
                    round(sum(count_http10) / greatest(sum(hits),1), 4) AS http10_ratio
                FROM mabase_prod.agg_host_ip_ja4_1h
                WHERE window_start >= now() - INTERVAL 24 HOUR
                GROUP BY src_ip
            )
            WHERE sec_fetch_absence > 0.5
               OR generic_accept_ratio > 0.3
               OR head_ratio > 0.1
               OR tls12_ratio > 0.5
            ORDER BY (head_ratio + sec_fetch_absence + generic_accept_ratio) DESC
            LIMIT %(limit)s
        """
        result = db.query(sql, {"limit": limit})
        items = []
        for row in result.result_rows:
            items.append({
                "ip": str(row[0]),
                "ja4": str(row[1] or ""),
                "country": str(row[2] or ""),
                "asn_name": str(row[3] or ""),
                "hits": int(row[4] or 0),
                "head_ratio": float(row[5] or 0),
                "sec_fetch_absence": float(row[6] or 0),
                "tls12_ratio": float(row[7] or 0),
                "generic_accept_ratio": float(row[8] or 0),
                "http10_ratio": float(row[9] or 0),
            })
        return {"items": items, "total": len(items)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/campaigns")
async def get_ml_campaigns(hours: int = Query(24, ge=1, le=168),
                           limit: int = Query(20, ge=1, le=100)):
    """Anomaly groups detected by DBSCAN (campaign_id >= 0).

    If no campaign is active, falls back to clustering by /24 subnet with a
    common JA4. Useful for catching distributed botnets without DBSCAN
    campaign state.
    """
    try:
        # First: check real campaigns
        campaign_sql = """
            SELECT
                campaign_id,
                count() AS total_detections,
                uniq(src_ip) AS unique_ips,
                any(threat_level) AS dominant_threat,
                groupUniqArray(3)(threat_level) AS threat_levels,
                groupUniqArray(3)(bot_name) AS bot_names,
                min(detected_at) AS first_seen,
                max(detected_at) AS last_seen
            FROM mabase_prod.ml_detected_anomalies
            WHERE detected_at >= now() - INTERVAL %(hours)s HOUR
              AND campaign_id >= 0
            GROUP BY campaign_id
            ORDER BY total_detections DESC
            LIMIT %(limit)s
        """
        result = db.query(campaign_sql, {"hours": hours, "limit": limit})
        campaigns = []
        for row in result.result_rows:
            campaigns.append({
                "id": f"C{row[0]}",
                "campaign_id": int(row[0]),
                "total_detections": int(row[1]),
                "unique_ips": int(row[2]),
                "dominant_threat": str(row[3] or ""),
                "threat_levels": list(row[4] or []),
                "bot_names": list(row[5] or []),
                "first_seen": str(row[6]),
                "last_seen": str(row[7]),
                "source": "dbscan",
            })
        # Fallback: subnet-based clustering when DBSCAN has no campaigns
        if not campaigns:
            subnet_sql = """
                SELECT
                    IPv4CIDRToRange(toIPv4(replaceRegexpAll(toString(src_ip),'^::ffff:','')), 24).1 AS subnet,
                    count() AS total_detections,
                    uniq(src_ip) AS unique_ips,
                    groupArray(3)(threat_level) AS threat_levels,
                    any(bot_name) AS bot_name,
                    any(ja4) AS sample_ja4,
                    min(detected_at) AS first_seen,
                    max(detected_at) AS last_seen
                FROM mabase_prod.ml_detected_anomalies
                WHERE detected_at >= now() - INTERVAL %(hours)s HOUR
                  AND threat_level IN ('HIGH','CRITICAL','MEDIUM')
                GROUP BY subnet
                HAVING unique_ips >= 3
                ORDER BY total_detections DESC
                LIMIT %(limit)s
            """
            result2 = db.query(subnet_sql, {"hours": hours, "limit": limit})
            for i, row in enumerate(result2.result_rows):
                subnet_str = str(row[0]) + "/24"
                campaigns.append({
                    "id": f"S{i+1:03d}",
                    "campaign_id": -1,  # sentinel: not a DBSCAN campaign
                    "subnet": subnet_str,
                    "total_detections": int(row[1]),
                    "unique_ips": int(row[2]),
                    "dominant_threat": str((row[3] or [""])[0]),
                    "threat_levels": list(row[3] or []),
                    "bot_names": [str(row[4] or "")],
                    "sample_ja4": str(row[5] or ""),
                    "first_seen": str(row[6]),
                    "last_seen": str(row[7]),
                    "source": "subnet_cluster",
                })
        dbscan_active = any(c["campaign_id"] >= 0 for c in campaigns)
        return {
            "campaigns": campaigns,
            "total": len(campaigns),
            "dbscan_active": dbscan_active,
            "hours": hours,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/scatter")
async def get_ml_scatter(limit: int = Query(200, ge=1, le=1000)):
    """Scatter-plot points (fuzzing_index × hit_velocity) — bypasses
    view_ai_features_1h.
    """
    try:
        sql = """
            SELECT
                replaceRegexpAll(toString(src_ip), '^::ffff:', '') AS ip,
                any(ja4) AS ja4,
                round(max(uniqMerge(uniq_query_params)) / greatest(max(uniqMerge(uniq_paths)), 1), 4) AS fuzzing_index,
                round(sum(hits) / greatest(dateDiff('second', min(first_seen), max(last_seen)), 1), 2) AS hit_velocity,
                sum(hits) AS hits,
                round(sum(count_head) / greatest(sum(hits), 1), 4) AS head_ratio,
                max(correlated_raw) AS correlated
            FROM mabase_prod.agg_host_ip_ja4_1h
            WHERE window_start >= now() - INTERVAL 24 HOUR
            GROUP BY src_ip
            ORDER BY fuzzing_index DESC
            LIMIT %(limit)s
        """
        result = db.query(sql, {"limit": limit})
        points = []
        for row in result.result_rows:
            fuzzing = float(row[2] or 0)
            velocity = float(row[3] or 0)
            points.append({
                "ip": str(row[0]),
                "ja4": str(row[1]),
                "fuzzing_index": fuzzing,
                "hit_velocity": velocity,
                "hits": int(row[4] or 0),
                # fake-nav / UA-CH signals are not selected here; pass 0.
                "attack_type": _attack_type(fuzzing, velocity, 0, 0),
            })
        return {"points": points}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e