Files
radiacode/train/vega_ml/training/vega/train_2d.py
Jacquin Antoine 745a64b342 Pipeline complet Radiacode 103 - identification automatique d'isotopes
- VegaModel CNN-FCNN 34.5M params, 82 isotopes, val acc 99.89%
- Generation 50k spectres synthetiques 1D (12-24h durees)
- Entrainement 100 epochs sur RTX 5060 Ti (CUDA 12.8, Blackwell)
- Detection continue avec soustraction du background
- Capture background 24h avec gestion deconnexion
- Docker Compose : conteneur train (GPU) + detect (CPU/USB)
- Modele entraine inclus (vega_best.pt, 395 Mo)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 12:29:56 +02:00

412 lines
13 KiB
Python

"""
Training Script for Vega 2D Model
Uses 2D convolutions to process gamma spectra with temporal information.
"""
import argparse
import json
import time
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional, Tuple, Dict, List
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from .model_2d import Vega2DModel, Vega2DConfig, count_parameters
from .dataset_2d import create_data_loaders_2d, SpectrumDataset2D
from .isotope_index import get_default_isotope_index
@dataclass
class TrainingConfig2D:
    """Settings consumed by ``train_vega_2d`` and its helpers.

    Every field has a default, so ``TrainingConfig2D()`` is a complete
    configuration; override individual values by keyword. Field order is
    significant for positional construction and for ``asdict`` output, so
    new fields should be appended at the end.
    """

    # Where spectra are read from, and where checkpoints / history go.
    data_dir: str = "O:/master_data_collection/isotopev2"
    model_dir: str = "models"

    # Number of time intervals each sample is padded/truncated to; also used
    # as the default Vega2DConfig.num_time_intervals.
    target_time_intervals: int = 60

    # Optimisation budget, batch size and AdamW hyper-parameters.
    epochs: int = 50
    batch_size: int = 32
    learning_rate: float = 1e-3
    weight_decay: float = 1e-5

    # Multipliers combining the presence (classification) loss and the
    # activity (regression) loss into the optimised total.
    classification_weight: float = 1.0
    regression_weight: float = 0.1

    # Mixed-precision training; only takes effect on CUDA devices.
    use_amp: bool = True

    # Epochs without validation-loss improvement before training stops.
    early_stopping_patience: int = 10

    # ReduceLROnPlateau settings: patience (epochs) and multiplicative factor.
    lr_scheduler_patience: int = 5
    lr_scheduler_factor: float = 0.5

    # Worker count forwarded to the data-loader factory.
    num_workers: int = 4
def train_epoch(
    model: nn.Module,
    train_loader,
    optimizer: optim.Optimizer,
    criterion_cls: nn.Module,
    criterion_reg: nn.Module,
    device: torch.device,
    scaler: Optional[GradScaler],
    config: 'TrainingConfig2D'
) -> Dict[str, float]:
    """Run one optimisation pass over ``train_loader``.

    Each batch must be a mapping with ``'spectrum'``, ``'presence_labels'``
    and ``'activity_labels'`` tensors, and ``model`` must return
    ``(logits, pred_activities)``. The optimised loss is
    ``classification_weight * cls_loss + regression_weight * reg_loss``.

    Args:
        model: Network to train; it is switched to train mode.
        train_loader: Iterable of batch dicts.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion_cls: Loss applied to ``(logits, presence_labels)``.
        criterion_reg: Loss applied to ``(pred_activities, activity_labels)``.
        device: Device the batch tensors are moved to.
        scaler: When not ``None``, the forward pass runs under autocast and
            gradients go through this scaler (mixed precision). Otherwise a
            plain backward/step is performed.
        config: Supplies ``classification_weight`` and ``regression_weight``.

    Returns:
        Mean ``'loss'``, ``'cls_loss'`` and ``'reg_loss'`` over all batches.
        If the loader yields no batches, all three are ``nan`` instead of
        raising ``ZeroDivisionError``.
    """
    import contextlib

    model.train()
    sums = {'loss': 0.0, 'cls_loss': 0.0, 'reg_loss': 0.0}
    num_batches = 0
    use_amp = scaler is not None

    for batch in train_loader:
        spectra = batch['spectrum'].to(device)
        presence = batch['presence_labels'].to(device)
        activities = batch['activity_labels'].to(device)

        optimizer.zero_grad()

        # torch.autocast replaces the deprecated torch.cuda.amp.autocast
        # (same float16 default on CUDA). The fp32 path uses no autocast at all.
        amp_context = (
            torch.autocast(device_type=device.type)
            if use_amp else contextlib.nullcontext()
        )
        with amp_context:
            logits, pred_activities = model(spectra)
            cls_loss = criterion_cls(logits, presence)
            reg_loss = criterion_reg(pred_activities, activities)
            loss = (config.classification_weight * cls_loss
                    + config.regression_weight * reg_loss)

        # Backward runs outside the autocast region, as the AMP docs recommend.
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        sums['loss'] += loss.item()
        sums['cls_loss'] += cls_loss.item()
        sums['reg_loss'] += reg_loss.item()
        num_batches += 1

    if num_batches == 0:
        return {key: float('nan') for key in sums}
    return {key: total / num_batches for key, total in sums.items()}
@torch.no_grad()
def validate(
    model: nn.Module,
    val_loader,
    criterion_cls: nn.Module,
    criterion_reg: nn.Module,
    device: torch.device,
    config: 'TrainingConfig2D',
    threshold: float = 0.5
) -> Dict[str, float]:
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Batches have the same layout as in ``train_epoch``. Presence predictions
    are ``sigmoid(logits) >= threshold``. Detection counts (TP/FP/FN) are
    accumulated per batch, so predictions are never stored for the whole set.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        val_loader: Iterable of batch dicts.
        criterion_cls: Loss applied to ``(logits, presence_labels)``.
        criterion_reg: Loss applied to ``(pred_activities, activity_labels)``.
        device: Device the batch tensors are moved to.
        config: Supplies ``classification_weight`` and ``regression_weight``.
        threshold: Probability cut-off for declaring an isotope present.

    Returns:
        Dict with the mean ``'loss'``, ``'cls_loss'`` and ``'reg_loss'`` per
        batch, plus these metrics:

        - ``'exact_match'``: fraction of samples whose whole label row is
          predicted correctly.
        - ``'precision'``, ``'recall'``, ``'f1'``: micro-averaged over all
          (sample, isotope) pairs, with 0.0 when a denominator is zero.

        With an empty loader, the losses and ``'exact_match'`` are ``nan``.
    """
    model.eval()
    loss_sums = {'loss': 0.0, 'cls_loss': 0.0, 'reg_loss': 0.0}
    num_batches = 0
    num_samples = 0
    num_exact = 0
    tp = fp = fn = 0

    for batch in val_loader:
        spectra = batch['spectrum'].to(device)
        presence = batch['presence_labels'].to(device)
        activities = batch['activity_labels'].to(device)

        logits, pred_activities = model(spectra)
        cls_loss = criterion_cls(logits, presence)
        reg_loss = criterion_reg(pred_activities, activities)
        loss = (config.classification_weight * cls_loss
                + config.regression_weight * reg_loss)

        loss_sums['loss'] += loss.item()
        loss_sums['cls_loss'] += cls_loss.item()
        loss_sums['reg_loss'] += reg_loss.item()
        num_batches += 1

        # Thresholded 0/1 predictions, compared against the labels exactly as
        # before (labels equal to 1 count as positive, labels equal to 0 as
        # negative).
        preds = (torch.sigmoid(logits) >= threshold).float()
        pred_pos = preds == 1
        pred_neg = preds == 0
        true_pos = presence == 1
        true_neg = presence == 0

        num_exact += int((preds == presence).all(dim=1).sum().item())
        num_samples += presence.shape[0]
        tp += int((pred_pos & true_pos).sum().item())
        fp += int((pred_pos & true_neg).sum().item())
        fn += int((pred_neg & true_pos).sum().item())

    nan = float('nan')
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return {
        'loss': loss_sums['loss'] / num_batches if num_batches else nan,
        'cls_loss': loss_sums['cls_loss'] / num_batches if num_batches else nan,
        'reg_loss': loss_sums['reg_loss'] / num_batches if num_batches else nan,
        'exact_match': num_exact / num_samples if num_samples else nan,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }
def train_vega_2d(
    config: Optional[TrainingConfig2D] = None,
    model_config: Optional[Vega2DConfig] = None
) -> Tuple[Vega2DModel, Dict]:
    """Train the Vega 2D model end to end.

    Training proceeds as follows:

    1. Build the model and the train/val/test loaders.
    2. Optimise with AdamW, with ReduceLROnPlateau driven by validation loss.
    3. Checkpoint the best-validation-loss weights to
       ``<model_dir>/vega_2d_best.pt``.
    4. Stop early after ``config.early_stopping_patience`` epochs without
       improvement.
    5. Write ``vega_2d_final.pt`` (last-epoch weights) and
       ``vega_2d_history.json``.
    6. Report test-set metrics.

    The test set is evaluated with the best-validation weights, i.e. the
    weights saved in ``vega_2d_best.pt``. The returned model holds those
    weights; if validation loss never improved, it keeps the last-epoch
    weights.

    Args:
        config: Training settings; defaults to ``TrainingConfig2D()``.
        model_config: Architecture settings; defaults to a ``Vega2DConfig``
            whose ``num_time_intervals`` equals ``config.target_time_intervals``.

    Returns:
        ``(model, history)``, where ``history`` maps metric names to
        per-epoch lists.
    """
    config = config or TrainingConfig2D()
    model_config = model_config or Vega2DConfig(num_time_intervals=config.target_time_intervals)

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    if device.type == 'cuda':
        gpu_index = torch.cuda.current_device()
        print(f" GPU: {torch.cuda.get_device_name(gpu_index)}")
        print(f" Memory: {torch.cuda.get_device_properties(gpu_index).total_memory / 1e9:.1f} GB")

    # Create model
    model = Vega2DModel(model_config).to(device)
    print("\nModel: Vega 2D")
    print(f" Input: ({model_config.num_time_intervals}, {model_config.num_channels})")
    print(f" Conv channels: {model_config.conv_channels}")
    print(f" FC dims: {model_config.fc_hidden_dims}")
    print(f" Parameters: {count_parameters(model):,}")

    # Create data loaders
    print(f"\nLoading data from: {config.data_dir}")
    isotope_index = get_default_isotope_index()
    train_loader, val_loader, test_loader = create_data_loaders_2d(
        data_dir=Path(config.data_dir),
        batch_size=config.batch_size,
        target_time_intervals=config.target_time_intervals,
        isotope_index=isotope_index,
        num_workers=config.num_workers
    )

    # Losses, optimizer, LR schedule and (CUDA-only) mixed-precision scaler
    criterion_cls = nn.BCEWithLogitsLoss()
    criterion_reg = nn.HuberLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.lr_scheduler_factor,
        patience=config.lr_scheduler_patience
    )
    scaler = GradScaler() if config.use_amp and device.type == 'cuda' else None

    # Training history
    history = {
        'train_loss': [], 'val_loss': [],
        'train_cls_loss': [], 'val_cls_loss': [],
        'train_reg_loss': [], 'val_reg_loss': [],
        'val_exact_match': [], 'val_precision': [], 'val_recall': [], 'val_f1': [],
        'lr': []
    }

    # Early stopping / best-model tracking
    best_val_loss = float('inf')
    best_epoch = -1
    best_state = None  # CPU snapshot of the best weights, restored before testing
    patience_counter = 0

    # parents=True so nested output paths (e.g. "runs/exp1/models") work.
    model_dir = Path(config.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nStarting training for {config.epochs} epochs...")
    print(f" Batch size: {config.batch_size}")
    print(f" Learning rate: {config.learning_rate}")
    print(f" AMP: {scaler is not None}")
    print()

    start_time = time.time()
    for epoch in range(config.epochs):
        epoch_start = time.time()

        train_metrics = train_epoch(
            model, train_loader, optimizer,
            criterion_cls, criterion_reg,
            device, scaler, config
        )
        val_metrics = validate(
            model, val_loader,
            criterion_cls, criterion_reg,
            device, config
        )

        scheduler.step(val_metrics['loss'])
        current_lr = optimizer.param_groups[0]['lr']

        # Record history
        for key in ('loss', 'cls_loss', 'reg_loss'):
            history[f'train_{key}'].append(train_metrics[key])
            history[f'val_{key}'].append(val_metrics[key])
        for key in ('exact_match', 'precision', 'recall', 'f1'):
            history[f'val_{key}'].append(val_metrics[key])
        history['lr'].append(current_lr)

        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1:3d}/{config.epochs} ({epoch_time:.1f}s) | "
              f"Train Loss: {train_metrics['loss']:.4f} | "
              f"Val Loss: {val_metrics['loss']:.4f} | "
              f"F1: {val_metrics['f1']:.4f} | "
              f"Recall: {val_metrics['recall']:.4f} | "
              f"LR: {current_lr:.2e}")

        # Save best model
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            best_epoch = epoch
            patience_counter = 0
            # Snapshot on CPU so later optimizer steps cannot overwrite it.
            best_state = {
                name: tensor.detach().to('cpu', copy=True)
                for name, tensor in model.state_dict().items()
            }
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'model_config': asdict(model_config),
                'training_config': asdict(config),
                'val_metrics': val_metrics,
                'history': history
            }, model_dir / 'vega_2d_best.pt')
            print(f" ✓ Saved best model (val_loss: {best_val_loss:.4f})")
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= config.early_stopping_patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

    total_time = time.time() - start_time
    print(f"\nTraining complete in {total_time/60:.1f} minutes")
    print(f"Best validation loss: {best_val_loss:.4f}")

    # Index of the last completed epoch (-1 when config.epochs == 0). The loop
    # variable would be unbound in that case.
    last_epoch = len(history['train_loss']) - 1

    # Save final (last-epoch) model
    torch.save({
        'epoch': last_epoch,
        'model_state_dict': model.state_dict(),
        'model_config': asdict(model_config),
        'training_config': asdict(config),
        'history': history
    }, model_dir / 'vega_2d_final.pt')

    # Save history
    with open(model_dir / 'vega_2d_history.json', 'w') as f:
        json.dump(history, f, indent=2)

    # Evaluate the checkpoint that is actually reported/deployed, not the
    # weights left over from the last (possibly worse) epoch.
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"\nRestored best model weights from epoch {best_epoch+1} for test evaluation")

    print("\nEvaluating on test set...")
    test_metrics = validate(
        model, test_loader,
        criterion_cls, criterion_reg,
        device, config
    )
    print(f" Test Loss: {test_metrics['loss']:.4f}")
    print(f" Test F1: {test_metrics['f1']:.4f}")
    print(f" Test Recall: {test_metrics['recall']:.4f}")
    print(f" Test Precision: {test_metrics['precision']:.4f}")
    print(f" Test Exact Match: {test_metrics['exact_match']:.4f}")

    return model, history
def main(argv: Optional[List[str]] = None):
    """Parse command-line options and launch Vega 2D training.

    Argument defaults are read from ``TrainingConfig2D`` so the CLI and the
    dataclass cannot drift apart.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing ``main()`` calls behave as before.
    """
    defaults = TrainingConfig2D()

    parser = argparse.ArgumentParser(description='Train Vega 2D Model')
    parser.add_argument('--data-dir', type=str, default=defaults.data_dir,
                        help='Path to training data')
    parser.add_argument('--model-dir', type=str, default=defaults.model_dir,
                        help='Path to save models')
    parser.add_argument('--epochs', type=int, default=defaults.epochs,
                        help='Number of epochs')
    parser.add_argument('--batch-size', type=int, default=defaults.batch_size,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=defaults.learning_rate,
                        help='Learning rate')
    parser.add_argument('--weight-decay', type=float, default=defaults.weight_decay,
                        help='AdamW weight decay')
    parser.add_argument('--patience', type=int, default=defaults.early_stopping_patience,
                        help='Early stopping patience (epochs without val-loss improvement)')
    parser.add_argument('--time-intervals', type=int, default=defaults.target_time_intervals,
                        help='Target time intervals (pad/truncate)')
    parser.add_argument('--no-amp', action='store_true',
                        help='Disable mixed precision training')
    parser.add_argument('--workers', type=int, default=defaults.num_workers,
                        help='Data loading workers')
    args = parser.parse_args(argv)

    config = TrainingConfig2D(
        data_dir=args.data_dir,
        model_dir=args.model_dir,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        early_stopping_patience=args.patience,
        target_time_intervals=args.time_intervals,
        use_amp=not args.no_amp,
        num_workers=args.workers
    )
    # Model input height must match the padding/truncation applied to the data.
    model_config = Vega2DConfig(
        num_time_intervals=args.time_intervals
    )
    train_vega_2d(config, model_config)


if __name__ == '__main__':
    main()