Pipeline complet Radiacode 103 - identification automatique d'isotopes

- VegaModel CNN-FCNN 34.5M params, 82 isotopes, val acc 99.89%
- Generation 50k spectres synthetiques 1D (12-24h durees)
- Entrainement 100 epochs sur RTX 5060 Ti (CUDA 12.8, Blackwell)
- Detection continue avec soustraction du background
- Capture background 24h avec gestion deconnexion
- Docker Compose : conteneur train (GPU) + detect (CPU/USB)
- Modele entraine inclus (vega_best.pt, 395 Mo)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jacquin Antoine
2026-05-19 12:29:56 +02:00
commit 745a64b342
52 changed files with 17558 additions and 0 deletions

View File

@ -0,0 +1,26 @@
"""
Vega Model - CNN-FCNN with Multi-Task Heads for Gamma Spectrum Isotope Identification
Architecture based on research findings from:
- Wang et al. (2026): CNN-FCNN achieves 99.8% accuracy
- Galib et al. (2021): Hybrid CNN outperforms pure architectures
- Turner et al. (2021): 1D CNN robust to gain shifts and shielding
Features:
- 1D CNN backbone for spectral feature extraction
- Multi-task heads for isotope classification + activity regression
- Support for 82 isotopes from the synthetic spectra database
"""
from .model import VegaModel, VegaConfig
from .dataset import SpectrumDataset, create_data_loaders
from .train import train_vega, VegaTrainer
__all__ = [
'VegaModel',
'VegaConfig',
'SpectrumDataset',
'create_data_loaders',
'train_vega',
'VegaTrainer'
]

View File

@ -0,0 +1,373 @@
"""
Dataset and DataLoader for Vega Model Training
Handles loading synthetic gamma spectra from numpy files and converting
them to PyTorch tensors with proper labels for multi-task learning.
Supports two label formats:
1. Individual JSON files per sample (recommended for large datasets)
2. Combined labels.json file (legacy format)
"""
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from .isotope_index import IsotopeIndex, get_default_isotope_index
@dataclass
class SpectrumSample:
    """Container bundling one spectrum with its descriptive metadata."""

    # Identifier of the sample
    sample_id: str
    # Spectrum data: 2D array (time_intervals, channels) or 1D array (channels,)
    spectrum: np.ndarray
    # Names of the isotopes contained in the sample
    isotopes_present: List[str]
    # Activity per isotope name (becquerels, going by the field name)
    activities_bq: Dict[str, float]
    # Duration associated with the sample (seconds, going by the field name)
    duration_seconds: float
    # Detector identifier
    detector: str
class SpectrumDataset(Dataset):
    """
    PyTorch Dataset for synthetic gamma spectra.

    Spectra are read from ``<data_dir>/spectra/spectrum_<id>.npy``. Labels come
    either from one JSON file per sample (``spectrum_<id>.json`` next to the
    spectrum, read lazily) or from a single ``<data_dir>/labels.json`` that is
    loaded once (legacy format). Each item is returned as tensors suitable for
    the Vega model.
    """

    # Filename prefix shared by spectrum and per-sample label files
    _FILE_PREFIX = "spectrum_"

    def __init__(
        self,
        data_dir: Path,
        isotope_index: Optional["IsotopeIndex"] = None,
        max_activity_bq: float = 1000.0,
        collapse_time: bool = True,
        transform=None
    ):
        """
        Initialize the dataset.

        Args:
            data_dir: Path to directory containing spectra/ subdirectory
            isotope_index: Index mapping isotope names to indices; the default
                index is used when None
            max_activity_bq: Maximum activity for normalization
            collapse_time: If True, average across time dimension to get 1D spectrum
            transform: Optional transform to apply to spectra

        Raises:
            FileNotFoundError: If neither label format is found.
        """
        self.data_dir = Path(data_dir)
        self.spectra_dir = self.data_dir / "spectra"
        # Test against None explicitly: an index object that defines __len__
        # is falsy when empty, and `or` would silently replace it.
        if isotope_index is None:
            isotope_index = get_default_isotope_index()
        self.isotope_index = isotope_index
        self.max_activity_bq = max_activity_bq
        self.collapse_time = collapse_time
        self.transform = transform

        # Detect label format and load sample list
        self.use_individual_labels = self._detect_label_format()
        if self.use_individual_labels:
            # Only the IDs are collected here; labels are read on demand
            self.sample_ids = self._scan_for_samples()
            self.metadata = None
            print("Using individual label files (efficient mode)")
        else:
            # Legacy mode keeps the whole labels.json in memory
            self.metadata = self._load_metadata()
            self.sample_ids = list(self.metadata['samples'].keys())
            print("Using combined labels.json (legacy mode)")

        print(f"Loaded dataset with {len(self.sample_ids)} samples")
        print(f"Isotope index has {self.isotope_index.num_isotopes} isotopes")

    def _detect_label_format(self) -> bool:
        """
        Return True for per-sample JSON labels, False for combined labels.json.

        Raises:
            FileNotFoundError: If neither kind of label file exists.
        """
        json_files = list(self.spectra_dir.glob("spectrum_*.json"))
        if len(json_files) > 0:
            return True

        # Fall back to combined labels.json
        labels_path = self.data_dir / "labels.json"
        if labels_path.exists():
            return False

        raise FileNotFoundError(
            f"No label files found. Expected either:\n"
            f" - Individual files: {self.spectra_dir}/spectrum_*.json\n"
            f" - Combined file: {self.data_dir}/labels.json"
        )

    def _scan_for_samples(self) -> List[str]:
        """Return sample IDs, in sorted path order, derived from the .npy files."""
        prefix_len = len(self._FILE_PREFIX)
        # Strip only the leading prefix; str.replace would also delete any
        # later "spectrum_" occurrence inside the ID itself.
        return [
            npy_path.stem[prefix_len:]
            for npy_path in sorted(self.spectra_dir.glob("spectrum_*.npy"))
        ]

    def _load_metadata(self) -> Dict:
        """
        Load the combined labels.json metadata file (legacy format).

        Raises:
            FileNotFoundError: If labels.json does not exist.
        """
        labels_path = self.data_dir / "labels.json"
        if not labels_path.exists():
            raise FileNotFoundError(f"Labels file not found: {labels_path}")
        with open(labels_path, 'r') as f:
            return json.load(f)

    def _load_sample_label(self, sample_id: str) -> Dict:
        """Return the label dict of one sample, from its JSON file or the combined metadata."""
        if self.use_individual_labels:
            json_path = self.spectra_dir / f"spectrum_{sample_id}.json"
            with open(json_path, 'r') as f:
                return json.load(f)
        return self.metadata['samples'][sample_id]

    def _build_label_tensors(self, sample_meta: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a label dict as (presence, normalized activity) tensors."""
        num_isotopes = self.isotope_index.num_isotopes

        presence_labels = torch.zeros(num_isotopes, dtype=torch.float32)
        for isotope_name in sample_meta['isotopes']:
            try:
                presence_labels[self.isotope_index.name_to_index(isotope_name)] = 1.0
            except KeyError:
                # Isotope not in our index (might be a decay product): skip it
                pass

        activity_labels = torch.zeros(num_isotopes, dtype=torch.float32)
        for isotope_name, activity in sample_meta.get('source_activities_bq', {}).items():
            try:
                position = self.isotope_index.name_to_index(isotope_name)
            except KeyError:
                continue
            # Scale by max_activity_bq and cap at 1.0
            activity_labels[position] = min(activity / self.max_activity_bq, 1.0)

        return presence_labels, activity_labels

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.sample_ids)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single sample.

        Returns:
            Dictionary containing:
                - spectrum: float32 tensor; 1D when the stored array is 1D or
                  when a 2D array is collapsed over time (collapse_time=True)
                - presence_labels: Binary tensor (num_isotopes,) indicating presence
                - activity_labels: Tensor (num_isotopes,) with normalized activities
                - sample_id: String identifier
        """
        sample_id = self.sample_ids[idx]
        sample_meta = self._load_sample_label(sample_id)

        spectrum = np.load(self.spectra_dir / f"spectrum_{sample_id}.npy")
        if self.collapse_time and spectrum.ndim == 2:
            # Average across time intervals to get a single spectrum
            spectrum = spectrum.mean(axis=0)

        spectrum_tensor = torch.tensor(spectrum, dtype=torch.float32)
        if self.transform:
            spectrum_tensor = self.transform(spectrum_tensor)

        presence_labels, activity_labels = self._build_label_tensors(sample_meta)
        return {
            'spectrum': spectrum_tensor,
            'presence_labels': presence_labels,
            'activity_labels': activity_labels,
            'sample_id': sample_id
        }

    def get_sample_info(self, idx: int) -> Dict:
        """
        Get metadata for a sample without loading the spectrum.

        Works in both label modes: the label is fetched through
        _load_sample_label, because self.metadata is None when per-sample
        JSON files are used.
        """
        sample_id = self.sample_ids[idx]
        return {
            'sample_id': sample_id,
            **self._load_sample_label(sample_id)
        }
def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Merge a list of per-sample dictionaries into one batch dictionary.

    Args:
        batch: List of sample dictionaries

    Returns:
        Dictionary whose tensor entries are stacked along a new first
        dimension, plus the sample IDs as a plain list under 'sample_ids'
    """
    tensor_keys = ('spectrum', 'presence_labels', 'activity_labels')
    collated: Dict = {
        key: torch.stack([sample[key] for sample in batch])
        for key in tensor_keys
    }
    # IDs are strings, so they stay a Python list rather than a tensor
    collated['sample_ids'] = [sample['sample_id'] for sample in batch]
    return collated
def create_data_loaders(
    data_dir: Path,
    batch_size: int = 32,
    train_split: float = 0.8,
    val_split: float = 0.1,
    test_split: float = 0.1,
    num_workers: int = 8,
    prefetch_factor: int = 4,
    persistent_workers: bool = True,
    isotope_index: Optional["IsotopeIndex"] = None,
    max_activity_bq: float = 1000.0,
    seed: int = 42
) -> Tuple[DataLoader, Optional[DataLoader], Optional[DataLoader]]:
    """
    Create train, validation, and test data loaders.

    Args:
        data_dir: Path to data directory
        batch_size: Batch size for training
        train_split: Fraction of data for training
        val_split: Fraction of data for validation
        test_split: Fraction of data for testing (only used to validate that
            the fractions sum to 1.0; the test set receives the remainder)
        num_workers: Number of data loading workers (parallel I/O)
        prefetch_factor: Batches to prefetch per worker
        persistent_workers: Keep workers alive between epochs
        isotope_index: Isotope name to index mapping
        max_activity_bq: Maximum activity for normalization
        seed: Random seed for reproducibility

    Returns:
        Tuple of (train_loader, val_loader, test_loader). val_loader and
        test_loader are None when their split ends up empty.

    Raises:
        AssertionError: If the split fractions do not sum to 1.0.
        ValueError: If the dataset contains no samples.
    """
    assert abs(train_split + val_split + test_split - 1.0) < 1e-6, \
        "Splits must sum to 1.0"

    # Create full dataset
    full_dataset = SpectrumDataset(
        data_dir=data_dir,
        isotope_index=isotope_index,
        max_activity_bq=max_activity_bq
    )

    total_size = len(full_dataset)
    if total_size == 0:
        # Without this guard the small-dataset fix-ups below yield a negative
        # test size and a confusing failure further down.
        raise ValueError(f"Dataset at {data_dir} contains no samples")

    # Calculate split sizes
    train_size = int(total_size * train_split)
    val_size = int(total_size * val_split)
    test_size = total_size - train_size - val_size

    # Handle small datasets: make sure training, then validation, then test
    # each get at least one sample when enough samples exist
    if train_size == 0:
        train_size = max(1, total_size - 2)
    if val_size == 0 and total_size > 1:
        val_size = 1
        train_size = max(1, train_size - 1)
    if test_size == 0 and total_size > 2:
        test_size = 1
        train_size = max(1, train_size - 1)

    # Ensure sizes add up after the adjustments above
    test_size = total_size - train_size - val_size
    print(f"Dataset splits: train={train_size}, val={val_size}, test={test_size}")

    # Seeded generator so the split is reproducible
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=generator
    )

    # Settings shared by all three loaders. prefetch_factor and
    # persistent_workers only apply when worker processes are used.
    use_workers = num_workers > 0
    shared_kwargs = dict(
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
        prefetch_factor=prefetch_factor if use_workers else None,
        persistent_workers=persistent_workers and use_workers,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=min(batch_size, train_size),
        shuffle=True,
        drop_last=True,  # Drop incomplete batches for consistent training
        **shared_kwargs
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=min(batch_size, max(1, val_size)),
        shuffle=False,
        **shared_kwargs
    ) if val_size > 0 else None
    test_loader = DataLoader(
        test_dataset,
        batch_size=min(batch_size, max(1, test_size)),
        shuffle=False,
        **shared_kwargs
    ) if test_size > 0 else None

    if use_workers:
        print(f"DataLoader: {num_workers} workers, prefetch_factor={prefetch_factor}, persistent={persistent_workers}")
    return train_loader, val_loader, test_loader
if __name__ == "__main__":
    # Manual smoke test: load the synthetic dataset, inspect one sample and
    # pull one training batch.
    import sys

    data_dir = Path(__file__).parent.parent.parent / "data" / "synthetic"
    if not data_dir.exists():
        print(f"Data directory not found: {data_dir}")
        sys.exit(1)

    dataset = SpectrumDataset(data_dir)
    print(f"\nDataset size: {len(dataset)}")

    # Inspect the first sample
    sample = dataset[0]
    print(f"\nSample keys: {sample.keys()}")
    for label, key in (
        ("Spectrum", 'spectrum'),
        ("Presence labels", 'presence_labels'),
        ("Activity labels", 'activity_labels'),
    ):
        print(f"{label} shape: {sample[key].shape}")
    print(f"Presence sum: {sample['presence_labels'].sum().item()}")

    # Build the loaders with a small batch size
    train_loader, val_loader, test_loader = create_data_loaders(
        data_dir,
        batch_size=4
    )
    print(f"\nTrain batches: {len(train_loader)}")
    if val_loader:
        print(f"Val batches: {len(val_loader)}")
    if test_loader:
        print(f"Test batches: {len(test_loader)}")

    # Fetch one batch to check the collated shapes
    batch = next(iter(train_loader))
    print(f"\nBatch spectrum shape: {batch['spectrum'].shape}")
    print(f"Batch presence shape: {batch['presence_labels'].shape}")

View File

@ -0,0 +1,308 @@
"""
Dataset for 2D Vega Model
Loads 2D spectra (time × channels) and pads/truncates to fixed dimensions.
"""
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from .isotope_index import IsotopeIndex, get_default_isotope_index
@dataclass
class SpectrumSample2D:
    """Container bundling one 2D spectrum with its descriptive metadata."""

    # Identifier of the sample
    sample_id: str
    # Spectrum data as a 2D array (time_intervals, channels)
    spectrum: np.ndarray
    # Names of the isotopes contained in the sample
    isotopes_present: List[str]
    # Activity per isotope name (becquerels, going by the field name)
    activities_bq: Dict[str, float]
    # Duration associated with the sample (seconds, going by the field name)
    duration_seconds: float
    # Detector identifier
    detector: str
class SpectrumDataset2D(Dataset):
    """
    PyTorch Dataset for 2D gamma spectra.

    Spectra are read from ``<data_dir>/spectra/spectrum_<id>.npy`` and their
    time dimension is padded or subsampled to a fixed size so samples can be
    batched. Labels come either from per-sample JSON files or from a combined
    ``<data_dir>/labels.json``.
    """

    # Filename prefix shared by spectrum and per-sample label files
    _FILE_PREFIX = "spectrum_"

    def __init__(
        self,
        data_dir: Path,
        isotope_index: Optional["IsotopeIndex"] = None,
        max_activity_bq: float = 1000.0,
        target_time_intervals: int = 60,
        transform=None
    ):
        """
        Initialize the dataset.

        Args:
            data_dir: Path to directory containing spectra/ subdirectory
            isotope_index: Index mapping isotope names to indices; the default
                index is used when None
            max_activity_bq: Maximum activity for normalization
            target_time_intervals: Fixed time dimension (pad/truncate to this)
            transform: Optional transform to apply

        Raises:
            FileNotFoundError: If neither label format is found.
        """
        self.data_dir = Path(data_dir)
        self.spectra_dir = self.data_dir / "spectra"
        # Test against None explicitly: an index object that defines __len__
        # is falsy when empty, and `or` would silently replace it.
        if isotope_index is None:
            isotope_index = get_default_isotope_index()
        self.isotope_index = isotope_index
        self.max_activity_bq = max_activity_bq
        self.target_time_intervals = target_time_intervals
        self.transform = transform

        # Detect label format and load sample list
        self.use_individual_labels = self._detect_label_format()
        if self.use_individual_labels:
            self.sample_ids = self._scan_for_samples()
            self.metadata = None
            print("Using individual label files (efficient mode)")
        else:
            self.metadata = self._load_metadata()
            self.sample_ids = list(self.metadata['samples'].keys())
            print("Using combined labels.json (legacy mode)")

        print(f"Loaded 2D dataset with {len(self.sample_ids)} samples")
        print(f"Target shape: ({target_time_intervals}, 1023)")
        print(f"Isotope index has {self.isotope_index.num_isotopes} isotopes")

    def _detect_label_format(self) -> bool:
        """
        Return True for per-sample JSON labels, False for combined labels.json.

        Raises:
            FileNotFoundError: If neither kind of label file exists.
        """
        json_files = list(self.spectra_dir.glob("spectrum_*.json"))
        if len(json_files) > 0:
            return True

        labels_path = self.data_dir / "labels.json"
        if labels_path.exists():
            return False

        raise FileNotFoundError(
            f"No label files found. Expected either:\n"
            f" - Individual files: {self.spectra_dir}/spectrum_*.json\n"
            f" - Combined file: {self.data_dir}/labels.json"
        )

    def _scan_for_samples(self) -> List[str]:
        """Return sample IDs, in sorted path order, derived from the .npy files."""
        prefix_len = len(self._FILE_PREFIX)
        # Strip only the leading prefix; str.replace would also delete any
        # later "spectrum_" occurrence inside the ID itself.
        return [
            npy_path.stem[prefix_len:]
            for npy_path in sorted(self.spectra_dir.glob("spectrum_*.npy"))
        ]

    def _load_metadata(self) -> Dict:
        """
        Load the combined labels.json metadata file.

        Raises:
            FileNotFoundError: If labels.json does not exist.
        """
        labels_path = self.data_dir / "labels.json"
        if not labels_path.exists():
            raise FileNotFoundError(f"Labels file not found: {labels_path}")
        with open(labels_path, 'r') as f:
            return json.load(f)

    def _load_sample_label(self, sample_id: str) -> Dict:
        """Return the label dict of one sample, from its JSON file or the combined metadata."""
        if self.use_individual_labels:
            json_path = self.spectra_dir / f"spectrum_{sample_id}.json"
            with open(json_path, 'r') as f:
                return json.load(f)
        return self.metadata['samples'][sample_id]

    def _pad_or_truncate(self, spectrum: np.ndarray) -> np.ndarray:
        """
        Pad or truncate spectrum to target time dimension.

        Args:
            spectrum: 2D array (time, channels)

        Returns:
            Array of shape (target_time_intervals, channels)
        """
        current_time, num_channels = spectrum.shape[0], spectrum.shape[1]
        target_time = self.target_time_intervals

        if current_time == target_time:
            return spectrum
        if current_time > target_time:
            # Keep evenly spaced rows so the whole time range stays covered
            indices = np.linspace(0, current_time - 1, target_time, dtype=int)
            return spectrum[indices, :]

        # Too short: append all-zero rows at the end
        padded = np.zeros((target_time, num_channels), dtype=spectrum.dtype)
        padded[:current_time, :] = spectrum
        return padded

    def _build_label_tensors(self, sample_meta: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a label dict as (presence, normalized activity) tensors."""
        num_isotopes = self.isotope_index.num_isotopes

        presence_labels = torch.zeros(num_isotopes, dtype=torch.float32)
        for isotope_name in sample_meta['isotopes']:
            try:
                presence_labels[self.isotope_index.name_to_index(isotope_name)] = 1.0
            except KeyError:
                # Names missing from the index are skipped
                pass

        activity_labels = torch.zeros(num_isotopes, dtype=torch.float32)
        for isotope_name, activity in sample_meta.get('source_activities_bq', {}).items():
            try:
                position = self.isotope_index.name_to_index(isotope_name)
            except KeyError:
                continue
            # Scale by max_activity_bq and cap at 1.0
            activity_labels[position] = min(activity / self.max_activity_bq, 1.0)

        return presence_labels, activity_labels

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.sample_ids)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single sample.

        Returns:
            Dictionary containing:
                - spectrum: Tensor of shape (target_time_intervals, num_channels),
                  divided by its maximum when that maximum is positive
                - presence_labels: Binary tensor (num_isotopes,)
                - activity_labels: Tensor (num_isotopes,) with normalized activities
                - sample_id: String identifier
        """
        sample_id = self.sample_ids[idx]
        sample_meta = self._load_sample_label(sample_id)

        spectrum = np.load(self.spectra_dir / f"spectrum_{sample_id}.npy")
        if spectrum.ndim == 1:
            # Treat a 1D spectrum as a single time interval
            spectrum = spectrum.reshape(1, -1)
        spectrum = self._pad_or_truncate(spectrum)

        # Max normalization; skipped when the maximum is not positive
        max_val = spectrum.max()
        if max_val > 0:
            spectrum = spectrum / max_val

        spectrum_tensor = torch.tensor(spectrum, dtype=torch.float32)
        if self.transform:
            spectrum_tensor = self.transform(spectrum_tensor)

        presence_labels, activity_labels = self._build_label_tensors(sample_meta)
        return {
            'spectrum': spectrum_tensor,
            'presence_labels': presence_labels,
            'activity_labels': activity_labels,
            'sample_id': sample_id
        }
def collate_fn_2d(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Stack the tensor entries of 2D samples and collect their IDs in a list."""
    tensor_keys = ('spectrum', 'presence_labels', 'activity_labels')
    collated: Dict = {
        key: torch.stack([sample[key] for sample in batch])
        for key in tensor_keys
    }
    # IDs are strings, so they stay a Python list rather than a tensor
    collated['sample_ids'] = [sample['sample_id'] for sample in batch]
    return collated
def create_data_loaders_2d(
    data_dir: Path,
    batch_size: int = 32,
    train_split: float = 0.8,
    val_split: float = 0.1,
    test_split: float = 0.1,
    num_workers: int = 4,
    target_time_intervals: int = 60,
    isotope_index: Optional["IsotopeIndex"] = None,
    max_activity_bq: float = 1000.0,
    seed: int = 42
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Create train, validation, and test data loaders for 2D data.

    Args:
        data_dir: Path to data directory
        batch_size: Batch size used by all three loaders
        train_split: Fraction of data for training
        val_split: Fraction of data for validation
        test_split: Not read by this function; the test set receives whatever
            remains after the train and validation splits
        num_workers: Number of data loading workers
        target_time_intervals: Fixed time dimension passed to the dataset
        isotope_index: Isotope name to index mapping
        max_activity_bq: Maximum activity for normalization
        seed: Seed of the generator that drives the random split

    Returns:
        Tuple of (train_loader, val_loader, test_loader)
    """
    dataset = SpectrumDataset2D(
        data_dir=data_dir,
        isotope_index=isotope_index,
        max_activity_bq=max_activity_bq,
        target_time_intervals=target_time_intervals
    )

    # Train and validation sizes are truncated; test takes the remainder
    total = len(dataset)
    train_size = int(total * train_split)
    val_size = int(total * val_split)
    test_size = total - train_size - val_size

    subsets = random_split(
        dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(seed)
    )
    print(f"Dataset splits: train={train_size}, val={val_size}, test={test_size}")

    def make_loader(subset, shuffle: bool) -> DataLoader:
        # All loaders share every setting except shuffling
        return DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=collate_fn_2d,
            pin_memory=True,
            persistent_workers=num_workers > 0
        )

    train_subset, val_subset, test_subset = subsets
    train_loader = make_loader(train_subset, True)
    val_loader = make_loader(val_subset, False)
    test_loader = make_loader(test_subset, False)
    return train_loader, val_loader, test_loader
if __name__ == "__main__":
    # Manual smoke test: load one sample from a fixed local dataset path
    from pathlib import Path

    data_dir = Path("O:/master_data_collection/isotopev2")
    dataset = SpectrumDataset2D(data_dir, target_time_intervals=60)

    sample = dataset[0]
    isotope_count = sample['presence_labels'].sum().item()
    print("\nSample:")
    print(f" Spectrum shape: {sample['spectrum'].shape}")
    print(f" Presence labels: {isotope_count:.0f} isotopes")
    print(f" Sample ID: {sample['sample_id']}")

View File

@ -0,0 +1,141 @@
"""
Isotope Index - Mapping between isotope names and model output indices.
This module provides a consistent mapping between isotope names and their
corresponding indices in the model's output tensors. This is critical for
training and inference to ensure consistent label encoding.
"""
import sys
from pathlib import Path
from typing import Dict, List, Optional
# Add project root to path for imports
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from synthetic_spectra.ground_truth.isotope_data import ISOTOPE_DATABASE, get_isotope_names
class IsotopeIndex:
    """
    Bidirectional lookup between isotope names and model output positions.

    Names are kept in sorted order, so the same collection of names always
    produces the same index assignment.
    """

    def __init__(self, isotope_names: Optional[List[str]] = None):
        """
        Build the index.

        Args:
            isotope_names: Names to index. When None, the names returned by
                get_isotope_names() are used.
        """
        names = get_isotope_names() if isotope_names is None else isotope_names

        # Sorting makes the position of each name deterministic
        self._isotope_names = sorted(names)

        # Fill both lookup directions in a single pass
        self._name_to_idx: Dict[str, int] = {}
        self._idx_to_name: Dict[int, str] = {}
        for position, isotope in enumerate(self._isotope_names):
            self._name_to_idx[isotope] = position
            self._idx_to_name[position] = isotope

    @property
    def num_isotopes(self) -> int:
        """Number of entries in the index."""
        return len(self._isotope_names)

    @property
    def isotope_names(self) -> List[str]:
        """Copy of the isotope names, in index order."""
        return list(self._isotope_names)

    def name_to_index(self, name: str) -> int:
        """
        Look up the position of an isotope name.

        Args:
            name: Isotope name (e.g., "Cs-137")

        Returns:
            Integer index for the isotope

        Raises:
            KeyError: If isotope name not in index
        """
        try:
            return self._name_to_idx[name]
        except KeyError:
            raise KeyError(f"Isotope '{name}' not found in index. "
                           f"Available isotopes: {self._isotope_names[:5]}...") from None

    def index_to_name(self, idx: int) -> str:
        """
        Look up the isotope name stored at a position.

        Args:
            idx: Integer index

        Returns:
            Isotope name string

        Raises:
            KeyError: If index out of range
        """
        try:
            return self._idx_to_name[idx]
        except KeyError:
            raise KeyError(f"Index {idx} out of range. Valid range: 0-{self.num_isotopes-1}") from None

    def names_to_indices(self, names: List[str]) -> List[int]:
        """Map each name to its index, preserving order."""
        return list(map(self.name_to_index, names))

    def indices_to_names(self, indices: List[int]) -> List[str]:
        """Map each index to its name, preserving order."""
        return list(map(self.index_to_name, indices))

    def save(self, path: Path):
        """Write the names to a text file, one per line, in index order."""
        with open(path, 'w') as f:
            f.writelines(f"{name}\n" for name in self._isotope_names)

    @classmethod
    def load(cls, path: Path) -> 'IsotopeIndex':
        """Build an index from a text file holding one name per line; blank lines are ignored."""
        with open(path, 'r') as f:
            stripped_lines = (line.strip() for line in f)
            isotope_names = [entry for entry in stripped_lines if entry]
        return cls(isotope_names)

    def __repr__(self) -> str:
        return f"IsotopeIndex(num_isotopes={self.num_isotopes})"

    def __len__(self) -> int:
        return self.num_isotopes
# Shared default index, built once when the module is imported; with no
# explicit names it is populated from get_isotope_names()
DEFAULT_ISOTOPE_INDEX = IsotopeIndex()


def get_default_isotope_index() -> IsotopeIndex:
    """Return the shared module-level default index (the same object on every call)."""
    return DEFAULT_ISOTOPE_INDEX
if __name__ == "__main__":
    # Print a short overview of the default index: its size plus the first
    # and last ten entries
    index = get_default_isotope_index()
    print(f"Isotope Index: {index}")

    total = index.num_isotopes
    head = range(min(10, total))
    tail = range(max(0, total - 10), total)

    print("\nFirst 10 isotopes:")
    for i in head:
        print(f" {i:3d}: {index.index_to_name(i)}")

    print("\nLast 10 isotopes:")
    for i in tail:
        print(f" {i:3d}: {index.index_to_name(i)}")

View File

@ -0,0 +1,416 @@
"""
Vega Model Architecture - CNN-FCNN with Multi-Task Heads
A hybrid Convolutional Neural Network with Fully Connected Neural Network
for gamma spectrum isotope identification. Based on peer-reviewed research
showing CNN-FCNN achieves state-of-the-art performance (99%+ accuracy).
Architecture:
Input: 1D gamma spectrum (1023 channels, 20-3000 keV)
Feature Extraction: 3 CNN modules with LeakyReLU, MaxPool, Dropout
Classification Head: Dense layers → Sigmoid (multi-label isotope presence)
Regression Head: Dense layers → ReLU (activity estimation in Bq)
"""
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
@dataclass
class VegaConfig:
    """Hyperparameters describing the Vega model and its loss weighting."""

    # --- Input ---
    # Number of energy channels in one spectrum
    num_channels: int = 1023

    # --- Output ---
    # Number of isotopes to classify (from the isotope database)
    num_isotopes: int = 82

    # --- CNN backbone ---
    # Output channels of each successive convolutional block
    conv_channels: List[int] = field(default_factory=lambda: [64, 128, 256])
    conv_kernel_size: int = 7
    pool_size: int = 2

    # --- Fully connected heads ---
    # Widths of the hidden layers
    fc_hidden_dims: List[int] = field(default_factory=lambda: [512, 256])

    # --- Regularization ---
    dropout_rate: float = 0.3
    spatial_dropout_rate: float = 0.1

    # --- Activation ---
    leaky_relu_slope: float = 0.1

    # --- Loss weighting ---
    classification_weight: float = 1.0
    regression_weight: float = 0.1

    # --- Training ---
    # Scale used for activity normalization
    max_activity_bq: float = 1000.0
class ConvBlock(nn.Module):
    """
    Two stacked 1D convolutions followed by max pooling and channel dropout.

    Each convolution is followed by batch normalization and LeakyReLU. The
    module docstring credits Turner et al. (2021) for the design of stacking
    two convolutions per module before pooling.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 7,
        pool_size: int = 2,
        dropout_rate: float = 0.1,
        leaky_slope: float = 0.1
    ):
        """
        Args:
            in_channels: Channels of the incoming feature map
            out_channels: Channels produced by both convolutions
            kernel_size: Kernel width of both convolutions
            pool_size: Window of the max pooling layer
            dropout_rate: Probability used by the channel-wise dropout
            leaky_slope: Negative slope of the LeakyReLU activations
        """
        super().__init__()
        # Padding of half the kernel width, applied on both sides
        half_kernel = kernel_size // 2

        # Stage 1: in_channels -> out_channels
        self.conv1 = nn.Conv1d(
            in_channels, out_channels,
            kernel_size=kernel_size, padding=half_kernel
        )
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.act1 = nn.LeakyReLU(leaky_slope)

        # Stage 2: out_channels -> out_channels
        self.conv2 = nn.Conv1d(
            out_channels, out_channels,
            kernel_size=kernel_size, padding=half_kernel
        )
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.act2 = nn.LeakyReLU(leaky_slope)

        # Downsampling and spatial (whole-channel) dropout
        self.pool = nn.MaxPool1d(pool_size)
        self.dropout = nn.Dropout1d(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv-BN-activation twice, then pooling and dropout."""
        hidden = self.act1(self.bn1(self.conv1(x)))
        hidden = self.act2(self.bn2(self.conv2(hidden)))
        return self.dropout(self.pool(hidden))
class VegaModel(nn.Module):
"""
Vega: CNN-FCNN Multi-Task Model for Isotope Identification
Named after the bright star Vega (α Lyrae), which emits radiation
across the electromagnetic spectrum - fitting for a gamma spectrum analyzer.
The model performs two tasks:
1. Multi-label classification: Which isotopes are present?
2. Activity regression: What is the activity (Bq) of each isotope?
"""
def __init__(self, config: VegaConfig):
super().__init__()
self.config = config
# Build CNN backbone
self.backbone = self._build_backbone()
# Calculate flattened size after backbone
self._flat_size = self._calculate_flat_size()
# Build classification head (multi-label)
self.classifier = self._build_classifier()
# Build regression head (activity estimation)
self.regressor = self._build_regressor()
# Initialize weights
self._init_weights()
def _build_backbone(self) -> nn.Sequential:
"""Build the CNN feature extraction backbone."""
layers = []
in_channels = 1 # Input is 1D spectrum with 1 channel
for out_channels in self.config.conv_channels:
layers.append(ConvBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=self.config.conv_kernel_size,
pool_size=self.config.pool_size,
dropout_rate=self.config.spatial_dropout_rate,
leaky_slope=self.config.leaky_relu_slope
))
in_channels = out_channels
return nn.Sequential(*layers)
def _calculate_flat_size(self) -> int:
"""Calculate the size of flattened features after backbone."""
# Create dummy input to calculate size
dummy = torch.zeros(1, 1, self.config.num_channels)
with torch.no_grad():
out = self.backbone(dummy)
return out.numel()
def _build_classifier(self) -> nn.Sequential:
"""Build the classification head for isotope presence prediction.
Outputs raw logits (not probabilities) for AMP compatibility.
Use BCEWithLogitsLoss for training, apply sigmoid during inference.
"""
layers = []
in_features = self._flat_size
# Hidden layers
for hidden_dim in self.config.fc_hidden_dims:
layers.extend([
nn.Linear(in_features, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.LeakyReLU(self.config.leaky_relu_slope),
nn.Dropout(self.config.dropout_rate)
])
in_features = hidden_dim
# Output layer - raw logits for AMP compatibility
layers.append(nn.Linear(in_features, self.config.num_isotopes))
return nn.Sequential(*layers)
def _build_regressor(self) -> nn.Sequential:
"""Build the regression head for activity estimation."""
layers = []
in_features = self._flat_size
# Hidden layers (shared architecture with classifier)
for hidden_dim in self.config.fc_hidden_dims:
layers.extend([
nn.Linear(in_features, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.LeakyReLU(self.config.leaky_relu_slope),
nn.Dropout(self.config.dropout_rate)
])
in_features = hidden_dim
# Output layer with ReLU for non-negative activity values
layers.extend([
nn.Linear(in_features, self.config.num_isotopes),
nn.ReLU() # Activity must be non-negative
])
return nn.Sequential(*layers)
def _init_weights(self):
"""Initialize weights using He initialization for LeakyReLU."""
for module in self.modules():
if isinstance(module, (nn.Conv1d, nn.Linear)):
nn.init.kaiming_normal_(
module.weight,
a=self.config.leaky_relu_slope,
mode='fan_out',
nonlinearity='leaky_relu'
)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.BatchNorm1d):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
def forward(
self,
x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass through the model.
Args:
x: Input spectrum tensor of shape (batch, channels) or (batch, 1, channels)
Values should be normalized [0, 1]
Returns:
Tuple of:
- isotope_logits: Raw logits for each isotope (batch, num_isotopes)
Apply sigmoid to get probabilities for inference
- activity_pred: Predicted activity in Bq for each isotope (batch, num_isotopes)
"""
# Ensure input has channel dimension
if x.dim() == 2:
x = x.unsqueeze(1) # (batch, channels) -> (batch, 1, channels)
# Feature extraction
features = self.backbone(x)
features = features.flatten(start_dim=1)
# Classification head (outputs logits)
isotope_logits = self.classifier(features)
# Regression head
activity_pred = self.regressor(features)
return isotope_logits, activity_pred
def predict(
self,
x: torch.Tensor,
threshold: float = 0.5,
return_all: bool = False
) -> Dict:
"""
Make predictions with post-processing.
Args:
x: Input spectrum tensor
threshold: Probability threshold for isotope presence
return_all: If True, return predictions for all isotopes
Returns:
Dictionary with predictions
"""
self.eval()
with torch.no_grad():
probs, activities = self(x)
# Apply threshold
present = probs >= threshold
# Mask activities by presence
masked_activities = activities * present.float()
return {
'probabilities': probs,
'activities_bq': masked_activities * self.config.max_activity_bq,
'present_mask': present,
'threshold': threshold
}
def count_parameters(self) -> int:
    """Return the number of scalar parameters that require gradients."""
    trainable = 0
    for param in self.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
def summary(self) -> str:
    """
    Build a short, human-readable description of the model.

    Returns:
        A multi-line string framed by 60-character rules that lists the
        main configuration values and the trainable parameter count.
    """
    cfg = self.config
    rule = "=" * 60

    report = [rule, "VEGA Model - CNN-FCNN Multi-Task Isotope Identifier", rule]
    report.append(f"Input channels: {cfg.num_channels}")
    report.append(f"Output isotopes: {cfg.num_isotopes}")
    report.append(f"CNN channels: {cfg.conv_channels}")
    report.append(f"FC hidden dims: {cfg.fc_hidden_dims}")
    report.append(f"Dropout rate: {cfg.dropout_rate}")
    report.append(f"Total parameters: {self.count_parameters():,}")
    report.append(rule)

    return "\n".join(report)
class VegaLoss(nn.Module):
    """
    Weighted multi-task loss for the Vega model.

    The total is ``classification_weight * BCE + regression_weight * Huber``:
    - BCE-with-logits on the multi-label isotope presence targets
    - Huber loss on the activity predictions, with both prediction and
      target multiplied by the presence mask first
    """

    def __init__(
        self,
        classification_weight: float = 1.0,
        regression_weight: float = 0.1,
        huber_delta: float = 1.0
    ):
        """
        Args:
            classification_weight: Multiplier for the BCE term.
            regression_weight: Multiplier for the Huber term.
            huber_delta: ``delta`` passed to ``nn.HuberLoss``.
        """
        super().__init__()
        self.classification_weight = classification_weight
        self.regression_weight = regression_weight

        # BCEWithLogitsLoss fuses the sigmoid into the loss, which keeps it
        # safe under mixed precision.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.huber_loss = nn.HuberLoss(delta=huber_delta)

    def forward(
        self,
        pred_logits: torch.Tensor,
        pred_activities: torch.Tensor,
        target_presence: torch.Tensor,
        target_activities: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute the combined loss for one batch.

        Args:
            pred_logits: Predicted isotope logits (batch, num_isotopes)
            pred_activities: Predicted activities (batch, num_isotopes)
            target_presence: Ground-truth presence labels (batch, num_isotopes);
                cast to float internally
            target_activities: Ground-truth activities (batch, num_isotopes)

        Returns:
            Tuple ``(total_loss, loss_dict)`` where ``total_loss`` is the
            weighted sum as a tensor and ``loss_dict`` holds the Python
            float values under 'total', 'classification' and 'regression'.
        """
        presence = target_presence.float()

        # Sigmoid is applied inside the loss, so raw logits go in.
        cls_loss = self.bce_loss(pred_logits, presence)

        # Absent isotopes are zeroed in both prediction and target so they
        # contribute no error.
        # NOTE(review): nn.HuberLoss uses its default 'mean' reduction, so the
        # average is still taken over every element, present or not - confirm
        # this dilution is intended.
        if presence.sum() > 0:
            reg_loss = self.huber_loss(
                pred_activities * presence,
                target_activities * presence
            )
        else:
            reg_loss = torch.tensor(0.0, device=pred_activities.device)

        total_loss = (
            self.classification_weight * cls_loss
            + self.regression_weight * reg_loss
        )

        loss_dict = {
            'total': total_loss.item(),
            'classification': cls_loss.item(),
            'regression': reg_loss.item()
        }
        return total_loss, loss_dict
if __name__ == "__main__":
    # Smoke test: build the default model, push one random batch through it
    # and evaluate the loss once.
    cfg = VegaConfig()
    net = VegaModel(cfg)
    print(net.summary())

    # Forward pass on random input; the first output holds raw logits.
    n_samples = 4
    spectra = torch.randn(n_samples, cfg.num_channels)
    logits, activity_out = net(spectra)
    print(f"\nInput shape: {spectra.shape}")
    print(f"Output probs shape: {logits.shape}")
    print(f"Output activities shape: {activity_out.shape}")

    # Loss on random targets
    criterion = VegaLoss()
    presence_target = torch.randint(0, 2, (n_samples, cfg.num_isotopes))
    activity_target = torch.rand(n_samples, cfg.num_isotopes) * 100
    _, loss_parts = criterion(logits, activity_out, presence_target, activity_target)
    print(f"\nLoss: {loss_parts}")

View File

@ -0,0 +1,231 @@
"""
Vega 2D Model - Uses Full Temporal Information
This model treats gamma spectra as 2D images (time × channels) and uses
Conv2d to extract both spectral and temporal features.
Input shape: (batch, 1, time_intervals, channels) = (B, 1, 60, 1023)
"""
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from typing import List, Tuple
@dataclass
class Vega2DConfig:
    """Hyper-parameters describing the input geometry and layers of Vega2DModel."""

    # --- Input geometry -------------------------------------------------
    num_channels: int = 1023          # width of a spectrum (energy channels)
    num_time_intervals: int = 60      # height of a spectrum (time slices)

    # --- Output ---------------------------------------------------------
    num_isotopes: int = 82            # size of both output heads

    # --- Convolutional backbone -----------------------------------------
    # One conv block is built per entry; the value is its output channel count.
    conv_channels: List[int] = field(default_factory=lambda: [32, 64, 128])
    # (time, energy) kernel; wider along the energy axis.
    kernel_size: Tuple[int, int] = (3, 7)
    # (time, energy) max-pooling window applied at the end of each block.
    pool_size: Tuple[int, int] = (2, 2)

    # --- Fully connected layers -----------------------------------------
    fc_hidden_dims: List[int] = field(default_factory=lambda: [512, 256])

    # --- Regularization -------------------------------------------------
    dropout_rate: float = 0.3
    leaky_relu_slope: float = 1e-2

    # --- Activity scaling -----------------------------------------------
    # Factor used by Vega2DModel.predict to turn regressor output into Bq.
    max_activity_bq: float = 1e3
class ConvBlock2D(nn.Module):
    """
    Two Conv2d -> BatchNorm2d -> LeakyReLU stages followed by max-pooling
    and channel-wise dropout.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        pool_size: Tuple[int, int],
        dropout_rate: float,
        leaky_relu_slope: float
    ):
        """
        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by both convolutions.
            kernel_size: (height, width) of both convolution kernels.
            pool_size: (height, width) window of the max-pooling layer.
            dropout_rate: Probability passed to ``nn.Dropout2d``.
            leaky_relu_slope: Negative slope of the LeakyReLU activation.
        """
        super().__init__()

        # Half-kernel padding keeps the spatial size unchanged through the
        # convolutions when the kernel dimensions are odd.
        k_h, k_w = kernel_size[0], kernel_size[1]
        same_pad = (k_h // 2, k_w // 2)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_pad)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=same_pad)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.activation = nn.LeakyReLU(leaky_relu_slope)
        self.pool = nn.MaxPool2d(pool_size)
        self.dropout = nn.Dropout2d(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv stages, then pool and drop out the result."""
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.activation(norm(conv(x)))
        return self.dropout(self.pool(x))
class Vega2DModel(nn.Module):
    """
    2D CNN model for gamma spectrum isotope identification.

    Treats a spectrum as a single-channel image with time on one axis and
    energy channels on the other. A stack of ConvBlock2D modules feeds a
    fully connected backbone with two heads: isotope logits and
    non-negative activity estimates.
    """

    def __init__(self, config: Vega2DConfig = None):
        """
        Args:
            config: Model hyper-parameters; a default ``Vega2DConfig`` is
                created when omitted.
        """
        super().__init__()
        self.config = config or Vega2DConfig()

        # CNN backbone: one ConvBlock2D per entry of conv_channels.
        self.conv_blocks = nn.ModuleList()
        in_channels = 1  # Single channel input (like grayscale image)
        for out_channels in self.config.conv_channels:
            self.conv_blocks.append(ConvBlock2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.config.kernel_size,
                pool_size=self.config.pool_size,
                dropout_rate=self.config.dropout_rate,
                leaky_relu_slope=self.config.leaky_relu_slope
            ))
            in_channels = out_channels

        # Width of the flattened feature vector entering the FC layers.
        self.flat_size = self._calculate_flat_size()

        # Fully connected backbone shared by both heads.
        fc_layers = []
        fc_in = self.flat_size
        for fc_out in self.config.fc_hidden_dims:
            fc_layers.extend([
                nn.Linear(fc_in, fc_out),
                nn.BatchNorm1d(fc_out),
                nn.LeakyReLU(self.config.leaky_relu_slope),
                nn.Dropout(self.config.dropout_rate)
            ])
            fc_in = fc_out
        self.fc_backbone = nn.Sequential(*fc_layers)

        # Output heads
        self.classifier = nn.Linear(fc_in, self.config.num_isotopes)  # Logits for BCE
        self.regressor = nn.Sequential(
            nn.Linear(fc_in, self.config.num_isotopes),
            nn.ReLU()  # Activity must be non-negative
        )

        self._init_weights()

    def _calculate_flat_size(self) -> int:
        """
        Compute the flattened feature size after all conv blocks.

        Mirrors ConvBlock2D: two stride-1 convolutions padded with
        ``kernel // 2`` followed by a max-pool whose stride equals its
        window. With that padding an odd kernel preserves the size while an
        even kernel grows it by one per convolution; pooling floor-divides.
        """
        kernel_h, kernel_w = self.config.kernel_size
        pool_h, pool_w = self.config.pool_size

        # Size change contributed by ONE padded convolution (0 for odd kernels).
        grow_h = 2 * (kernel_h // 2) - kernel_h + 1
        grow_w = 2 * (kernel_w // 2) - kernel_w + 1

        h = self.config.num_time_intervals
        w = self.config.num_channels
        for _ in self.config.conv_channels:
            h = (h + 2 * grow_h) // pool_h  # two convs, then pooling
            w = (w + 2 * grow_w) // pool_w

        # Final size = last_channels * h * w
        return self.config.conv_channels[-1] * h * w

    def _init_weights(self):
        """
        Initialize weights with Kaiming-normal for the configured LeakyReLU.

        The configured negative slope is passed as ``a`` so the gain matches
        the activation actually used (the default ``a=0`` would compute the
        plain-ReLU gain). Biases are zeroed and batch-norm layers are reset
        to weight 1 / bias 0.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(
                    m.weight,
                    a=self.config.leaky_relu_slope,
                    mode='fan_out',
                    nonlinearity='leaky_relu'
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch, 1, time_intervals, channels)
                or (batch, time_intervals, channels) - a channel dimension
                is inserted for 3-D input.

        Returns:
            Tuple of (logits, activities):
                - logits: (batch, num_isotopes) - raw scores for BCE loss
                - activities: (batch, num_isotopes) - non-negative regressor
                  output; ``predict`` scales it by ``config.max_activity_bq``
        """
        # Add channel dimension if needed: (B, T, C) -> (B, 1, T, C)
        if x.dim() == 3:
            x = x.unsqueeze(1)

        for conv_block in self.conv_blocks:
            x = conv_block(x)

        x = x.flatten(start_dim=1)
        x = self.fc_backbone(x)

        logits = self.classifier(x)
        activities = self.regressor(x)
        return logits, activities

    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict isotope presence and activities.

        Inference runs in eval mode (dropout disabled, batch-norm using its
        running statistics) and without gradient tracking; the module's
        previous train/eval mode is restored before returning.

        Args:
            x: Input spectrum, in any shape accepted by ``forward``.
            threshold: Probability threshold for presence.

        Returns:
            Tuple of (presence, activities):
                - presence: (batch, num_isotopes) float tensor of 0/1 flags
                - activities: (batch, num_isotopes) regressor output
                  multiplied by ``config.max_activity_bq``
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits, activities_norm = self(x)
        finally:
            # Leave the caller's train/eval mode exactly as we found it.
            self.train(was_training)

        probs = torch.sigmoid(logits)
        presence = (probs >= threshold).float()
        activities_bq = activities_norm * self.config.max_activity_bq
        return presence, activities_bq
def count_parameters(model: nn.Module) -> int:
    """Return how many scalar parameters of ``model`` require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
if __name__ == "__main__":
    # Smoke test: build the default 2D model, report its size and run one
    # random batch through it.
    cfg = Vega2DConfig()
    net = Vega2DModel(cfg)

    print("Vega 2D Model")
    print(f" Input: ({cfg.num_time_intervals}, {cfg.num_channels})")
    print(f" Conv channels: {cfg.conv_channels}")
    print(f" FC dims: {cfg.fc_hidden_dims}")
    print(f" Flat size: {net.flat_size}")
    print(f" Parameters: {count_parameters(net):,}")

    # Forward pass on random input shaped (batch, 1, time, channels)
    sample = torch.randn(4, 1, cfg.num_time_intervals, cfg.num_channels)
    out_logits, out_activities = net(sample)
    print(f"\n Test batch: {sample.shape}")
    print(f" Logits: {out_logits.shape}")
    print(f" Activities: {out_activities.shape}")

View File

@ -0,0 +1,120 @@
#!/usr/bin/env python
"""
Run Vega Training
Simple script to train the Vega model on synthetic gamma spectra.
Designed for both quick test runs and full-scale training.
"""
import sys
import argparse
from pathlib import Path
# Add project root to path
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from training.vega.train import train_vega, TrainingConfig
from training.vega.model import VegaConfig
def main():
    """
    Parse the command line, build the training configuration and train Vega.

    ``--test`` overrides the epoch count with 5 and the early-stopping
    patience with 3 for a quick run.

    Returns:
        0, used as the process exit code.
    """
    parser = argparse.ArgumentParser(
        description="Train Vega model for isotope identification"
    )
    add = parser.add_argument

    # Data paths
    add("--data-dir", "-d", type=str,
        default="O:/master_data_collection/isotopev2",
        help="Path to synthetic data directory")
    add("--model-dir", "-m", type=str,
        default="models",
        help="Directory to save trained models")

    # Training parameters
    add("--epochs", "-e", type=int,
        default=100,
        help="Maximum number of training epochs")
    add("--batch-size", "-b", type=int,
        default=64,
        help="Batch size for training (default: 64 for better GPU utilization)")
    add("--learning-rate", "-lr", type=float,
        default=1e-3,
        help="Initial learning rate")

    # Switches
    add("--test", action="store_true",
        help="Quick test mode with reduced epochs")
    add("--no-amp", action="store_true",
        help="Disable automatic mixed precision training")

    # Data loading parallelism
    add("--workers", "-w", type=int,
        default=8,
        help="Number of data loading workers (default: 8 for parallel I/O)")

    args = parser.parse_args()
    quick = args.test

    config = TrainingConfig(
        data_dir=args.data_dir,
        model_dir=args.model_dir,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        num_epochs=5 if quick else args.epochs,
        patience=3 if quick else 10,
        use_amp=not args.no_amp,
        num_workers=args.workers
    )
    model_config = VegaConfig()

    # Echo the effective settings before the (long) training run starts.
    rule = "=" * 60
    print("\n" + rule)
    print("VEGA TRAINING")
    print(rule)
    print(f"Data directory: {args.data_dir}")
    print(f"Model directory: {args.model_dir}")
    print(f"Epochs: {config.num_epochs}")
    print(f"Batch size: {config.batch_size}")
    print(f"Learning rate: {config.learning_rate}")
    print(f"Mixed precision: {config.use_amp}")
    print(f"Data workers: {config.num_workers}")
    if quick:
        print("MODE: Quick test run")
    print(rule + "\n")

    # The trained model and results are not needed here; artifacts are
    # written by the trainer itself.
    train_vega(config=config, model_config=model_config)
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@ -0,0 +1,614 @@
"""
Training Script for Vega Model
Implements the training loop with:
- Mixed precision training for RTX 5090 efficiency
- Learning rate scheduling
- Early stopping
- Model checkpointing
- Training metrics logging
"""
import os
import sys
import json
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Tuple
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
# Sklearn metrics for comprehensive evaluation
from sklearn.metrics import (
roc_auc_score,
f1_score,
precision_score,
recall_score,
hamming_loss
)
from .model import VegaModel, VegaConfig, VegaLoss
from .dataset import create_data_loaders, SpectrumDataset
from .isotope_index import IsotopeIndex, get_default_isotope_index
@dataclass
class TrainingConfig:
    """Settings for a Vega training run: paths, optimization, splits, loading."""

    # --- Paths ------------------------------------------------------------
    data_dir: str = "O:/master_data_collection/isotopev2"
    model_dir: str = "models"
    model_name: str = "vega"          # prefix of every file the trainer writes

    # --- Optimization -----------------------------------------------------
    batch_size: int = 64              # raised from 32 for better GPU utilization
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    num_epochs: int = 100

    # --- Early stopping ---------------------------------------------------
    patience: int = 10                # epochs without improvement before stopping
    min_delta: float = 1e-4           # minimum loss decrease counted as improvement

    # --- ReduceLROnPlateau ------------------------------------------------
    lr_scheduler_patience: int = 5
    lr_scheduler_factor: float = 0.5
    min_lr: float = 1e-6

    # --- Mixed precision --------------------------------------------------
    use_amp: bool = True

    # --- Dataset split fractions -------------------------------------------
    train_split: float = 0.8
    val_split: float = 0.1
    test_split: float = 0.1

    # --- DataLoader parallelism ---------------------------------------------
    num_workers: int = 8              # parallel data loading workers
    prefetch_factor: int = 4          # batches to prefetch per worker
    persistent_workers: bool = True   # keep workers alive between epochs

    # --- Reproducibility --------------------------------------------------
    seed: int = 42

    # --- Activity normalization ---------------------------------------------
    max_activity_bq: float = 1e3
class VegaTrainer:
    """
    Trainer class for the Vega model.

    Owns the optimizer, the ReduceLROnPlateau scheduler, the optional AMP
    gradient scaler and the bookkeeping for early stopping, checkpointing
    and the per-epoch metric history.
    """

    def __init__(
        self,
        model: VegaModel,
        config: TrainingConfig,
        device: Optional[torch.device] = None,
        force_cpu: bool = False
    ):
        """
        Args:
            model: Model to train; moved to the selected device.
            config: Training hyper-parameters and output locations.
            device: Explicit device. When omitted, CUDA is used if a small
                test operation succeeds on it, otherwise the CPU.
            force_cpu: Use the CPU regardless of ``device``.
        """
        self.model = model
        self.config = config

        # Device selection - force CPU if requested or if CUDA incompatible
        if force_cpu:
            self.device = torch.device('cpu')
        elif device:
            self.device = device
        else:
            if torch.cuda.is_available():
                try:
                    # A visible GPU may still be unusable (e.g. unsupported
                    # architecture), so run one real operation on it.
                    test_tensor = torch.zeros(1, device='cuda')
                    _ = test_tensor + 1
                    self.device = torch.device('cuda')
                except RuntimeError:
                    print("CUDA device found but not compatible, falling back to CPU")
                    self.device = torch.device('cpu')
            else:
                self.device = torch.device('cpu')

        self.model = self.model.to(self.device)

        self.loss_fn = VegaLoss(
            classification_weight=model.config.classification_weight,
            regression_weight=model.config.regression_weight
        )

        self.optimizer = Adam(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            patience=config.lr_scheduler_patience,
            factor=config.lr_scheduler_factor,
            min_lr=config.min_lr
        )

        # Mixed precision is only enabled on a working CUDA device.
        if config.use_amp and self.device.type == 'cuda':
            self.scaler = torch.amp.GradScaler('cuda')
        else:
            self.scaler = None

        # Training state. ``current_epoch`` is the index of the epoch being
        # run / last run; ``_start_epoch`` is the index the next call to
        # ``train`` begins with.
        self.current_epoch = 0
        self._start_epoch = 0
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0
        self.training_history = []

        self.model_dir = Path(config.model_dir)
        self.model_dir.mkdir(parents=True, exist_ok=True)

        print(f"Training on device: {self.device}")
        if self.device.type == 'cuda':
            print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Mixed precision: {config.use_amp}")

    def train_epoch(self, train_loader) -> Dict[str, float]:
        """
        Train for one epoch.

        Args:
            train_loader: Iterable of batches; each batch is a mapping with
                'spectrum', 'presence_labels' and 'activity_labels' tensors.

        Returns:
            Mean total / classification / regression loss, element-wise
            presence accuracy at a 0.5 probability threshold, and the
            seconds spent waiting for data versus computing.
        """
        self.model.train()
        total_loss = 0.0
        cls_loss_sum = 0.0
        reg_loss_sum = 0.0
        num_batches = 0

        correct_isotopes = 0
        total_isotopes = 0

        # Timing for profiling - track data loading vs GPU compute
        data_time = 0.0
        compute_time = 0.0
        data_start = time.time()

        for batch in train_loader:
            # Time spent waiting for this batch to arrive
            data_time += time.time() - data_start
            compute_start = time.time()

            spectra = batch['spectrum'].to(self.device)
            presence = batch['presence_labels'].to(self.device)
            activities = batch['activity_labels'].to(self.device)

            self.optimizer.zero_grad()

            if self.scaler is not None:
                with torch.amp.autocast('cuda'):
                    pred_logits, pred_activities = self.model(spectra)
                    loss, loss_dict = self.loss_fn(
                        pred_logits, pred_activities, presence, activities
                    )
                # Scale the loss so small fp16 gradients do not underflow
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                pred_logits, pred_activities = self.model(spectra)
                loss, loss_dict = self.loss_fn(
                    pred_logits, pred_activities, presence, activities
                )
                loss.backward()
                self.optimizer.step()

            total_loss += loss_dict['total']
            cls_loss_sum += loss_dict['classification']
            reg_loss_sum += loss_dict['regression']
            num_batches += 1

            # Training accuracy (no graph kept, to avoid memory buildup)
            with torch.no_grad():
                pred_probs = torch.sigmoid(pred_logits)
                pred_presence = (pred_probs >= 0.5).float()
                correct_isotopes += (pred_presence == presence).sum().item()
                total_isotopes += presence.numel()

            # Mark compute time and restart timing for data loading
            compute_time += time.time() - compute_start
            data_start = time.time()

        train_accuracy = correct_isotopes / total_isotopes if total_isotopes > 0 else 0.0

        return {
            'train_loss': total_loss / num_batches,
            'train_cls_loss': cls_loss_sum / num_batches,
            'train_reg_loss': reg_loss_sum / num_batches,
            'train_accuracy': train_accuracy,
            'data_time': data_time,
            'compute_time': compute_time
        }

    @torch.no_grad()
    def validate(self, val_loader) -> Dict[str, float]:
        """
        Validate the model with comprehensive metrics.

        Args:
            val_loader: Iterable of validation batches, or None.

        Returns:
            Empty dict when ``val_loader`` is None; otherwise mean losses
            plus element-wise accuracy, ROC-AUC (macro/micro), F1
            (macro/micro), micro precision/recall, Hamming loss and exact
            match ratio. Every value is a plain Python float so the history
            stays serializable by JSON and by ``torch.load``'s restricted
            unpickler.
        """
        if val_loader is None:
            return {}

        self.model.eval()
        total_loss = 0.0
        cls_loss_sum = 0.0
        reg_loss_sum = 0.0
        num_batches = 0

        # Collected over the whole loader for the sklearn metrics
        all_probs = []
        all_preds = []
        all_labels = []

        for batch in val_loader:
            spectra = batch['spectrum'].to(self.device)
            presence = batch['presence_labels'].to(self.device)
            activities = batch['activity_labels'].to(self.device)

            pred_logits, pred_activities = self.model(spectra)
            loss, loss_dict = self.loss_fn(
                pred_logits, pred_activities, presence, activities
            )

            total_loss += loss_dict['total']
            cls_loss_sum += loss_dict['classification']
            reg_loss_sum += loss_dict['regression']
            num_batches += 1

            pred_probs = torch.sigmoid(pred_logits)
            pred_presence = (pred_probs >= 0.5).float()
            all_probs.append(pred_probs.cpu().numpy())
            all_preds.append(pred_presence.cpu().numpy())
            all_labels.append(presence.cpu().numpy())

        all_probs = np.vstack(all_probs)
        all_preds = np.vstack(all_preds)
        all_labels = np.vstack(all_labels)

        # Basic accuracy (element-wise)
        correct = (all_preds == all_labels).sum()
        total = all_labels.size
        accuracy = correct / total if total > 0 else 0.0

        metrics = {
            'val_loss': total_loss / num_batches,
            'val_cls_loss': cls_loss_sum / num_batches,
            'val_reg_loss': reg_loss_sum / num_batches,
            'val_accuracy': accuracy,
        }

        try:
            # ROC-AUC is undefined for a label column with a single class,
            # so only isotopes with both 0s and 1s are scored.
            valid_cols = [
                i for i in range(all_labels.shape[1])
                if len(np.unique(all_labels[:, i])) == 2
            ]
            if valid_cols:
                metrics['val_auc_macro'] = roc_auc_score(
                    all_labels[:, valid_cols],
                    all_probs[:, valid_cols],
                    average='macro'
                )
                metrics['val_auc_micro'] = roc_auc_score(
                    all_labels[:, valid_cols],
                    all_probs[:, valid_cols],
                    average='micro'
                )
            else:
                metrics['val_auc_macro'] = 0.0
                metrics['val_auc_micro'] = 0.0
        except ValueError:
            # Handle case where AUC can't be computed
            metrics['val_auc_macro'] = 0.0
            metrics['val_auc_micro'] = 0.0

        # F1 (macro and micro), precision and recall (micro-averaged)
        metrics['val_f1_macro'] = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        metrics['val_f1_micro'] = f1_score(all_labels, all_preds, average='micro', zero_division=0)
        metrics['val_precision'] = precision_score(all_labels, all_preds, average='micro', zero_division=0)
        metrics['val_recall'] = recall_score(all_labels, all_preds, average='micro', zero_division=0)

        # Hamming loss (fraction of labels incorrectly predicted)
        metrics['val_hamming'] = hamming_loss(all_labels, all_preds)

        # Exact match ratio (all isotopes correct for a sample)
        exact_matches = (all_preds == all_labels).all(axis=1).sum()
        metrics['val_exact_match'] = exact_matches / len(all_labels)

        # numpy/sklearn hand back numpy scalars. They would be pickled into
        # every checkpoint through training_history, and torch.load's safe
        # (weights_only) unpickler rejects numpy scalar globals - which
        # breaks resuming. Store builtin floats instead.
        return {name: float(value) for name, value in metrics.items()}

    def save_checkpoint(self, path: Path, is_best: bool = False):
        """
        Save a model checkpoint.

        Args:
            path: Destination file.
            is_best: Also write the same checkpoint next to ``path`` as
                ``<model_name>_best.pt``.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'model_config': asdict(self.model.config),
            'training_config': asdict(self.config),
            'training_history': self.training_history
        }

        if self.scaler is not None:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        torch.save(checkpoint, path)

        if is_best:
            best_path = path.parent / f"{self.config.model_name}_best.pt"
            torch.save(checkpoint, best_path)

    def load_checkpoint(self, path: Path):
        """
        Load a model checkpoint and restore the training state.

        The stored 'epoch' is the epoch that had just finished when the
        checkpoint was written, so the next ``train`` call continues with
        the epoch after it.

        NOTE(review): checkpoints written before validation metrics were
        converted to builtin floats contain numpy scalars and may need
        ``weights_only=False`` (trusted files only) on recent PyTorch.
        """
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if self.scaler is not None and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.current_epoch = checkpoint['epoch']
        # Do not repeat the epoch the checkpoint already contains.
        self._start_epoch = self.current_epoch + 1
        self.best_val_loss = checkpoint['best_val_loss']
        self.training_history = checkpoint.get('training_history', [])

        print(f"Loaded checkpoint from epoch {self.current_epoch}")

    def train(
        self,
        train_loader,
        val_loader,
        resume_from: Optional[Path] = None
    ) -> Dict:
        """
        Full training loop.

        Each epoch trains, validates, steps the LR scheduler on the
        validation loss (training loss when there is none), writes a
        per-epoch checkpoint (plus the ``_best`` copy on improvement) and
        stops early after ``config.patience`` epochs without improvement.
        A final checkpoint and the JSON history are written at the end.

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader (may be None)
            resume_from: Optional path to checkpoint to resume from

        Returns:
            Dict with 'best_val_loss', 'total_epochs', 'total_time'
            (seconds) and 'history'.
        """
        if resume_from is not None:
            self.load_checkpoint(resume_from)

        print("\n" + "=" * 60)
        print("Starting Vega Training")
        print("=" * 60)
        print(f"Epochs: {self.config.num_epochs}")
        print(f"Batch size: {self.config.batch_size}")
        print(f"Learning rate: {self.config.learning_rate}")
        print(f"Training samples: {len(train_loader.dataset)}")
        if val_loader:
            print(f"Validation samples: {len(val_loader.dataset)}")
        print("=" * 60 + "\n")

        start_time = time.time()

        for epoch in range(self._start_epoch, self.config.num_epochs):
            self.current_epoch = epoch
            epoch_start = time.time()

            train_metrics = self.train_epoch(train_loader)
            val_metrics = self.validate(val_loader)

            metrics = {**train_metrics, **val_metrics, 'epoch': epoch}
            self.training_history.append(metrics)

            # Update learning rate
            if val_loader and 'val_loss' in val_metrics:
                self.scheduler.step(val_metrics['val_loss'])
            else:
                self.scheduler.step(train_metrics['train_loss'])

            # Check for improvement (falls back to the training loss)
            val_loss = val_metrics.get('val_loss', train_metrics['train_loss'])
            is_best = val_loss < self.best_val_loss - self.config.min_delta

            if is_best:
                self.best_val_loss = val_loss
                self.epochs_without_improvement = 0
            else:
                self.epochs_without_improvement += 1

            checkpoint_path = self.model_dir / f"{self.config.model_name}_epoch_{epoch}.pt"
            self.save_checkpoint(checkpoint_path, is_best=is_best)
            # This epoch is now complete and persisted.
            self._start_epoch = epoch + 1

            # Logging
            epoch_time = time.time() - epoch_start
            lr = self.optimizer.param_groups[0]['lr']

            # Primary metrics line
            log_str = (
                f"Epoch {epoch+1:3d}/{self.config.num_epochs} | "
                f"Train Loss: {train_metrics['train_loss']:.4f} | "
                f"Train Acc: {train_metrics['train_accuracy']:.4f} | "
            )
            if val_loader:
                log_str += (
                    f"Val Loss: {val_metrics['val_loss']:.4f} | "
                    f"Val Acc: {val_metrics['val_accuracy']:.4f} | "
                )
            log_str += f"LR: {lr:.2e} | Time: {epoch_time:.1f}s"
            if is_best:
                log_str += " *"
            print(log_str)

            # Timing breakdown line
            data_t = train_metrics.get('data_time', 0)
            compute_t = train_metrics.get('compute_time', 0)
            if data_t > 0 or compute_t > 0:
                data_pct = 100 * data_t / (data_t + compute_t) if (data_t + compute_t) > 0 else 0
                print(f" └── Data: {data_t:.1f}s ({data_pct:.0f}%) | Compute: {compute_t:.1f}s ({100-data_pct:.0f}%)")

            # Secondary metrics line (detailed classification metrics)
            if val_loader and 'val_auc_macro' in val_metrics:
                detail_str = (
                    f" └── AUC: {val_metrics['val_auc_macro']:.4f} | "
                    f"F1: {val_metrics['val_f1_macro']:.4f} | "
                    f"Prec: {val_metrics['val_precision']:.4f} | "
                    f"Recall: {val_metrics['val_recall']:.4f} | "
                    f"Exact: {val_metrics['val_exact_match']:.4f}"
                )
                print(detail_str)

            # Early stopping
            if self.epochs_without_improvement >= self.config.patience:
                print(f"\nEarly stopping after {epoch + 1} epochs")
                break

        total_time = time.time() - start_time

        # Save final model
        final_path = self.model_dir / f"{self.config.model_name}_final.pt"
        self.save_checkpoint(final_path)

        # Save training history
        history_path = self.model_dir / f"{self.config.model_name}_history.json"
        with open(history_path, 'w') as f:
            json.dump(self.training_history, f, indent=2)

        print("\n" + "=" * 60)
        print("Training complete!")
        print(f"Total time: {total_time / 60:.1f} minutes")
        print(f"Best validation loss: {self.best_val_loss:.4f}")
        print(f"Model saved to: {final_path}")
        print("=" * 60)

        return {
            'best_val_loss': self.best_val_loss,
            'total_epochs': self.current_epoch + 1,
            'total_time': total_time,
            'history': self.training_history
        }
def train_vega(
    data_dir: Optional[str] = None,
    model_dir: Optional[str] = None,
    config: Optional[TrainingConfig] = None,
    model_config: Optional[VegaConfig] = None
) -> Tuple[VegaModel, Dict]:
    """
    Build the data loaders, model and trainer, then run a full training.

    Relative data/model directories are resolved against the directory three
    levels above this file. The passed ``config`` is updated in place with
    the overrides and the absolute model directory. After training, the
    isotope index is saved next to the model files.

    Args:
        data_dir: Overrides ``config.data_dir`` when non-empty.
        model_dir: Overrides ``config.model_dir`` when non-empty.
        config: Training configuration; defaults to ``TrainingConfig()``.
        model_config: Model configuration; when omitted it is derived from
            the isotope index size and ``config.max_activity_bq``.

    Returns:
        Tuple of (trained model, training results)
    """
    project_root = Path(__file__).parent.parent.parent

    def _anchored(location: str) -> Path:
        # Relative locations are interpreted relative to the project root.
        candidate = Path(location)
        return candidate if candidate.is_absolute() else project_root / candidate

    if config is None:
        config = TrainingConfig()
    if data_dir:
        config.data_dir = data_dir
    if model_dir:
        config.model_dir = model_dir

    data_path = _anchored(config.data_dir)
    model_path = _anchored(config.model_dir)
    config.model_dir = str(model_path)

    # Seed torch (CPU and, when present, CUDA) for reproducibility
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.seed)

    isotope_index = get_default_isotope_index()

    # The test loader is created with the others but not used here.
    train_loader, val_loader, test_loader = create_data_loaders(
        data_dir=data_path,
        batch_size=config.batch_size,
        train_split=config.train_split,
        val_split=config.val_split,
        test_split=config.test_split,
        num_workers=config.num_workers,
        prefetch_factor=config.prefetch_factor,
        persistent_workers=config.persistent_workers,
        isotope_index=isotope_index,
        max_activity_bq=config.max_activity_bq,
        seed=config.seed
    )

    if model_config is None:
        model_config = VegaConfig(
            num_isotopes=isotope_index.num_isotopes,
            max_activity_bq=config.max_activity_bq
        )
    model = VegaModel(model_config)
    print(model.summary())

    results = VegaTrainer(model, config).train(train_loader, val_loader)

    # Keep the label order with the weights so predictions can be decoded.
    isotope_index.save(model_path / f"{config.model_name}_isotope_index.txt")

    return model, results


if __name__ == "__main__":
    # Quick test training
    model, results = train_vega()

View File

@ -0,0 +1,411 @@
"""
Training Script for Vega 2D Model
Uses 2D convolutions to process gamma spectra with temporal information.
"""
import argparse
import json
import time
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional, Tuple, Dict, List
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from .model_2d import Vega2DModel, Vega2DConfig, count_parameters
from .dataset_2d import create_data_loaders_2d, SpectrumDataset2D
from .isotope_index import get_default_isotope_index
@dataclass
class TrainingConfig2D:
    """Settings for a Vega 2D training run."""

    # --- Paths ------------------------------------------------------------
    data_dir: str = "O:/master_data_collection/isotopev2"
    model_dir: str = "models"

    # --- Model input ------------------------------------------------------
    target_time_intervals: int = 60   # time dimension handed to the model/loaders

    # --- Optimization -----------------------------------------------------
    epochs: int = 50
    batch_size: int = 32
    learning_rate: float = 1e-3
    weight_decay: float = 1e-5

    # --- Loss weights -----------------------------------------------------
    classification_weight: float = 1.0
    regression_weight: float = 0.1

    # --- Mixed precision --------------------------------------------------
    use_amp: bool = True

    # --- Early stopping ---------------------------------------------------
    early_stopping_patience: int = 10

    # --- Learning rate scheduler --------------------------------------------
    lr_scheduler_patience: int = 5
    lr_scheduler_factor: float = 0.5

    # --- Data loading -----------------------------------------------------
    num_workers: int = 4
def train_epoch(
    model: nn.Module,
    train_loader,
    optimizer: optim.Optimizer,
    criterion_cls: nn.Module,
    criterion_reg: nn.Module,
    device: torch.device,
    scaler: Optional[GradScaler],
    config: TrainingConfig2D
) -> Dict[str, float]:
    """
    Train for one epoch.

    Args:
        model: Network returning ``(logits, activities)`` for a batch.
        train_loader: Iterable of batches; each batch is a mapping with
            'spectrum', 'presence_labels' and 'activity_labels' tensors.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion_cls: Loss applied to ``(logits, presence_labels)``.
        criterion_reg: Loss applied to ``(activities, activity_labels)``.
        device: Device the batch tensors are moved to.
        scaler: Gradient scaler; when given, the forward pass runs under
            CUDA autocast and the step goes through the scaler. ``None``
            selects plain full-precision training.
        config: Supplies ``classification_weight`` and ``regression_weight``.

    Returns:
        Mean 'loss', 'cls_loss' and 'reg_loss' over the epoch's batches.
    """
    model.train()

    use_amp = scaler is not None
    total_loss = 0.0
    total_cls_loss = 0.0
    total_reg_loss = 0.0
    num_batches = 0

    for batch in train_loader:
        spectra = batch['spectrum'].to(device)
        presence = batch['presence_labels'].to(device)
        activities = batch['activity_labels'].to(device)

        optimizer.zero_grad()

        # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast;
        # with enabled=False it is a no-op, so one code path serves both modes.
        with torch.amp.autocast('cuda', enabled=use_amp):
            logits, pred_activities = model(spectra)
            cls_loss = criterion_cls(logits, presence)
            reg_loss = criterion_reg(pred_activities, activities)
            loss = (
                config.classification_weight * cls_loss
                + config.regression_weight * reg_loss
            )

        if use_amp:
            # Scale the loss so small fp16 gradients do not underflow
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        total_cls_loss += cls_loss.item()
        total_reg_loss += reg_loss.item()
        num_batches += 1

    return {
        'loss': total_loss / num_batches,
        'cls_loss': total_cls_loss / num_batches,
        'reg_loss': total_reg_loss / num_batches
    }
@torch.no_grad()
def validate(
    model: nn.Module,
    val_loader,
    criterion_cls: nn.Module,
    criterion_reg: nn.Module,
    device: torch.device,
    config: TrainingConfig2D,
    threshold: float = 0.5
) -> Dict[str, float]:
    """
    Evaluate the model on a validation loader without gradient tracking.

    Args:
        model: Network returning ``(logits, activities)`` for a batch.
        val_loader: Iterable of batches with 'spectrum', 'presence_labels'
            and 'activity_labels' tensors.
        criterion_cls: Loss applied to ``(logits, presence_labels)``.
        criterion_reg: Loss applied to ``(activities, activity_labels)``.
        device: Device the batch tensors are moved to.
        config: Supplies ``classification_weight`` and ``regression_weight``.
        threshold: Probability at or above which an isotope counts as predicted.

    Returns:
        Mean 'loss', 'cls_loss' and 'reg_loss' per batch, the 'exact_match'
        ratio (samples with every isotope correct) and 'precision',
        'recall' and 'f1' pooled over all samples and isotopes.
    """
    model.eval()

    loss_sum = 0.0
    cls_sum = 0.0
    reg_sum = 0.0
    num_batches = 0
    pred_chunks = []
    label_chunks = []

    for batch in val_loader:
        spectra = batch['spectrum'].to(device)
        presence = batch['presence_labels'].to(device)
        activities = batch['activity_labels'].to(device)

        logits, pred_activities = model(spectra)
        cls_loss = criterion_cls(logits, presence)
        reg_loss = criterion_reg(pred_activities, activities)
        loss = config.classification_weight * cls_loss + config.regression_weight * reg_loss

        loss_sum += loss.item()
        cls_sum += cls_loss.item()
        reg_sum += reg_loss.item()
        num_batches += 1

        # Binarize the sigmoid probabilities for the metric computation
        binarized = (torch.sigmoid(logits) >= threshold).float()
        pred_chunks.append(binarized.cpu())
        label_chunks.append(presence.cpu())

    predicted = torch.cat(pred_chunks, dim=0)
    actual = torch.cat(label_chunks, dim=0)

    # A sample only counts when every isotope flag is right.
    exact_match = (predicted == actual).all(dim=1).float().mean().item()

    # Confusion counts pooled over all samples and isotopes
    predicted_pos = predicted == 1
    actual_pos = actual == 1
    tp = (predicted_pos & actual_pos).sum().item()
    fp = (predicted_pos & (actual == 0)).sum().item()
    fn = ((predicted == 0) & actual_pos).sum().item()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0

    return {
        'loss': loss_sum / num_batches,
        'cls_loss': cls_sum / num_batches,
        'reg_loss': reg_sum / num_batches,
        'exact_match': exact_match,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }
def train_vega_2d(
    config: TrainingConfig2D = None,
    model_config: Vega2DConfig = None
) -> Tuple[Vega2DModel, Dict]:
    """
    Train the Vega 2D model.

    Builds the model and data loaders, trains with AdamW and a
    ReduceLROnPlateau scheduler driven by the validation loss, writes a
    checkpoint whenever the validation loss improves, stops early after
    ``config.early_stopping_patience`` epochs without improvement, and
    finally evaluates on the test loader.

    Args:
        config: Training settings; a default ``TrainingConfig2D()`` is used
            when omitted.
        model_config: Model settings; when omitted a ``Vega2DConfig`` is
            built with ``num_time_intervals=config.target_time_intervals``.

    Returns:
        ``(model, history)`` where ``history`` maps metric names to
        per-epoch lists.

    Note:
        The returned model and the test metrics use the weights from the
        last epoch that ran. The lowest-validation-loss weights are only
        stored in ``vega_2d_best.pt``.
    """
    config = config or TrainingConfig2D()
    model_config = model_config or Vega2DConfig(num_time_intervals=config.target_time_intervals)

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    if device.type == 'cuda':
        print(f" GPU: {torch.cuda.get_device_name()}")
        print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Create model
    model = Vega2DModel(model_config).to(device)
    print(f"\nModel: Vega 2D")
    print(f" Input: ({model_config.num_time_intervals}, {model_config.num_channels})")
    print(f" Conv channels: {model_config.conv_channels}")
    print(f" FC dims: {model_config.fc_hidden_dims}")
    print(f" Parameters: {count_parameters(model):,}")

    # Create data loaders
    print(f"\nLoading data from: {config.data_dir}")
    isotope_index = get_default_isotope_index()
    train_loader, val_loader, test_loader = create_data_loaders_2d(
        data_dir=Path(config.data_dir),
        batch_size=config.batch_size,
        target_time_intervals=config.target_time_intervals,
        isotope_index=isotope_index,
        num_workers=config.num_workers
    )

    # Loss functions
    criterion_cls = nn.BCEWithLogitsLoss()
    criterion_reg = nn.HuberLoss()

    # Optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    # Learning rate scheduler (reacts to the validation loss)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.lr_scheduler_factor,
        patience=config.lr_scheduler_patience
    )

    # Mixed precision scaler (only when AMP is requested and CUDA is in use)
    scaler = GradScaler() if config.use_amp and device.type == 'cuda' else None

    # Training history
    history = {
        'train_loss': [], 'val_loss': [],
        'train_cls_loss': [], 'val_cls_loss': [],
        'train_reg_loss': [], 'val_reg_loss': [],
        'val_exact_match': [], 'val_precision': [], 'val_recall': [], 'val_f1': [],
        'lr': []
    }

    # Early stopping
    best_val_loss = float('inf')
    patience_counter = 0

    # Model directory; parents=True so a nested output path works too
    model_dir = Path(config.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nStarting training for {config.epochs} epochs...")
    print(f" Batch size: {config.batch_size}")
    print(f" Learning rate: {config.learning_rate}")
    print(f" AMP: {scaler is not None}")
    print()

    start_time = time.time()

    # Keeps `epoch` defined for the final checkpoint when config.epochs == 0
    epoch = -1

    for epoch in range(config.epochs):
        epoch_start = time.time()

        # Train
        train_metrics = train_epoch(
            model, train_loader, optimizer,
            criterion_cls, criterion_reg,
            device, scaler, config
        )

        # Validate
        val_metrics = validate(
            model, val_loader,
            criterion_cls, criterion_reg,
            device, config
        )

        # Update scheduler
        scheduler.step(val_metrics['loss'])
        current_lr = optimizer.param_groups[0]['lr']

        # Record history
        history['train_loss'].append(train_metrics['loss'])
        history['val_loss'].append(val_metrics['loss'])
        history['train_cls_loss'].append(train_metrics['cls_loss'])
        history['val_cls_loss'].append(val_metrics['cls_loss'])
        history['train_reg_loss'].append(train_metrics['reg_loss'])
        history['val_reg_loss'].append(val_metrics['reg_loss'])
        history['val_exact_match'].append(val_metrics['exact_match'])
        history['val_precision'].append(val_metrics['precision'])
        history['val_recall'].append(val_metrics['recall'])
        history['val_f1'].append(val_metrics['f1'])
        history['lr'].append(current_lr)

        epoch_time = time.time() - epoch_start

        # Print progress
        print(f"Epoch {epoch+1:3d}/{config.epochs} ({epoch_time:.1f}s) | "
              f"Train Loss: {train_metrics['loss']:.4f} | "
              f"Val Loss: {val_metrics['loss']:.4f} | "
              f"F1: {val_metrics['f1']:.4f} | "
              f"Recall: {val_metrics['recall']:.4f} | "
              f"LR: {current_lr:.2e}")

        # Save best model
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'model_config': asdict(model_config),
                'training_config': asdict(config),
                'val_metrics': val_metrics,
                'history': history
            }, model_dir / 'vega_2d_best.pt')
            print(f" ✓ Saved best model (val_loss: {best_val_loss:.4f})")
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= config.early_stopping_patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

    total_time = time.time() - start_time
    print(f"\nTraining complete in {total_time/60:.1f} minutes")
    print(f"Best validation loss: {best_val_loss:.4f}")

    # Save final model (weights of the last epoch that ran)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'model_config': asdict(model_config),
        'training_config': asdict(config),
        'history': history
    }, model_dir / 'vega_2d_final.pt')

    # Save history
    with open(model_dir / 'vega_2d_history.json', 'w') as f:
        json.dump(history, f, indent=2)

    # Test set evaluation
    print("\nEvaluating on test set...")
    test_metrics = validate(
        model, test_loader,
        criterion_cls, criterion_reg,
        device, config
    )
    print(f" Test Loss: {test_metrics['loss']:.4f}")
    print(f" Test F1: {test_metrics['f1']:.4f}")
    print(f" Test Recall: {test_metrics['recall']:.4f}")
    print(f" Test Precision: {test_metrics['precision']:.4f}")
    print(f" Test Exact Match: {test_metrics['exact_match']:.4f}")

    return model, history
def main():
    """Parse the command-line options and launch Vega 2D training."""
    cli = argparse.ArgumentParser(description='Train Vega 2D Model')
    add_option = cli.add_argument

    # Locations
    add_option('--data-dir', type=str, default='O:/master_data_collection/isotopev2',
               help='Path to training data')
    add_option('--model-dir', type=str, default='models',
               help='Path to save models')

    # Optimisation settings
    add_option('--epochs', type=int, default=50,
               help='Number of epochs')
    add_option('--batch-size', type=int, default=32,
               help='Batch size')
    add_option('--lr', type=float, default=1e-3,
               help='Learning rate')

    # Input shape and runtime switches
    add_option('--time-intervals', type=int, default=60,
               help='Target time intervals (pad/truncate)')
    add_option('--no-amp', action='store_true',
               help='Disable mixed precision training')
    add_option('--workers', type=int, default=4,
               help='Data loading workers')

    opts = cli.parse_args()

    # --no-amp is a negative flag, so it is inverted for use_amp
    training_settings = TrainingConfig2D(
        data_dir=opts.data_dir,
        model_dir=opts.model_dir,
        epochs=opts.epochs,
        batch_size=opts.batch_size,
        learning_rate=opts.lr,
        target_time_intervals=opts.time_intervals,
        use_amp=not opts.no_amp,
        num_workers=opts.workers
    )

    # The model is sized with the same time-interval count as the data
    architecture = Vega2DConfig(num_time_intervals=opts.time_intervals)

    train_vega_2d(training_settings, architecture)
# Run the training CLI only when this file is executed as a script.
if __name__ == '__main__':
    main()

View File

@ -0,0 +1,847 @@
"""
Vega Training v2 - Optuna Hyperparameter Optimization
Uses Optuna to search for optimal hyperparameters to maximize model performance,
with a focus on improving recall for isotope detection.
Key optimizations:
1. Model architecture (CNN channels, FC dims, kernel sizes)
2. Training hyperparameters (LR, batch size, weight decay, dropout)
3. Loss function weights (classification vs regression balance)
4. Classification threshold optimization
5. Focal loss for handling class imbalance
"""
import os
import sys
import json
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Tuple, List
from dataclasses import dataclass, asdict, field
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
import numpy as np
import optuna
from optuna.trial import Trial
from optuna.pruners import MedianPruner, HyperbandPruner
from optuna.samplers import TPESampler
# Sklearn metrics
from sklearn.metrics import (
roc_auc_score,
f1_score,
precision_score,
recall_score,
hamming_loss
)
# Add project root to path
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from training.vega.model import VegaModel, VegaConfig
from training.vega.dataset import create_data_loaders, SpectrumDataset
from training.vega.isotope_index import IsotopeIndex, get_default_isotope_index
class FocalLoss(nn.Module):
    """
    Focal loss over logits for multi-label (element-wise binary) targets.

    Each element's binary cross-entropy is scaled by
    ``alpha_t * (1 - p_t) ** gamma``, where ``p_t`` is the probability the
    model gives to the target value, so confidently correct elements
    contribute less than hard ones.

        FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)

    Args:
        alpha: Weight on positive targets; negatives get ``1 - alpha``
            (default: 0.25)
        gamma: Focusing exponent; larger values down-weight easy elements
            more strongly (default: 2.0)
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced element-wise loss
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the focal loss of logits ``inputs`` against binary ``targets``."""
        # Element-wise cross-entropy computed directly from the logits
        per_element_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # Probability assigned to whichever value the target holds
        positive_prob = torch.sigmoid(inputs)
        p_t = positive_prob * targets + (1 - positive_prob) * (1 - targets)

        # alpha for positive targets, (1 - alpha) for negative targets
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)

        weighted = alpha_t * (1 - p_t) ** self.gamma * per_element_bce

        if self.reduction == 'sum':
            return weighted.sum()
        if self.reduction == 'mean':
            return weighted.mean()
        return weighted
class VegaLossV2(nn.Module):
    """
    Weighted sum of a classification loss and an activity regression loss.

    The classification term is either ``FocalLoss`` or
    ``nn.BCEWithLogitsLoss``; ``pos_weight`` is only passed to the latter.
    The regression term is a Huber loss (``delta=0.1``) evaluated only on
    entries whose presence label is above 0.5.
    """

    def __init__(
        self,
        classification_weight: float = 1.0,
        regression_weight: float = 0.1,
        use_focal_loss: bool = True,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
        pos_weight: Optional[torch.Tensor] = None
    ):
        super().__init__()
        self.classification_weight = classification_weight
        self.regression_weight = regression_weight
        self.use_focal_loss = use_focal_loss

        self.cls_loss = (
            FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
            if use_focal_loss
            else nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        )
        self.reg_loss = nn.HuberLoss(delta=0.1)

    def forward(
        self,
        pred_logits: torch.Tensor,
        pred_activities: torch.Tensor,
        true_presence: torch.Tensor,
        true_activities: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute the combined loss.

        Returns:
            ``(total_loss, parts)`` where ``total_loss`` is a tensor and
            ``parts`` holds the Python floats ``'total'``,
            ``'classification'`` and ``'regression'``.
        """
        cls_term = self.cls_loss(pred_logits, true_presence)

        # Regression is restricted to isotopes labelled as present
        present = true_presence > 0.5
        has_present = bool(present.any())
        if has_present:
            reg_term = self.reg_loss(pred_activities[present], true_activities[present])
        else:
            # Nothing to regress on: contribute a constant zero
            reg_term = torch.tensor(0.0, device=pred_logits.device)

        combined = (
            self.classification_weight * cls_term +
            self.regression_weight * reg_term
        )

        parts = {
            'total': combined.item(),
            'classification': cls_term.item(),
            'regression': reg_term.item() if has_present else 0.0
        }
        return combined, parts
@dataclass
class OptunaConfig:
    """
    Configuration for Optuna hyperparameter optimization.

    Groups the data/output locations, the study budget, the per-trial
    training limits and the data-loader settings that are shared by every
    trial and by the final training run of the best configuration.
    """
    # Data
    # Directory handed to create_data_loaders as the dataset location
    data_dir: str = "O:/master_data_collection/isotopev2"
    # Directory receiving the study database, best_params.json and checkpoints
    model_dir: str = "models/optuna"
    # Name of the study; also used as the SQLite file name inside model_dir
    study_name: str = "vega_v2_optimization"
    # Optuna settings
    # Maximum number of trials passed to study.optimize
    n_trials: int = 50
    # Wall-clock budget in hours (converted to seconds for study.optimize)
    timeout_hours: float = 24.0
    n_startup_trials: int = 10  # Random sampling before TPE
    # Training settings for each trial
    # Upper bound on epochs per trial; also the pruner's max_resource
    max_epochs: int = 30  # Shorter epochs for faster trials
    patience: int = 5  # Early stopping patience (epochs without improvement)
    # Data splits (forwarded unchanged to create_data_loaders)
    train_split: float = 0.8
    val_split: float = 0.1
    test_split: float = 0.1
    # Fixed settings (data-loader options and mixed precision)
    num_workers: int = 8
    prefetch_factor: int = 4
    persistent_workers: bool = True
    # Mixed precision is only enabled when this is True and the device is CUDA
    use_amp: bool = True
    # Optimization objective
    # Key looked up in the validation metrics dict and maximized by the study
    optimize_metric: str = "val_recall"  # Focus on recall
    # Reproducibility
    # Seed given to the TPE sampler and to create_data_loaders
    seed: int = 42
def suggest_hyperparameters(trial: Trial) -> Dict:
    """
    Sample one hyperparameter configuration from the search space.

    The space is conditional: the number of ``conv_ch_*`` / ``fc_dim_*``
    entries follows the sampled layer counts, the scheduler-specific keys
    depend on the sampled scheduler, and the focal-loss keys are only
    present when focal loss is selected.

    Returns:
        Dict with every sampled value, keyed by the names used by the
        model, loss and training code.
    """
    # ========== Model Architecture ==========
    # CNN backbone: depth first, then one channel width per layer
    conv_depth = trial.suggest_int("n_conv_layers", 2, 4)
    hp = {
        "conv_channels": [
            trial.suggest_categorical(f"conv_ch_{layer}", [32, 64, 128, 256, 512])
            for layer in range(conv_depth)
        ]
    }
    hp["conv_kernel_size"] = trial.suggest_categorical("conv_kernel_size", [3, 5, 7, 9, 11])
    hp["pool_size"] = trial.suggest_categorical("pool_size", [2, 3, 4])

    # Fully connected head: depth first, then one width per layer
    fc_depth = trial.suggest_int("n_fc_layers", 1, 3)
    hp["fc_hidden_dims"] = [
        trial.suggest_categorical(f"fc_dim_{layer}", [128, 256, 512, 1024])
        for layer in range(fc_depth)
    ]

    # Regularization
    hp["dropout_rate"] = trial.suggest_float("dropout_rate", 0.1, 0.5)
    hp["spatial_dropout_rate"] = trial.suggest_float("spatial_dropout_rate", 0.05, 0.3)
    hp["leaky_relu_slope"] = trial.suggest_float("leaky_relu_slope", 0.01, 0.2)

    # ========== Training Hyperparameters ==========
    hp["batch_size"] = trial.suggest_categorical("batch_size", [128, 256, 512, 1024])
    hp["learning_rate"] = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    hp["weight_decay"] = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
    hp["optimizer"] = trial.suggest_categorical("optimizer", ["adam", "adamw"])

    # Scheduler choice decides which extra knobs are sampled
    hp["scheduler"] = trial.suggest_categorical("scheduler", ["plateau", "cosine"])
    if hp["scheduler"] != "plateau":
        hp["cosine_t_0"] = trial.suggest_int("cosine_t_0", 5, 15)
        hp["cosine_t_mult"] = trial.suggest_int("cosine_t_mult", 1, 2)
    else:
        hp["lr_factor"] = trial.suggest_float("lr_factor", 0.1, 0.5)
        hp["lr_patience"] = trial.suggest_int("lr_patience", 3, 10)

    # ========== Loss Function ==========
    hp["use_focal_loss"] = trial.suggest_categorical("use_focal_loss", [True, False])
    if hp["use_focal_loss"]:
        hp["focal_alpha"] = trial.suggest_float("focal_alpha", 0.1, 0.5)
        hp["focal_gamma"] = trial.suggest_float("focal_gamma", 1.0, 3.0)
    hp["classification_weight"] = trial.suggest_float("classification_weight", 0.5, 2.0)
    hp["regression_weight"] = trial.suggest_float("regression_weight", 0.01, 0.5, log=True)

    # ========== Classification Threshold ==========
    hp["threshold"] = trial.suggest_float("threshold", 0.3, 0.7)

    return hp
def create_model_from_params(params: Dict, num_isotopes: int) -> VegaModel:
    """
    Build a VegaModel from a hyperparameter dictionary.

    Every key listed below must be present in ``params``; each one is
    forwarded to ``VegaConfig`` under the same name, together with
    ``num_isotopes``.
    """
    forwarded_keys = (
        "conv_channels",
        "conv_kernel_size",
        "pool_size",
        "fc_hidden_dims",
        "dropout_rate",
        "spatial_dropout_rate",
        "leaky_relu_slope",
        "classification_weight",
        "regression_weight",
    )
    config_kwargs = {key: params[key] for key in forwarded_keys}
    return VegaModel(VegaConfig(num_isotopes=num_isotopes, **config_kwargs))
def train_single_trial(
    trial: Trial,
    params: Dict,
    train_loader,
    val_loader,
    device: torch.device,
    optuna_config: OptunaConfig
) -> float:
    """
    Train one Optuna trial and return the best value of the objective metric.

    After every epoch the metric named by ``optuna_config.optimize_metric``
    is reported to Optuna. The trial raises ``optuna.TrialPruned`` when the
    pruner asks for it and stops early after ``optuna_config.patience``
    epochs without improvement.
    """
    # Model with 82 outputs (hard-coded; original comment: from isotope database)
    net = create_model_from_params(params, 82).to(device)

    # Loss built from the sampled weights and focal-loss settings
    criterion = VegaLossV2(
        classification_weight=params["classification_weight"],
        regression_weight=params["regression_weight"],
        use_focal_loss=params.get("use_focal_loss", False),
        focal_alpha=params.get("focal_alpha", 0.25),
        focal_gamma=params.get("focal_gamma", 2.0)
    )

    # Optimizer: AdamW when requested, plain Adam otherwise
    optimizer_cls = AdamW if params["optimizer"] == "adamw" else Adam
    optimizer = optimizer_cls(
        net.parameters(),
        lr=params["learning_rate"],
        weight_decay=params["weight_decay"]
    )

    # Scheduler: cosine warm restarts, or reduce-on-plateau on the val loss
    cosine_schedule = params["scheduler"] == "cosine"
    if cosine_schedule:
        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=params.get("cosine_t_0", 10),
            T_mult=params.get("cosine_t_mult", 1)
        )
    else:
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            patience=params.get("lr_patience", 5),
            factor=params.get("lr_factor", 0.5)
        )

    # Gradient scaler exists only when AMP is requested on a CUDA device
    amp_enabled = optuna_config.use_amp and device.type == 'cuda'
    scaler = torch.amp.GradScaler('cuda') if amp_enabled else None

    best_metric = 0.0
    stale_epochs = 0
    threshold = params["threshold"]

    for epoch in range(optuna_config.max_epochs):
        # ---- one pass over the training data ----
        net.train()
        for batch in train_loader:
            spectra = batch['spectrum'].to(device)
            presence = batch['presence_labels'].to(device)
            activities = batch['activity_labels'].to(device)

            optimizer.zero_grad()

            if scaler is None:
                logits, activity_preds = net(spectra)
                batch_loss, _ = criterion(logits, activity_preds, presence, activities)
                batch_loss.backward()
                optimizer.step()
            else:
                with torch.amp.autocast('cuda'):
                    logits, activity_preds = net(spectra)
                    batch_loss, _ = criterion(logits, activity_preds, presence, activities)
                scaler.scale(batch_loss).backward()
                scaler.step(optimizer)
                scaler.update()

        # ---- validation and scheduler update ----
        val_metrics = validate_model(net, val_loader, device, threshold, criterion)

        if cosine_schedule:
            scheduler.step()
        else:
            scheduler.step(val_metrics['val_loss'])

        # Objective value; falls back to recall when the key is absent
        current_metric = val_metrics.get(optuna_config.optimize_metric, val_metrics['val_recall'])

        # Let the pruner decide whether the trial is worth continuing
        trial.report(current_metric, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        # Early-stopping bookkeeping
        if current_metric > best_metric:
            best_metric, stale_epochs = current_metric, 0
        else:
            stale_epochs += 1

        if stale_epochs >= optuna_config.patience:
            break

    return best_metric
@torch.no_grad()
def validate_model(
    model: VegaModel,
    val_loader,
    device: torch.device,
    threshold: float,
    loss_fn: nn.Module
) -> Dict[str, float]:
    """
    Validate model and return comprehensive metrics.

    Args:
        model: Network returning ``(pred_logits, pred_activities)``.
        val_loader: Iterable of dict batches with the keys ``'spectrum'``,
            ``'presence_labels'`` and ``'activity_labels'``.
        device: Device every batch tensor is moved to.
        threshold: Sigmoid probability at or above which a label counts as
            predicted present.
        loss_fn: Callable returning ``(loss, details)`` for
            ``(pred_logits, pred_activities, presence, activities)``.

    Returns:
        Dict of plain Python floats with the keys ``'val_loss'``,
        ``'val_accuracy'``, ``'val_exact_match'``, ``'val_auc_macro'``,
        ``'val_f1_macro'``, ``'val_precision'``, ``'val_recall'`` and
        ``'val_hamming'``. ``'val_auc_macro'`` is always present and is 0.5
        when it cannot be computed.

    Raises:
        ValueError: If ``val_loader`` yields no batches.

    Note:
        F1, precision and recall are computed on the flattened label matrix
        with ``average='macro'``, i.e. the unweighted mean of the score for
        label value 0 and the score for label value 1, not the
        positive-class score alone.
    """
    model.eval()

    all_probs = []
    all_preds = []
    all_labels = []
    total_loss = 0.0
    num_batches = 0

    for batch in val_loader:
        spectra = batch['spectrum'].to(device)
        presence = batch['presence_labels'].to(device)
        activities = batch['activity_labels'].to(device)

        pred_logits, pred_activities = model(spectra)
        loss, _ = loss_fn(pred_logits, pred_activities, presence, activities)
        total_loss += loss.item()
        num_batches += 1

        # Get predictions
        probs = torch.sigmoid(pred_logits)
        preds = (probs >= threshold).float()
        all_probs.append(probs.cpu().numpy())
        all_preds.append(preds.cpu().numpy())
        all_labels.append(presence.cpu().numpy())

    # Fail with a clear message instead of an opaque np.vstack error
    if num_batches == 0:
        raise ValueError("validate_model() received a data loader that yielded no batches")

    # Concatenate the per-batch arrays into (num_samples, num_labels) matrices
    all_probs = np.vstack(all_probs)
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)

    # Values are cast to built-in floats so the dict can be stored in JSON
    # and in checkpoints without carrying NumPy scalar objects.
    correct = all_preds == all_labels
    metrics = {
        'val_loss': total_loss / num_batches,
        'val_accuracy': float(correct.mean()),
        # Per-sample exact match: every label of the row is correct
        'val_exact_match': float(correct.all(axis=1).mean())
    }

    # ROC AUC is only defined for columns where both classes occur. Default
    # to 0.5 so the key exists even when no column qualifies.
    auc_macro = 0.5
    positives_per_column = all_labels.sum(axis=0)
    valid_cols = (positives_per_column > 0) & (positives_per_column < len(all_labels))
    if valid_cols.any():
        try:
            auc_macro = float(roc_auc_score(
                all_labels[:, valid_cols],
                all_probs[:, valid_cols],
                average='macro'
            ))
        except ValueError:
            # Best-effort secondary metric: keep the neutral default
            auc_macro = 0.5
    metrics['val_auc_macro'] = auc_macro

    # Flatten for F1, precision, recall
    all_preds_flat = all_preds.flatten()
    all_labels_flat = all_labels.flatten()
    metrics['val_f1_macro'] = float(
        f1_score(all_labels_flat, all_preds_flat, average='macro', zero_division=0)
    )
    metrics['val_precision'] = float(
        precision_score(all_labels_flat, all_preds_flat, average='macro', zero_division=0)
    )
    metrics['val_recall'] = float(
        recall_score(all_labels_flat, all_preds_flat, average='macro', zero_division=0)
    )
    metrics['val_hamming'] = float(hamming_loss(all_labels, all_preds))

    return metrics
def objective(trial: Trial, optuna_config: OptunaConfig) -> float:
    """
    Optuna objective function.

    Samples a configuration, builds data loaders with the sampled batch
    size, trains the trial and returns the metric to maximize (recall by
    default).

    Raises:
        optuna.TrialPruned: Re-raised when the pruner stops the trial, so
            Optuna records the trial as pruned.
        Exception: Any other training error is logged and re-raised.
    """
    # Suggest hyperparameters
    params = suggest_hyperparameters(trial)

    # Log parameters
    print(f"\n{'='*60}")
    print(f"Trial {trial.number}")
    print(f"{'='*60}")
    for k, v in params.items():
        print(f" {k}: {v}")

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get isotope index
    isotope_index = get_default_isotope_index()

    # Create data loaders with trial's batch size
    train_loader, val_loader, _ = create_data_loaders(
        data_dir=Path(optuna_config.data_dir),
        batch_size=params["batch_size"],
        train_split=optuna_config.train_split,
        val_split=optuna_config.val_split,
        test_split=optuna_config.test_split,
        num_workers=optuna_config.num_workers,
        prefetch_factor=optuna_config.prefetch_factor,
        persistent_workers=optuna_config.persistent_workers,
        isotope_index=isotope_index,
        seed=optuna_config.seed
    )

    try:
        metric = train_single_trial(
            trial, params, train_loader, val_loader, device, optuna_config
        )
    except optuna.TrialPruned:
        # TrialPruned is an Exception subclass; handle it first so a normal
        # pruning decision is not reported as a failure.
        print(f"Trial {trial.number} pruned")
        raise
    except Exception as e:
        print(f"Trial {trial.number} failed: {e}")
        raise

    print(f"Trial {trial.number} completed with {optuna_config.optimize_metric}: {metric:.4f}")
    return metric
def run_optimization(config: OptunaConfig) -> optuna.Study:
    """
    Run the full Optuna optimization study.

    Creates (or resumes) a SQLite-backed study in ``config.model_dir`` using
    a TPE sampler and a Hyperband pruner, runs the trials, prints the best
    result and writes it to ``best_params.json``.

    Returns:
        The study. When no trial completed (for example every trial was
        pruned or the timeout expired first) the study is returned without
        printing or saving best parameters.
    """
    # Create model directory
    model_dir = Path(config.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    # Create study with TPE sampler and Hyperband pruner
    sampler = TPESampler(
        n_startup_trials=config.n_startup_trials,
        seed=config.seed
    )
    pruner = HyperbandPruner(
        min_resource=3,
        max_resource=config.max_epochs,
        reduction_factor=3
    )

    # Create or load study
    storage = f"sqlite:///{model_dir / config.study_name}.db"
    study = optuna.create_study(
        study_name=config.study_name,
        storage=storage,
        load_if_exists=True,
        direction="maximize",  # Maximize recall
        sampler=sampler,
        pruner=pruner
    )

    print("\n" + "=" * 60)
    print("VEGA V2 - OPTUNA HYPERPARAMETER OPTIMIZATION")
    print("=" * 60)
    print(f"Study name: {config.study_name}")
    print(f"Optimization metric: {config.optimize_metric}")
    print(f"Number of trials: {config.n_trials}")
    print(f"Timeout: {config.timeout_hours} hours")
    print(f"Data directory: {config.data_dir}")
    print("=" * 60 + "\n")

    # Run optimization
    study.optimize(
        lambda trial: objective(trial, config),
        n_trials=config.n_trials,
        timeout=config.timeout_hours * 3600,
        show_progress_bar=True,
        gc_after_trial=True
    )

    # Print results
    print("\n" + "=" * 60)
    print("OPTIMIZATION COMPLETE")
    print("=" * 60)

    # study.best_trial raises ValueError when nothing completed, so check
    # first instead of crashing after the whole budget has been spent.
    completed_trials = study.get_trials(
        deepcopy=False,
        states=(optuna.trial.TrialState.COMPLETE,)
    )
    if not completed_trials:
        print("No trial completed; there are no best parameters to report or save.")
        return study

    print(f"Best trial: {study.best_trial.number}")
    print(f"Best {config.optimize_metric}: {study.best_value:.4f}")
    print("\nBest hyperparameters:")
    for k, v in study.best_params.items():
        print(f" {k}: {v}")

    # Save best parameters
    best_params_path = model_dir / "best_params.json"
    with open(best_params_path, 'w') as f:
        json.dump({
            'best_value': study.best_value,
            'best_params': study.best_params,
            'study_name': config.study_name,
            'optimize_metric': config.optimize_metric
        }, f, indent=2)
    print(f"\nBest parameters saved to: {best_params_path}")

    return study
def train_best_model(
    study: optuna.Study,
    config: OptunaConfig,
    full_epochs: int = 100
) -> Tuple[VegaModel, Dict]:
    """
    Train the best model from the study with full epochs.

    Rebuilds the hyperparameter dict from ``study.best_params`` (filling in
    defaults for keys the best trial did not sample), trains for
    ``full_epochs`` epochs without early stopping, writes a checkpoint each
    time validation recall improves, and evaluates on the test loader.

    Args:
        study: Study providing ``best_params``.
        config: Data, output and loader settings.
        full_epochs: Number of training epochs.

    Returns:
        ``(model, results)`` where ``results`` holds ``'best_recall'``,
        ``'test_metrics'``, ``'history'`` and ``'params'``.

    Note:
        The returned model and the test metrics use the weights from the
        final epoch. The best-recall weights are only stored in
        ``vega_v2_best.pt``.
    """
    print("\n" + "=" * 60)
    print("TRAINING BEST MODEL")
    print("=" * 60)

    best_params = study.best_params

    # Reconstruct full params dict from best_params
    params = {}

    # CNN layers
    n_conv_layers = best_params.get("n_conv_layers", 3)
    params["conv_channels"] = [best_params.get(f"conv_ch_{i}", 128) for i in range(n_conv_layers)]
    params["conv_kernel_size"] = best_params.get("conv_kernel_size", 7)
    params["pool_size"] = best_params.get("pool_size", 2)

    # FC layers
    n_fc_layers = best_params.get("n_fc_layers", 2)
    params["fc_hidden_dims"] = [best_params.get(f"fc_dim_{i}", 256) for i in range(n_fc_layers)]

    # Other params: copy whatever the best trial actually sampled
    for key in ["dropout_rate", "spatial_dropout_rate", "leaky_relu_slope",
                "batch_size", "learning_rate", "weight_decay", "optimizer",
                "scheduler", "lr_factor", "lr_patience", "cosine_t_0", "cosine_t_mult",
                "use_focal_loss", "focal_alpha", "focal_gamma",
                "classification_weight", "regression_weight", "threshold"]:
        if key in best_params:
            params[key] = best_params[key]

    # Set defaults for missing params
    params.setdefault("dropout_rate", 0.3)
    params.setdefault("spatial_dropout_rate", 0.1)
    params.setdefault("leaky_relu_slope", 0.1)
    params.setdefault("batch_size", 512)
    params.setdefault("learning_rate", 1e-3)
    params.setdefault("weight_decay", 1e-4)
    params.setdefault("optimizer", "adamw")
    params.setdefault("scheduler", "plateau")
    params.setdefault("classification_weight", 1.0)
    params.setdefault("regression_weight", 0.1)
    params.setdefault("threshold", 0.5)
    params.setdefault("use_focal_loss", True)
    params.setdefault("focal_alpha", 0.25)
    params.setdefault("focal_gamma", 2.0)

    print("Training with parameters:")
    for k, v in params.items():
        print(f" {k}: {v}")

    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    isotope_index = get_default_isotope_index()
    model_dir = Path(config.model_dir)
    # Ensure the output directory exists when this function is used on its
    # own (e.g. with a study loaded from storage).
    model_dir.mkdir(parents=True, exist_ok=True)

    # Create data loaders
    train_loader, val_loader, test_loader = create_data_loaders(
        data_dir=Path(config.data_dir),
        batch_size=params["batch_size"],
        train_split=config.train_split,
        val_split=config.val_split,
        test_split=config.test_split,
        num_workers=config.num_workers,
        prefetch_factor=config.prefetch_factor,
        persistent_workers=config.persistent_workers,
        isotope_index=isotope_index,
        seed=config.seed
    )

    # Create model
    model = create_model_from_params(params, isotope_index.num_isotopes)
    model = model.to(device)

    # Create loss
    loss_fn = VegaLossV2(
        classification_weight=params["classification_weight"],
        regression_weight=params["regression_weight"],
        use_focal_loss=params.get("use_focal_loss", False),
        focal_alpha=params.get("focal_alpha", 0.25),
        focal_gamma=params.get("focal_gamma", 2.0)
    )

    # Optimizer
    if params["optimizer"] == "adamw":
        optimizer = AdamW(model.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])
    else:
        optimizer = Adam(model.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])

    # Scheduler
    if params.get("scheduler") == "cosine":
        scheduler = CosineAnnealingWarmRestarts(
            optimizer, T_0=params.get("cosine_t_0", 10), T_mult=params.get("cosine_t_mult", 1)
        )
    else:
        scheduler = ReduceLROnPlateau(
            optimizer, mode='min', patience=params.get("lr_patience", 5), factor=params.get("lr_factor", 0.5)
        )

    # Mixed precision (only when requested and running on CUDA)
    scaler = torch.amp.GradScaler('cuda') if config.use_amp and device.type == 'cuda' else None

    # Training
    best_recall = 0.0
    threshold = params["threshold"]
    history = []

    for epoch in range(full_epochs):
        # Train
        model.train()
        train_loss = 0.0
        num_batches = 0

        for batch in train_loader:
            spectra = batch['spectrum'].to(device)
            presence = batch['presence_labels'].to(device)
            activities = batch['activity_labels'].to(device)

            optimizer.zero_grad()

            if scaler is not None:
                with torch.amp.autocast('cuda'):
                    pred_logits, pred_activities = model(spectra)
                    loss, _ = loss_fn(pred_logits, pred_activities, presence, activities)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                pred_logits, pred_activities = model(spectra)
                loss, _ = loss_fn(pred_logits, pred_activities, presence, activities)
                loss.backward()
                optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss /= num_batches

        # Validate. Metrics are cast to built-in floats so neither the
        # checkpoint nor the JSON history stores NumPy scalar objects.
        val_metrics = {
            name: float(value)
            for name, value in validate_model(model, val_loader, device, threshold, loss_fn).items()
        }

        # Scheduler step
        if params.get("scheduler") == "cosine":
            scheduler.step()
        else:
            scheduler.step(val_metrics['val_loss'])

        # Log
        lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1:3d}/{full_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_metrics['val_loss']:.4f} | Recall: {val_metrics['val_recall']:.4f} | "
              f"F1: {val_metrics['val_f1_macro']:.4f} | Exact: {val_metrics['val_exact_match']:.4f} | LR: {lr:.2e}")

        # Save history
        history.append({
            'epoch': epoch,
            'train_loss': train_loss,
            **val_metrics,
            'lr': lr
        })

        # Save best model
        if val_metrics['val_recall'] > best_recall:
            best_recall = val_metrics['val_recall']
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_recall': best_recall,
                'params': params,
                'val_metrics': val_metrics
            }
            torch.save(checkpoint, model_dir / "vega_v2_best.pt")
            print(f" └── New best! Saved model with recall: {best_recall:.4f}")

    # Save final model
    torch.save({
        'model_state_dict': model.state_dict(),
        'params': params,
        'history': history
    }, model_dir / "vega_v2_final.pt")

    # Save history
    with open(model_dir / "vega_v2_history.json", 'w') as f:
        json.dump(history, f, indent=2)

    # Test evaluation
    print("\n" + "=" * 60)
    print("TEST SET EVALUATION")
    print("=" * 60)

    test_metrics = {
        name: float(value)
        for name, value in validate_model(model, test_loader, device, threshold, loss_fn).items()
    }
    for k, v in test_metrics.items():
        print(f" {k}: {v:.4f}")

    return model, {
        'best_recall': best_recall,
        'test_metrics': test_metrics,
        'history': history,
        'params': params
    }
def main():
    """
    Command-line entry point for the Optuna search.

    Parses the options, runs the optimization study and, when
    ``--train-best`` is given, trains the best configuration for
    ``--full-epochs`` epochs. Returns 0 as the process exit status.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Vega V2 - Optuna Hyperparameter Optimization")
    add_option = cli.add_argument

    # Locations and study identity
    add_option("--data-dir", type=str, default="O:/master_data_collection/isotopev2",
               help="Data directory")
    add_option("--model-dir", type=str, default="models/optuna",
               help="Model output directory")
    add_option("--study-name", type=str, default="vega_v2_optimization",
               help="Optuna study name")

    # Search budget
    add_option("--n-trials", type=int, default=50,
               help="Number of Optuna trials")
    add_option("--timeout", type=float, default=24.0,
               help="Timeout in hours")
    add_option("--max-epochs", type=int, default=30,
               help="Max epochs per trial")
    add_option("--optimize-metric", type=str, default="val_recall",
               choices=["val_recall", "val_f1_macro", "val_auc_macro", "val_exact_match"],
               help="Metric to optimize")

    # Optional final training of the winning configuration
    add_option("--train-best", action="store_true",
               help="Train best model with full epochs after optimization")
    add_option("--full-epochs", type=int, default=100,
               help="Epochs for training best model")
    add_option("--workers", type=int, default=8,
               help="Number of data loading workers")

    opts = cli.parse_args()

    # Translate the CLI options into the study configuration
    search_config = OptunaConfig(
        data_dir=opts.data_dir,
        model_dir=opts.model_dir,
        study_name=opts.study_name,
        n_trials=opts.n_trials,
        timeout_hours=opts.timeout,
        max_epochs=opts.max_epochs,
        optimize_metric=opts.optimize_metric,
        num_workers=opts.workers
    )

    finished_study = run_optimization(search_config)

    if opts.train_best:
        train_best_model(finished_study, search_config, full_epochs=opts.full_epochs)

    return 0
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())