""" Endpoints pour l'analyse de corrélations et la classification SOC """ from fastapi import APIRouter, HTTPException, Query from typing import Optional, List from datetime import datetime import ipaddress import json from ..database import db from ..models import ( SubnetAnalysis, CountryAnalysis, CountryData, JA4Analysis, JA4SubnetData, UserAgentAnalysis, UserAgentData, CorrelationIndicators, ClassificationRecommendation, ClassificationLabel, ClassificationCreate, Classification, ClassificationsListResponse ) router = APIRouter(prefix="/api/analysis", tags=["analysis"]) # ============================================================================= # ANALYSE SUBNET / ASN # ============================================================================= @router.get("/{ip}/subnet", response_model=SubnetAnalysis) async def analyze_subnet(ip: str): """ Analyse les IPs du même subnet et ASN """ try: # Calculer le subnet /24 ip_obj = ipaddress.ip_address(ip) subnet = ipaddress.ip_network(f"{ip}/24", strict=False) subnet_str = str(subnet) # Récupérer les infos ASN pour cette IP asn_query = """ SELECT asn_number, asn_org FROM ml_detected_anomalies WHERE src_ip = %(ip)s ORDER BY detected_at DESC LIMIT 1 """ asn_result = db.query(asn_query, {"ip": ip}) if not asn_result.result_rows: # Fallback: utiliser données par défaut asn_number = "0" asn_org = "Unknown" else: asn_number = str(asn_result.result_rows[0][0] or "0") asn_org = asn_result.result_rows[0][1] or "Unknown" # IPs du même subnet /24 subnet_ips_query = """ SELECT DISTINCT src_ip FROM ml_detected_anomalies WHERE toIPv4(src_ip) >= toIPv4(%(subnet_start)s) AND toIPv4(src_ip) <= toIPv4(%(subnet_end)s) AND detected_at >= now() - INTERVAL 24 HOUR ORDER BY src_ip """ subnet_result = db.query(subnet_ips_query, { "subnet_start": str(subnet.network_address), "subnet_end": str(subnet.broadcast_address) }) subnet_ips = [str(row[0]) for row in subnet_result.result_rows] # Total IPs du même ASN if asn_number != "0": asn_total_query = """ SELECT uniq(src_ip) FROM ml_detected_anomalies WHERE asn_number = %(asn_number)s AND detected_at >= now() - INTERVAL 24 HOUR """ asn_total_result = db.query(asn_total_query, {"asn_number": asn_number}) asn_total = asn_total_result.result_rows[0][0] if asn_total_result.result_rows else 0 else: asn_total = 0 return SubnetAnalysis( ip=ip, subnet=subnet_str, ips_in_subnet=subnet_ips, total_in_subnet=len(subnet_ips), asn_number=asn_number, asn_org=asn_org, total_in_asn=asn_total, alert=len(subnet_ips) > 10 ) except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"Erreur: {str(e)}") @router.get("/{ip}/country", response_model=dict) async def analyze_ip_country(ip: str): """ Analyse le pays d'une IP spécifique et la répartition des autres pays du même ASN """ try: # Pays de l'IP ip_country_query = """ SELECT country_code, asn_number FROM ml_detected_anomalies WHERE src_ip = %(ip)s ORDER BY detected_at DESC LIMIT 1 """ ip_result = db.query(ip_country_query, {"ip": ip}) if not ip_result.result_rows: return {"ip_country": None, "asn_countries": []} ip_country_code = ip_result.result_rows[0][0] asn_number = ip_result.result_rows[0][1] # Noms des pays country_names = { "CN": "China", "US": "United States", "DE": "Germany", "FR": "France", "RU": "Russia", "GB": "United Kingdom", "NL": "Netherlands", "IN": "India", "BR": "Brazil", "JP": "Japan", "KR": "South Korea", "IT": "Italy", "ES": "Spain", "CA": "Canada", "AU": "Australia" } # Répartition des autres pays du même ASN asn_countries_query = """ SELECT country_code, count() AS count FROM ml_detected_anomalies WHERE asn_number = %(asn_number)s AND detected_at >= now() - INTERVAL 24 HOUR GROUP BY country_code ORDER BY count DESC LIMIT 10 """ asn_result = db.query(asn_countries_query, {"asn_number": asn_number}) total = sum(row[1] for row in asn_result.result_rows) asn_countries = [ { "code": row[0], "name": country_names.get(row[0], row[0]), "count": row[1], "percentage": round((row[1] / total * 100), 2) if total > 0 else 0.0 } for row in asn_result.result_rows ] return { "ip_country": { "code": ip_country_code, "name": country_names.get(ip_country_code, ip_country_code) }, "asn_countries": asn_countries } except Exception as e: raise HTTPException(status_code=500, detail=f"Erreur: {str(e)}") # ============================================================================= # ANALYSE PAYS # ============================================================================= @router.get("/country", response_model=CountryAnalysis) async def analyze_country(days: int = Query(1, ge=1, le=30)): """ Analyse la distribution des pays """ try: # Top pays top_query = """ SELECT country_code, count() AS count FROM ml_detected_anomalies WHERE detected_at >= now() - INTERVAL %(days)s DAY AND country_code != '' AND country_code IS NOT NULL GROUP BY country_code ORDER BY count DESC LIMIT 10 """ top_result = db.query(top_query, {"days": days}) # Calculer le total pour le pourcentage total = sum(row[1] for row in top_result.result_rows) # Noms des pays (mapping simple) country_names = { "CN": "China", "US": "United States", "DE": "Germany", "FR": "France", "RU": "Russia", "GB": "United Kingdom", "NL": "Netherlands", "IN": "India", "BR": "Brazil", "JP": "Japan", "KR": "South Korea", "IT": "Italy", "ES": "Spain", "CA": "Canada", "AU": "Australia" } top_countries = [ CountryData( code=row[0], name=country_names.get(row[0], row[0]), count=row[1], percentage=round((row[1] / total * 100), 2) if total > 0 else 0.0 ) for row in top_result.result_rows ] # Baseline (7 derniers jours) baseline_query = """ SELECT country_code, count() AS count FROM ml_detected_anomalies WHERE detected_at >= now() - INTERVAL 7 DAY AND country_code != '' AND country_code IS NOT NULL GROUP BY country_code ORDER BY count DESC LIMIT 5 """ baseline_result = db.query(baseline_query) baseline_total = sum(row[1] for row in baseline_result.result_rows) baseline = { row[0]: round((row[1] / baseline_total * 100), 2) if baseline_total > 0 else 0.0 for row in baseline_result.result_rows } # Détecter pays surreprésenté alert_country = None for country in top_countries: baseline_pct = baseline.get(country.code, 0) if baseline_pct > 0 and country.percentage > baseline_pct * 2 and country.percentage > 30: alert_country = country.code break return CountryAnalysis( top_countries=top_countries, baseline=baseline, alert_country=alert_country ) except Exception as e: raise HTTPException(status_code=500, detail=f"Erreur: {str(e)}") # ============================================================================= # ANALYSE JA4 # ============================================================================= @router.get("/{ip}/ja4", response_model=JA4Analysis) async def analyze_ja4(ip: str): """ Analyse le JA4 fingerprint """ try: # JA4 de cette IP ja4_query = """ SELECT ja4 FROM ml_detected_anomalies WHERE src_ip = %(ip)s AND ja4 != '' AND ja4 IS NOT NULL ORDER BY detected_at DESC LIMIT 1 """ ja4_result = db.query(ja4_query, {"ip": ip}) if not ja4_result.result_rows: return JA4Analysis( ja4="", shared_ips_count=0, top_subnets=[], other_ja4_for_ip=[] ) ja4 = ja4_result.result_rows[0][0] # IPs avec le même JA4 shared_query = """ SELECT uniq(src_ip) FROM ml_detected_anomalies WHERE ja4 = %(ja4)s AND detected_at >= now() - INTERVAL 24 HOUR """ shared_result = db.query(shared_query, {"ja4": ja4}) shared_count = shared_result.result_rows[0][0] if shared_result.result_rows else 0 # Top subnets pour ce JA4 - Simplifié subnets_query = """ SELECT src_ip, count() AS count FROM ml_detected_anomalies WHERE ja4 = %(ja4)s AND detected_at >= now() - INTERVAL 24 HOUR GROUP BY src_ip ORDER BY count DESC LIMIT 100 """ subnets_result = db.query(subnets_query, {"ja4": ja4}) # Grouper par subnet /24 from collections import defaultdict subnet_counts = defaultdict(int) for row in subnets_result.result_rows: ip_addr = row[0] parts = ip_addr.split('.') if len(parts) == 4: subnet = f"{parts[0]}.{parts[1]}.{parts[2]}.0/24" subnet_counts[subnet] += row[1] top_subnets = [ JA4SubnetData(subnet=subnet, count=count) for subnet, count in sorted(subnet_counts.items(), key=lambda x: x[1], reverse=True)[:10] ] # Autres JA4 pour cette IP other_ja4_query = """ SELECT DISTINCT ja4 FROM ml_detected_anomalies WHERE src_ip = %(ip)s AND ja4 != '' AND ja4 IS NOT NULL AND ja4 != %(current_ja4)s """ other_result = db.query(other_ja4_query, {"ip": ip, "current_ja4": ja4}) other_ja4 = [row[0] for row in other_result.result_rows] return JA4Analysis( ja4=ja4, shared_ips_count=shared_count, top_subnets=top_subnets, other_ja4_for_ip=other_ja4 ) except Exception as e: raise HTTPException(status_code=500, detail=f"Erreur: {str(e)}") # ============================================================================= # ANALYSE USER-AGENTS # ============================================================================= @router.get("/{ip}/user-agents", response_model=UserAgentAnalysis) async def analyze_user_agents(ip: str): """ Analyse les User-Agents """ try: # User-Agents pour cette IP (depuis http_logs) ip_ua_query = """ SELECT header_user_agent AS ua, count() AS count FROM mabase_prod.http_logs WHERE src_ip = %(ip)s AND header_user_agent != '' AND header_user_agent IS NOT NULL AND time >= now() - INTERVAL 24 HOUR GROUP BY ua ORDER BY count DESC LIMIT 10 """ ip_ua_result = db.query(ip_ua_query, {"ip": ip}) # Classification des UAs def classify_ua(ua: str) -> str: ua_lower = ua.lower() if any(bot in ua_lower for bot in ['bot', 'crawler', 'spider', 'curl', 'wget', 'python', 'requests', 'scrapy']): return 'bot' if any(script in ua_lower for script in ['python', 'java', 'php', 'ruby', 'perl', 'node']): return 'script' if not ua or ua.strip() == '': return 'script' return 'normal' # Calculer le total total_count = sum(row[1] for row in ip_ua_result.result_rows) ip_user_agents = [ UserAgentData( value=row[0], count=row[1], percentage=round((row[1] / total_count * 100), 2) if total_count > 0 else 0.0, classification=classify_ua(row[0]) ) for row in ip_ua_result.result_rows ] # Pour les UAs du JA4, on retourne les mêmes pour l'instant ja4_user_agents = ip_user_agents # Pourcentage de bots bot_count = sum(ua.count for ua in ip_user_agents if ua.classification in ['bot', 'script']) bot_percentage = (bot_count / total_count * 100) if total_count > 0 else 0 return UserAgentAnalysis( ip_user_agents=ip_user_agents, ja4_user_agents=ja4_user_agents, bot_percentage=bot_percentage, alert=bot_percentage > 20 ) except Exception as e: raise HTTPException(status_code=500, detail=f"Erreur: {str(e)}") # ============================================================================= # RECOMMANDATION DE CLASSIFICATION # ============================================================================= @router.get("/{ip}/recommendation", response_model=ClassificationRecommendation) async def get_classification_recommendation(ip: str): """ Génère une recommandation de classification basée sur les corrélations """ try: # Récupérer les analyses try: subnet_analysis = await analyze_subnet(ip) except: subnet_analysis = None try: country_analysis = await analyze_country(1) except: country_analysis = None try: ja4_analysis = await analyze_ja4(ip) except: ja4_analysis = None try: ua_analysis = await analyze_user_agents(ip) except: ua_analysis = None # Indicateurs par défaut indicators = CorrelationIndicators( subnet_ips_count=subnet_analysis.total_in_subnet if subnet_analysis else 0, asn_ips_count=subnet_analysis.total_in_asn if subnet_analysis else 0, country_percentage=0.0, ja4_shared_ips=ja4_analysis.shared_ips_count if ja4_analysis else 0, user_agents_count=len(ua_analysis.ja4_user_agents) if ua_analysis else 0, bot_ua_percentage=ua_analysis.bot_percentage if ua_analysis else 0.0 ) # Score de confiance score = 0.0 reasons = [] tags = [] # Subnet > 10 IPs if subnet_analysis and subnet_analysis.total_in_subnet > 10: score += 0.25 reasons.append(f"{subnet_analysis.total_in_subnet} IPs du même subnet") tags.append("distributed") # JA4 partagé > 50 IPs if ja4_analysis and ja4_analysis.shared_ips_count > 50: score += 0.25 reasons.append(f"{ja4_analysis.shared_ips_count} IPs avec même JA4") tags.append("ja4-rotation") # Bot UA > 20% if ua_analysis and ua_analysis.bot_percentage > 20: score += 0.25 reasons.append(f"{ua_analysis.bot_percentage:.0f}% UAs bots/scripts") tags.append("bot-ua") # Pays surreprésenté if country_analysis and country_analysis.alert_country: score += 0.15 reasons.append(f"Pays {country_analysis.alert_country} surreprésenté") tags.append(f"country-{country_analysis.alert_country.lower()}") # ASN hosting if subnet_analysis: hosting_keywords = ["ovh", "amazon", "aws", "google", "azure", "digitalocean", "linode", "vultr", "china169", "chinamobile"] if any(kw in (subnet_analysis.asn_org or "").lower() for kw in hosting_keywords): score += 0.10 tags.append("hosting-asn") # Déterminer label if score >= 0.7: label = ClassificationLabel.MALICIOUS tags.append("campaign") elif score >= 0.4: label = ClassificationLabel.SUSPICIOUS else: label = ClassificationLabel.LEGITIMATE reason = " | ".join(reasons) if reasons else "Aucun indicateur fort" return ClassificationRecommendation( label=label, confidence=min(score, 1.0), indicators=indicators, suggested_tags=tags, reason=reason ) except Exception as e: raise HTTPException(status_code=500, detail=f"Erreur: {str(e)}") # ============================================================================= # CLASSIFICATIONS CRUD # ============================================================================= @router.post("/classifications", response_model=Classification) async def create_classification(data: ClassificationCreate): """ Crée une classification pour une IP ou un JA4 """ try: # Validation: soit ip, soit ja4 doit être fourni if not data.ip and not data.ja4: raise HTTPException(status_code=400, detail="IP ou JA4 requis") query = """ INSERT INTO mabase_prod.classifications (ip, ja4, label, tags, comment, confidence, features, analyst, created_at) VALUES (%(ip)s, %(ja4)s, %(label)s, %(tags)s, %(comment)s, %(confidence)s, %(features)s, %(analyst)s, now()) """ db.query(query, { "ip": data.ip or "", "ja4": data.ja4 or "", "label": data.label.value, "tags": data.tags, "comment": data.comment, "confidence": data.confidence, "features": json.dumps(data.features), "analyst": data.analyst }) # Récupérer la classification créée where_clause = "ip = %(entity)s" if data.ip else "ja4 = %(entity)s" select_query = f""" SELECT ip, ja4, label, tags, comment, confidence, features, analyst, created_at FROM mabase_prod.classifications WHERE {where_clause} ORDER BY created_at DESC LIMIT 1 """ result = db.query(select_query, {"entity": data.ip or data.ja4}) if not result.result_rows: raise HTTPException(status_code=404, detail="Classification non trouvée") row = result.result_rows[0] return Classification( ip=row[0] or None, ja4=row[1] or None, label=ClassificationLabel(row[2]), tags=row[3], comment=row[4], confidence=row[5], features=json.loads(row[6]) if row[6] else {}, analyst=row[7], created_at=row[8] ) except Exception as e: raise HTTPException(status_code=500, detail=f"Erreur: {str(e)}") @router.get("/classifications", response_model=ClassificationsListResponse) async def list_classifications( ip: Optional[str] = Query(None, description="Filtrer par IP"), ja4: Optional[str] = Query(None, description="Filtrer par JA4"), label: Optional[str] = Query(None, description="Filtrer par label"), limit: int = Query(100, ge=1, le=1000) ): """ Liste les classifications """ try: where_clauses = ["1=1"] params = {"limit": limit} if ip: where_clauses.append("ip = %(ip)s") params["ip"] = ip if ja4: where_clauses.append("ja4 = %(ja4)s") params["ja4"] = ja4 if label: where_clauses.append("label = %(label)s") params["label"] = label where_clause = " AND ".join(where_clauses) query = f""" SELECT ip, ja4, label, tags, comment, confidence, features, analyst, created_at FROM mabase_prod.classifications WHERE {where_clause} ORDER BY created_at DESC LIMIT %(limit)s """ result = db.query(query, params) classifications = [ Classification( ip=row[0] or None, ja4=row[1] or None, label=ClassificationLabel(row[2]), tags=row[3], comment=row[4], confidence=row[5], features=json.loads(row[6]) if row[6] else {}, analyst=row[7], created_at=row[8] ) for row in result.result_rows ] # Total count_query = f""" SELECT count() FROM mabase_prod.classifications WHERE {where_clause} """ count_result = db.query(count_query, params) total = count_result.result_rows[0][0] if count_result.result_rows else 0 return ClassificationsListResponse( items=classifications, total=total ) except Exception as e: raise HTTPException(status_code=500, detail=f"Erreur: {str(e)}") @router.get("/classifications/stats") async def get_classification_stats(): """ Statistiques des classifications """ try: stats_query = """ SELECT label, count() AS total, uniq(ip) AS unique_ips, avg(confidence) AS avg_confidence FROM mabase_prod.classifications GROUP BY label ORDER BY total DESC """ result = db.query(stats_query) stats = [ { "label": row[0], "total": row[1], "unique_ips": row[2], "avg_confidence": float(row[3]) if row[3] else 0.0 } for row in result.result_rows ] return {"stats": stats} except Exception as e: raise HTTPException(status_code=500, detail=f"Erreur: {str(e)}")