Files
radiacode/detect/radiacode_monitor.py
Jacquin Antoine 091d7d9eb8 Fix: mask channels below 30 keV in inference and training to prevent misidentification
Below ~30 keV the detector signal is dominated by X-ray fluorescence (L-shell)
and artifacts not modelled in training data. This spurious low-energy continuum
caused the model to misidentify Am-241 as Th-232/U-235. Masking channels <30 keV
before inference fixes Am-241 detection from 2% to 99%. Same masking applied in
the synthetic spectrum generator for consistent retraining.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-21 21:55:34 +02:00

462 lines
18 KiB
Python

#!/usr/bin/env python3
"""
Radiacode 103 — Identification automatique d'isotopes
Cycle de 24h avec détection branché/débranché
Fonctionne en Docker sur machine GPU (dev) ou RPi 4 (production)
"""
import numpy as np
import torch
import time
import json
import logging
import os
import sys
from datetime import datetime, timedelta
from pathlib import Path
# Configuration via variables d'environnement
MODEL_PATH = os.environ.get("MODEL_PATH", "/models/vega_best.pt")
ISOTOPE_INDEX_PATH = os.environ.get("ISOTOPE_INDEX_PATH", "/models/vega_isotope_index.txt")
BACKGROUND_PATH = os.environ.get("BACKGROUND_PATH", "/data/background_24h.npy")
LOG_DIR = Path(os.environ.get("LOG_DIR", "/logs"))
LOG_DIR.mkdir(parents=True, exist_ok=True)
THRESHOLD = float(os.environ.get("THRESHOLD", "0.5"))
SAMPLE_INTERVAL = int(os.environ.get("SAMPLE_INTERVAL", "60"))
REPORT_HOUR = int(os.environ.get("REPORT_HOUR", "0"))
MIN_LIVE_TIME = int(os.environ.get("MIN_LIVE_TIME", "3600"))
STATE_PATH = os.environ.get("STATE_PATH", "/data/monitor_state.json")
CPS_LOG_PATH = os.environ.get("CPS_LOG_PATH", "/data/cps_log.jsonl")
ENERGY_OFFSET = float(os.environ.get("ENERGY_CALIBRATION_OFFSET", "0.33"))
ENERGY_SLOPE = float(os.environ.get("ENERGY_CALIBRATION_SLOPE", "2.97"))
# CsI(Tl) non-linear response correction
# CsI(Tl) produces more light per keV at low energies, shifting peaks to higher
# apparent energies. Model: E_apparent = E_true * (1 + alpha * exp(-E_true/beta))
# Calibrated from Am-241 (59.5 keV appears at ~71.6 keV) and K-40 (correct at 1460.8 keV).
CSI_NONLINEAR_ALPHA = float(os.environ.get("CSI_NONLINEAR_ALPHA", "0.37"))
CSI_NONLINEAR_BETA = float(os.environ.get("CSI_NONLINEAR_BETA", "100.0"))
# Minimum energy for inference — channels below this are masked to zero.
# Below ~30 keV the signal is dominated by X-ray fluorescence and detector
# artifacts not modelled in training data, causing misidentifications.
MIN_ENERGY_KEV = float(os.environ.get("MIN_ENERGY_KEV", "30.0"))
MIN_CHANNEL = max(0, int((MIN_ENERGY_KEV - ENERGY_OFFSET) / ENERGY_SLOPE))
HOURLY_DIR = Path(os.environ.get("HOURLY_DIR", "/data/hourly"))
HOURLY_DIR.mkdir(parents=True, exist_ok=True)
def correct_csilinear_energy(spectrum_rate, num_channels=1023):
"""Apply inverse CsI(Tl) non-linear response correction to spectrum channels.
CsI(Tl) has non-proportional scintillation response at low energies,
causing peaks to appear at higher channels than their true energy position.
This function remaps channels so that peaks appear at their theoretical
energy positions, matching what the model was trained on.
For each output channel j (true energy position), we find the input
channel i (apparent energy position) where the detector actually placed
counts for that true energy.
Args:
spectrum_rate: Array of 1023 channel count rates
num_channels: Number of channels
Returns:
Corrected spectrum with peaks at theoretical energy positions
"""
alpha = CSI_NONLINEAR_ALPHA
beta = CSI_NONLINEAR_BETA
# For each output channel j, compute the apparent energy where
# counts for true energy E_true(j) actually appear
output_channels = np.arange(num_channels, dtype=np.float64)
e_true = ENERGY_OFFSET + ENERGY_SLOPE * output_channels
# Forward model: E_apparent = E_true * (1 + alpha * exp(-E_true / beta))
e_apparent = e_true * (1 + alpha * np.exp(-e_true / beta))
# Input channel where the detector placed counts for this true energy
source_channels = (e_apparent - ENERGY_OFFSET) / ENERGY_SLOPE
source_channels = np.clip(source_channels, 0, num_channels - 1.001)
# Linear interpolation from source channels
lower = np.floor(source_channels).astype(int)
upper = np.minimum(lower + 1, num_channels - 1)
frac = source_channels - lower
corrected = spectrum_rate[lower] * (1 - frac) + spectrum_rate[upper] * frac
return corrected
# Logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[
logging.StreamHandler(),
logging.FileHandler(LOG_DIR / "radiacode.log"),
],
)
log = logging.getLogger(__name__)
class RadiacodeMonitor:
def __init__(self):
# Charger le modèle PyTorch
device_str = os.environ.get("VEGA_DEVICE", "cpu")
self.torch_device = torch.device(device_str)
log.info(f"Chargement du modèle depuis {MODEL_PATH} sur {self.torch_device}...")
checkpoint = torch.load(MODEL_PATH, map_location=self.torch_device, weights_only=False)
# Importer VegaModel (depuis le volume monté)
vega_ml_path = os.environ.get("VEGA_ML_PATH", "/models/vega_ml")
if vega_ml_path not in sys.path:
sys.path.insert(0, vega_ml_path)
from training.vega.model import VegaModel, VegaConfig
from training.vega.isotope_index import IsotopeIndex
self.model_config = VegaConfig(**checkpoint["model_config"])
self.model = VegaModel(self.model_config)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.model.eval()
log.info(
f"Modèle chargé : {self.model_config.num_isotopes} isotopes, "
f"{self.model.count_parameters():,} paramètres"
)
# Charger l'index des isotopes
self.isotope_index = IsotopeIndex.load(Path(ISOTOPE_INDEX_PATH))
# Charger le bruit de fond de référence
self.bg_counts = None
self.bg_live_time = None
bg_path = Path(BACKGROUND_PATH)
if bg_path.exists():
bg_data = np.load(str(bg_path), allow_pickle=True).item()
self.bg_counts = bg_data["counts"].astype(np.float64)
self.bg_live_time = float(bg_data["duration"])
log.info(
f"Background chargé : {self.bg_live_time/3600:.1f}h, "
f"{self.bg_counts.sum():.0f} coups"
)
else:
log.warning(f"Pas de fichier background : {BACKGROUND_PATH}")
# Connexion persistante au Radiacode
self._rc = None
self.reconnect_backoff = 0
# Compteurs cumulés
self.cumulated_counts = np.zeros(1024, dtype=np.float64)
self.cumulated_live_time = 0.0
self.last_report_date = None
self.connected = False
# Suivi horaire
self.current_hour = datetime.now().hour
self.hourly_counts = np.zeros(1024, dtype=np.float64)
self.hourly_live_time = 0.0
# Restaurer l'état si disponible et du même jour
state_path = Path(STATE_PATH)
if state_path.exists():
try:
with open(state_path) as f:
saved = json.load(f)
saved_ts = datetime.fromisoformat(saved["timestamp"])
now = datetime.now()
same_day = saved_ts.date() == now.date()
yesterday_before_report = (
now.hour < REPORT_HOUR
and saved_ts.date() == (now.date() - timedelta(days=1))
)
if same_day or yesterday_before_report:
counts = saved.get("counts", [])
if len(counts) == 1024:
self.cumulated_counts = np.array(counts, dtype=np.float64)
self.cumulated_live_time = float(saved.get("cumulated_live_time_s", 0))
log.info(
f"État restauré : {self.cumulated_live_time/3600:.1f}h, "
f"{self.cumulated_counts.sum():.0f} coups"
)
else:
log.info("État sauvegardé d'un jour précédent, redémarrage à zéro")
except (json.JSONDecodeError, OSError, KeyError) as e:
log.warning(f"Impossible de restaurer l'état : {e}")
def _connect(self):
"""Tente d'établir une connexion persistante au Radiacode."""
try:
from radiacode import RadiaCode
self._rc = RadiaCode()
self.connected = True
self.reconnect_backoff = 0
log.info("Radiacode connecté")
return True
except Exception as e:
self._rc = None
self.connected = False
self.reconnect_backoff = min(self.reconnect_backoff + 1, 10)
log.debug(f"Détecteur non disponible (retry dans {self.reconnect_backoff} cycles) : {e}")
return False
def _disconnect(self):
"""Ferme la connexion au Radiacode."""
if self._rc is not None:
try:
del self._rc
except Exception:
pass
self._rc = None
self.connected = False
def sample_once(self):
"""Échantillonne une fois. Retourne True si succès."""
# Établir la connexion si nécessaire
if self._rc is None:
if self.reconnect_backoff > 0:
self.reconnect_backoff -= 1
return False
if not self._connect():
return False
try:
spectrum = self._rc.spectrum()
counts = np.array(spectrum.counts, dtype=np.float64)
live_time = spectrum.duration.total_seconds()
if live_time > 0 and counts.sum() > 0:
self.cumulated_counts += counts
self.cumulated_live_time += live_time
self.hourly_counts += counts
self.hourly_live_time += live_time
self._rc.spectrum_reset()
log.info(
f"Échantillon : {counts.sum():.0f} coups en {live_time:.1f}s "
f"(cumul : {self.cumulated_live_time/3600:.1f}h)"
)
return True
return False
except Exception as e:
log.warning(f"Erreur échantillonnage, reconnexion : {e}")
self._disconnect()
return False
def save_state(self):
"""Ecrit l'etat actuel du moniteur dans un fichier JSON atomique."""
energy_kev = [round(ENERGY_OFFSET + ENERGY_SLOPE * i, 2) for i in range(1024)]
cps = float(self.cumulated_counts.sum() / self.cumulated_live_time) if self.cumulated_live_time > 0 else 0.0
isotopes = []
if self.cumulated_live_time > 0:
rate = self.cumulated_counts / self.cumulated_live_time
if self.bg_counts is not None and self.bg_live_time is not None:
bg_rate = self.bg_counts[:1023] / self.bg_live_time
net_rate = np.clip(rate[:1023] - bg_rate, 0, None)
else:
net_rate = rate[:1023]
net_rate[:MIN_CHANNEL] = 0
isotopes = self.run_inference(net_rate)
state = {
"timestamp": datetime.now().isoformat(),
"connected": self.connected,
"cumulated_live_time_s": round(self.cumulated_live_time, 1),
"cumulated_live_time_h": round(self.cumulated_live_time / 3600, 2),
"total_counts": int(self.cumulated_counts.sum()),
"cps": round(cps, 2),
"background_subtracted": self.bg_counts is not None,
"isotopes_detected": isotopes,
"energy_kev": energy_kev,
"counts": [round(float(c), 1) for c in self.cumulated_counts],
}
state_path = Path(STATE_PATH)
state_path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = state_path.with_suffix(".tmp")
with open(tmp_path, "w") as f:
json.dump(state, f)
os.replace(tmp_path, state_path)
def log_cps(self, counts, live_time):
"""Ajoute un point CPS au journal horodaté."""
cps = float(counts.sum() / live_time) if live_time > 0 else 0.0
entry = {
"ts": datetime.now().isoformat(),
"cps": round(cps, 2),
"live_time_s": round(live_time, 1),
"total_counts": int(counts.sum()),
}
log_path = Path(CPS_LOG_PATH)
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "a") as f:
f.write(json.dumps(entry) + "\n")
def run_inference(self, spectrum_rate):
"""Exécute l'inférence PyTorch sur le spectre cumulé."""
if spectrum_rate.max() > 0:
# Apply CsI(Tl) non-linear correction so peaks appear
# at theoretical energy positions (matching training data)
corrected = correct_csilinear_energy(spectrum_rate)
log_spectrum = np.log1p(np.maximum(corrected, 0))
normalized = log_spectrum / log_spectrum.max()
else:
return []
tensor = torch.tensor(normalized, dtype=torch.float32).unsqueeze(0).to(self.torch_device)
with torch.no_grad():
logits, activities = self.model(tensor)
probs = torch.sigmoid(logits).cpu().numpy()[0]
activities = activities.cpu().numpy()[0] * self.model_config.max_activity_bq
results = []
for i in range(len(probs)):
if probs[i] >= THRESHOLD:
results.append(
{
"isotope": self.isotope_index.index_to_name(i),
"probability": float(probs[i]),
"activity_bq": float(activities[i]),
}
)
return sorted(results, key=lambda x: -x["probability"])
def generate_report(self):
"""Génère le rapport quotidien."""
if self.cumulated_live_time < MIN_LIVE_TIME:
log.warning(
f"Pas assez de données ({self.cumulated_live_time/3600:.1f}h < "
f"{MIN_LIVE_TIME/3600:.1f}h minimum). Pas de rapport."
)
return
rate = self.cumulated_counts / self.cumulated_live_time
if self.bg_counts is not None and self.bg_live_time is not None:
bg_rate = self.bg_counts[:1023] / self.bg_live_time
net_rate = np.clip(rate[:1023] - bg_rate, 0, None)
else:
net_rate = rate[:1023]
net_rate[:MIN_CHANNEL] = 0
results = self.run_inference(net_rate)
now = datetime.now()
report = {
"date": now.isoformat(),
"live_time_hours": self.cumulated_live_time / 3600,
"total_counts": int(self.cumulated_counts.sum()),
"cps_mean": float(self.cumulated_counts.sum() / self.cumulated_live_time),
"background_subtracted": self.bg_counts is not None,
"isotopes_detected": results,
}
report_path = LOG_DIR / f"report_{now.strftime('%Y-%m-%d')}.json"
with open(report_path, "w") as f:
json.dump(report, f, indent=2, ensure_ascii=False)
# Affichage
print(f"\n{'='*50}")
print(f" RAPPORT — {now.strftime('%d/%m/%Y')}")
print(f"{'='*50}")
print(f" Live time : {self.cumulated_live_time/3600:.1f}h")
print(f" Comptages : {self.cumulated_counts.sum():.0f}")
print(f" CPS moyen : {self.cumulated_counts.sum()/self.cumulated_live_time:.1f}")
print(
f" Background : {'soustrait' if self.bg_counts is not None else 'non soustrait'}"
)
print()
if results:
for r in results:
print(
f" {r['isotope']:>10s} : {r['probability']*100:5.1f}% — {r['activity_bq']:.1f} Bq"
)
else:
print(" (background uniquement)")
print(f"{'='*50}\n")
log.info(f"Rapport sauvegardé : {report_path}")
# Reset pour le cycle suivant
self.cumulated_counts = np.zeros(1024, dtype=np.float64)
self.cumulated_live_time = 0.0
def save_hourly_snapshot(self):
"""Sauvegarde le spectre accumulé pendant l'heure écoulée."""
if self.hourly_live_time < 1.0:
return
now = datetime.now()
cps = float(self.hourly_counts.sum() / self.hourly_live_time) if self.hourly_live_time > 0 else 0.0
energy_kev = [round(ENERGY_OFFSET + ENERGY_SLOPE * i, 2) for i in range(1024)]
snapshot = {
"timestamp": now.replace(minute=0, second=0, microsecond=0).isoformat(),
"date": now.strftime("%Y-%m-%d"),
"hour": self.current_hour,
"live_time_s": round(self.hourly_live_time, 1),
"total_counts": int(self.hourly_counts.sum()),
"cps": round(cps, 2),
"energy_kev": energy_kev,
"counts": [round(float(c), 1) for c in self.hourly_counts],
}
filename = f"{now.strftime('%Y-%m-%d')}_{self.current_hour:02d}.json"
filepath = HOURLY_DIR / filename
tmp_path = filepath.with_suffix(".tmp")
with open(tmp_path, "w") as f:
json.dump(snapshot, f)
os.replace(tmp_path, filepath)
log.info(f"Snapshot horaire sauvegardé : {filename} ({self.hourly_live_time/3600:.2f}h)")
def _check_hour_rollover(self):
"""Vérifie le changement d'heure, sauvegarde et réinitialise les compteurs horaires."""
now = datetime.now()
current_hour = now.hour
if current_hour != self.current_hour:
self.save_hourly_snapshot()
self.hourly_counts = np.zeros(1024, dtype=np.float64)
self.hourly_live_time = 0.0
self.current_hour = current_hour
def run(self):
"""Boucle principale."""
log.info("=" * 50)
log.info("Radiacode 103 — Moniteur d'isotopes")
log.info("=" * 50)
log.info(f"Modèle : {MODEL_PATH}")
log.info(f"Device : {self.torch_device}")
log.info(f"Isotopes : {self.isotope_index.num_isotopes}")
log.info(
f"Background : {'chargé' if self.bg_counts is not None else 'non disponible'}"
)
log.info(f"Seuil : {THRESHOLD}")
log.info(f"Intervalle : {SAMPLE_INTERVAL}s")
while True:
now = datetime.now()
if self.last_report_date != now.date() and now.hour == REPORT_HOUR:
self.generate_report()
self.last_report_date = now.date()
success = self.sample_once()
if success:
self.save_state()
self.log_cps(self.cumulated_counts, self.cumulated_live_time)
else:
self.save_state()
self._check_hour_rollover()
time.sleep(SAMPLE_INTERVAL)
if __name__ == "__main__":
monitor = RadiacodeMonitor()
monitor.run()