Pipeline complet Radiacode 103 - identification automatique d'isotopes
- VegaModel CNN-FCNN 34.5M params, 82 isotopes, val acc 99.89% - Génération de 50k spectres synthétiques 1D (durées 12-24h) - Entraînement 100 epochs sur RTX 5060 Ti (CUDA 12.8, Blackwell) - Détection continue avec soustraction du background - Capture background 24h avec gestion de la déconnexion - Docker Compose : conteneur train (GPU) + detect (CPU/USB) - Modèle entraîné inclus (vega_best.pt, 395 Mo) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
614
train/vega_ml/training/vega/train.py
Normal file
614
train/vega_ml/training/vega/train.py
Normal file
@ -0,0 +1,614 @@
|
||||
"""
|
||||
Training Script for Vega Model
|
||||
|
||||
Implements the training loop with:
|
||||
- Mixed precision training for RTX 5090 efficiency
|
||||
- Learning rate scheduling
|
||||
- Early stopping
|
||||
- Model checkpointing
|
||||
- Training metrics logging
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
import numpy as np
|
||||
|
||||
# Sklearn metrics for comprehensive evaluation
|
||||
from sklearn.metrics import (
|
||||
roc_auc_score,
|
||||
f1_score,
|
||||
precision_score,
|
||||
recall_score,
|
||||
hamming_loss
|
||||
)
|
||||
|
||||
from .model import VegaModel, VegaConfig, VegaLoss
|
||||
from .dataset import create_data_loaders, SpectrumDataset
|
||||
from .isotope_index import IsotopeIndex, get_default_isotope_index
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingConfig:
|
||||
"""Configuration for training."""
|
||||
# Data
|
||||
data_dir: str = "O:/master_data_collection/isotopev2"
|
||||
|
||||
# Model save path
|
||||
model_dir: str = "models"
|
||||
model_name: str = "vega"
|
||||
|
||||
# Training hyperparameters
|
||||
batch_size: int = 64 # Increased from 32 for better GPU utilization
|
||||
learning_rate: float = 1e-3
|
||||
weight_decay: float = 1e-4
|
||||
num_epochs: int = 100
|
||||
|
||||
# Early stopping
|
||||
patience: int = 10
|
||||
min_delta: float = 1e-4
|
||||
|
||||
# Learning rate scheduling
|
||||
lr_scheduler_patience: int = 5
|
||||
lr_scheduler_factor: float = 0.5
|
||||
min_lr: float = 1e-6
|
||||
|
||||
# Mixed precision
|
||||
use_amp: bool = True
|
||||
|
||||
# Data splits
|
||||
train_split: float = 0.8
|
||||
val_split: float = 0.1
|
||||
test_split: float = 0.1
|
||||
|
||||
# Workers - parallel data loading for better GPU utilization
|
||||
num_workers: int = 8 # Parallel data loading workers
|
||||
prefetch_factor: int = 4 # Batches to prefetch per worker
|
||||
persistent_workers: bool = True # Keep workers alive between epochs
|
||||
|
||||
# Reproducibility
|
||||
seed: int = 42
|
||||
|
||||
# Activity normalization
|
||||
max_activity_bq: float = 1000.0
|
||||
|
||||
|
||||
class VegaTrainer:
    """
    Trainer for the Vega model.

    Owns the loss, optimizer, learning-rate scheduler and optional AMP
    gradient scaler, and drives the epoch loop: training, validation,
    checkpointing, early stopping and metric logging.
    """

    def __init__(
        self,
        model: "VegaModel",
        config: "TrainingConfig",
        device: Optional[torch.device] = None,
        force_cpu: bool = False
    ):
        """
        Build the trainer and move the model onto the chosen device.

        Args:
            model: Network to train. Its ``config`` supplies the loss weights.
            config: Training hyper-parameters and output locations.
            device: Explicit device to train on. Ignored when ``force_cpu``
                is set; when omitted, CUDA is used if it actually works.
            force_cpu: Train on the CPU even if CUDA is available.
        """
        self.model = model
        self.config = config
        self.device = self._select_device(device, force_cpu)

        # Move model to device
        self.model = self.model.to(self.device)

        # Setup loss function
        self.loss_fn = VegaLoss(
            classification_weight=model.config.classification_weight,
            regression_weight=model.config.regression_weight
        )

        # Setup optimizer
        self.optimizer = Adam(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Setup learning rate scheduler (steps on the monitored loss)
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            patience=config.lr_scheduler_patience,
            factor=config.lr_scheduler_factor,
            min_lr=config.min_lr
        )

        # Mixed precision is only enabled when training on CUDA
        if config.use_amp and self.device.type == 'cuda':
            self.scaler = torch.amp.GradScaler('cuda')
        else:
            self.scaler = None

        # Training state
        self.current_epoch = 0
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0
        self.training_history = []

        # Create model directory
        self.model_dir = Path(config.model_dir)
        self.model_dir.mkdir(parents=True, exist_ok=True)

        print(f"Training on device: {self.device}")
        if self.device.type == 'cuda':
            print(f"GPU: {torch.cuda.get_device_name(0)}")
            print(f"Mixed precision: {config.use_amp}")

    @staticmethod
    def _select_device(device: Optional[torch.device], force_cpu: bool) -> torch.device:
        """Pick the training device, probing CUDA before trusting it."""
        if force_cpu:
            return torch.device('cpu')
        if device is not None:
            return device
        if not torch.cuda.is_available():
            return torch.device('cpu')

        try:
            # A tiny allocation + op: fails on GPUs the installed build
            # cannot actually run kernels on.
            probe = torch.zeros(1, device='cuda')
            _ = probe + 1
        except RuntimeError:
            print("CUDA device found but not compatible, falling back to CPU")
            return torch.device('cpu')
        return torch.device('cuda')

    def train_epoch(self, train_loader) -> Dict[str, float]:
        """
        Train for one epoch.

        Args:
            train_loader: Iterable of batches; each batch is a mapping with
                the keys ``'spectrum'``, ``'presence_labels'`` and
                ``'activity_labels'``.

        Returns:
            Mean total / classification / regression loss per batch, the
            element-wise presence accuracy at a 0.5 threshold, and the
            seconds spent waiting for data versus computing.

        Raises:
            ValueError: If ``train_loader`` yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        cls_loss_sum = 0.0
        reg_loss_sum = 0.0
        num_batches = 0

        # Track accuracy during training
        correct_isotopes = 0
        total_isotopes = 0

        # Timing for profiling - track data loading vs compute
        data_time = 0.0
        compute_time = 0.0
        data_start = time.time()

        for batch in train_loader:
            # Time spent waiting for this batch to arrive
            data_time += time.time() - data_start
            compute_start = time.time()

            # Move to device
            spectra = batch['spectrum'].to(self.device)
            presence = batch['presence_labels'].to(self.device)
            activities = batch['activity_labels'].to(self.device)

            self.optimizer.zero_grad()

            if self.scaler is not None:
                # Mixed precision: autocast the forward pass, scale the loss
                with torch.amp.autocast('cuda'):
                    pred_logits, pred_activities = self.model(spectra)
                    loss, loss_dict = self.loss_fn(
                        pred_logits, pred_activities, presence, activities
                    )

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                pred_logits, pred_activities = self.model(spectra)
                loss, loss_dict = self.loss_fn(
                    pred_logits, pred_activities, presence, activities
                )

                loss.backward()
                self.optimizer.step()

            total_loss += loss_dict['total']
            cls_loss_sum += loss_dict['classification']
            reg_loss_sum += loss_dict['regression']
            num_batches += 1

            # Training accuracy (no_grad so no graph is kept alive)
            with torch.no_grad():
                pred_probs = torch.sigmoid(pred_logits)
                pred_presence = (pred_probs >= 0.5).float()
                correct_isotopes += (pred_presence == presence).sum().item()
                total_isotopes += presence.numel()

            # Mark compute time and restart timing for data loading
            compute_time += time.time() - compute_start
            data_start = time.time()

        if num_batches == 0:
            # Previously surfaced as a bare ZeroDivisionError further down.
            raise ValueError("train_loader yielded no batches; cannot train on an empty dataset")

        train_accuracy = correct_isotopes / total_isotopes if total_isotopes > 0 else 0.0

        return {
            'train_loss': total_loss / num_batches,
            'train_cls_loss': cls_loss_sum / num_batches,
            'train_reg_loss': reg_loss_sum / num_batches,
            'train_accuracy': train_accuracy,
            'data_time': data_time,
            'compute_time': compute_time
        }

    @torch.no_grad()
    def validate(self, val_loader) -> Dict[str, float]:
        """
        Evaluate the model on a loader and compute multi-label metrics.

        Args:
            val_loader: Iterable of batches in the same format as for
                ``train_epoch``, or ``None``.

        Returns:
            A dict of ``val_*`` metrics as plain Python floats. Empty when
            ``val_loader`` is ``None`` or yields no batches.
        """
        if val_loader is None:
            return {}

        self.model.eval()
        total_loss = 0.0
        cls_loss_sum = 0.0
        reg_loss_sum = 0.0
        num_batches = 0

        # Collect all predictions and labels for sklearn metrics
        prob_chunks = []
        pred_chunks = []
        label_chunks = []

        for batch in val_loader:
            spectra = batch['spectrum'].to(self.device)
            presence = batch['presence_labels'].to(self.device)
            activities = batch['activity_labels'].to(self.device)

            pred_logits, pred_activities = self.model(spectra)
            loss, loss_dict = self.loss_fn(
                pred_logits, pred_activities, presence, activities
            )

            total_loss += loss_dict['total']
            cls_loss_sum += loss_dict['classification']
            reg_loss_sum += loss_dict['regression']
            num_batches += 1

            pred_probs = torch.sigmoid(pred_logits)
            pred_presence = (pred_probs >= 0.5).float()

            prob_chunks.append(pred_probs.cpu().numpy())
            pred_chunks.append(pred_presence.cpu().numpy())
            label_chunks.append(presence.cpu().numpy())

        if num_batches == 0:
            # Nothing to evaluate: behave like "no validation set" instead of
            # failing in np.vstack / dividing by zero.
            return {}

        all_probs = np.vstack(prob_chunks)
        all_preds = np.vstack(pred_chunks)
        all_labels = np.vstack(label_chunks)

        # Basic accuracy (element-wise over every sample/isotope pair)
        correct = (all_preds == all_labels).sum()
        total = all_labels.size
        accuracy = float(correct / total) if total > 0 else 0.0

        metrics = {
            'val_loss': total_loss / num_batches,
            'val_cls_loss': cls_loss_sum / num_batches,
            'val_reg_loss': reg_loss_sum / num_batches,
            'val_accuracy': accuracy,
        }

        try:
            # ROC-AUC is only defined for label columns containing both classes
            valid_cols = [
                col for col in range(all_labels.shape[1])
                if len(np.unique(all_labels[:, col])) == 2
            ]

            if valid_cols:
                metrics['val_auc_macro'] = float(roc_auc_score(
                    all_labels[:, valid_cols],
                    all_probs[:, valid_cols],
                    average='macro'
                ))
                metrics['val_auc_micro'] = float(roc_auc_score(
                    all_labels[:, valid_cols],
                    all_probs[:, valid_cols],
                    average='micro'
                ))
            else:
                metrics['val_auc_macro'] = 0.0
                metrics['val_auc_micro'] = 0.0
        except ValueError:
            # Handle case where AUC can't be computed
            metrics['val_auc_macro'] = 0.0
            metrics['val_auc_micro'] = 0.0

        # F1 (macro and micro averaged), precision and recall (micro averaged).
        # Values are cast to float so the history holds plain Python numbers
        # when it is written to JSON and into checkpoints.
        metrics['val_f1_macro'] = float(f1_score(all_labels, all_preds, average='macro', zero_division=0))
        metrics['val_f1_micro'] = float(f1_score(all_labels, all_preds, average='micro', zero_division=0))
        metrics['val_precision'] = float(precision_score(all_labels, all_preds, average='micro', zero_division=0))
        metrics['val_recall'] = float(recall_score(all_labels, all_preds, average='micro', zero_division=0))

        # Hamming loss (fraction of labels incorrectly predicted)
        metrics['val_hamming'] = float(hamming_loss(all_labels, all_preds))

        # Exact match ratio (all isotopes correct for a sample)
        exact_matches = (all_preds == all_labels).all(axis=1).sum()
        metrics['val_exact_match'] = float(exact_matches / len(all_labels))

        return metrics

    def save_checkpoint(self, path: Path, is_best: bool = False):
        """
        Save the full training state to ``path``.

        Args:
            path: Destination file for this checkpoint.
            is_best: When true, the same checkpoint is also written next to
                ``path`` as ``<model_name>_best.pt``.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            # Stored so that early stopping continues correctly after a resume
            'epochs_without_improvement': self.epochs_without_improvement,
            'model_config': asdict(self.model.config),
            'training_config': asdict(self.config),
            'training_history': self.training_history
        }

        if self.scaler is not None:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        torch.save(checkpoint, path)

        if is_best:
            best_path = path.parent / f"{self.config.model_name}_best.pt"
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, path: Path):
        """
        Restore model, optimizer, scheduler and training state from ``path``.

        After loading, ``current_epoch`` is the epoch stored in the
        checkpoint, i.e. the last epoch that finished before it was written.
        """
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True; checkpoints written before metrics were cast to
        # plain floats may hold NumPy scalars in 'training_history' - confirm
        # such files still load in the deployed PyTorch version.
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if self.scaler is not None and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.current_epoch = checkpoint['epoch']
        self.best_val_loss = checkpoint['best_val_loss']
        # Older checkpoints do not carry the counter; start it from zero then
        self.epochs_without_improvement = checkpoint.get('epochs_without_improvement', 0)
        self.training_history = checkpoint.get('training_history', [])

        print(f"Loaded checkpoint from epoch {self.current_epoch}")

    def train(
        self,
        train_loader,
        val_loader,
        resume_from: Optional[Path] = None
    ) -> Dict:
        """
        Full training loop.

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader (may be ``None``; the training
                loss is then used for scheduling and early stopping)
            resume_from: Optional path to checkpoint to resume from

        Returns:
            Training results dictionary with the keys ``best_val_loss``,
            ``total_epochs``, ``total_time`` and ``history``.
        """
        start_epoch = self.current_epoch
        if resume_from is not None:
            self.load_checkpoint(resume_from)
            # Checkpoints are written after an epoch has finished, so pick up
            # with the following epoch rather than repeating the stored one.
            start_epoch = self.current_epoch + 1

        print("\n" + "=" * 60)
        print("Starting Vega Training")
        print("=" * 60)
        print(f"Epochs: {self.config.num_epochs}")
        print(f"Batch size: {self.config.batch_size}")
        print(f"Learning rate: {self.config.learning_rate}")
        print(f"Training samples: {len(train_loader.dataset)}")
        if val_loader is not None:
            print(f"Validation samples: {len(val_loader.dataset)}")
        print("=" * 60 + "\n")

        start_time = time.time()

        for epoch in range(start_epoch, self.config.num_epochs):
            self.current_epoch = epoch
            epoch_start = time.time()

            train_metrics = self.train_epoch(train_loader)
            val_metrics = self.validate(val_loader)
            has_val = 'val_loss' in val_metrics

            # Combine metrics
            metrics = {**train_metrics, **val_metrics, 'epoch': epoch}
            self.training_history.append(metrics)

            # Validation loss drives scheduling and early stopping when
            # available, otherwise fall back to the training loss.
            monitored_loss = val_metrics.get('val_loss', train_metrics['train_loss'])
            self.scheduler.step(monitored_loss)

            is_best = monitored_loss < self.best_val_loss - self.config.min_delta
            if is_best:
                self.best_val_loss = monitored_loss
                self.epochs_without_improvement = 0
            else:
                self.epochs_without_improvement += 1

            # Save checkpoint
            checkpoint_path = self.model_dir / f"{self.config.model_name}_epoch_{epoch}.pt"
            self.save_checkpoint(checkpoint_path, is_best=is_best)

            # Logging
            epoch_time = time.time() - epoch_start
            lr = self.optimizer.param_groups[0]['lr']

            # Primary metrics line
            log_str = (
                f"Epoch {epoch+1:3d}/{self.config.num_epochs} | "
                f"Train Loss: {train_metrics['train_loss']:.4f} | "
                f"Train Acc: {train_metrics['train_accuracy']:.4f} | "
            )
            if has_val:
                log_str += (
                    f"Val Loss: {val_metrics['val_loss']:.4f} | "
                    f"Val Acc: {val_metrics['val_accuracy']:.4f} | "
                )
            log_str += f"LR: {lr:.2e} | Time: {epoch_time:.1f}s"

            if is_best:
                log_str += " *"

            print(log_str)

            # Timing breakdown line
            data_t = train_metrics.get('data_time', 0)
            compute_t = train_metrics.get('compute_time', 0)
            if data_t > 0 or compute_t > 0:
                data_pct = 100 * data_t / (data_t + compute_t) if (data_t + compute_t) > 0 else 0
                print(f"         └── Data: {data_t:.1f}s ({data_pct:.0f}%) | Compute: {compute_t:.1f}s ({100-data_pct:.0f}%)")

            # Secondary metrics line (detailed classification metrics)
            if has_val and 'val_auc_macro' in val_metrics:
                detail_str = (
                    f"         └── AUC: {val_metrics['val_auc_macro']:.4f} | "
                    f"F1: {val_metrics['val_f1_macro']:.4f} | "
                    f"Prec: {val_metrics['val_precision']:.4f} | "
                    f"Recall: {val_metrics['val_recall']:.4f} | "
                    f"Exact: {val_metrics['val_exact_match']:.4f}"
                )
                print(detail_str)

            # Early stopping
            if self.epochs_without_improvement >= self.config.patience:
                print(f"\nEarly stopping after {epoch + 1} epochs")
                break

        total_time = time.time() - start_time

        # Save final model (weights of the last epoch run, not the best one)
        final_path = self.model_dir / f"{self.config.model_name}_final.pt"
        self.save_checkpoint(final_path)

        # Save training history
        history_path = self.model_dir / f"{self.config.model_name}_history.json"
        with open(history_path, 'w') as f:
            json.dump(self.training_history, f, indent=2)

        print("\n" + "=" * 60)
        print("Training complete!")
        print(f"Total time: {total_time / 60:.1f} minutes")
        print(f"Best validation loss: {self.best_val_loss:.4f}")
        print(f"Model saved to: {final_path}")
        print("=" * 60)

        return {
            'best_val_loss': self.best_val_loss,
            'total_epochs': self.current_epoch + 1,
            'total_time': total_time,
            'history': self.training_history
        }
|
||||
|
||||
|
||||
def train_vega(
    data_dir: Optional[str] = None,
    model_dir: Optional[str] = None,
    config: Optional[TrainingConfig] = None,
    model_config: Optional[VegaConfig] = None
) -> Tuple[VegaModel, Dict]:
    """
    Convenience function to train a Vega model.

    Builds the data loaders, the model and a ``VegaTrainer``, runs the
    training loop and stores the isotope index next to the checkpoints.

    Args:
        data_dir: Path to data directory; overrides ``config.data_dir``
        model_dir: Path to save models; overrides ``config.model_dir``
        config: Training configuration (defaults to ``TrainingConfig()``).
            Note that the passed object is updated in place: ``data_dir`` /
            ``model_dir`` overrides are written to it and ``model_dir`` is
            replaced by its absolute form.
        model_config: Model configuration (defaults to one derived from the
            isotope index and ``config.max_activity_bq``)

    Returns:
        Tuple of (trained model, training results). The model carries the
        weights of the last epoch that ran; the best checkpoint is only
        available on disk.
    """
    # Local import: only needed here to seed Python's own RNG.
    import random

    # Relative paths are resolved against the directory three levels above
    # this file.
    project_root = Path(__file__).parent.parent.parent

    if config is None:
        config = TrainingConfig()

    if data_dir:
        config.data_dir = data_dir
    if model_dir:
        config.model_dir = model_dir

    # Make paths absolute
    data_path = Path(config.data_dir)
    if not data_path.is_absolute():
        data_path = project_root / data_path

    model_path = Path(config.model_dir)
    if not model_path.is_absolute():
        model_path = project_root / model_path

    config.model_dir = str(model_path)

    # Seed every RNG that may take part in a run, not just torch's, so the
    # single `seed` setting actually makes runs repeatable.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.seed)

    # Get isotope index
    isotope_index = get_default_isotope_index()

    # Create data loaders with parallel loading.
    # The test loader is created but not evaluated by this function.
    train_loader, val_loader, _test_loader = create_data_loaders(
        data_dir=data_path,
        batch_size=config.batch_size,
        train_split=config.train_split,
        val_split=config.val_split,
        test_split=config.test_split,
        num_workers=config.num_workers,
        prefetch_factor=config.prefetch_factor,
        persistent_workers=config.persistent_workers,
        isotope_index=isotope_index,
        max_activity_bq=config.max_activity_bq,
        seed=config.seed
    )

    # Create model
    if model_config is None:
        model_config = VegaConfig(
            num_isotopes=isotope_index.num_isotopes,
            max_activity_bq=config.max_activity_bq
        )

    model = VegaModel(model_config)
    print(model.summary())

    # Create trainer
    trainer = VegaTrainer(model, config)

    # Train
    results = trainer.train(train_loader, val_loader)

    # Save isotope index with model so predictions can be mapped back to names
    index_path = model_path / f"{config.model_name}_isotope_index.txt"
    isotope_index.save(index_path)

    return model, results
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run a training session with the default
    # TrainingConfig (default data directory, model directory and
    # hyper-parameters) as a quick end-to-end check.
    model, results = train_vega()
|
||||
Reference in New Issue
Block a user