#!/usr/bin/env python3
"""Test isotope detection: vega_best vs vega_final, with/without background subtraction.

Loads the live detector spectrum and a 24 h background, optionally subtracts the
background, applies a CsI(Tl) non-linearity correction, and prints the isotopes
reported by two Vega checkpoints plus their top-10 class probabilities.
"""
import os
import sys
import json
import numpy as np
import torch
from pathlib import Path

# Paths (container-mounted)
MODELS_DIR = Path(os.environ.get("MODELS_DIR", "/models"))
DATA_DIR = Path(os.environ.get("DATA_DIR", "/data"))
VEGA_ML_PATH = Path(os.environ.get("VEGA_ML_PATH", "/models/vega_ml"))

# Add vega_ml to path so the project-local `training` package below is importable.
sys.path.insert(0, str(VEGA_ML_PATH))

from training.vega.model import VegaModel, VegaConfig  # noqa: E402
from training.vega.isotope_index import IsotopeIndex  # noqa: E402

# Number of spectrum channels fed to the model (spectra are truncated to this).
NUM_CHANNELS = 1023

# Energy calibration: E[keV] = ENERGY_OFFSET + ENERGY_SLOPE * channel
ENERGY_OFFSET = 0.33
ENERGY_SLOPE = 2.97
THRESHOLD = 0.5

# CsI(Tl) non-linear response correction
CSI_NONLINEAR_ALPHA = 0.37
CSI_NONLINEAR_BETA = 100.0


def channel_to_energy(channel):
    """Convert a channel index (scalar or array) to energy in keV via the linear calibration."""
    return ENERGY_OFFSET + ENERGY_SLOPE * channel


def correct_csilinear_energy(spectrum_rate, num_channels=NUM_CHANNELS):
    """Apply inverse CsI(Tl) non-linear response correction.

    Remaps channels so peaks appear at theoretical energy positions (matching
    training data), correcting for the detector's non-proportional scintillation
    response that shifts low-energy peaks upward.

    For every output channel, the true energy is mapped through the forward
    response model to the apparent energy the detector would report, and the
    input spectrum is linearly interpolated at the matching fractional input
    channel.

    Args:
        spectrum_rate: 1-D array-like of per-channel values (count rates in this
            script). Must hold at least ``num_channels`` entries; extra entries
            are ignored.
        num_channels: number of channels in the returned spectrum.

    Returns:
        float64 array of length ``num_channels``.

    Raises:
        ValueError: if ``num_channels < 1`` or ``spectrum_rate`` is not 1-D or
            is shorter than ``num_channels``.
    """
    spectrum = np.asarray(spectrum_rate, dtype=np.float64)
    if num_channels < 1:
        raise ValueError(f"num_channels must be >= 1, got {num_channels}")
    if spectrum.ndim != 1 or spectrum.shape[0] < num_channels:
        raise ValueError(
            f"expected a 1-D spectrum with at least {num_channels} channels, "
            f"got shape {spectrum.shape}"
        )

    channels = np.arange(num_channels, dtype=np.float64)
    e_true = channel_to_energy(channels)
    # Forward model: E_apparent = E_true * (1 + alpha * exp(-E_true / beta))
    e_apparent = e_true * (
        1 + CSI_NONLINEAR_ALPHA * np.exp(-e_true / CSI_NONLINEAR_BETA)
    )
    # Input channel where the detector placed counts for this true energy.
    source_channels = (e_apparent - ENERGY_OFFSET) / ENERGY_SLOPE
    # Linear interpolation between neighbouring input channels. np.interp clamps
    # positions outside [0, num_channels - 1] to the edge values, so the last
    # channel is sampled exactly rather than blended with its neighbour.
    return np.interp(source_channels, channels, spectrum[:num_channels])


def load_model(model_path):
    """Load a VegaModel checkpoint on CPU and put it in eval mode.

    The checkpoint is expected to be a dict with ``"model_config"`` (kwargs for
    ``VegaConfig``) and ``"model_state_dict"`` entries.

    Returns:
        ``(model, config)`` tuple.
    """
    device = torch.device("cpu")
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects (and run arbitrary code). Only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    config = VegaConfig(**checkpoint["model_config"])
    model = VegaModel(config)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, config


def _prepare_input(spectrum_rate, apply_correction=True):
    """Turn a rate spectrum into the normalised (1, channels) float32 model input."""
    if apply_correction:
        # Apply CsI(Tl) non-linear correction so peaks match training data positions
        spectrum_rate = correct_csilinear_energy(spectrum_rate)
    log_spectrum = np.log1p(np.maximum(spectrum_rate, 0))
    max_val = log_spectrum.max()
    # Scale so the largest log-rate is 1; an all-zero spectrum is left as is.
    normalized = log_spectrum / max_val if max_val > 0 else log_spectrum
    return torch.tensor(normalized, dtype=torch.float32).unsqueeze(0)


def _predict(model, spectrum_rate, apply_correction=True):
    """Run one forward pass and return ``(probabilities, raw_activities)`` for the sample.

    Presumably both arrays have one entry per isotope in the index (the caller
    maps positions through ``IsotopeIndex.index_to_name``) — TODO confirm
    against VegaModel.
    """
    tensor = _prepare_input(spectrum_rate, apply_correction)
    with torch.no_grad():
        logits, activities = model(tensor)
    return torch.sigmoid(logits).numpy()[0], activities.numpy()[0]


def run_inference(model, config, isotope_index, spectrum_rate,
                  threshold=THRESHOLD, apply_correction=True):
    """Run inference on a spectrum rate array (1023 channels).

    Args:
        model: loaded VegaModel in eval mode.
        config: its VegaConfig; ``max_activity_bq`` scales the activity output.
        isotope_index: maps output positions to isotope names.
        spectrum_rate: 1-D array-like of per-channel count rates.
        threshold: minimum sigmoid probability for an isotope to be reported.
        apply_correction: apply the CsI(Tl) non-linearity correction first.

    Returns:
        List of ``{"isotope", "probability", "activity_bq"}`` dicts for the
        isotopes at or above ``threshold``, most probable first. Empty when the
        spectrum is empty or has no positive value.
    """
    spectrum_rate = np.asarray(spectrum_rate, dtype=np.float64)
    if spectrum_rate.size == 0 or spectrum_rate.max() <= 0:
        return []

    probs, raw_activities = _predict(model, spectrum_rate, apply_correction)
    # The model's activity output is multiplied by max_activity_bq to get Bq
    # (presumably it is normalised to [0, 1] during training — TODO confirm).
    activities = raw_activities * config.max_activity_bq

    results = [
        {
            "isotope": isotope_index.index_to_name(int(i)),
            "probability": float(probs[i]),
            "activity_bq": float(activities[i]),
        }
        for i in np.flatnonzero(probs >= threshold)
    ]
    # sorted() stays stable with reverse=True, so ties keep index order.
    return sorted(results, key=lambda r: r["probability"], reverse=True)


def _load_rates():
    """Load the live spectrum and the 24 h background; return ``(rate, net_rate)`` in counts/s.

    ``net_rate`` is the background-subtracted rate clipped at zero.

    Raises:
        ValueError: if either live time is not positive.
    """
    # Load monitor state (real spectrum from detector)
    with open(DATA_DIR / "monitor_state.json", encoding="utf-8") as f:
        state = json.load(f)
    counts = np.array(state["counts"], dtype=np.float64)
    live_time = float(state["cumulated_live_time_s"])
    print(f"Spectre reel : {live_time:.0f}s live time, {counts.sum():.0f} coups, {len(counts)} canaux")
    print(f"CPS : {state['cps']:.2f}")

    # Load background.
    # NOTE(security): allow_pickle=True unpickles the file; only load trusted data.
    bg_data = np.load(DATA_DIR / "background_24h.npy", allow_pickle=True).item()
    bg_counts = bg_data["counts"].astype(np.float64)
    bg_live_time = float(bg_data["duration"])
    print(f"Background : {bg_live_time/3600:.1f}h, {bg_counts.sum():.0f} coups\n")

    # A zero live time would yield inf/NaN rates and meaningless model output.
    if live_time <= 0 or bg_live_time <= 0:
        raise ValueError(
            f"live times must be positive (spectrum={live_time}s, "
            f"background={bg_live_time}s)"
        )

    rate = counts[:NUM_CHANNELS] / live_time
    bg_rate = bg_counts[:NUM_CHANNELS] / bg_live_time
    net_rate = np.clip(rate - bg_rate, 0, None)
    return rate, net_rate


def _print_peak_summary(rate, net_rate):
    """Print peak channels/energies with and without CsI correction, plus the Am-241 region."""
    corrected_rate = correct_csilinear_energy(rate)
    corrected_net = correct_csilinear_energy(net_rate)

    print("=" * 70)
    print(" Sans correction CsI:")
    print(f" Canal max (brut) : {rate.argmax():>4d} ({channel_to_energy(rate.argmax()):.1f} keV)")
    print(f" Canal max (net) : {net_rate.argmax():>4d} ({channel_to_energy(net_rate.argmax()):.1f} keV)")
    print(" Avec correction CsI:")
    print(f" Canal max (brut) : {corrected_rate.argmax():>4d} ({channel_to_energy(corrected_rate.argmax()):.1f} keV)")
    print(f" Canal max (net) : {corrected_net.argmax():>4d} ({channel_to_energy(corrected_net.argmax()):.1f} keV)")
    print(f" Rate max (brut) : {rate.max():.2f} cps")
    print(f" Rate max (net) : {net_rate.max():.2f} cps")
    print("=" * 70)

    # Am-241 should be at 59.5 keV → ch ~20
    print("\n Am-241 region (59.5 keV) apres correction CsI:")
    for ch in range(16, 26):
        e = channel_to_energy(ch)
        print(f" ch {ch:3d} ({e:5.1f} keV): brut={corrected_rate[ch]:.5f} net={corrected_net[ch]:.5f}")


def _print_detections(models, scenarios, isotope_index):
    """Print the isotopes each model detects above THRESHOLD for each scenario."""
    for model_name, (model, config) in models.items():
        print(f"\n{'─' * 70}")
        print(f" Modele : {model_name}")
        print(f"{'─' * 70}")
        for scenario_name, spectrum in scenarios.items():
            print(f"\n Scenario : {scenario_name}")
            results = run_inference(model, config, isotope_index, spectrum)
            if not results:
                print(f" Aucun isotope detecte (seuil = {THRESHOLD})")
                continue
            print(f" {'Isotope':>10s} {'Probabilite':>12s} {'Activite (Bq)':>15s}")
            print(f" {'─'*10} {'─'*12} {'─'*15}")
            for r in results:
                print(f" {r['isotope']:>10s} {r['probability']*100:>11.1f}% {r['activity_bq']:>15.1f}")


def _print_top_probabilities(models, scenarios, isotope_index, top_k=10):
    """Print the ``top_k`` class probabilities per model/scenario, even below THRESHOLD."""
    print(f"\n{'═' * 70}")
    print(" Top-10 probabilites (tous scenarios, meme sous le seuil)")
    print(f"{'═' * 70}")
    for model_name, (model, _config) in models.items():
        print(f"\n Modele : {model_name}")
        for scenario_name, spectrum in scenarios.items():
            if spectrum.max() <= 0:
                continue
            # Same CsI-corrected preprocessing as run_inference.
            probs, _ = _predict(model, spectrum)
            top = np.argsort(probs)[::-1][:top_k]
            print(f"\n {scenario_name} :")
            for idx in top:
                name = isotope_index.index_to_name(int(idx))
                prob = probs[idx]
                marker = " *" if prob >= THRESHOLD else ""
                print(f" {name:>10s} : {prob*100:6.2f}%{marker}")


def main():
    """Compare vega_best and vega_final on the gross and background-subtracted live spectrum."""
    # Load isotope index
    isotope_index = IsotopeIndex.load(MODELS_DIR / "vega_isotope_index.txt")
    print(f"Isotope index: {isotope_index.num_isotopes} isotopes\n")

    rate, net_rate = _load_rates()
    _print_peak_summary(rate, net_rate)

    # Load both models
    models = {
        "vega_best": load_model(MODELS_DIR / "vega_best.pt"),
        "vega_final": load_model(MODELS_DIR / "vega_final.pt"),
    }
    scenarios = {
        "brut (sans soustraction)": rate,
        "net (avec soustraction bg)": net_rate,
    }

    _print_detections(models, scenarios, isotope_index)
    _print_top_probabilities(models, scenarios, isotope_index)


if __name__ == "__main__":
    main()