Files
radiacode/train/vega_ml/inference/vega_inference.py
Jacquin Antoine 745a64b342 Pipeline complet Radiacode 103 - identification automatique d'isotopes
- VegaModel CNN-FCNN 34.5M params, 82 isotopes, val acc 99.89%
- Generation 50k spectres synthetiques 1D (12-24h durees)
- Entrainement 100 epochs sur RTX 5060 Ti (CUDA 12.8, Blackwell)
- Detection continue avec soustraction du background
- Capture background 24h avec gestion deconnexion
- Docker Compose : conteneur train (GPU) + detect (CPU/USB)
- Modele entraite inclus (vega_best.pt, 395 Mo)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 12:29:56 +02:00

407 lines
13 KiB
Python

"""
Vega Inference Script
Load a trained Vega model and run inference on gamma spectra to identify
isotopes and estimate their activities.
"""
import sys
import json
import numpy as np
import torch
from pathlib import Path
from typing import Dict, List, Optional, Union
from dataclasses import dataclass, asdict
# Add project root to path
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from training.vega.model import VegaModel, VegaConfig
from training.vega.isotope_index import IsotopeIndex
@dataclass
class IsotopePrediction:
"""Prediction for a single isotope."""
name: str
probability: float
activity_bq: float
present: bool
@dataclass
class SpectrumPrediction:
"""Full prediction results for a spectrum."""
isotopes: List[IsotopePrediction]
num_present: int
confidence: float
threshold_used: float
def to_dict(self) -> Dict:
"""Convert to dictionary."""
return {
'isotopes': [
{
'name': iso.name,
'probability': round(iso.probability, 4),
'activity_bq': round(iso.activity_bq, 2),
'present': iso.present
}
for iso in self.isotopes
],
'num_isotopes_detected': self.num_present,
'confidence': round(self.confidence, 4),
'threshold': self.threshold_used
}
def get_present_isotopes(self) -> List[IsotopePrediction]:
"""Get only isotopes predicted as present."""
return [iso for iso in self.isotopes if iso.present]
def summary(self) -> str:
"""Get a human-readable summary."""
present = self.get_present_isotopes()
if not present:
return "No isotopes detected above threshold"
lines = [f"Detected {len(present)} isotope(s):"]
for iso in sorted(present, key=lambda x: -x.probability):
lines.append(
f" - {iso.name}: {iso.probability*100:.1f}% confidence, "
f"{iso.activity_bq:.1f} Bq"
)
return "\n".join(lines)
class VegaInference:
"""
Inference engine for the Vega model.
Loads a trained model and provides methods for running predictions
on gamma spectra.
"""
def __init__(
self,
model_path: Union[str, Path],
isotope_index_path: Optional[Union[str, Path]] = None,
device: Optional[torch.device] = None
):
"""
Initialize the inference engine.
Args:
model_path: Path to the saved model checkpoint
isotope_index_path: Path to isotope index file. If None, will try
to find it in the same directory as the model.
device: Device to run inference on. If None, uses CUDA if available.
"""
self.model_path = Path(model_path)
# Determine device with CUDA compatibility test
if device is not None:
self.device = device
elif torch.cuda.is_available():
# Test if CUDA actually works (RTX 5090/Blackwell may not be compatible)
try:
test_tensor = torch.zeros(1, device='cuda')
_ = test_tensor + 1
self.device = torch.device('cuda')
print("Using CUDA for inference")
except RuntimeError as e:
if "no kernel image is available" in str(e):
print(f"CUDA device detected but not compatible (likely Blackwell arch)")
print("Falling back to CPU for inference")
self.device = torch.device('cpu')
else:
raise
else:
self.device = torch.device('cpu')
print("Using CPU for inference")
# Load checkpoint
print(f"Loading model from: {self.model_path}")
self.checkpoint = torch.load(self.model_path, map_location=self.device)
# Load model config
model_config_dict = self.checkpoint['model_config']
self.model_config = VegaConfig(**model_config_dict)
# Create and load model
self.model = VegaModel(self.model_config)
self.model.load_state_dict(self.checkpoint['model_state_dict'])
self.model = self.model.to(self.device)
self.model.eval()
# Load isotope index
if isotope_index_path is None:
# Try to find in same directory
isotope_index_path = self.model_path.parent / "vega_isotope_index.txt"
if Path(isotope_index_path).exists():
self.isotope_index = IsotopeIndex.load(Path(isotope_index_path))
else:
# Use default
from training.vega.isotope_index import get_default_isotope_index
self.isotope_index = get_default_isotope_index()
print("Warning: Using default isotope index")
print(f"Model loaded successfully!")
print(f"Device: {self.device}")
print(f"Isotopes: {self.isotope_index.num_isotopes}")
def preprocess_spectrum(
self,
spectrum: np.ndarray,
normalize: bool = True
) -> torch.Tensor:
"""
Preprocess a spectrum for inference.
Args:
spectrum: Input spectrum array. Can be:
- 1D: (channels,) - single spectrum
- 2D: (time, channels) - will be averaged over time
normalize: Whether to normalize to [0, 1]
Returns:
Preprocessed tensor ready for model
"""
# Handle 2D spectra
if spectrum.ndim == 2:
spectrum = spectrum.mean(axis=0)
# Normalize
if normalize and spectrum.max() > 0:
spectrum = spectrum / spectrum.max()
# Convert to tensor
tensor = torch.tensor(spectrum, dtype=torch.float32)
# Add batch dimension
tensor = tensor.unsqueeze(0)
return tensor.to(self.device)
@torch.no_grad()
def predict(
self,
spectrum: Union[np.ndarray, torch.Tensor],
threshold: float = 0.5,
return_all: bool = False
) -> SpectrumPrediction:
"""
Run inference on a spectrum.
Args:
spectrum: Input spectrum (numpy array or tensor)
threshold: Probability threshold for considering an isotope present
return_all: If True, include all isotopes in output. If False,
only include those above threshold.
Returns:
SpectrumPrediction with isotope predictions
"""
# Preprocess if numpy
if isinstance(spectrum, np.ndarray):
spectrum = self.preprocess_spectrum(spectrum)
# Run model (outputs logits)
logits, activities = self.model(spectrum)
# Apply sigmoid to get probabilities
probs = torch.sigmoid(logits)
# Convert to numpy
probs = probs.cpu().numpy()[0]
activities = activities.cpu().numpy()[0]
# Scale activities
activities = activities * self.model_config.max_activity_bq
# Create predictions
isotopes = []
for i in range(len(probs)):
prob = float(probs[i])
activity = float(activities[i])
present = prob >= threshold
if return_all or present:
isotopes.append(IsotopePrediction(
name=self.isotope_index.index_to_name(i),
probability=prob,
activity_bq=activity if present else 0.0,
present=present
))
# Calculate overall confidence (average of top predictions)
present_isotopes = [iso for iso in isotopes if iso.present]
if present_isotopes:
confidence = np.mean([iso.probability for iso in present_isotopes])
else:
confidence = 1.0 - probs.max() # Confidence in "background only"
return SpectrumPrediction(
isotopes=isotopes,
num_present=len(present_isotopes),
confidence=float(confidence),
threshold_used=threshold
)
def predict_batch(
self,
spectra: List[np.ndarray],
threshold: float = 0.5
) -> List[SpectrumPrediction]:
"""
Run inference on multiple spectra.
Args:
spectra: List of spectrum arrays
threshold: Probability threshold
Returns:
List of predictions
"""
return [self.predict(s, threshold) for s in spectra]
def predict_from_file(
self,
file_path: Union[str, Path],
threshold: float = 0.5
) -> SpectrumPrediction:
"""
Load a spectrum from a numpy file and run inference.
Args:
file_path: Path to .npy file
threshold: Probability threshold
Returns:
SpectrumPrediction
"""
spectrum = np.load(file_path)
return self.predict(spectrum, threshold)
def run_inference_demo(
model_path: str,
data_dir: str,
threshold: float = 0.5
):
"""
Demo function to run inference on test data.
Args:
model_path: Path to model checkpoint
data_dir: Path to data directory with spectra
threshold: Detection threshold
"""
# Initialize inference engine
inference = VegaInference(model_path)
# Find spectra files
data_path = Path(data_dir)
spectra_dir = data_path / "spectra"
if not spectra_dir.exists():
print(f"Spectra directory not found: {spectra_dir}")
return
# Load labels for comparison
labels_path = data_path / "labels.json"
with open(labels_path, 'r') as f:
labels = json.load(f)
print("\n" + "=" * 70)
print("VEGA INFERENCE DEMO")
print("=" * 70)
# Process each spectrum
npy_files = list(spectra_dir.glob("*.npy"))
print(f"\nFound {len(npy_files)} spectra to process\n")
for npy_file in npy_files:
# Extract sample ID from filename
sample_id = npy_file.stem.replace("spectrum_", "")
# Get ground truth
if sample_id in labels['samples']:
ground_truth = labels['samples'][sample_id]
true_isotopes = ground_truth['isotopes']
true_activities = ground_truth.get('source_activities_bq', {})
else:
true_isotopes = []
true_activities = {}
# Run prediction
prediction = inference.predict_from_file(npy_file, threshold=threshold)
# Display results
print("-" * 70)
print(f"Sample: {sample_id}")
print(f"Ground Truth Isotopes: {true_isotopes if true_isotopes else 'Background only'}")
if true_activities:
activities_str = ", ".join(
f"{k}: {v:.1f} Bq" for k, v in true_activities.items()
)
print(f"Ground Truth Activities: {activities_str}")
print(f"\nPrediction:")
print(prediction.summary())
# Compare
predicted_names = {iso.name for iso in prediction.get_present_isotopes()}
true_names = set(true_isotopes)
correct = predicted_names & true_names
missed = true_names - predicted_names
false_positives = predicted_names - true_names
if correct:
print(f"\n✓ Correctly identified: {correct}")
if missed:
print(f"✗ Missed: {missed}")
if false_positives:
print(f"! False positives: {false_positives}")
print()
print("=" * 70)
print("Inference complete!")
print("=" * 70)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run Vega model inference")
parser.add_argument(
"--model", "-m",
type=str,
default="models/vega_best.pt",
help="Path to model checkpoint"
)
parser.add_argument(
"--data", "-d",
type=str,
default="O:/master_data_collection/isotopev2",
help="Path to data directory"
)
parser.add_argument(
"--threshold", "-t",
type=float,
default=0.5,
help="Detection threshold (0-1)"
)
args = parser.parse_args()
# Make paths absolute if needed
project_root = Path(__file__).parent.parent
model_path = args.model if Path(args.model).is_absolute() else project_root / args.model
data_path = args.data if Path(args.data).is_absolute() else project_root / args.data
run_inference_demo(str(model_path), str(data_path), args.threshold)