- VegaModel CNN-FCNN 34.5M params, 82 isotopes, val acc 99.89% - Generation 50k spectres synthetiques 1D (12-24h durees) - Entrainement 100 epochs sur RTX 5060 Ti (CUDA 12.8, Blackwell) - Detection continue avec soustraction du background - Capture background 24h avec gestion deconnexion - Docker Compose : conteneur train (GPU) + detect (CPU/USB) - Modele entraine inclus (vega_best.pt, 395 Mo) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
615 lines
21 KiB
Python
615 lines
21 KiB
Python
"""
|
|
Training Script for Vega Model
|
|
|
|
Implements the training loop with:
|
|
- Mixed precision training for RTX 5090 efficiency
|
|
- Learning rate scheduling
|
|
- Early stopping
|
|
- Model checkpointing
|
|
- Training metrics logging
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import time
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, Optional, Tuple
|
|
from dataclasses import dataclass, asdict
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.optim import Adam
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
import numpy as np
|
|
|
|
# Sklearn metrics for comprehensive evaluation
|
|
from sklearn.metrics import (
|
|
roc_auc_score,
|
|
f1_score,
|
|
precision_score,
|
|
recall_score,
|
|
hamming_loss
|
|
)
|
|
|
|
from .model import VegaModel, VegaConfig, VegaLoss
|
|
from .dataset import create_data_loaders, SpectrumDataset
|
|
from .isotope_index import IsotopeIndex, get_default_isotope_index
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Settings consumed by ``VegaTrainer`` and ``train_vega``.

    Every field has a default, so ``TrainingConfig()`` is a complete
    configuration. Field order is part of the interface (positional
    construction), so new fields should be appended at the end.
    """

    # --- Locations -------------------------------------------------------
    # Directory holding the training data.
    data_dir: str = "O:/master_data_collection/isotopev2"
    # Directory where checkpoints and the history file are written.
    model_dir: str = "models"
    # Prefix used for every file written into ``model_dir``.
    model_name: str = "vega"

    # --- Optimisation ----------------------------------------------------
    batch_size: int = 64  # raised from 32 for better GPU utilization
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    num_epochs: int = 100

    # --- Early stopping --------------------------------------------------
    # Stop after this many consecutive epochs without improvement.
    patience: int = 10
    # Minimum decrease of the monitored loss that counts as an improvement.
    min_delta: float = 1e-4

    # --- ReduceLROnPlateau scheduler -------------------------------------
    lr_scheduler_patience: int = 5
    lr_scheduler_factor: float = 0.5
    min_lr: float = 1e-6

    # --- Mixed precision (only takes effect on a CUDA device) ------------
    use_amp: bool = True

    # --- Dataset partitioning --------------------------------------------
    train_split: float = 0.8
    val_split: float = 0.1
    test_split: float = 0.1

    # --- Data-loader parallelism -----------------------------------------
    num_workers: int = 8  # parallel data loading workers
    prefetch_factor: int = 4  # batches to prefetch per worker
    persistent_workers: bool = True  # keep workers alive between epochs

    # --- Reproducibility -------------------------------------------------
    seed: int = 42

    # --- Activity normalization ------------------------------------------
    max_activity_bq: float = 1000.0
|
|
|
|
|
|
class VegaTrainer:
    """
    Trainer class for the Vega model.

    Handles the training loop, validation, checkpointing, and metrics.

    Batches are expected to be mappings exposing the keys ``'spectrum'``,
    ``'presence_labels'`` and ``'activity_labels'``; the model is called as
    ``model(spectra)`` and must return ``(logits, activities)``.
    """

    def __init__(
        self,
        model: "VegaModel",
        config: "TrainingConfig",
        device: Optional[torch.device] = None,
        force_cpu: bool = False
    ):
        """Build optimizer, scheduler, loss and bookkeeping for ``model``.

        Args:
            model: Network to train; moved to the selected device.
            config: Training hyper-parameters and output locations.
            device: Explicit device to use. Ignored when ``force_cpu`` is set.
            force_cpu: Train on the CPU even if CUDA is available.
        """
        self.model = model
        self.config = config

        self.device = self._select_device(device, force_cpu)

        # Move model to device
        self.model = self.model.to(self.device)

        # Setup loss function
        self.loss_fn = VegaLoss(
            classification_weight=model.config.classification_weight,
            regression_weight=model.config.regression_weight
        )

        # Setup optimizer
        self.optimizer = Adam(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Setup learning rate scheduler
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            patience=config.lr_scheduler_patience,
            factor=config.lr_scheduler_factor,
            min_lr=config.min_lr
        )

        # Mixed precision is only enabled when actually running on CUDA.
        if config.use_amp and self.device.type == 'cuda':
            self.scaler = torch.amp.GradScaler('cuda')
        else:
            self.scaler = None

        # Training state
        self.current_epoch = 0  # epoch being run / last epoch completed
        self._next_epoch = 0  # first epoch the next train() call will run
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0
        self.training_history = []

        # Create model directory
        self.model_dir = Path(config.model_dir)
        self.model_dir.mkdir(parents=True, exist_ok=True)

        print(f"Training on device: {self.device}")
        if self.device.type == 'cuda':
            print(f"GPU: {torch.cuda.get_device_name(0)}")
            print(f"Mixed precision: {config.use_amp}")

    @staticmethod
    def _select_device(device: Optional[torch.device], force_cpu: bool) -> torch.device:
        """Pick the training device, falling back to CPU if CUDA is unusable."""
        if force_cpu:
            return torch.device('cpu')
        if device is not None:
            return device
        if not torch.cuda.is_available():
            return torch.device('cpu')

        try:
            # CUDA can be "available" yet unusable (e.g. unsupported GPU
            # architecture), so run a tiny op before committing to it.
            probe = torch.zeros(1, device='cuda')
            _ = probe + 1
        except RuntimeError:
            print("CUDA device found but not compatible, falling back to CPU")
            return torch.device('cpu')
        return torch.device('cuda')

    def train_epoch(self, train_loader) -> Dict[str, float]:
        """Train for one epoch.

        Args:
            train_loader: Iterable of training batches.

        Returns:
            Mean total/classification/regression loss per batch, element-wise
            presence accuracy at a 0.5 threshold, and the wall-clock seconds
            spent waiting for data vs. processing batches. All values are 0
            when the loader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        cls_loss_sum = 0.0
        reg_loss_sum = 0.0
        num_batches = 0

        # Track accuracy during training
        correct_isotopes = 0
        total_isotopes = 0

        # Timing for profiling - track data loading vs compute
        data_time = 0.0
        compute_time = 0.0
        data_start = time.time()

        for batch in train_loader:
            # Data loading time (time spent waiting for next batch)
            data_time += time.time() - data_start
            compute_start = time.time()

            # Move to device
            spectra = batch['spectrum'].to(self.device)
            presence = batch['presence_labels'].to(self.device)
            activities = batch['activity_labels'].to(self.device)

            self.optimizer.zero_grad()

            if self.scaler is not None:
                # Forward pass under autocast, backward pass with loss scaling
                with torch.amp.autocast('cuda'):
                    pred_logits, pred_activities = self.model(spectra)
                    loss, loss_dict = self.loss_fn(
                        pred_logits, pred_activities, presence, activities
                    )

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                pred_logits, pred_activities = self.model(spectra)
                loss, loss_dict = self.loss_fn(
                    pred_logits, pred_activities, presence, activities
                )

                loss.backward()
                self.optimizer.step()

            total_loss += loss_dict['total']
            cls_loss_sum += loss_dict['classification']
            reg_loss_sum += loss_dict['regression']
            num_batches += 1

            # Training accuracy; no_grad keeps these ops out of the graph
            with torch.no_grad():
                pred_probs = torch.sigmoid(pred_logits)
                pred_presence = (pred_probs >= 0.5).float()
                correct_isotopes += (pred_presence == presence).sum().item()
                total_isotopes += presence.numel()

            # Mark compute time and restart timing for data loading
            compute_time += time.time() - compute_start
            data_start = time.time()

        train_accuracy = correct_isotopes / total_isotopes if total_isotopes > 0 else 0.0

        # An empty loader yields zeros instead of a ZeroDivisionError,
        # mirroring how the accuracy above treats an empty epoch.
        divisor = max(num_batches, 1)

        return {
            'train_loss': total_loss / divisor,
            'train_cls_loss': cls_loss_sum / divisor,
            'train_reg_loss': reg_loss_sum / divisor,
            'train_accuracy': train_accuracy,
            'data_time': data_time,
            'compute_time': compute_time
        }

    @torch.no_grad()
    def validate(self, val_loader) -> Dict[str, float]:
        """Validate the model with comprehensive metrics.

        Args:
            val_loader: Iterable of validation batches, or ``None``.

        Returns:
            Dictionary of ``val_*`` metrics as plain Python floats, or an
            empty dictionary when ``val_loader`` is ``None`` or yields no
            batches.
        """
        if val_loader is None:
            return {}

        self.model.eval()
        total_loss = 0.0
        cls_loss_sum = 0.0
        reg_loss_sum = 0.0
        num_batches = 0

        # Collect all predictions and labels for sklearn metrics
        prob_chunks = []
        pred_chunks = []
        label_chunks = []

        for batch in val_loader:
            spectra = batch['spectrum'].to(self.device)
            presence = batch['presence_labels'].to(self.device)
            activities = batch['activity_labels'].to(self.device)

            pred_logits, pred_activities = self.model(spectra)
            loss, loss_dict = self.loss_fn(
                pred_logits, pred_activities, presence, activities
            )

            total_loss += loss_dict['total']
            cls_loss_sum += loss_dict['classification']
            reg_loss_sum += loss_dict['regression']
            num_batches += 1

            # Collect predictions for metrics
            pred_probs = torch.sigmoid(pred_logits)
            pred_presence = (pred_probs >= 0.5).float()

            prob_chunks.append(pred_probs.cpu().numpy())
            pred_chunks.append(pred_presence.cpu().numpy())
            label_chunks.append(presence.cpu().numpy())

        if num_batches == 0:
            # Nothing to evaluate: behave like "no validation" rather than
            # failing inside np.vstack / dividing by zero.
            return {}

        all_probs = np.vstack(prob_chunks)
        all_preds = np.vstack(pred_chunks)
        all_labels = np.vstack(label_chunks)

        # Basic accuracy (element-wise)
        correct = (all_preds == all_labels).sum()
        total = all_labels.size
        accuracy = float(correct / total) if total > 0 else 0.0

        metrics = {
            'val_loss': total_loss / num_batches,
            'val_cls_loss': cls_loss_sum / num_batches,
            'val_reg_loss': reg_loss_sum / num_batches,
            'val_accuracy': accuracy,
        }

        # ROC-AUC is undefined for a label column containing a single class,
        # so restrict it to columns where both classes occur.
        valid_cols = [
            col for col in range(all_labels.shape[1])
            if len(np.unique(all_labels[:, col])) == 2
        ]
        auc_macro = 0.0
        auc_micro = 0.0
        if valid_cols:
            try:
                auc_macro = float(roc_auc_score(
                    all_labels[:, valid_cols],
                    all_probs[:, valid_cols],
                    average='macro'
                ))
                auc_micro = float(roc_auc_score(
                    all_labels[:, valid_cols],
                    all_probs[:, valid_cols],
                    average='micro'
                ))
            except ValueError:
                # AUC can't be computed for this data; report 0 for both
                auc_macro = 0.0
                auc_micro = 0.0
        metrics['val_auc_macro'] = auc_macro
        metrics['val_auc_micro'] = auc_micro

        # F1 (macro and micro averaged); precision/recall (micro averaged).
        # Values are cast to float so the history stored in checkpoints and
        # JSON contains only plain Python numbers.
        metrics['val_f1_macro'] = float(f1_score(all_labels, all_preds, average='macro', zero_division=0))
        metrics['val_f1_micro'] = float(f1_score(all_labels, all_preds, average='micro', zero_division=0))
        metrics['val_precision'] = float(precision_score(all_labels, all_preds, average='micro', zero_division=0))
        metrics['val_recall'] = float(recall_score(all_labels, all_preds, average='micro', zero_division=0))

        # Hamming loss (fraction of labels incorrectly predicted)
        metrics['val_hamming'] = float(hamming_loss(all_labels, all_preds))

        # Exact match ratio (all isotopes correct for a sample)
        exact_matches = (all_preds == all_labels).all(axis=1).sum()
        metrics['val_exact_match'] = float(exact_matches / len(all_labels))

        return metrics

    def save_checkpoint(self, path: Path, is_best: bool = False):
        """Save a model checkpoint.

        Args:
            path: Destination file.
            is_best: When true, also write ``<model_name>_best.pt`` next to
                ``path``.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            # Persisted so early stopping continues correctly after a resume
            'epochs_without_improvement': self.epochs_without_improvement,
            'model_config': asdict(self.model.config),
            'training_config': asdict(self.config),
            'training_history': self.training_history
        }

        if self.scaler is not None:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        torch.save(checkpoint, path)

        if is_best:
            best_path = path.parent / f"{self.config.model_name}_best.pt"
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, path: Path):
        """Load a model checkpoint written by ``save_checkpoint``.

        The stored ``'epoch'`` is treated as the last completed epoch, so a
        following ``train()`` call continues with the epoch after it.

        Args:
            path: Checkpoint file to read.
        """
        # Checkpoints carry non-tensor metadata (configs, metric history that
        # older runs stored as NumPy scalars), which the weights_only=True
        # default of recent PyTorch refuses to unpickle.
        # NOTE: this unpickles arbitrary objects - only load trusted files.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if self.scaler is not None and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.current_epoch = checkpoint['epoch']
        self._next_epoch = self.current_epoch + 1
        self.best_val_loss = checkpoint['best_val_loss']
        # Older checkpoints do not contain the counter; start from zero then
        self.epochs_without_improvement = checkpoint.get('epochs_without_improvement', 0)
        self.training_history = checkpoint.get('training_history', [])

        print(f"Loaded checkpoint from epoch {self.current_epoch}")

    def train(
        self,
        train_loader,
        val_loader,
        resume_from: Optional[Path] = None
    ) -> Dict:
        """
        Full training loop.

        Writes one checkpoint per epoch, a ``_best`` checkpoint whenever the
        monitored loss improves by more than ``min_delta``, and a final
        checkpoint plus a JSON history file when training ends.

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader (may be ``None``; the training
                loss is then monitored instead)
            resume_from: Optional path to checkpoint to resume from

        Returns:
            Training results dictionary
        """
        if resume_from is not None:
            self.load_checkpoint(resume_from)

        print("\n" + "=" * 60)
        print("Starting Vega Training")
        print("=" * 60)
        print(f"Epochs: {self.config.num_epochs}")
        print(f"Batch size: {self.config.batch_size}")
        print(f"Learning rate: {self.config.learning_rate}")
        print(f"Training samples: {len(train_loader.dataset)}")
        # `is not None` avoids calling the loader's __len__ via truthiness
        if val_loader is not None:
            print(f"Validation samples: {len(val_loader.dataset)}")
        print("=" * 60 + "\n")

        start_time = time.time()

        # Start after the last completed epoch so a resumed run does not
        # repeat (and double-log) the epoch stored in the checkpoint.
        for epoch in range(self._next_epoch, self.config.num_epochs):
            self.current_epoch = epoch
            epoch_start = time.time()

            train_metrics = self.train_epoch(train_loader)
            val_metrics = self.validate(val_loader)
            has_val = 'val_loss' in val_metrics

            # Combine metrics
            metrics = {**train_metrics, **val_metrics, 'epoch': epoch}
            self.training_history.append(metrics)

            # Monitor validation loss when available, else training loss
            monitored_loss = val_metrics['val_loss'] if has_val else train_metrics['train_loss']
            self.scheduler.step(monitored_loss)

            # Check for improvement
            is_best = monitored_loss < self.best_val_loss - self.config.min_delta
            if is_best:
                self.best_val_loss = monitored_loss
                self.epochs_without_improvement = 0
            else:
                self.epochs_without_improvement += 1

            # Save checkpoint
            checkpoint_path = self.model_dir / f"{self.config.model_name}_epoch_{epoch}.pt"
            self.save_checkpoint(checkpoint_path, is_best=is_best)
            self._next_epoch = epoch + 1

            # Logging
            epoch_time = time.time() - epoch_start
            lr = self.optimizer.param_groups[0]['lr']

            # Primary metrics line
            log_str = (
                f"Epoch {epoch+1:3d}/{self.config.num_epochs} | "
                f"Train Loss: {train_metrics['train_loss']:.4f} | "
                f"Train Acc: {train_metrics['train_accuracy']:.4f} | "
            )
            if has_val:
                log_str += (
                    f"Val Loss: {val_metrics['val_loss']:.4f} | "
                    f"Val Acc: {val_metrics['val_accuracy']:.4f} | "
                )
            log_str += f"LR: {lr:.2e} | Time: {epoch_time:.1f}s"

            if is_best:
                log_str += " *"

            print(log_str)

            # Timing breakdown line
            data_t = train_metrics.get('data_time', 0)
            compute_t = train_metrics.get('compute_time', 0)
            if data_t > 0 or compute_t > 0:
                data_pct = 100 * data_t / (data_t + compute_t) if (data_t + compute_t) > 0 else 0
                print(f" └── Data: {data_t:.1f}s ({data_pct:.0f}%) | Compute: {compute_t:.1f}s ({100-data_pct:.0f}%)")

            # Secondary metrics line (detailed classification metrics)
            if 'val_auc_macro' in val_metrics:
                detail_str = (
                    f" └── AUC: {val_metrics['val_auc_macro']:.4f} | "
                    f"F1: {val_metrics['val_f1_macro']:.4f} | "
                    f"Prec: {val_metrics['val_precision']:.4f} | "
                    f"Recall: {val_metrics['val_recall']:.4f} | "
                    f"Exact: {val_metrics['val_exact_match']:.4f}"
                )
                print(detail_str)

            # Early stopping
            if self.epochs_without_improvement >= self.config.patience:
                print(f"\nEarly stopping after {epoch + 1} epochs")
                break

        total_time = time.time() - start_time

        # Save final model
        final_path = self.model_dir / f"{self.config.model_name}_final.pt"
        self.save_checkpoint(final_path)

        # Save training history
        history_path = self.model_dir / f"{self.config.model_name}_history.json"
        with open(history_path, 'w') as f:
            json.dump(self.training_history, f, indent=2)

        print("\n" + "=" * 60)
        print("Training complete!")
        print(f"Total time: {total_time / 60:.1f} minutes")
        print(f"Best validation loss: {self.best_val_loss:.4f}")
        print(f"Model saved to: {final_path}")
        print("=" * 60)

        return {
            'best_val_loss': self.best_val_loss,
            'total_epochs': self.current_epoch + 1,
            'total_time': total_time,
            'history': self.training_history
        }
|
|
|
|
|
|
def train_vega(
    data_dir: Optional[str] = None,
    model_dir: Optional[str] = None,
    config: Optional[TrainingConfig] = None,
    model_config: Optional[VegaConfig] = None
) -> Tuple[VegaModel, Dict]:
    """
    Train a Vega model end to end with sensible defaults.

    Builds the data loaders, the model and a ``VegaTrainer``, runs the
    training loop, and stores the isotope index next to the checkpoints.

    Args:
        data_dir: Overrides ``config.data_dir`` when given
        model_dir: Overrides ``config.model_dir`` when given
        config: Training configuration; a default one is created if omitted.
            The object is updated in place with the overrides and with the
            absolute model directory.
        model_config: Model configuration; derived from the isotope index and
            ``config.max_activity_bq`` if omitted

    Returns:
        Tuple of (trained model, training results)
    """
    # Relative paths are resolved against the directory three levels above
    # this file.
    root = Path(__file__).parent.parent.parent

    def _anchored(raw_path: str) -> Path:
        """Return ``raw_path`` unchanged if absolute, else placed under root."""
        candidate = Path(raw_path)
        return candidate if candidate.is_absolute() else root / candidate

    if config is None:
        config = TrainingConfig()

    # Explicit arguments win over whatever the config carries
    if data_dir:
        config.data_dir = data_dir
    if model_dir:
        config.model_dir = model_dir

    data_path = _anchored(config.data_dir)
    model_path = _anchored(config.model_dir)

    # Only the model directory is written back; the trainer reads it from
    # the config.
    config.model_dir = str(model_path)

    # Seed torch (CPU and, when present, CUDA)
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.seed)

    isotope_index = get_default_isotope_index()

    # The test loader is created alongside the others but not used here
    train_loader, val_loader, test_loader = create_data_loaders(
        data_dir=data_path,
        batch_size=config.batch_size,
        train_split=config.train_split,
        val_split=config.val_split,
        test_split=config.test_split,
        num_workers=config.num_workers,
        prefetch_factor=config.prefetch_factor,
        persistent_workers=config.persistent_workers,
        isotope_index=isotope_index,
        max_activity_bq=config.max_activity_bq,
        seed=config.seed
    )

    if model_config is None:
        model_config = VegaConfig(
            num_isotopes=isotope_index.num_isotopes,
            max_activity_bq=config.max_activity_bq
        )

    model = VegaModel(model_config)
    print(model.summary())

    trainer = VegaTrainer(model, config)
    results = trainer.train(train_loader, val_loader)

    # Keep the isotope index alongside the saved model files
    isotope_index.save(model_path / f"{config.model_name}_isotope_index.txt")

    return model, results
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run one training session with all defaults
    # (default TrainingConfig, default data and model directories) as a
    # quick end-to-end test of the pipeline.
    model, results = train_vega()
|