Pipeline complet Radiacode 103 - identification automatique d'isotopes
- VegaModel CNN-FCNN 34.5M params, 82 isotopes, val acc 99.89% - Generation 50k spectres synthetiques 1D (12-24h durees) - Entrainement 100 epochs sur RTX 5060 Ti (CUDA 12.8, Blackwell) - Detection continue avec soustraction du background - Capture background 24h avec gestion deconnexion - Docker Compose : conteneur train (GPU) + detect (CPU/USB) - Modele entraite inclus (vega_best.pt, 395 Mo) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
411
train/vega_ml/training/vega/train_2d.py
Normal file
411
train/vega_ml/training/vega/train_2d.py
Normal file
@ -0,0 +1,411 @@
|
||||
"""
|
||||
Training Script for Vega 2D Model
|
||||
|
||||
Uses 2D convolutions to process gamma spectra with temporal information.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Optional, Tuple, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
|
||||
from .model_2d import Vega2DModel, Vega2DConfig, count_parameters
|
||||
from .dataset_2d import create_data_loaders_2d, SpectrumDataset2D
|
||||
from .isotope_index import get_default_isotope_index
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingConfig2D:
|
||||
"""Training configuration for 2D model."""
|
||||
|
||||
# Data
|
||||
data_dir: str = "O:/master_data_collection/isotopev2"
|
||||
model_dir: str = "models"
|
||||
|
||||
# Model
|
||||
target_time_intervals: int = 60
|
||||
|
||||
# Training
|
||||
epochs: int = 50
|
||||
batch_size: int = 32
|
||||
learning_rate: float = 1e-3
|
||||
weight_decay: float = 1e-5
|
||||
|
||||
# Loss weights
|
||||
classification_weight: float = 1.0
|
||||
regression_weight: float = 0.1
|
||||
|
||||
# Mixed precision
|
||||
use_amp: bool = True
|
||||
|
||||
# Early stopping
|
||||
early_stopping_patience: int = 10
|
||||
|
||||
# Learning rate scheduler
|
||||
lr_scheduler_patience: int = 5
|
||||
lr_scheduler_factor: float = 0.5
|
||||
|
||||
# Data loading
|
||||
num_workers: int = 4
|
||||
|
||||
|
||||
def train_epoch(
|
||||
model: nn.Module,
|
||||
train_loader,
|
||||
optimizer: optim.Optimizer,
|
||||
criterion_cls: nn.Module,
|
||||
criterion_reg: nn.Module,
|
||||
device: torch.device,
|
||||
scaler: Optional[GradScaler],
|
||||
config: TrainingConfig2D
|
||||
) -> Dict[str, float]:
|
||||
"""Train for one epoch."""
|
||||
model.train()
|
||||
|
||||
total_loss = 0.0
|
||||
total_cls_loss = 0.0
|
||||
total_reg_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
for batch in train_loader:
|
||||
spectra = batch['spectrum'].to(device)
|
||||
presence = batch['presence_labels'].to(device)
|
||||
activities = batch['activity_labels'].to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
if scaler is not None:
|
||||
with autocast():
|
||||
logits, pred_activities = model(spectra)
|
||||
cls_loss = criterion_cls(logits, presence)
|
||||
reg_loss = criterion_reg(pred_activities, activities)
|
||||
loss = config.classification_weight * cls_loss + config.regression_weight * reg_loss
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
logits, pred_activities = model(spectra)
|
||||
cls_loss = criterion_cls(logits, presence)
|
||||
reg_loss = criterion_reg(pred_activities, activities)
|
||||
loss = config.classification_weight * cls_loss + config.regression_weight * reg_loss
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_cls_loss += cls_loss.item()
|
||||
total_reg_loss += reg_loss.item()
|
||||
num_batches += 1
|
||||
|
||||
return {
|
||||
'loss': total_loss / num_batches,
|
||||
'cls_loss': total_cls_loss / num_batches,
|
||||
'reg_loss': total_reg_loss / num_batches
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def validate(
|
||||
model: nn.Module,
|
||||
val_loader,
|
||||
criterion_cls: nn.Module,
|
||||
criterion_reg: nn.Module,
|
||||
device: torch.device,
|
||||
config: TrainingConfig2D,
|
||||
threshold: float = 0.5
|
||||
) -> Dict[str, float]:
|
||||
"""Validate the model."""
|
||||
model.eval()
|
||||
|
||||
total_loss = 0.0
|
||||
total_cls_loss = 0.0
|
||||
total_reg_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
all_preds = []
|
||||
all_labels = []
|
||||
|
||||
for batch in val_loader:
|
||||
spectra = batch['spectrum'].to(device)
|
||||
presence = batch['presence_labels'].to(device)
|
||||
activities = batch['activity_labels'].to(device)
|
||||
|
||||
logits, pred_activities = model(spectra)
|
||||
cls_loss = criterion_cls(logits, presence)
|
||||
reg_loss = criterion_reg(pred_activities, activities)
|
||||
loss = config.classification_weight * cls_loss + config.regression_weight * reg_loss
|
||||
|
||||
total_loss += loss.item()
|
||||
total_cls_loss += cls_loss.item()
|
||||
total_reg_loss += reg_loss.item()
|
||||
num_batches += 1
|
||||
|
||||
# Collect predictions for metrics
|
||||
probs = torch.sigmoid(logits)
|
||||
preds = (probs >= threshold).float()
|
||||
all_preds.append(preds.cpu())
|
||||
all_labels.append(presence.cpu())
|
||||
|
||||
# Calculate metrics
|
||||
all_preds = torch.cat(all_preds, dim=0)
|
||||
all_labels = torch.cat(all_labels, dim=0)
|
||||
|
||||
# Per-sample accuracy (all isotopes correct)
|
||||
exact_match = (all_preds == all_labels).all(dim=1).float().mean().item()
|
||||
|
||||
# Per-isotope metrics
|
||||
tp = ((all_preds == 1) & (all_labels == 1)).sum().item()
|
||||
fp = ((all_preds == 1) & (all_labels == 0)).sum().item()
|
||||
fn = ((all_preds == 0) & (all_labels == 1)).sum().item()
|
||||
tn = ((all_preds == 0) & (all_labels == 0)).sum().item()
|
||||
|
||||
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
||||
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
||||
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
|
||||
|
||||
return {
|
||||
'loss': total_loss / num_batches,
|
||||
'cls_loss': total_cls_loss / num_batches,
|
||||
'reg_loss': total_reg_loss / num_batches,
|
||||
'exact_match': exact_match,
|
||||
'precision': precision,
|
||||
'recall': recall,
|
||||
'f1': f1
|
||||
}
|
||||
|
||||
|
||||
def train_vega_2d(
|
||||
config: TrainingConfig2D = None,
|
||||
model_config: Vega2DConfig = None
|
||||
) -> Tuple[Vega2DModel, Dict]:
|
||||
"""
|
||||
Train the Vega 2D model.
|
||||
"""
|
||||
config = config or TrainingConfig2D()
|
||||
model_config = model_config or Vega2DConfig(num_time_intervals=config.target_time_intervals)
|
||||
|
||||
# Setup device
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
print(f"Using device: {device}")
|
||||
|
||||
if device.type == 'cuda':
|
||||
print(f" GPU: {torch.cuda.get_device_name()}")
|
||||
print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
|
||||
|
||||
# Create model
|
||||
model = Vega2DModel(model_config).to(device)
|
||||
print(f"\nModel: Vega 2D")
|
||||
print(f" Input: ({model_config.num_time_intervals}, {model_config.num_channels})")
|
||||
print(f" Conv channels: {model_config.conv_channels}")
|
||||
print(f" FC dims: {model_config.fc_hidden_dims}")
|
||||
print(f" Parameters: {count_parameters(model):,}")
|
||||
|
||||
# Create data loaders
|
||||
print(f"\nLoading data from: {config.data_dir}")
|
||||
isotope_index = get_default_isotope_index()
|
||||
|
||||
train_loader, val_loader, test_loader = create_data_loaders_2d(
|
||||
data_dir=Path(config.data_dir),
|
||||
batch_size=config.batch_size,
|
||||
target_time_intervals=config.target_time_intervals,
|
||||
isotope_index=isotope_index,
|
||||
num_workers=config.num_workers
|
||||
)
|
||||
|
||||
# Loss functions
|
||||
criterion_cls = nn.BCEWithLogitsLoss()
|
||||
criterion_reg = nn.HuberLoss()
|
||||
|
||||
# Optimizer
|
||||
optimizer = optim.AdamW(
|
||||
model.parameters(),
|
||||
lr=config.learning_rate,
|
||||
weight_decay=config.weight_decay
|
||||
)
|
||||
|
||||
# Learning rate scheduler
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer,
|
||||
mode='min',
|
||||
factor=config.lr_scheduler_factor,
|
||||
patience=config.lr_scheduler_patience
|
||||
)
|
||||
|
||||
# Mixed precision scaler
|
||||
scaler = GradScaler() if config.use_amp and device.type == 'cuda' else None
|
||||
|
||||
# Training history
|
||||
history = {
|
||||
'train_loss': [], 'val_loss': [],
|
||||
'train_cls_loss': [], 'val_cls_loss': [],
|
||||
'train_reg_loss': [], 'val_reg_loss': [],
|
||||
'val_exact_match': [], 'val_precision': [], 'val_recall': [], 'val_f1': [],
|
||||
'lr': []
|
||||
}
|
||||
|
||||
# Early stopping
|
||||
best_val_loss = float('inf')
|
||||
patience_counter = 0
|
||||
|
||||
# Model directory
|
||||
model_dir = Path(config.model_dir)
|
||||
model_dir.mkdir(exist_ok=True)
|
||||
|
||||
print(f"\nStarting training for {config.epochs} epochs...")
|
||||
print(f" Batch size: {config.batch_size}")
|
||||
print(f" Learning rate: {config.learning_rate}")
|
||||
print(f" AMP: {scaler is not None}")
|
||||
print()
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for epoch in range(config.epochs):
|
||||
epoch_start = time.time()
|
||||
|
||||
# Train
|
||||
train_metrics = train_epoch(
|
||||
model, train_loader, optimizer,
|
||||
criterion_cls, criterion_reg,
|
||||
device, scaler, config
|
||||
)
|
||||
|
||||
# Validate
|
||||
val_metrics = validate(
|
||||
model, val_loader,
|
||||
criterion_cls, criterion_reg,
|
||||
device, config
|
||||
)
|
||||
|
||||
# Update scheduler
|
||||
scheduler.step(val_metrics['loss'])
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
|
||||
# Record history
|
||||
history['train_loss'].append(train_metrics['loss'])
|
||||
history['val_loss'].append(val_metrics['loss'])
|
||||
history['train_cls_loss'].append(train_metrics['cls_loss'])
|
||||
history['val_cls_loss'].append(val_metrics['cls_loss'])
|
||||
history['train_reg_loss'].append(train_metrics['reg_loss'])
|
||||
history['val_reg_loss'].append(val_metrics['reg_loss'])
|
||||
history['val_exact_match'].append(val_metrics['exact_match'])
|
||||
history['val_precision'].append(val_metrics['precision'])
|
||||
history['val_recall'].append(val_metrics['recall'])
|
||||
history['val_f1'].append(val_metrics['f1'])
|
||||
history['lr'].append(current_lr)
|
||||
|
||||
epoch_time = time.time() - epoch_start
|
||||
|
||||
# Print progress
|
||||
print(f"Epoch {epoch+1:3d}/{config.epochs} ({epoch_time:.1f}s) | "
|
||||
f"Train Loss: {train_metrics['loss']:.4f} | "
|
||||
f"Val Loss: {val_metrics['loss']:.4f} | "
|
||||
f"F1: {val_metrics['f1']:.4f} | "
|
||||
f"Recall: {val_metrics['recall']:.4f} | "
|
||||
f"LR: {current_lr:.2e}")
|
||||
|
||||
# Save best model
|
||||
if val_metrics['loss'] < best_val_loss:
|
||||
best_val_loss = val_metrics['loss']
|
||||
patience_counter = 0
|
||||
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'model_config': asdict(model_config),
|
||||
'training_config': asdict(config),
|
||||
'val_metrics': val_metrics,
|
||||
'history': history
|
||||
}, model_dir / 'vega_2d_best.pt')
|
||||
print(f" ✓ Saved best model (val_loss: {best_val_loss:.4f})")
|
||||
else:
|
||||
patience_counter += 1
|
||||
|
||||
# Early stopping
|
||||
if patience_counter >= config.early_stopping_patience:
|
||||
print(f"\nEarly stopping at epoch {epoch+1}")
|
||||
break
|
||||
|
||||
total_time = time.time() - start_time
|
||||
print(f"\nTraining complete in {total_time/60:.1f} minutes")
|
||||
print(f"Best validation loss: {best_val_loss:.4f}")
|
||||
|
||||
# Save final model
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'model_config': asdict(model_config),
|
||||
'training_config': asdict(config),
|
||||
'history': history
|
||||
}, model_dir / 'vega_2d_final.pt')
|
||||
|
||||
# Save history
|
||||
with open(model_dir / 'vega_2d_history.json', 'w') as f:
|
||||
json.dump(history, f, indent=2)
|
||||
|
||||
# Test set evaluation
|
||||
print("\nEvaluating on test set...")
|
||||
test_metrics = validate(
|
||||
model, test_loader,
|
||||
criterion_cls, criterion_reg,
|
||||
device, config
|
||||
)
|
||||
print(f" Test Loss: {test_metrics['loss']:.4f}")
|
||||
print(f" Test F1: {test_metrics['f1']:.4f}")
|
||||
print(f" Test Recall: {test_metrics['recall']:.4f}")
|
||||
print(f" Test Precision: {test_metrics['precision']:.4f}")
|
||||
print(f" Test Exact Match: {test_metrics['exact_match']:.4f}")
|
||||
|
||||
return model, history
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Train Vega 2D Model')
|
||||
parser.add_argument('--data-dir', type=str, default='O:/master_data_collection/isotopev2',
|
||||
help='Path to training data')
|
||||
parser.add_argument('--model-dir', type=str, default='models',
|
||||
help='Path to save models')
|
||||
parser.add_argument('--epochs', type=int, default=50,
|
||||
help='Number of epochs')
|
||||
parser.add_argument('--batch-size', type=int, default=32,
|
||||
help='Batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-3,
|
||||
help='Learning rate')
|
||||
parser.add_argument('--time-intervals', type=int, default=60,
|
||||
help='Target time intervals (pad/truncate)')
|
||||
parser.add_argument('--no-amp', action='store_true',
|
||||
help='Disable mixed precision training')
|
||||
parser.add_argument('--workers', type=int, default=4,
|
||||
help='Data loading workers')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = TrainingConfig2D(
|
||||
data_dir=args.data_dir,
|
||||
model_dir=args.model_dir,
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batch_size,
|
||||
learning_rate=args.lr,
|
||||
target_time_intervals=args.time_intervals,
|
||||
use_amp=not args.no_amp,
|
||||
num_workers=args.workers
|
||||
)
|
||||
|
||||
model_config = Vega2DConfig(
|
||||
num_time_intervals=args.time_intervals
|
||||
)
|
||||
|
||||
train_vega_2d(config, model_config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user