Pipeline complet Radiacode 103 - identification automatique d'isotopes
- VegaModel CNN-FCNN 34.5M params, 82 isotopes, val acc 99.89% - Generation 50k spectres synthetiques 1D (12-24h durees) - Entrainement 100 epochs sur RTX 5060 Ti (CUDA 12.8, Blackwell) - Detection continue avec soustraction du background - Capture background 24h avec gestion deconnexion - Docker Compose : conteneur train (GPU) + detect (CPU/USB) - Modele entraine inclus (vega_best.pt, 395 Mo) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
26
train/vega_ml/training/vega/__init__.py
Normal file
26
train/vega_ml/training/vega/__init__.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""
|
||||
Vega Model - CNN-FCNN with Multi-Task Heads for Gamma Spectrum Isotope Identification
|
||||
|
||||
Architecture based on research findings from:
|
||||
- Wang et al. (2026): CNN-FCNN achieves 99.8% accuracy
|
||||
- Galib et al. (2021): Hybrid CNN outperforms pure architectures
|
||||
- Turner et al. (2021): 1D CNN robust to gain shifts and shielding
|
||||
|
||||
Features:
|
||||
- 1D CNN backbone for spectral feature extraction
|
||||
- Multi-task heads for isotope classification + activity regression
|
||||
- Support for 82 isotopes from the synthetic spectra database
|
||||
"""
|
||||
|
||||
from .model import VegaModel, VegaConfig
|
||||
from .dataset import SpectrumDataset, create_data_loaders
|
||||
from .train import train_vega, VegaTrainer
|
||||
|
||||
# Names exported by ``from <package> import *``; each one is imported from the
# ``model``, ``dataset`` or ``train`` submodule of this package.
__all__ = [
    "VegaModel",
    "VegaConfig",
    "SpectrumDataset",
    "create_data_loaders",
    "train_vega",
    "VegaTrainer",
]
|
||||
373
train/vega_ml/training/vega/dataset.py
Normal file
373
train/vega_ml/training/vega/dataset.py
Normal file
@ -0,0 +1,373 @@
|
||||
"""
|
||||
Dataset and DataLoader for Vega Model Training
|
||||
|
||||
Handles loading synthetic gamma spectra from numpy files and converting
|
||||
them to PyTorch tensors with proper labels for multi-task learning.
|
||||
|
||||
Supports two label formats:
|
||||
1. Individual JSON files per sample (recommended for large datasets)
|
||||
2. Combined labels.json file (legacy format)
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .isotope_index import IsotopeIndex, get_default_isotope_index
|
||||
|
||||
|
||||
@dataclass
class SpectrumSample:
    """Record bundling one spectrum array with its descriptive metadata."""

    # Identifier of the sample.
    sample_id: str
    # Spectrum values: 2D (time_intervals, channels) or 1D (channels,).
    spectrum: np.ndarray
    # Names of the isotopes present in the sample.
    isotopes_present: List[str]
    # Activity per isotope name (presumably becquerel, going by the field name).
    activities_bq: Dict[str, float]
    # Acquisition duration (presumably seconds, going by the field name).
    duration_seconds: float
    # Detector label; the expected values are not shown in this module.
    detector: str
|
||||
|
||||
|
||||
class SpectrumDataset(Dataset):
    """
    PyTorch Dataset over spectra stored as ``.npy`` files with JSON labels.

    Spectra live in ``<data_dir>/spectra/spectrum_<id>.npy``. Labels come from
    one of two layouts:

    1. One ``spectrum_<id>.json`` next to each spectrum, read on demand
       (preferred when any such file exists).
    2. A single ``<data_dir>/labels.json`` whose ``"samples"`` mapping holds
       every label (legacy format, loaded once in ``__init__``).

    Each item is converted to tensors suitable for the Vega model.
    """

    # Filename prefix shared by spectrum and per-sample label files.
    _FILE_PREFIX = "spectrum_"

    def __init__(
        self,
        data_dir: Path,
        isotope_index: "Optional[IsotopeIndex]" = None,
        max_activity_bq: float = 1000.0,
        collapse_time: bool = True,
        transform=None
    ):
        """
        Initialize the dataset.

        Args:
            data_dir: Path to directory containing spectra/ subdirectory
            isotope_index: Index mapping isotope names to indices; the default
                index is used when this is None
            max_activity_bq: Maximum activity for normalization
            collapse_time: If True, average across time dimension to get 1D spectrum
            transform: Optional callable applied to the spectrum tensor

        Raises:
            FileNotFoundError: If neither label layout is present.
        """
        self.data_dir = Path(data_dir)
        self.spectra_dir = self.data_dir / "spectra"
        # Explicit None check: an index object that defines __len__ can be
        # falsy, and a caller-supplied index must never be silently replaced.
        if isotope_index is None:
            isotope_index = get_default_isotope_index()
        self.isotope_index = isotope_index
        self.max_activity_bq = max_activity_bq
        self.collapse_time = collapse_time
        self.transform = transform

        # Detect label format and load sample list
        self.use_individual_labels = self._detect_label_format()

        if self.use_individual_labels:
            # Only filenames are scanned here; labels are read per item.
            self.sample_ids = self._scan_for_samples()
            self.metadata = None
            print(f"Using individual label files (efficient mode)")
        else:
            self.metadata = self._load_metadata()
            self.sample_ids = list(self.metadata['samples'].keys())
            print(f"Using combined labels.json (legacy mode)")

        print(f"Loaded dataset with {len(self.sample_ids)} samples")
        print(f"Isotope index has {self.isotope_index.num_isotopes} isotopes")

    def _detect_label_format(self) -> bool:
        """
        Decide which label layout to use.

        Returns:
            True for individual per-sample JSON files, False for the combined
            labels.json.

        Raises:
            FileNotFoundError: If neither layout is found.
        """
        # Stop at the first match instead of listing every label file.
        pattern = f"{self._FILE_PREFIX}*.json"
        if next(self.spectra_dir.glob(pattern), None) is not None:
            return True

        if (self.data_dir / "labels.json").exists():
            return False

        raise FileNotFoundError(
            f"No label files found. Expected either:\n"
            f"  - Individual files: {self.spectra_dir}/spectrum_*.json\n"
            f"  - Combined file: {self.data_dir}/labels.json"
        )

    def _scan_for_samples(self) -> List[str]:
        """Return sample IDs, sorted by filename, derived from the .npy files."""
        prefix_len = len(self._FILE_PREFIX)
        # Slice off only the leading prefix so IDs that themselves contain
        # "spectrum_" are kept intact.
        return [
            npy_path.stem[prefix_len:]
            for npy_path in sorted(self.spectra_dir.glob(f"{self._FILE_PREFIX}*.npy"))
        ]

    def _load_metadata(self) -> Dict:
        """Load the combined labels.json metadata file (legacy format)."""
        labels_path = self.data_dir / "labels.json"
        if not labels_path.exists():
            raise FileNotFoundError(f"Labels file not found: {labels_path}")

        with open(labels_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_sample_label(self, sample_id: str) -> Dict:
        """Return the label dict for one sample, whichever layout is in use."""
        if not self.use_individual_labels:
            return self.metadata['samples'][sample_id]

        json_path = self.spectra_dir / f"{self._FILE_PREFIX}{sample_id}.json"
        with open(json_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _encode_labels(self, sample_meta: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the (presence, normalized activity) label vectors for one sample."""
        num_isotopes = self.isotope_index.num_isotopes
        presence = torch.zeros(num_isotopes, dtype=torch.float32)
        activity = torch.zeros(num_isotopes, dtype=torch.float32)

        for isotope_name in sample_meta['isotopes']:
            try:
                presence[self.isotope_index.name_to_index(isotope_name)] = 1.0
            except KeyError:
                # Isotope not in our index (might be a decay product)
                continue

        for isotope_name, value in sample_meta.get('source_activities_bq', {}).items():
            try:
                slot = self.isotope_index.name_to_index(isotope_name)
            except KeyError:
                continue
            # Scale by max_activity_bq and cap at 1.0.
            activity[slot] = min(value / self.max_activity_bq, 1.0)

        return presence, activity

    def __len__(self) -> int:
        return len(self.sample_ids)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single sample.

        Returns:
            Dictionary containing:
            - spectrum: float32 tensor; 1D when the stored array is 1D or the
              time axis was collapsed
            - presence_labels: Binary tensor (num_isotopes,) indicating presence
            - activity_labels: Tensor (num_isotopes,) with normalized activities
            - sample_id: String identifier
        """
        sample_id = self.sample_ids[idx]
        sample_meta = self._load_sample_label(sample_id)

        spectrum = np.load(self.spectra_dir / f"{self._FILE_PREFIX}{sample_id}.npy")

        # Average across time intervals to get a single spectrum.
        if self.collapse_time and spectrum.ndim == 2:
            spectrum = spectrum.mean(axis=0)

        spectrum_tensor = torch.tensor(spectrum, dtype=torch.float32)

        if self.transform is not None:
            spectrum_tensor = self.transform(spectrum_tensor)

        presence_labels, activity_labels = self._encode_labels(sample_meta)

        return {
            'spectrum': spectrum_tensor,
            'presence_labels': presence_labels,
            'activity_labels': activity_labels,
            'sample_id': sample_id
        }

    def get_sample_info(self, idx: int) -> Dict:
        """Get metadata for a sample without loading the spectrum."""
        sample_id = self.sample_ids[idx]
        # Go through _load_sample_label so this also works in individual-label
        # mode, where self.metadata is None.
        return {
            'sample_id': sample_id,
            **self._load_sample_label(sample_id)
        }
|
||||
|
||||
|
||||
def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Merge per-sample dictionaries into one batched dictionary.

    Args:
        batch: List of sample dictionaries, each holding 'spectrum',
            'presence_labels', 'activity_labels' and 'sample_id'

    Returns:
        Dictionary with the three tensor entries stacked along a new leading
        dimension, plus 'sample_ids' as a plain list in batch order
    """
    collated = {}
    for tensor_key in ('spectrum', 'presence_labels', 'activity_labels'):
        collated[tensor_key] = torch.stack([sample[tensor_key] for sample in batch])
    # IDs are not tensors, so they are gathered into a list instead of stacked.
    collated['sample_ids'] = [sample['sample_id'] for sample in batch]
    return collated
|
||||
|
||||
|
||||
def create_data_loaders(
    data_dir: Path,
    batch_size: int = 32,
    train_split: float = 0.8,
    val_split: float = 0.1,
    test_split: float = 0.1,
    num_workers: int = 8,
    prefetch_factor: int = 4,
    persistent_workers: bool = True,
    isotope_index: Optional[IsotopeIndex] = None,
    max_activity_bq: float = 1000.0,
    seed: int = 42
) -> Tuple[DataLoader, Optional[DataLoader], Optional[DataLoader]]:
    """
    Create train, validation, and test data loaders.

    Args:
        data_dir: Path to data directory
        batch_size: Batch size for training
        train_split: Fraction of data for training
        val_split: Fraction of data for validation
        test_split: Fraction of data for testing
        num_workers: Number of data loading workers (parallel I/O)
        prefetch_factor: Batches to prefetch per worker
        persistent_workers: Keep workers alive between epochs
        isotope_index: Isotope name to index mapping
        max_activity_bq: Maximum activity for normalization
        seed: Random seed for reproducibility

    Returns:
        Tuple of (train_loader, val_loader, test_loader). The validation and
        test loaders are None when their split ends up with zero samples.

    Raises:
        ValueError: If the splits do not sum to 1.0 or the dataset is empty.
    """
    # Explicit raise instead of assert so the check survives `python -O`.
    # Written as `not (... < tol)` so NaN inputs are rejected as well.
    if not (abs(train_split + val_split + test_split - 1.0) < 1e-6):
        raise ValueError("Splits must sum to 1.0")

    full_dataset = SpectrumDataset(
        data_dir=data_dir,
        isotope_index=isotope_index,
        max_activity_bq=max_activity_bq
    )

    total_size = len(full_dataset)
    if total_size == 0:
        # Without this guard the small-dataset fix-ups below would produce a
        # negative test size.
        raise ValueError(f"No samples found in {data_dir}; cannot create data loaders")

    train_size = int(total_size * train_split)
    val_size = int(total_size * val_split)
    test_size = total_size - train_size - val_size

    # Small datasets: guarantee a non-empty train split and, where enough
    # samples exist, one sample each for validation and test.
    if train_size == 0:
        train_size = max(1, total_size - 2)
    if val_size == 0 and total_size > 1:
        val_size = 1
        train_size = max(1, train_size - 1)
    if test_size == 0 and total_size > 2:
        test_size = 1
        train_size = max(1, train_size - 1)

    # The test split absorbs whatever remains so the sizes sum to total_size.
    test_size = total_size - train_size - val_size

    print(f"Dataset splits: train={train_size}, val={val_size}, test={test_size}")

    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=generator
    )

    # prefetch_factor and persistent_workers are only passed as active values
    # when worker processes are in use.
    use_workers = num_workers > 0
    shared_kwargs = dict(
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
        prefetch_factor=prefetch_factor if use_workers else None,
        persistent_workers=persistent_workers and use_workers,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=min(batch_size, train_size),
        shuffle=True,
        drop_last=True,  # Drop incomplete batches for consistent training
        **shared_kwargs
    )

    val_loader = None
    if val_size > 0:
        val_loader = DataLoader(
            val_dataset,
            batch_size=min(batch_size, max(1, val_size)),
            shuffle=False,
            **shared_kwargs
        )

    test_loader = None
    if test_size > 0:
        test_loader = DataLoader(
            test_dataset,
            batch_size=min(batch_size, max(1, test_size)),
            shuffle=False,
            **shared_kwargs
        )

    if use_workers:
        print(f"DataLoader: {num_workers} workers, prefetch_factor={prefetch_factor}, persistent={persistent_workers}")

    return train_loader, val_loader, test_loader
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: build the dataset and loaders from the synthetic data
    # directory three levels above this file and print a few shapes.
    import sys

    synthetic_dir = Path(__file__).parent.parent.parent / "data" / "synthetic"
    if not synthetic_dir.exists():
        print(f"Data directory not found: {synthetic_dir}")
        sys.exit(1)

    demo_dataset = SpectrumDataset(synthetic_dir)
    print(f"\nDataset size: {len(demo_dataset)}")

    # Inspect the first sample.
    first_sample = demo_dataset[0]
    print(f"\nSample keys: {first_sample.keys()}")
    print(f"Spectrum shape: {first_sample['spectrum'].shape}")
    print(f"Presence labels shape: {first_sample['presence_labels'].shape}")
    print(f"Activity labels shape: {first_sample['activity_labels'].shape}")
    print(f"Presence sum: {first_sample['presence_labels'].sum().item()}")

    # Build the three loaders with a small batch size.
    train_loader, val_loader, test_loader = create_data_loaders(synthetic_dir, batch_size=4)

    print(f"\nTrain batches: {len(train_loader)}")
    if val_loader:
        print(f"Val batches: {len(val_loader)}")
    if test_loader:
        print(f"Test batches: {len(test_loader)}")

    # Pull one training batch to confirm collation works end to end.
    first_batch = next(iter(train_loader))
    print(f"\nBatch spectrum shape: {first_batch['spectrum'].shape}")
    print(f"Batch presence shape: {first_batch['presence_labels'].shape}")
|
||||
308
train/vega_ml/training/vega/dataset_2d.py
Normal file
308
train/vega_ml/training/vega/dataset_2d.py
Normal file
@ -0,0 +1,308 @@
|
||||
"""
|
||||
Dataset for 2D Vega Model
|
||||
|
||||
Loads 2D spectra (time × channels) and pads/truncates to fixed dimensions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .isotope_index import IsotopeIndex, get_default_isotope_index
|
||||
|
||||
|
||||
@dataclass
class SpectrumSample2D:
    """Record bundling one 2D spectrum array with its descriptive metadata."""

    # Identifier of the sample.
    sample_id: str
    # Spectrum values as a 2D array (time_intervals, channels).
    spectrum: np.ndarray
    # Names of the isotopes present in the sample.
    isotopes_present: List[str]
    # Activity per isotope name (presumably becquerel, going by the field name).
    activities_bq: Dict[str, float]
    # Acquisition duration (presumably seconds, going by the field name).
    duration_seconds: float
    # Detector label; the expected values are not shown in this module.
    detector: str
|
||||
|
||||
|
||||
class SpectrumDataset2D(Dataset):
    """
    PyTorch Dataset for 2D gamma spectra.

    Spectra are read from ``<data_dir>/spectra/spectrum_<id>.npy`` and their
    time dimension is padded or subsampled to a fixed size so that items can
    be stacked into batches. Labels come from per-sample
    ``spectrum_<id>.json`` files when any exist, otherwise from a combined
    ``<data_dir>/labels.json``.
    """

    # Filename prefix shared by spectrum and per-sample label files.
    _FILE_PREFIX = "spectrum_"

    def __init__(
        self,
        data_dir: Path,
        isotope_index: "Optional[IsotopeIndex]" = None,
        max_activity_bq: float = 1000.0,
        target_time_intervals: int = 60,
        transform=None
    ):
        """
        Initialize the dataset.

        Args:
            data_dir: Path to directory containing spectra/ subdirectory
            isotope_index: Index mapping isotope names to indices; the default
                index is used when this is None
            max_activity_bq: Maximum activity for normalization
            target_time_intervals: Fixed time dimension (pad/truncate to this)
            transform: Optional callable applied to the spectrum tensor

        Raises:
            ValueError: If target_time_intervals is smaller than 1.
            FileNotFoundError: If neither label layout is present.
        """
        # A non-positive target would yield empty spectra, which only fail
        # later inside __getitem__ when the maximum is taken.
        if target_time_intervals < 1:
            raise ValueError(
                f"target_time_intervals must be >= 1, got {target_time_intervals}"
            )

        self.data_dir = Path(data_dir)
        self.spectra_dir = self.data_dir / "spectra"
        # Explicit None check: an index object that defines __len__ can be
        # falsy, and a caller-supplied index must never be silently replaced.
        if isotope_index is None:
            isotope_index = get_default_isotope_index()
        self.isotope_index = isotope_index
        self.max_activity_bq = max_activity_bq
        self.target_time_intervals = target_time_intervals
        self.transform = transform

        # Detect label format and load sample list
        self.use_individual_labels = self._detect_label_format()

        if self.use_individual_labels:
            self.sample_ids = self._scan_for_samples()
            self.metadata = None
            print(f"Using individual label files (efficient mode)")
        else:
            self.metadata = self._load_metadata()
            self.sample_ids = list(self.metadata['samples'].keys())
            print(f"Using combined labels.json (legacy mode)")

        print(f"Loaded 2D dataset with {len(self.sample_ids)} samples")
        print(f"Target shape: ({target_time_intervals}, 1023)")
        print(f"Isotope index has {self.isotope_index.num_isotopes} isotopes")

    def _detect_label_format(self) -> bool:
        """
        Decide which label layout to use.

        Returns:
            True for individual per-sample JSON files, False for the combined
            labels.json.

        Raises:
            FileNotFoundError: If neither layout is found.
        """
        # Stop at the first match instead of listing every label file.
        pattern = f"{self._FILE_PREFIX}*.json"
        if next(self.spectra_dir.glob(pattern), None) is not None:
            return True

        if (self.data_dir / "labels.json").exists():
            return False

        raise FileNotFoundError(
            f"No label files found. Expected either:\n"
            f"  - Individual files: {self.spectra_dir}/spectrum_*.json\n"
            f"  - Combined file: {self.data_dir}/labels.json"
        )

    def _scan_for_samples(self) -> List[str]:
        """Return sample IDs, sorted by filename, derived from the .npy files."""
        prefix_len = len(self._FILE_PREFIX)
        # Slice off only the leading prefix so IDs that themselves contain
        # "spectrum_" are kept intact.
        return [
            npy_path.stem[prefix_len:]
            for npy_path in sorted(self.spectra_dir.glob(f"{self._FILE_PREFIX}*.npy"))
        ]

    def _load_metadata(self) -> Dict:
        """Load the combined labels.json metadata file."""
        labels_path = self.data_dir / "labels.json"
        if not labels_path.exists():
            raise FileNotFoundError(f"Labels file not found: {labels_path}")

        with open(labels_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_sample_label(self, sample_id: str) -> Dict:
        """Return the label dict for one sample, whichever layout is in use."""
        if not self.use_individual_labels:
            return self.metadata['samples'][sample_id]

        json_path = self.spectra_dir / f"{self._FILE_PREFIX}{sample_id}.json"
        with open(json_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _pad_or_truncate(self, spectrum: np.ndarray) -> np.ndarray:
        """
        Pad or truncate spectrum to target time dimension.

        Args:
            spectrum: 2D array (time, channels)

        Returns:
            Array of shape (target_time_intervals, channels). The input array
            itself is returned when it already has the target length.
        """
        current_time, num_channels = spectrum.shape[0], spectrum.shape[1]
        target_time = self.target_time_intervals

        if current_time == target_time:
            return spectrum

        if current_time > target_time:
            # Keep evenly spaced rows (first and last included) so the whole
            # acquisition stays represented rather than just its beginning.
            keep = np.linspace(0, current_time - 1, target_time, dtype=int)
            return spectrum[keep, :]

        # Shorter than the target: zero-fill the trailing rows.
        padded = np.zeros((target_time, num_channels), dtype=spectrum.dtype)
        padded[:current_time, :] = spectrum
        return padded

    def _encode_labels(self, sample_meta: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the (presence, normalized activity) label vectors for one sample."""
        num_isotopes = self.isotope_index.num_isotopes
        presence = torch.zeros(num_isotopes, dtype=torch.float32)
        activity = torch.zeros(num_isotopes, dtype=torch.float32)

        for isotope_name in sample_meta['isotopes']:
            try:
                presence[self.isotope_index.name_to_index(isotope_name)] = 1.0
            except KeyError:
                # Names missing from the index are skipped.
                continue

        for isotope_name, value in sample_meta.get('source_activities_bq', {}).items():
            try:
                slot = self.isotope_index.name_to_index(isotope_name)
            except KeyError:
                continue
            # Scale by max_activity_bq and cap at 1.0.
            activity[slot] = min(value / self.max_activity_bq, 1.0)

        return presence, activity

    def __len__(self) -> int:
        return len(self.sample_ids)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single sample.

        Returns:
            Dictionary containing:
            - spectrum: Tensor of shape (target_time_intervals, num_channels)
            - presence_labels: Binary tensor (num_isotopes,)
            - activity_labels: Tensor (num_isotopes,) with normalized activities
            - sample_id: String identifier
        """
        sample_id = self.sample_ids[idx]
        sample_meta = self._load_sample_label(sample_id)

        spectrum = np.load(self.spectra_dir / f"{self._FILE_PREFIX}{sample_id}.npy")

        # A 1D spectrum is treated as a single time interval.
        if spectrum.ndim == 1:
            spectrum = spectrum.reshape(1, -1)

        spectrum = self._pad_or_truncate(spectrum)

        # Max normalization; an all-zero spectrum is left untouched to avoid
        # dividing by zero.
        max_val = spectrum.max()
        if max_val > 0:
            spectrum = spectrum / max_val

        spectrum_tensor = torch.tensor(spectrum, dtype=torch.float32)

        if self.transform is not None:
            spectrum_tensor = self.transform(spectrum_tensor)

        presence_labels, activity_labels = self._encode_labels(sample_meta)

        return {
            'spectrum': spectrum_tensor,
            'presence_labels': presence_labels,
            'activity_labels': activity_labels,
            'sample_id': sample_id
        }
|
||||
|
||||
|
||||
def collate_fn_2d(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Merge per-sample dictionaries into one batched dictionary for 2D spectra.

    The 'spectrum', 'presence_labels' and 'activity_labels' tensors are stacked
    along a new leading dimension; the sample IDs are returned as a plain list
    under 'sample_ids'.
    """
    collated = {}
    for tensor_key in ('spectrum', 'presence_labels', 'activity_labels'):
        collated[tensor_key] = torch.stack([sample[tensor_key] for sample in batch])
    collated['sample_ids'] = [sample['sample_id'] for sample in batch]
    return collated
|
||||
|
||||
|
||||
def create_data_loaders_2d(
    data_dir: Path,
    batch_size: int = 32,
    train_split: float = 0.8,
    val_split: float = 0.1,
    test_split: float = 0.1,
    num_workers: int = 4,
    target_time_intervals: int = 60,
    isotope_index: Optional[IsotopeIndex] = None,
    max_activity_bq: float = 1000.0,
    seed: int = 42
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Create train, validation, and test data loaders for 2D data.

    The train and validation sizes are the truncated products of the dataset
    length with ``train_split`` and ``val_split``; the test split receives the
    remaining samples (``test_split`` itself is not read). The split is seeded
    with ``seed``.

    Returns:
        Tuple of (train_loader, val_loader, test_loader); only the training
        loader shuffles.
    """
    full_set = SpectrumDataset2D(
        data_dir=data_dir,
        isotope_index=isotope_index,
        max_activity_bq=max_activity_bq,
        target_time_intervals=target_time_intervals
    )

    n_total = len(full_set)
    n_train = int(n_total * train_split)
    n_val = int(n_total * val_split)
    n_test = n_total - n_train - n_val

    rng = torch.Generator().manual_seed(seed)
    train_part, val_part, test_part = random_split(
        full_set, [n_train, n_val, n_test], generator=rng
    )

    print(f"Dataset splits: train={n_train}, val={n_val}, test={n_test}")

    def _build_loader(subset, shuffle):
        # All three loaders share every setting except shuffling.
        return DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=collate_fn_2d,
            pin_memory=True,
            persistent_workers=num_workers > 0
        )

    train_loader = _build_loader(train_part, True)
    val_loader = _build_loader(val_part, False)
    test_loader = _build_loader(test_part, False)

    return train_loader, val_loader, test_loader
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: load the first sample and print a short summary.
    import sys
    from pathlib import Path

    # An optional first command-line argument overrides the default location,
    # so the check can run on machines without the original drive mapping.
    if len(sys.argv) > 1:
        data_dir = Path(sys.argv[1])
    else:
        data_dir = Path("O:/master_data_collection/isotopev2")

    if not data_dir.exists():
        print(f"Data directory not found: {data_dir}")
        sys.exit(1)

    dataset = SpectrumDataset2D(data_dir, target_time_intervals=60)
    sample = dataset[0]

    print(f"\nSample:")
    print(f"  Spectrum shape: {sample['spectrum'].shape}")
    print(f"  Presence labels: {sample['presence_labels'].sum().item():.0f} isotopes")
    print(f"  Sample ID: {sample['sample_id']}")
|
||||
141
train/vega_ml/training/vega/isotope_index.py
Normal file
141
train/vega_ml/training/vega/isotope_index.py
Normal file
@ -0,0 +1,141 @@
|
||||
"""
|
||||
Isotope Index - Mapping between isotope names and model output indices.
|
||||
|
||||
This module provides a consistent mapping between isotope names and their
|
||||
corresponding indices in the model's output tensors. This is critical for
|
||||
training and inference to ensure consistent label encoding.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
# Add project root to path for imports
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from synthetic_spectra.ground_truth.isotope_data import ISOTOPE_DATABASE, get_isotope_names
|
||||
|
||||
|
||||
class IsotopeIndex:
    """
    Manages the mapping between isotope names and model indices.

    The index is deterministic - isotopes are sorted alphabetically to ensure
    consistent ordering across training and inference.
    """

    def __init__(self, isotope_names: Optional[List[str]] = None):
        """
        Initialize the isotope index.

        Args:
            isotope_names: Optional list of isotope names. If None, uses the
                names returned by ``get_isotope_names()``.
        """
        if isotope_names is None:
            isotope_names = get_isotope_names()

        # Sorting makes the index independent of the order names arrive in.
        # NOTE(review): duplicates are not removed; a repeated name takes two
        # slots while name_to_index only ever returns the later one.
        self._isotope_names = sorted(isotope_names)

        # Forward and reverse lookup tables over the same ordering.
        self._name_to_idx: Dict[str, int] = {
            name: idx for idx, name in enumerate(self._isotope_names)
        }
        self._idx_to_name: Dict[int, str] = dict(enumerate(self._isotope_names))

    @property
    def num_isotopes(self) -> int:
        """Total number of isotopes in the index."""
        return len(self._isotope_names)

    @property
    def isotope_names(self) -> List[str]:
        """List of all isotope names in index order (a copy, safe to mutate)."""
        return self._isotope_names.copy()

    def name_to_index(self, name: str) -> int:
        """
        Get the index for an isotope name.

        Args:
            name: Isotope name (e.g., "Cs-137")

        Returns:
            Integer index for the isotope

        Raises:
            KeyError: If isotope name not in index
        """
        if name not in self._name_to_idx:
            raise KeyError(f"Isotope '{name}' not found in index. "
                           f"Available isotopes: {self._isotope_names[:5]}...")
        return self._name_to_idx[name]

    def index_to_name(self, idx: int) -> str:
        """
        Get the isotope name for an index.

        Args:
            idx: Integer index

        Returns:
            Isotope name string

        Raises:
            KeyError: If index out of range
        """
        if idx not in self._idx_to_name:
            raise KeyError(f"Index {idx} out of range. Valid range: 0-{self.num_isotopes-1}")
        return self._idx_to_name[idx]

    def names_to_indices(self, names: List[str]) -> List[int]:
        """Convert list of names to list of indices."""
        return [self.name_to_index(name) for name in names]

    def indices_to_names(self, indices: List[int]) -> List[str]:
        """Convert list of indices to list of names."""
        return [self.index_to_name(idx) for idx in indices]

    def save(self, path: Path):
        """Save the isotope index to a text file, one name per line."""
        # Fixed encoding so the file reads back identically on any platform.
        with open(path, 'w', encoding='utf-8') as f:
            f.writelines(f"{name}\n" for name in self._isotope_names)

    @classmethod
    def load(cls, path: Path) -> 'IsotopeIndex':
        """Load an isotope index from a file written by ``save``; blank lines are ignored."""
        with open(path, 'r', encoding='utf-8') as f:
            isotope_names = [line.strip() for line in f if line.strip()]
        return cls(isotope_names)

    def __repr__(self) -> str:
        return f"IsotopeIndex(num_isotopes={self.num_isotopes})"

    def __len__(self) -> int:
        return self.num_isotopes
|
||||
|
||||
|
||||
# Shared module-level index, built once at import time with no explicit name
# list (so IsotopeIndex pulls the names itself).
DEFAULT_ISOTOPE_INDEX = IsotopeIndex()


def get_default_isotope_index() -> IsotopeIndex:
    """Return the shared module-level ``DEFAULT_ISOTOPE_INDEX`` instance."""
    return DEFAULT_ISOTOPE_INDEX
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Print a summary of the default index: its repr, then the first and last
    # ten entries with their positions.
    index = get_default_isotope_index()
    total = index.num_isotopes
    print(f"Isotope Index: {index}")

    print(f"\nFirst 10 isotopes:")
    for position in range(min(10, total)):
        print(f"  {position:3d}: {index.index_to_name(position)}")

    print(f"\nLast 10 isotopes:")
    # max(0, ...) keeps the start in range when fewer than ten entries exist.
    for position in range(max(0, total - 10), total):
        print(f"  {position:3d}: {index.index_to_name(position)}")
|
||||
416
train/vega_ml/training/vega/model.py
Normal file
416
train/vega_ml/training/vega/model.py
Normal file
@ -0,0 +1,416 @@
|
||||
"""
|
||||
Vega Model Architecture - CNN-FCNN with Multi-Task Heads
|
||||
|
||||
A hybrid Convolutional Neural Network with Fully Connected Neural Network
|
||||
for gamma spectrum isotope identification. Based on peer-reviewed research
|
||||
showing CNN-FCNN achieves state-of-the-art performance (99%+ accuracy).
|
||||
|
||||
Architecture:
|
||||
Input: 1D gamma spectrum (1023 channels, 20-3000 keV)
|
||||
↓
|
||||
Feature Extraction: 3 CNN modules with LeakyReLU, MaxPool, Dropout
|
||||
↓
|
||||
Classification Head: Dense layers → Sigmoid (multi-label isotope presence)
|
||||
↓
|
||||
Regression Head: Dense layers → ReLU (activity estimation in Bq)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
class VegaConfig:
    """Hyperparameter bundle for the Vega model, grouped by purpose."""

    # --- Input ---
    # Number of energy channels in spectrum
    num_channels: int = 1023

    # --- Output ---
    # Number of isotopes to classify (from isotope database)
    num_isotopes: int = 82

    # --- CNN backbone ---
    # Mutable defaults go through default_factory so every config instance
    # gets its own list.
    conv_channels: List[int] = field(default_factory=lambda: [64, 128, 256])
    conv_kernel_size: int = 7
    pool_size: int = 2

    # --- Classification head ---
    fc_hidden_dims: List[int] = field(default_factory=lambda: [512, 256])

    # --- Regularization ---
    dropout_rate: float = 0.3
    spatial_dropout_rate: float = 0.1

    # --- Activation ---
    leaky_relu_slope: float = 0.1

    # --- Loss weighting ---
    classification_weight: float = 1.0
    regression_weight: float = 0.1

    # --- Training ---
    # For activity normalization
    max_activity_bq: float = 1000.0
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
    """
    Two Conv1d -> BatchNorm1d -> LeakyReLU stages followed by pooling and dropout.

    The two-convolutions-per-module layout is the one the original author
    attributes to Turner et al. (2021).

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
        kernel_size: Width of each convolution kernel. Padding is
            ``kernel_size // 2``, so odd kernels preserve the sequence length.
        pool_size: Window of the max-pooling step.
        dropout_rate: Probability passed to the channel-wise ``Dropout1d``.
        leaky_slope: Negative slope of both LeakyReLU activations.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 7,
        pool_size: int = 2,
        dropout_rate: float = 0.1,
        leaky_slope: float = 0.1
    ):
        super().__init__()

        same_padding = kernel_size // 2

        # Stage 1: in_channels -> out_channels.
        # NOTE: submodule names and registration order are kept stable
        # because they determine the state_dict keys.
        self.conv1 = nn.Conv1d(
            in_channels, out_channels,
            kernel_size=kernel_size, padding=same_padding
        )
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.act1 = nn.LeakyReLU(leaky_slope)

        # Stage 2: out_channels -> out_channels.
        self.conv2 = nn.Conv1d(
            out_channels, out_channels,
            kernel_size=kernel_size, padding=same_padding
        )
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.act2 = nn.LeakyReLU(leaky_slope)

        # Downsampling and channel-wise (spatial) dropout.
        self.pool = nn.MaxPool1d(pool_size)
        self.dropout = nn.Dropout1d(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv stages, then max-pool and dropout."""
        stages = (
            (self.conv1, self.bn1, self.act1),
            (self.conv2, self.bn2, self.act2),
        )
        for conv, norm, activation in stages:
            x = activation(norm(conv(x)))

        return self.dropout(self.pool(x))
|
||||
|
||||
|
||||
class VegaModel(nn.Module):
    """
    Vega: CNN-FCNN Multi-Task Model for Isotope Identification

    Named after the bright star Vega (α Lyrae), which emits radiation
    across the electromagnetic spectrum - fitting for a gamma spectrum analyzer.

    The model performs two tasks on a shared CNN feature vector:
    1. Multi-label classification: Which isotopes are present?
    2. Activity regression: What is the (normalized) activity of each isotope?
    """

    def __init__(self, config: "VegaConfig"):
        """Build backbone and heads from ``config`` and initialize weights."""
        super().__init__()
        self.config = config

        # Build CNN backbone
        self.backbone = self._build_backbone()

        # Flattened feature size feeding both dense heads
        self._flat_size = self._calculate_flat_size()

        # Classification head (multi-label) is created before the regression
        # head; keep this order so parameter registration stays stable.
        self.classifier = self._build_classifier()
        self.regressor = self._build_regressor()

        # Initialize weights
        self._init_weights()

    def _build_backbone(self) -> nn.Sequential:
        """Build the CNN feature extraction backbone (one ConvBlock per entry
        of ``config.conv_channels``)."""
        layers = []
        in_channels = 1  # Input is a 1D spectrum with a single channel

        for out_channels in self.config.conv_channels:
            layers.append(ConvBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.config.conv_kernel_size,
                pool_size=self.config.pool_size,
                dropout_rate=self.config.spatial_dropout_rate,
                leaky_slope=self.config.leaky_relu_slope
            ))
            in_channels = out_channels

        return nn.Sequential(*layers)

    def _calculate_flat_size(self) -> int:
        """Return the number of features the backbone emits for one spectrum.

        The size is measured by pushing a dummy spectrum through the backbone.
        The pass runs in eval mode so it does not update BatchNorm running
        statistics or draw dropout masks; the previous mode is restored
        afterwards.
        """
        dummy = torch.zeros(1, 1, self.config.num_channels)
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                out = self.backbone(dummy)
        finally:
            self.backbone.train(was_training)
        return out.numel()

    def _build_hidden_layers(self) -> Tuple[List[nn.Module], int]:
        """Create the Linear/BatchNorm/LeakyReLU/Dropout stack shared by both
        heads and return it together with its output width."""
        layers: List[nn.Module] = []
        in_features = self._flat_size

        for hidden_dim in self.config.fc_hidden_dims:
            layers.extend([
                nn.Linear(in_features, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.LeakyReLU(self.config.leaky_relu_slope),
                nn.Dropout(self.config.dropout_rate)
            ])
            in_features = hidden_dim

        return layers, in_features

    def _build_classifier(self) -> nn.Sequential:
        """Build the classification head for isotope presence prediction.

        Outputs raw logits (not probabilities) for AMP compatibility.
        Use BCEWithLogitsLoss for training, apply sigmoid during inference.
        """
        layers, in_features = self._build_hidden_layers()

        # Output layer - raw logits for AMP compatibility
        layers.append(nn.Linear(in_features, self.config.num_isotopes))

        return nn.Sequential(*layers)

    def _build_regressor(self) -> nn.Sequential:
        """Build the regression head for activity estimation.

        Uses the same hidden architecture as the classifier (with its own
        weights) and ends in a ReLU so predictions are non-negative.
        """
        layers, in_features = self._build_hidden_layers()

        layers.extend([
            nn.Linear(in_features, self.config.num_isotopes),
            nn.ReLU()  # Activity must be non-negative
        ])

        return nn.Sequential(*layers)

    def _init_weights(self):
        """Initialize weights using He initialization for LeakyReLU.

        Conv/Linear biases are zeroed; BatchNorm starts as the identity
        (weight 1, bias 0).
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(
                    module.weight,
                    a=self.config.leaky_relu_slope,
                    mode='fan_out',
                    nonlinearity='leaky_relu'
                )
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(
        self,
        x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the model.

        Args:
            x: Input spectrum tensor of shape (batch, channels) or (batch, 1, channels)
               Values should be normalized [0, 1]

        Returns:
            Tuple of:
                - isotope_logits: Raw logits for each isotope (batch, num_isotopes)
                                 Apply sigmoid to get probabilities for inference
                - activity_pred: Non-negative activity prediction for each isotope
                                 (batch, num_isotopes). ``predict`` multiplies this
                                 by ``config.max_activity_bq`` to report Bq.
        """
        # Ensure input has channel dimension
        if x.dim() == 2:
            x = x.unsqueeze(1)  # (batch, channels) -> (batch, 1, channels)

        # Feature extraction
        features = self.backbone(x)
        features = features.flatten(start_dim=1)

        # Both heads read the same flattened features
        isotope_logits = self.classifier(features)
        activity_pred = self.regressor(features)

        return isotope_logits, activity_pred

    def predict(
        self,
        x: torch.Tensor,
        threshold: float = 0.5,
        return_all: bool = False
    ) -> Dict:
        """
        Make predictions with post-processing.

        Switches the model to eval mode (it is left in eval mode afterwards)
        and runs without gradient tracking.

        Args:
            x: Input spectrum tensor
            threshold: Probability threshold for isotope presence
            return_all: Currently ignored; accepted for API compatibility.

        Returns:
            Dictionary with:
                - 'probabilities': sigmoid of the classifier logits
                - 'activities_bq': activities scaled by ``max_activity_bq``,
                  zeroed where the isotope is not predicted present
                - 'present_mask': boolean mask ``probabilities >= threshold``
                - 'threshold': the threshold that was applied
        """
        self.eval()
        with torch.no_grad():
            logits, activities = self(x)
            # forward() returns raw logits; convert them so the threshold
            # really is applied to probabilities.
            probs = torch.sigmoid(logits)

        # Apply threshold
        present = probs >= threshold

        # Mask activities by presence
        masked_activities = activities * present.float()

        return {
            'probabilities': probs,
            'activities_bq': masked_activities * self.config.max_activity_bq,
            'present_mask': present,
            'threshold': threshold
        }

    def count_parameters(self) -> int:
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def summary(self) -> str:
        """Get a summary of the model architecture."""
        lines = [
            "=" * 60,
            "VEGA Model - CNN-FCNN Multi-Task Isotope Identifier",
            "=" * 60,
            f"Input channels: {self.config.num_channels}",
            f"Output isotopes: {self.config.num_isotopes}",
            f"CNN channels: {self.config.conv_channels}",
            f"FC hidden dims: {self.config.fc_hidden_dims}",
            f"Dropout rate: {self.config.dropout_rate}",
            f"Total parameters: {self.count_parameters():,}",
            "=" * 60
        ]
        return "\n".join(lines)
|
||||
|
||||
|
||||
class VegaLoss(nn.Module):
    """
    Weighted sum of the two Vega training objectives.

    - ``BCEWithLogitsLoss`` on the isotope logits (multi-label classification)
    - ``HuberLoss`` on the activity predictions (robust to outliers)
    """

    def __init__(
        self,
        classification_weight: float = 1.0,
        regression_weight: float = 0.1,
        huber_delta: float = 1.0
    ):
        """Store the term weights and create the two criterion modules.

        Args:
            classification_weight: Multiplier for the BCE term.
            regression_weight: Multiplier for the Huber term.
            huber_delta: ``delta`` passed to ``nn.HuberLoss``.
        """
        super().__init__()
        self.classification_weight = classification_weight
        self.regression_weight = regression_weight

        # BCEWithLogitsLoss fuses sigmoid + BCE, which is the AMP-safe form.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.huber_loss = nn.HuberLoss(delta=huber_delta)

    def forward(
        self,
        pred_logits: torch.Tensor,
        pred_activities: torch.Tensor,
        target_presence: torch.Tensor,
        target_activities: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute the combined loss.

        Args:
            pred_logits: Predicted isotope logits (batch, num_isotopes)
            pred_activities: Predicted activities (batch, num_isotopes)
            target_presence: Ground truth presence labels (batch, num_isotopes)
            target_activities: Ground truth activities (batch, num_isotopes)

        Returns:
            ``(total_loss, loss_dict)`` where ``loss_dict`` holds the Python
            float values under the keys 'total', 'classification' and
            'regression'.
        """
        present = target_presence.float()

        # Classification term; the criterion applies the sigmoid itself.
        cls_term = self.bce_loss(pred_logits, present)

        # Regression term. Entries for absent isotopes are zeroed in both the
        # prediction and the target, so they contribute no error; the mean is
        # still taken over every element.
        if present.sum() > 0:
            reg_term = self.huber_loss(
                pred_activities * present,
                target_activities * present
            )
        else:
            reg_term = torch.tensor(0.0, device=pred_activities.device)

        combined = (
            self.classification_weight * cls_term
            + self.regression_weight * reg_term
        )

        breakdown = {
            'total': combined.item(),
            'classification': cls_term.item(),
            'regression': reg_term.item()
        }

        return combined, breakdown
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: build the default model, run one forward pass on random
    # input and evaluate the combined loss on random targets.
    config = VegaConfig()
    model = VegaModel(config)
    print(model.summary())

    # Forward pass (the first output holds raw logits)
    batch_size = 4
    x = torch.randn(batch_size, config.num_channels)
    logits, activities = model(x)

    print(f"\nInput shape: {x.shape}")
    print(f"Output probs shape: {logits.shape}")
    print(f"Output activities shape: {activities.shape}")

    # Loss on random presence labels and random activities in [0, 100)
    label_shape = (batch_size, config.num_isotopes)
    loss_fn = VegaLoss()
    target_presence = torch.randint(0, 2, label_shape)
    target_activities = torch.rand(batch_size, config.num_isotopes) * 100

    loss, loss_dict = loss_fn(logits, activities, target_presence, target_activities)
    print(f"\nLoss: {loss_dict}")
|
||||
231
train/vega_ml/training/vega/model_2d.py
Normal file
231
train/vega_ml/training/vega/model_2d.py
Normal file
@ -0,0 +1,231 @@
|
||||
"""
|
||||
Vega 2D Model - Uses Full Temporal Information
|
||||
|
||||
This model treats gamma spectra as 2D images (time × channels) and uses
|
||||
Conv2d to extract both spectral and temporal features.
|
||||
|
||||
Input shape: (batch, 1, time_intervals, channels) = (B, 1, 60, 1023)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
@dataclass
class Vega2DConfig:
    """Hyperparameters for the 2D (time x energy) Vega model.

    Field order is part of the positional constructor signature and must
    not be changed.
    """

    # --- Input -----------------------------------------------------------
    # Number of energy channels (width of the input image).
    num_channels: int = 1023
    # Fixed number of time intervals (height of the input image).
    num_time_intervals: int = 60

    # --- Output ----------------------------------------------------------
    # Number of isotopes predicted by each head.
    num_isotopes: int = 82

    # --- CNN architecture ------------------------------------------------
    # Output channel count of each successive ConvBlock2D.
    conv_channels: List[int] = field(default_factory=lambda: [32, 64, 128])
    # (time, energy) kernel; deliberately wider along the energy axis.
    kernel_size: Tuple[int, int] = (3, 7)
    # (time, energy) max-pooling window.
    pool_size: Tuple[int, int] = (2, 2)

    # --- Dense layers ----------------------------------------------------
    # Hidden layer widths of the shared fully connected trunk.
    fc_hidden_dims: List[int] = field(default_factory=lambda: [512, 256])

    # --- Regularization / activation -------------------------------------
    # Dropout probability (used in both the conv blocks and the FC trunk).
    dropout_rate: float = 0.3
    # Negative slope of every LeakyReLU.
    leaky_relu_slope: float = 0.01

    # --- Activity scaling ------------------------------------------------
    # Factor that converts normalized activity predictions to Bq.
    max_activity_bq: float = 1000.0
|
||||
|
||||
|
||||
class ConvBlock2D(nn.Module):
    """Two Conv2d -> BatchNorm2d -> LeakyReLU stages, then max-pool and Dropout2d.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
        kernel_size: (height, width) of each convolution kernel.
        pool_size: (height, width) of the max-pooling window.
        dropout_rate: Probability passed to ``Dropout2d``.
        leaky_relu_slope: Negative slope of the shared LeakyReLU.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        pool_size: Tuple[int, int],
        dropout_rate: float,
        leaky_relu_slope: float
    ):
        super().__init__()

        # Half-kernel padding on each axis; keeps the spatial size for odd
        # kernels so only the pooling step shrinks the feature map.
        pad_h = kernel_size[0] // 2
        pad_w = kernel_size[1] // 2
        padding = (pad_h, pad_w)

        # NOTE: submodule names and registration order are kept stable
        # because they determine the state_dict keys.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.activation = nn.LeakyReLU(leaky_relu_slope)
        self.pool = nn.MaxPool2d(pool_size)
        self.dropout = nn.Dropout2d(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both conv stages, then downsample and apply dropout."""
        first = self.conv1(x)
        first = self.bn1(first)
        first = self.activation(first)

        second = self.conv2(first)
        second = self.bn2(second)
        second = self.activation(second)

        pooled = self.pool(second)
        return self.dropout(pooled)
|
||||
|
||||
|
||||
class Vega2DModel(nn.Module):
|
||||
"""
|
||||
2D CNN model for gamma spectrum isotope identification.
|
||||
|
||||
Treats spectra as images with time on one axis and energy channels on the other.
|
||||
This preserves temporal information that is lost in the 1D approach.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Vega2DConfig = None):
|
||||
super().__init__()
|
||||
self.config = config or Vega2DConfig()
|
||||
|
||||
# Build CNN backbone
|
||||
self.conv_blocks = nn.ModuleList()
|
||||
in_channels = 1 # Single channel input (like grayscale image)
|
||||
|
||||
for out_channels in self.config.conv_channels:
|
||||
self.conv_blocks.append(ConvBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=self.config.kernel_size,
|
||||
pool_size=self.config.pool_size,
|
||||
dropout_rate=self.config.dropout_rate,
|
||||
leaky_relu_slope=self.config.leaky_relu_slope
|
||||
))
|
||||
in_channels = out_channels
|
||||
|
||||
# Calculate flattened size after conv blocks
|
||||
self.flat_size = self._calculate_flat_size()
|
||||
|
||||
# Fully connected classifier
|
||||
fc_layers = []
|
||||
fc_in = self.flat_size
|
||||
|
||||
for fc_out in self.config.fc_hidden_dims:
|
||||
fc_layers.extend([
|
||||
nn.Linear(fc_in, fc_out),
|
||||
nn.BatchNorm1d(fc_out),
|
||||
nn.LeakyReLU(self.config.leaky_relu_slope),
|
||||
nn.Dropout(self.config.dropout_rate)
|
||||
])
|
||||
fc_in = fc_out
|
||||
|
||||
self.fc_backbone = nn.Sequential(*fc_layers)
|
||||
|
||||
# Output heads
|
||||
self.classifier = nn.Linear(fc_in, self.config.num_isotopes) # Logits for BCE
|
||||
self.regressor = nn.Sequential(
|
||||
nn.Linear(fc_in, self.config.num_isotopes),
|
||||
nn.ReLU() # Activity must be non-negative
|
||||
)
|
||||
|
||||
# Initialize weights
|
||||
self._init_weights()
|
||||
|
||||
def _calculate_flat_size(self) -> int:
|
||||
"""Calculate the flattened size after all conv blocks."""
|
||||
# Start with input dimensions
|
||||
h = self.config.num_time_intervals # 60
|
||||
w = self.config.num_channels # 1023
|
||||
|
||||
# Each conv block applies pooling that halves dimensions
|
||||
for _ in self.config.conv_channels:
|
||||
h = h // self.config.pool_size[0]
|
||||
w = w // self.config.pool_size[1]
|
||||
|
||||
# Final size = last_channels * h * w
|
||||
return self.config.conv_channels[-1] * h * w
|
||||
|
||||
def _init_weights(self):
|
||||
"""Initialize weights using Kaiming initialization."""
|
||||
for m in self.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (batch, 1, time_intervals, channels)
|
||||
or (batch, time_intervals, channels) - will add channel dim
|
||||
|
||||
Returns:
|
||||
Tuple of (logits, activities):
|
||||
- logits: (batch, num_isotopes) - raw scores for BCE loss
|
||||
- activities: (batch, num_isotopes) - predicted activities (normalized 0-1)
|
||||
"""
|
||||
# Add channel dimension if needed: (B, T, C) -> (B, 1, T, C)
|
||||
if x.dim() == 3:
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
# CNN backbone
|
||||
for conv_block in self.conv_blocks:
|
||||
x = conv_block(x)
|
||||
|
||||
# Flatten
|
||||
x = x.view(x.size(0), -1)
|
||||
|
||||
# FC backbone
|
||||
x = self.fc_backbone(x)
|
||||
|
||||
# Output heads
|
||||
logits = self.classifier(x)
|
||||
activities = self.regressor(x)
|
||||
|
||||
return logits, activities
|
||||
|
||||
def predict(self, x: torch.Tensor, threshold: float = 0.5) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict isotope presence and activities.
|
||||
|
||||
Args:
|
||||
x: Input spectrum
|
||||
threshold: Probability threshold for presence
|
||||
|
||||
Returns:
|
||||
Tuple of (presence, activities):
|
||||
- presence: (batch, num_isotopes) binary predictions
|
||||
- activities: (batch, num_isotopes) in Bq
|
||||
"""
|
||||
logits, activities_norm = self.forward(x)
|
||||
probs = torch.sigmoid(logits)
|
||||
presence = (probs >= threshold).float()
|
||||
activities_bq = activities_norm * self.config.max_activity_bq
|
||||
return presence, activities_bq
|
||||
|
||||
|
||||
def count_parameters(model: nn.Module) -> int:
    """Return the total element count of all parameters that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: build the default 2D model, report its size and run one
    # forward pass on a random batch.
    config = Vega2DConfig()
    model = Vega2DModel(config)

    param_count = count_parameters(model)

    print(f"Vega 2D Model")
    print(f" Input: ({config.num_time_intervals}, {config.num_channels})")
    print(f" Conv channels: {config.conv_channels}")
    print(f" FC dims: {config.fc_hidden_dims}")
    print(f" Flat size: {model.flat_size}")
    print(f" Parameters: {param_count:,}")

    # Forward pass on 4 random single-channel (time x energy) images
    batch = torch.randn(4, 1, config.num_time_intervals, config.num_channels)
    logits, activities = model(batch)

    print(f"\n Test batch: {batch.shape}")
    print(f" Logits: {logits.shape}")
    print(f" Activities: {activities.shape}")
|
||||
120
train/vega_ml/training/vega/run_training.py
Normal file
120
train/vega_ml/training/vega/run_training.py
Normal file
@ -0,0 +1,120 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Run Vega Training
|
||||
|
||||
Simple script to train the Vega model on synthetic gamma spectra.
|
||||
Designed for both quick test runs and full-scale training.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from training.vega.train import train_vega, TrainingConfig
|
||||
from training.vega.model import VegaConfig
|
||||
|
||||
|
||||
def main():
    """Parse command-line options, print the run settings and start training.

    Returns:
        0 after ``train_vega`` returns.
    """
    parser = argparse.ArgumentParser(
        description="Train Vega model for isotope identification"
    )

    # (flags, add_argument keyword arguments), registered in this order:
    # data paths, training parameters, quick test mode, mixed precision,
    # data loading parallelism.
    option_specs = [
        (("--data-dir", "-d"), dict(
            type=str,
            default="O:/master_data_collection/isotopev2",
            help="Path to synthetic data directory",
        )),
        (("--model-dir", "-m"), dict(
            type=str,
            default="models",
            help="Directory to save trained models",
        )),
        (("--epochs", "-e"), dict(
            type=int,
            default=100,
            help="Maximum number of training epochs",
        )),
        (("--batch-size", "-b"), dict(
            type=int,
            default=64,
            help="Batch size for training (default: 64 for better GPU utilization)",
        )),
        (("--learning-rate", "-lr"), dict(
            type=float,
            default=1e-3,
            help="Initial learning rate",
        )),
        (("--test",), dict(
            action="store_true",
            help="Quick test mode with reduced epochs",
        )),
        (("--no-amp",), dict(
            action="store_true",
            help="Disable automatic mixed precision training",
        )),
        (("--workers", "-w"), dict(
            type=int,
            default=8,
            help="Number of data loading workers (default: 8 for parallel I/O)",
        )),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()

    # Quick test mode overrides the epoch budget and early-stopping patience.
    if args.test:
        num_epochs, patience = 5, 3
    else:
        num_epochs, patience = args.epochs, 10

    config = TrainingConfig(
        data_dir=args.data_dir,
        model_dir=args.model_dir,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        num_epochs=num_epochs,
        patience=patience,
        use_amp=not args.no_amp,
        num_workers=args.workers
    )

    # The model always uses its default configuration.
    model_config = VegaConfig()

    rule = "=" * 60
    banner = [
        "\n" + rule,
        "VEGA TRAINING",
        rule,
        f"Data directory: {args.data_dir}",
        f"Model directory: {args.model_dir}",
        f"Epochs: {config.num_epochs}",
        f"Batch size: {config.batch_size}",
        f"Learning rate: {config.learning_rate}",
        f"Mixed precision: {config.use_amp}",
        f"Data workers: {config.num_workers}",
    ]
    if args.test:
        banner.append("MODE: Quick test run")
    banner.append(rule + "\n")
    for line in banner:
        print(line)

    # Run training; the returned model and results are not used here.
    _model, _results = train_vega(config=config, model_config=model_config)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)
|
||||
614
train/vega_ml/training/vega/train.py
Normal file
614
train/vega_ml/training/vega/train.py
Normal file
@ -0,0 +1,614 @@
|
||||
"""
|
||||
Training Script for Vega Model
|
||||
|
||||
Implements the training loop with:
|
||||
- Mixed precision training for RTX 5090 efficiency
|
||||
- Learning rate scheduling
|
||||
- Early stopping
|
||||
- Model checkpointing
|
||||
- Training metrics logging
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
import numpy as np
|
||||
|
||||
# Sklearn metrics for comprehensive evaluation
|
||||
from sklearn.metrics import (
|
||||
roc_auc_score,
|
||||
f1_score,
|
||||
precision_score,
|
||||
recall_score,
|
||||
hamming_loss
|
||||
)
|
||||
|
||||
from .model import VegaModel, VegaConfig, VegaLoss
|
||||
from .dataset import create_data_loaders, SpectrumDataset
|
||||
from .isotope_index import IsotopeIndex, get_default_isotope_index
|
||||
|
||||
|
||||
@dataclass
class TrainingConfig:
    """Settings for a Vega training run.

    Field order is part of the positional constructor signature and must
    not be changed.
    """

    # --- Data ------------------------------------------------------------
    # Directory holding the synthetic spectra.
    data_dir: str = "O:/master_data_collection/isotopev2"

    # --- Model output ----------------------------------------------------
    # Directory the trainer creates for saved models, and the model's name.
    model_dir: str = "models"
    model_name: str = "vega"

    # --- Optimization hyperparameters ------------------------------------
    # Batch size (raised from 32 to 64 for better GPU utilization).
    batch_size: int = 64
    # Initial Adam learning rate and weight decay.
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    # Upper bound on the number of epochs.
    num_epochs: int = 100

    # --- Early stopping --------------------------------------------------
    patience: int = 10
    min_delta: float = 1e-4

    # --- Learning rate scheduling (ReduceLROnPlateau) --------------------
    lr_scheduler_patience: int = 5
    lr_scheduler_factor: float = 0.5
    min_lr: float = 1e-6

    # --- Mixed precision -------------------------------------------------
    # Request AMP; the trainer only enables it when running on CUDA.
    use_amp: bool = True

    # --- Data splits -----------------------------------------------------
    train_split: float = 0.8
    val_split: float = 0.1
    test_split: float = 0.1

    # --- Data loading ----------------------------------------------------
    # Parallel data loading workers.
    num_workers: int = 8
    # Batches to prefetch per worker.
    prefetch_factor: int = 4
    # Keep workers alive between epochs.
    persistent_workers: bool = True

    # --- Reproducibility -------------------------------------------------
    seed: int = 42

    # --- Activity normalization ------------------------------------------
    max_activity_bq: float = 1000.0
|
||||
|
||||
|
||||
class VegaTrainer:
|
||||
"""
|
||||
Trainer class for the Vega model.
|
||||
|
||||
Handles the training loop, validation, checkpointing, and metrics.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    model: VegaModel,
    config: TrainingConfig,
    device: Optional[torch.device] = None,
    force_cpu: bool = False
):
    """
    Prepare the model, loss, optimizer, scheduler and bookkeeping state.

    Args:
        model: Network to train; it is moved to the selected device.
        config: Training hyperparameters and the model output directory.
        device: Explicit device to train on. Ignored when ``force_cpu`` is set.
        force_cpu: Train on the CPU regardless of ``device`` or CUDA availability.
    """
    self.model = model
    self.config = config

    # Device selection - force CPU if requested or if CUDA incompatible
    if force_cpu:
        self.device = torch.device('cpu')
    elif device is not None:
        self.device = device
    elif torch.cuda.is_available():
        # CUDA can be reported as available and still fail at runtime, so
        # run a tiny op before committing to it.
        try:
            test_tensor = torch.zeros(1, device='cuda')
            _ = test_tensor + 1
            self.device = torch.device('cuda')
        except RuntimeError:
            print("CUDA device found but not compatible, falling back to CPU")
            self.device = torch.device('cpu')
    else:
        self.device = torch.device('cpu')

    # Move model to device
    self.model = self.model.to(self.device)

    # Loss weights come from the model's own configuration
    self.loss_fn = VegaLoss(
        classification_weight=model.config.classification_weight,
        regression_weight=model.config.regression_weight
    )

    # Setup optimizer
    self.optimizer = Adam(
        self.model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    # Setup learning rate scheduler
    self.scheduler = ReduceLROnPlateau(
        self.optimizer,
        mode='min',
        patience=config.lr_scheduler_patience,
        factor=config.lr_scheduler_factor,
        min_lr=config.min_lr
    )

    # Mixed precision is only enabled on CUDA; a None scaler means the
    # training loop runs in full precision.
    if config.use_amp and self.device.type == 'cuda':
        self.scaler = torch.amp.GradScaler('cuda')
    else:
        self.scaler = None

    # Training state
    self.current_epoch = 0
    self.best_val_loss = float('inf')
    self.epochs_without_improvement = 0
    self.training_history = []

    # Create model directory
    self.model_dir = Path(config.model_dir)
    self.model_dir.mkdir(parents=True, exist_ok=True)

    print(f"Training on device: {self.device}")
    if self.device.type == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    # Report whether AMP is actually active, not merely requested.
    amp_active = self.scaler is not None
    print(f"Mixed precision: {amp_active}")
|
||||
|
||||
def train_epoch(self, train_loader) -> Dict[str, float]:
    """Train for one epoch.

    Args:
        train_loader: Iterable of batches; each batch is a mapping with the
            keys 'spectrum', 'presence_labels' and 'activity_labels'.

    Returns:
        Dict with the per-batch mean of the total, classification and
        regression losses, the element-wise presence accuracy at a 0.5
        probability threshold, and the seconds spent waiting for data
        ('data_time') versus computing ('compute_time'). All loss and
        accuracy values are 0.0 when the loader yields no batches.
    """
    self.model.train()
    total_loss = 0.0
    cls_loss_sum = 0.0
    reg_loss_sum = 0.0
    num_batches = 0

    # Track accuracy during training
    correct_isotopes = 0
    total_isotopes = 0

    # Timing for profiling - track data loading vs compute
    data_time = 0.0
    compute_time = 0.0
    data_start = time.time()

    for batch in train_loader:
        # Data loading time (time spent waiting for next batch)
        data_time += time.time() - data_start
        compute_start = time.time()

        # Move to device
        spectra = batch['spectrum'].to(self.device)
        presence = batch['presence_labels'].to(self.device)
        activities = batch['activity_labels'].to(self.device)

        # Zero gradients
        self.optimizer.zero_grad()

        if self.scaler is not None:
            # Mixed precision: forward and loss under autocast, then a
            # scaled backward pass and optimizer step.
            with torch.amp.autocast('cuda'):
                pred_logits, pred_activities = self.model(spectra)
                loss, loss_dict = self.loss_fn(
                    pred_logits, pred_activities, presence, activities
                )

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            pred_logits, pred_activities = self.model(spectra)
            loss, loss_dict = self.loss_fn(
                pred_logits, pred_activities, presence, activities
            )

            loss.backward()
            self.optimizer.step()

        total_loss += loss_dict['total']
        cls_loss_sum += loss_dict['classification']
        reg_loss_sum += loss_dict['regression']
        num_batches += 1

        # Training accuracy, computed without tracking gradients
        with torch.no_grad():
            pred_probs = torch.sigmoid(pred_logits)
            pred_presence = (pred_probs >= 0.5).float()
            correct_isotopes += (pred_presence == presence).sum().item()
            total_isotopes += presence.numel()

        # Mark compute time and restart timing for data loading
        compute_time += time.time() - compute_start
        data_start = time.time()

    train_accuracy = correct_isotopes / total_isotopes if total_isotopes > 0 else 0.0

    # An empty loader leaves every sum at 0.0; dividing by 1 then reports
    # 0.0 instead of raising ZeroDivisionError, matching the accuracy guard.
    denominator = num_batches if num_batches > 0 else 1

    return {
        'train_loss': total_loss / denominator,
        'train_cls_loss': cls_loss_sum / denominator,
        'train_reg_loss': reg_loss_sum / denominator,
        'train_accuracy': train_accuracy,
        'data_time': data_time,
        'compute_time': compute_time
    }
|
||||
|
||||
@torch.no_grad()
def validate(self, val_loader) -> Dict[str, float]:
    """Evaluate the model on a validation loader and return multi-label metrics.

    Args:
        val_loader: Iterable of batch dicts with the keys ``'spectrum'``,
            ``'presence_labels'`` and ``'activity_labels'``, or ``None``.

    Returns:
        An empty dict when ``val_loader`` is ``None`` or yields no batches.
        Otherwise a dict with ``val_loss``, ``val_cls_loss``, ``val_reg_loss``,
        ``val_accuracy``, ``val_auc_macro``, ``val_auc_micro``,
        ``val_f1_macro``, ``val_f1_micro``, ``val_precision``, ``val_recall``,
        ``val_hamming`` and ``val_exact_match``. Metrics derived from NumPy or
        scikit-learn are converted to built-in floats.
    """
    if val_loader is None:
        return {}

    self.model.eval()
    total_loss = 0.0
    cls_loss_sum = 0.0
    reg_loss_sum = 0.0
    num_batches = 0

    # Per-batch predictions and labels, gathered for the dataset-wide metrics
    prob_chunks = []
    pred_chunks = []
    label_chunks = []

    for batch in val_loader:
        spectra = batch['spectrum'].to(self.device)
        presence = batch['presence_labels'].to(self.device)
        activities = batch['activity_labels'].to(self.device)

        pred_logits, pred_activities = self.model(spectra)
        _, loss_dict = self.loss_fn(
            pred_logits, pred_activities, presence, activities
        )

        total_loss += loss_dict['total']
        cls_loss_sum += loss_dict['classification']
        reg_loss_sum += loss_dict['regression']
        num_batches += 1

        pred_probs = torch.sigmoid(pred_logits)
        pred_presence = (pred_probs >= 0.5).float()

        prob_chunks.append(pred_probs.cpu().numpy())
        pred_chunks.append(pred_presence.cpu().numpy())
        label_chunks.append(presence.cpu().numpy())

    # An empty loader has nothing to stack or average; report "no metrics"
    # the same way as a missing loader instead of crashing in np.vstack.
    if num_batches == 0:
        return {}

    all_probs = np.vstack(prob_chunks)
    all_preds = np.vstack(pred_chunks)
    all_labels = np.vstack(label_chunks)
    num_samples = len(all_labels)

    # Element-wise accuracy over every (sample, isotope) cell
    matches = all_preds == all_labels
    total = all_labels.size
    accuracy = int(matches.sum()) / total if total > 0 else 0.0

    metrics = {
        'val_loss': total_loss / num_batches,
        'val_cls_loss': cls_loss_sum / num_batches,
        'val_reg_loss': reg_loss_sum / num_batches,
        'val_accuracy': accuracy,
    }

    # ROC-AUC is only defined for label columns that contain both classes
    valid_cols = [
        i for i in range(all_labels.shape[1])
        if len(np.unique(all_labels[:, i])) == 2
    ]
    auc_macro = 0.0
    auc_micro = 0.0
    if valid_cols:
        try:
            auc_macro = float(roc_auc_score(
                all_labels[:, valid_cols],
                all_probs[:, valid_cols],
                average='macro'
            ))
            auc_micro = float(roc_auc_score(
                all_labels[:, valid_cols],
                all_probs[:, valid_cols],
                average='micro'
            ))
        except ValueError:
            # AUC can't be computed for this label layout; report 0.0 for both
            auc_macro = 0.0
            auc_micro = 0.0
    metrics['val_auc_macro'] = auc_macro
    metrics['val_auc_micro'] = auc_micro

    # F1 is reported macro- and micro-averaged; precision and recall are micro-averaged
    metrics['val_f1_macro'] = float(f1_score(all_labels, all_preds, average='macro', zero_division=0))
    metrics['val_f1_micro'] = float(f1_score(all_labels, all_preds, average='micro', zero_division=0))
    metrics['val_precision'] = float(precision_score(all_labels, all_preds, average='micro', zero_division=0))
    metrics['val_recall'] = float(recall_score(all_labels, all_preds, average='micro', zero_division=0))

    # Hamming loss (fraction of labels incorrectly predicted)
    metrics['val_hamming'] = float(hamming_loss(all_labels, all_preds))

    # Exact match ratio (all isotopes correct for a sample)
    exact_matches = int(matches.all(axis=1).sum())
    metrics['val_exact_match'] = exact_matches / num_samples if num_samples > 0 else 0.0

    return metrics
|
||||
|
||||
def save_checkpoint(self, path: Path, is_best: bool = False):
    """Write the current training state to ``path``.

    The saved dict holds the epoch, the model/optimizer/scheduler state dicts,
    the best validation loss, both configs as plain dicts and the training
    history. The AMP scaler state is added when a scaler is in use.

    Args:
        path: Destination file for the checkpoint.
        is_best: When true, also write the same checkpoint next to ``path``
            as ``<model_name>_best.pt``.
    """
    state = dict(
        epoch=self.current_epoch,
        model_state_dict=self.model.state_dict(),
        optimizer_state_dict=self.optimizer.state_dict(),
        scheduler_state_dict=self.scheduler.state_dict(),
        best_val_loss=self.best_val_loss,
        model_config=asdict(self.model.config),
        training_config=asdict(self.config),
        training_history=self.training_history,
    )
    if self.scaler is not None:
        state['scaler_state_dict'] = self.scaler.state_dict()

    destinations = [path]
    if is_best:
        destinations.append(path.parent / f"{self.config.model_name}_best.pt")

    for destination in destinations:
        torch.save(state, destination)
|
||||
|
||||
def load_checkpoint(self, path: Path):
    """Restore model, optimizer, scheduler and bookkeeping state from ``path``.

    Sets ``current_epoch``, ``best_val_loss`` and ``training_history`` from the
    checkpoint. The AMP scaler state is restored only when a scaler is in use
    and the checkpoint contains one.

    Args:
        path: Checkpoint file to read. Must come from a trusted source.
    """
    # These checkpoints hold more than tensors: config dicts and the metric
    # history, which can contain NumPy scalars. The restricted unpickler that
    # newer PyTorch releases use by default rejects those, so full unpickling
    # is requested explicitly.
    # SECURITY: full unpickling can execute arbitrary code. Only load files
    # produced by this training pipeline.
    checkpoint = torch.load(path, map_location=self.device, weights_only=False)

    self.model.load_state_dict(checkpoint['model_state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    if self.scaler is not None and 'scaler_state_dict' in checkpoint:
        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

    self.current_epoch = checkpoint['epoch']
    self.best_val_loss = checkpoint['best_val_loss']
    # Checkpoints without a stored history fall back to an empty one
    self.training_history = checkpoint.get('training_history', [])

    print(f"Loaded checkpoint from epoch {self.current_epoch}")
|
||||
|
||||
def train(
    self,
    train_loader,
    val_loader,
    resume_from: Optional[Path] = None
) -> Dict:
    """
    Run the full training loop with validation, checkpointing and early stopping.

    Each epoch trains, validates, records the merged metrics in the training
    history, steps the LR scheduler on the monitored loss, saves a checkpoint
    and prints a summary. The monitored loss is the validation loss when one
    is available, otherwise the training loss.

    Args:
        train_loader: Training data loader
        val_loader: Validation data loader (may be None)
        resume_from: Optional path to checkpoint to resume from

    Returns:
        Dict with 'best_val_loss', 'total_epochs', 'total_time' (seconds)
        and 'history'.
    """
    start_epoch = self.current_epoch
    if resume_from is not None:
        self.load_checkpoint(resume_from)
        # A checkpoint is written after its epoch has finished, so resuming
        # continues with the following epoch instead of repeating that one.
        start_epoch = self.current_epoch + 1

    print("\n" + "=" * 60)
    print("Starting Vega Training")
    print("=" * 60)
    print(f"Epochs: {self.config.num_epochs}")
    print(f"Batch size: {self.config.batch_size}")
    print(f"Learning rate: {self.config.learning_rate}")
    print(f"Training samples: {len(train_loader.dataset)}")
    if val_loader is not None:
        print(f"Validation samples: {len(val_loader.dataset)}")
    print("=" * 60 + "\n")

    start_time = time.time()

    for epoch in range(start_epoch, self.config.num_epochs):
        self.current_epoch = epoch
        epoch_start = time.time()

        train_metrics = self.train_epoch(train_loader)
        val_metrics = self.validate(val_loader)
        # Decide on what validate() actually returned rather than on the
        # truthiness of the loader object.
        has_val = 'val_loss' in val_metrics

        metrics = {**train_metrics, **val_metrics, 'epoch': epoch}
        self.training_history.append(metrics)

        # Monitored loss drives the scheduler, best-model tracking and early stopping
        val_loss = val_metrics['val_loss'] if has_val else train_metrics['train_loss']
        self.scheduler.step(val_loss)

        is_best = val_loss < self.best_val_loss - self.config.min_delta
        if is_best:
            self.best_val_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1

        checkpoint_path = self.model_dir / f"{self.config.model_name}_epoch_{epoch}.pt"
        self.save_checkpoint(checkpoint_path, is_best=is_best)

        epoch_time = time.time() - epoch_start
        lr = self.optimizer.param_groups[0]['lr']

        # Primary metrics line
        log_str = (
            f"Epoch {epoch+1:3d}/{self.config.num_epochs} | "
            f"Train Loss: {train_metrics['train_loss']:.4f} | "
            f"Train Acc: {train_metrics['train_accuracy']:.4f} | "
        )
        if has_val:
            log_str += (
                f"Val Loss: {val_metrics['val_loss']:.4f} | "
                f"Val Acc: {val_metrics['val_accuracy']:.4f} | "
            )
        log_str += f"LR: {lr:.2e} | Time: {epoch_time:.1f}s"

        if is_best:
            log_str += " *"

        print(log_str)

        # Timing breakdown line
        data_t = train_metrics.get('data_time', 0)
        compute_t = train_metrics.get('compute_time', 0)
        if data_t > 0 or compute_t > 0:
            data_pct = 100 * data_t / (data_t + compute_t) if (data_t + compute_t) > 0 else 0
            print(f" └── Data: {data_t:.1f}s ({data_pct:.0f}%) | Compute: {compute_t:.1f}s ({100-data_pct:.0f}%)")

        # Secondary metrics line (detailed classification metrics)
        if 'val_auc_macro' in val_metrics:
            detail_str = (
                f" └── AUC: {val_metrics['val_auc_macro']:.4f} | "
                f"F1: {val_metrics['val_f1_macro']:.4f} | "
                f"Prec: {val_metrics['val_precision']:.4f} | "
                f"Recall: {val_metrics['val_recall']:.4f} | "
                f"Exact: {val_metrics['val_exact_match']:.4f}"
            )
            print(detail_str)

        # Early stopping
        if self.epochs_without_improvement >= self.config.patience:
            print(f"\nEarly stopping after {epoch + 1} epochs")
            break

    total_time = time.time() - start_time

    # Save final model
    final_path = self.model_dir / f"{self.config.model_name}_final.pt"
    self.save_checkpoint(final_path)

    # Save training history
    history_path = self.model_dir / f"{self.config.model_name}_history.json"
    with open(history_path, 'w', encoding='utf-8') as f:
        json.dump(self.training_history, f, indent=2)

    print("\n" + "=" * 60)
    print("Training complete!")
    print(f"Total time: {total_time / 60:.1f} minutes")
    print(f"Best validation loss: {self.best_val_loss:.4f}")
    print(f"Model saved to: {final_path}")
    print("=" * 60)

    return {
        'best_val_loss': self.best_val_loss,
        'total_epochs': self.current_epoch + 1,
        'total_time': total_time,
        'history': self.training_history
    }
|
||||
|
||||
|
||||
def train_vega(
    data_dir: Optional[str] = None,
    model_dir: Optional[str] = None,
    config: Optional[TrainingConfig] = None,
    model_config: Optional[VegaConfig] = None
) -> Tuple[VegaModel, Dict]:
    """
    Build the data loaders, model and trainer, run training, and save the isotope index.

    Relative data/model directories are resolved against the directory three
    levels above this file. The resolved model directory is written back to
    ``config.model_dir``.

    Args:
        data_dir: Overrides ``config.data_dir`` when non-empty
        model_dir: Overrides ``config.model_dir`` when non-empty
        config: Training configuration (a default one is created when None)
        model_config: Model configuration (derived from the isotope index and
            ``config.max_activity_bq`` when None)

    Returns:
        Tuple of (trained model, training results)
    """
    root = Path(__file__).parent.parent.parent

    def _anchor(raw_path):
        # Absolute paths are kept as-is; relative ones hang off the project root
        candidate = Path(raw_path)
        return candidate if candidate.is_absolute() else root / candidate

    if config is None:
        config = TrainingConfig()
    if data_dir:
        config.data_dir = data_dir
    if model_dir:
        config.model_dir = model_dir

    data_path = _anchor(config.data_dir)
    model_path = _anchor(config.model_dir)
    config.model_dir = str(model_path)

    # Seed the CPU generator and, when present, the CUDA generator
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.seed)

    isotope_index = get_default_isotope_index()

    loader_kwargs = dict(
        data_dir=data_path,
        batch_size=config.batch_size,
        train_split=config.train_split,
        val_split=config.val_split,
        test_split=config.test_split,
        num_workers=config.num_workers,
        prefetch_factor=config.prefetch_factor,
        persistent_workers=config.persistent_workers,
        isotope_index=isotope_index,
        max_activity_bq=config.max_activity_bq,
        seed=config.seed,
    )
    # The test loader is created by the helper but not used here
    train_loader, val_loader, test_loader = create_data_loaders(**loader_kwargs)

    if model_config is None:
        model_config = VegaConfig(
            num_isotopes=isotope_index.num_isotopes,
            max_activity_bq=config.max_activity_bq,
        )

    model = VegaModel(model_config)
    print(model.summary())

    results = VegaTrainer(model, config).train(train_loader, val_loader)

    # Store the isotope index next to the saved model files
    isotope_index.save(model_path / f"{config.model_name}_isotope_index.txt")

    return model, results
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Quick test training: when run as a script, call train_vega() with all of
    # its defaults (no overrides for data dir, model dir or configs).
    model, results = train_vega()
|
||||
411
train/vega_ml/training/vega/train_2d.py
Normal file
411
train/vega_ml/training/vega/train_2d.py
Normal file
@ -0,0 +1,411 @@
|
||||
"""
|
||||
Training Script for Vega 2D Model
|
||||
|
||||
Uses 2D convolutions to process gamma spectra with temporal information.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Optional, Tuple, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
|
||||
from .model_2d import Vega2DModel, Vega2DConfig, count_parameters
|
||||
from .dataset_2d import create_data_loaders_2d, SpectrumDataset2D
|
||||
from .isotope_index import get_default_isotope_index
|
||||
|
||||
|
||||
@dataclass
class TrainingConfig2D:
    """Settings consumed by the 2D training script (paths, schedule, loss weights)."""

    data_dir: str = "O:/master_data_collection/isotopev2"  # where training data is read from
    model_dir: str = "models"  # where checkpoints and history are written

    target_time_intervals: int = 60  # passed to the 2D model config and the data loaders

    epochs: int = 50  # upper bound on training epochs
    batch_size: int = 32
    learning_rate: float = 0.001  # AdamW learning rate
    weight_decay: float = 0.00001  # AdamW weight decay

    classification_weight: float = 1.0  # multiplier on the classification loss
    regression_weight: float = 0.1  # multiplier on the regression loss

    use_amp: bool = True  # mixed precision; only takes effect on CUDA

    early_stopping_patience: int = 10  # epochs without a better val loss before stopping

    lr_scheduler_patience: int = 5  # ReduceLROnPlateau patience
    lr_scheduler_factor: float = 0.5  # ReduceLROnPlateau factor

    num_workers: int = 4  # data loading workers
|
||||
|
||||
|
||||
def _forward_losses_2d(model, spectra, presence, activities,
                       criterion_cls, criterion_reg, config):
    """Run one forward pass and return (weighted total, classification, regression) losses."""
    logits, pred_activities = model(spectra)
    cls_loss = criterion_cls(logits, presence)
    reg_loss = criterion_reg(pred_activities, activities)
    loss = config.classification_weight * cls_loss + config.regression_weight * reg_loss
    return loss, cls_loss, reg_loss


def train_epoch(
    model: nn.Module,
    train_loader,
    optimizer: optim.Optimizer,
    criterion_cls: nn.Module,
    criterion_reg: nn.Module,
    device: torch.device,
    scaler: Optional[GradScaler],
    config: "TrainingConfig2D"
) -> Dict[str, float]:
    """Train for one epoch and return the mean losses over its batches.

    Args:
        model: Network returning ``(logits, predicted_activities)``.
        train_loader: Iterable of batch dicts with the keys ``'spectrum'``,
            ``'presence_labels'`` and ``'activity_labels'``.
        optimizer: Optimizer stepped once per batch.
        criterion_cls: Loss applied to ``(logits, presence labels)``.
        criterion_reg: Loss applied to ``(predicted, true)`` activities.
        device: Device the batch tensors are moved to.
        scaler: Gradient scaler; when not ``None`` the forward pass runs under
            CUDA autocast and the backward/step go through the scaler.
        config: Supplies ``classification_weight`` and ``regression_weight``.

    Returns:
        Dict with the per-batch means ``'loss'``, ``'cls_loss'`` and ``'reg_loss'``.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()

    total_loss = 0.0
    total_cls_loss = 0.0
    total_reg_loss = 0.0
    num_batches = 0

    for batch in train_loader:
        spectra = batch['spectrum'].to(device)
        presence = batch['presence_labels'].to(device)
        activities = batch['activity_labels'].to(device)

        optimizer.zero_grad()

        if scaler is not None:
            # Same autocast entry point as the 1D trainer
            with torch.amp.autocast('cuda'):
                loss, cls_loss, reg_loss = _forward_losses_2d(
                    model, spectra, presence, activities,
                    criterion_cls, criterion_reg, config
                )

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss, cls_loss, reg_loss = _forward_losses_2d(
                model, spectra, presence, activities,
                criterion_cls, criterion_reg, config
            )

            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        total_cls_loss += cls_loss.item()
        total_reg_loss += reg_loss.item()
        num_batches += 1

    if num_batches == 0:
        # Averaging over zero batches would otherwise fail with a bare ZeroDivisionError
        raise ValueError("train_loader yielded no batches; cannot train on an empty dataset")

    return {
        'loss': total_loss / num_batches,
        'cls_loss': total_cls_loss / num_batches,
        'reg_loss': total_reg_loss / num_batches
    }
|
||||
|
||||
|
||||
@torch.no_grad()
def validate(
    model: nn.Module,
    val_loader,
    criterion_cls: nn.Module,
    criterion_reg: nn.Module,
    device: torch.device,
    config: "TrainingConfig2D",
    threshold: float = 0.5
) -> Dict[str, float]:
    """Evaluate the model and return mean losses plus detection metrics.

    Args:
        model: Network returning ``(logits, predicted_activities)``.
        val_loader: Iterable of batch dicts with the keys ``'spectrum'``,
            ``'presence_labels'`` and ``'activity_labels'``.
        criterion_cls: Loss applied to ``(logits, presence labels)``.
        criterion_reg: Loss applied to ``(predicted, true)`` activities.
        device: Device the batch tensors are moved to.
        config: Supplies ``classification_weight`` and ``regression_weight``.
        threshold: Sigmoid probability at or above which a label counts as predicted.

    Returns:
        Dict with ``'loss'``, ``'cls_loss'``, ``'reg_loss'`` (per-batch means),
        ``'exact_match'`` (fraction of samples with every label correct) and
        ``'precision'``, ``'recall'``, ``'f1'`` computed from counts pooled
        over all labels. When the loader yields no batches the three losses
        are NaN and the other metrics are 0.0.
    """
    model.eval()

    total_loss = 0.0
    total_cls_loss = 0.0
    total_reg_loss = 0.0
    num_batches = 0

    all_preds = []
    all_labels = []

    for batch in val_loader:
        spectra = batch['spectrum'].to(device)
        presence = batch['presence_labels'].to(device)
        activities = batch['activity_labels'].to(device)

        logits, pred_activities = model(spectra)
        cls_loss = criterion_cls(logits, presence)
        reg_loss = criterion_reg(pred_activities, activities)
        loss = config.classification_weight * cls_loss + config.regression_weight * reg_loss

        total_loss += loss.item()
        total_cls_loss += cls_loss.item()
        total_reg_loss += reg_loss.item()
        num_batches += 1

        # Collect thresholded predictions for the metrics below
        probs = torch.sigmoid(logits)
        preds = (probs >= threshold).float()
        all_preds.append(preds.cpu())
        all_labels.append(presence.cpu())

    if num_batches == 0:
        # Nothing to evaluate: torch.cat([]) and the averages below would fail.
        # NaN never compares as "better", so callers tracking the best loss
        # will not treat an empty evaluation as an improvement.
        nan = float('nan')
        return {
            'loss': nan,
            'cls_loss': nan,
            'reg_loss': nan,
            'exact_match': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1': 0.0
        }

    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    # Per-sample accuracy (all isotopes correct)
    exact_match = (all_preds == all_labels).all(dim=1).float().mean().item()

    # Counts pooled over every (sample, isotope) cell
    tp = ((all_preds == 1) & (all_labels == 1)).sum().item()
    fp = ((all_preds == 1) & (all_labels == 0)).sum().item()
    fn = ((all_preds == 0) & (all_labels == 1)).sum().item()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return {
        'loss': total_loss / num_batches,
        'cls_loss': total_cls_loss / num_batches,
        'reg_loss': total_reg_loss / num_batches,
        'exact_match': exact_match,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }
|
||||
|
||||
|
||||
def train_vega_2d(
    config: Optional[TrainingConfig2D] = None,
    model_config: Optional[Vega2DConfig] = None
) -> Tuple[Vega2DModel, Dict]:
    """
    Train the Vega 2D model.

    Builds the model, data loaders, losses, AdamW optimizer and
    ReduceLROnPlateau scheduler, then trains with early stopping on the
    validation loss. Writes ``vega_2d_best.pt`` whenever the validation loss
    improves, plus ``vega_2d_final.pt`` and ``vega_2d_history.json`` at the
    end, and finally prints metrics for the test split.

    Args:
        config: Training settings (defaults to ``TrainingConfig2D()``).
        model_config: Model settings (defaults to a ``Vega2DConfig`` using
            ``config.target_time_intervals``).

    Returns:
        Tuple of (model in its state after the last trained epoch, history dict).
    """
    if config is None:
        config = TrainingConfig2D()
    if model_config is None:
        model_config = Vega2DConfig(num_time_intervals=config.target_time_intervals)

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    if device.type == 'cuda':
        print(f" GPU: {torch.cuda.get_device_name()}")
        print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Create model
    model = Vega2DModel(model_config).to(device)
    print("\nModel: Vega 2D")
    print(f" Input: ({model_config.num_time_intervals}, {model_config.num_channels})")
    print(f" Conv channels: {model_config.conv_channels}")
    print(f" FC dims: {model_config.fc_hidden_dims}")
    print(f" Parameters: {count_parameters(model):,}")

    # Create data loaders
    print(f"\nLoading data from: {config.data_dir}")
    isotope_index = get_default_isotope_index()

    train_loader, val_loader, test_loader = create_data_loaders_2d(
        data_dir=Path(config.data_dir),
        batch_size=config.batch_size,
        target_time_intervals=config.target_time_intervals,
        isotope_index=isotope_index,
        num_workers=config.num_workers
    )

    # Loss functions
    criterion_cls = nn.BCEWithLogitsLoss()
    criterion_reg = nn.HuberLoss()

    # Optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.lr_scheduler_factor,
        patience=config.lr_scheduler_patience
    )

    # Mixed precision is only enabled on CUDA
    scaler = GradScaler() if config.use_amp and device.type == 'cuda' else None

    # Training history: one list per tracked quantity, appended once per epoch
    history = {
        'train_loss': [], 'val_loss': [],
        'train_cls_loss': [], 'val_cls_loss': [],
        'train_reg_loss': [], 'val_reg_loss': [],
        'val_exact_match': [], 'val_precision': [], 'val_recall': [], 'val_f1': [],
        'lr': []
    }

    # Early stopping
    best_val_loss = float('inf')
    patience_counter = 0

    # Model directory; parents=True so nested output paths are created too
    model_dir = Path(config.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nStarting training for {config.epochs} epochs...")
    print(f" Batch size: {config.batch_size}")
    print(f" Learning rate: {config.learning_rate}")
    print(f" AMP: {scaler is not None}")
    print()

    start_time = time.time()

    # Keeps 'epoch' defined for the final checkpoint even when the loop body
    # never runs (config.epochs <= 0); -1 then means "no epoch completed".
    epoch = -1

    for epoch in range(config.epochs):
        epoch_start = time.time()

        # Train
        train_metrics = train_epoch(
            model, train_loader, optimizer,
            criterion_cls, criterion_reg,
            device, scaler, config
        )

        # Validate
        val_metrics = validate(
            model, val_loader,
            criterion_cls, criterion_reg,
            device, config
        )

        # Update scheduler
        scheduler.step(val_metrics['loss'])
        current_lr = optimizer.param_groups[0]['lr']

        # Record history
        for key in ('loss', 'cls_loss', 'reg_loss'):
            history[f'train_{key}'].append(train_metrics[key])
            history[f'val_{key}'].append(val_metrics[key])
        for key in ('exact_match', 'precision', 'recall', 'f1'):
            history[f'val_{key}'].append(val_metrics[key])
        history['lr'].append(current_lr)

        epoch_time = time.time() - epoch_start

        # Print progress
        print(f"Epoch {epoch+1:3d}/{config.epochs} ({epoch_time:.1f}s) | "
              f"Train Loss: {train_metrics['loss']:.4f} | "
              f"Val Loss: {val_metrics['loss']:.4f} | "
              f"F1: {val_metrics['f1']:.4f} | "
              f"Recall: {val_metrics['recall']:.4f} | "
              f"LR: {current_lr:.2e}")

        # Save best model
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            patience_counter = 0

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'model_config': asdict(model_config),
                'training_config': asdict(config),
                'val_metrics': val_metrics,
                'history': history
            }, model_dir / 'vega_2d_best.pt')
            print(f" ✓ Saved best model (val_loss: {best_val_loss:.4f})")
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= config.early_stopping_patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

    total_time = time.time() - start_time
    print(f"\nTraining complete in {total_time/60:.1f} minutes")
    print(f"Best validation loss: {best_val_loss:.4f}")

    # Save final model (last-epoch weights, which may differ from the best checkpoint)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'model_config': asdict(model_config),
        'training_config': asdict(config),
        'history': history
    }, model_dir / 'vega_2d_final.pt')

    # Save history
    with open(model_dir / 'vega_2d_history.json', 'w') as f:
        json.dump(history, f, indent=2)

    # Test set evaluation
    print("\nEvaluating on test set...")
    test_metrics = validate(
        model, test_loader,
        criterion_cls, criterion_reg,
        device, config
    )
    print(f" Test Loss: {test_metrics['loss']:.4f}")
    print(f" Test F1: {test_metrics['f1']:.4f}")
    print(f" Test Recall: {test_metrics['recall']:.4f}")
    print(f" Test Precision: {test_metrics['precision']:.4f}")
    print(f" Test Exact Match: {test_metrics['exact_match']:.4f}")

    return model, history
|
||||
|
||||
|
||||
def main():
    """Parse the command-line options and launch 2D training with them."""
    cli = argparse.ArgumentParser(description='Train Vega 2D Model')

    # (flag, add_argument keyword arguments), registered in this order
    option_table = (
        ('--data-dir', dict(type=str, default='O:/master_data_collection/isotopev2',
                            help='Path to training data')),
        ('--model-dir', dict(type=str, default='models',
                             help='Path to save models')),
        ('--epochs', dict(type=int, default=50,
                          help='Number of epochs')),
        ('--batch-size', dict(type=int, default=32,
                              help='Batch size')),
        ('--lr', dict(type=float, default=1e-3,
                      help='Learning rate')),
        ('--time-intervals', dict(type=int, default=60,
                                  help='Target time intervals (pad/truncate)')),
        ('--no-amp', dict(action='store_true',
                          help='Disable mixed precision training')),
        ('--workers', dict(type=int, default=4,
                           help='Data loading workers')),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()

    training_settings = TrainingConfig2D(
        data_dir=opts.data_dir,
        model_dir=opts.model_dir,
        epochs=opts.epochs,
        batch_size=opts.batch_size,
        learning_rate=opts.lr,
        target_time_intervals=opts.time_intervals,
        use_amp=not opts.no_amp,
        num_workers=opts.workers,
    )
    model_settings = Vega2DConfig(num_time_intervals=opts.time_intervals)

    train_vega_2d(training_settings, model_settings)
|
||||
|
||||
|
||||
# Script entry point: parse the CLI options and start 2D training via main()
if __name__ == '__main__':
    main()
|
||||
847
train/vega_ml/training/vega/train_v2_optuna.py
Normal file
847
train/vega_ml/training/vega/train_v2_optuna.py
Normal file
@ -0,0 +1,847 @@
|
||||
"""
|
||||
Vega Training v2 - Optuna Hyperparameter Optimization
|
||||
|
||||
Uses Optuna to search for optimal hyperparameters to maximize model performance,
|
||||
with a focus on improving recall for isotope detection.
|
||||
|
||||
Key optimizations:
|
||||
1. Model architecture (CNN channels, FC dims, kernel sizes)
|
||||
2. Training hyperparameters (LR, batch size, weight decay, dropout)
|
||||
3. Loss function weights (classification vs regression balance)
|
||||
4. Classification threshold optimization
|
||||
5. Focal loss for handling class imbalance
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
from dataclasses import dataclass, asdict, field
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Adam, AdamW
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
|
||||
import numpy as np
|
||||
|
||||
import optuna
|
||||
from optuna.trial import Trial
|
||||
from optuna.pruners import MedianPruner, HyperbandPruner
|
||||
from optuna.samplers import TPESampler
|
||||
|
||||
# Sklearn metrics
|
||||
from sklearn.metrics import (
|
||||
roc_auc_score,
|
||||
f1_score,
|
||||
precision_score,
|
||||
recall_score,
|
||||
hamming_loss
|
||||
)
|
||||
|
||||
# Add project root to path
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from training.vega.model import VegaModel, VegaConfig
|
||||
from training.vega.dataset import create_data_loaders, SpectrumDataset
|
||||
from training.vega.isotope_index import IsotopeIndex, get_default_isotope_index
|
||||
|
||||
|
||||
class FocalLoss(nn.Module):
    """
    Focal loss on logits for multi-label classification.

    Scales the element-wise binary cross-entropy by ``alpha_t * (1 - p_t) ** gamma``,
    so elements the model already classifies confidently contribute less:

    FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)

    Args:
        alpha: Weight on positive targets; negatives get ``1 - alpha`` (default: 0.25)
        gamma: Focusing exponent - larger values down-weight easy elements more (default: 2.0)
        reduction: ``'mean'`` or ``'sum'``; any other value returns the unreduced loss
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the focal loss for logits ``inputs`` against binary ``targets``."""
        elementwise_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # Probability the model assigns to the true class of each element
        prob = torch.sigmoid(inputs)
        p_true = prob * targets + (1 - prob) * (1 - targets)

        # Class-balancing factor times the focusing term
        balance = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        weighted = balance * (1 - p_true) ** self.gamma * elementwise_bce

        if self.reduction == 'sum':
            return weighted.sum()
        if self.reduction == 'mean':
            return weighted.mean()
        return weighted
|
||||
|
||||
|
||||
class VegaLossV2(nn.Module):
    """
    Weighted sum of a classification loss and a masked activity-regression loss.

    The classification term is either focal loss or BCE-with-logits. The
    regression term is a Huber loss (delta 0.1) evaluated only where the true
    presence label exceeds 0.5.
    """

    def __init__(
        self,
        classification_weight: float = 1.0,
        regression_weight: float = 0.1,
        use_focal_loss: bool = True,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
        pos_weight: Optional[torch.Tensor] = None
    ):
        super().__init__()
        self.classification_weight = classification_weight
        self.regression_weight = regression_weight
        self.use_focal_loss = use_focal_loss

        # pos_weight only applies to the BCE variant
        self.cls_loss = (
            FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
            if use_focal_loss
            else nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        )
        self.reg_loss = nn.HuberLoss(delta=0.1)

    def forward(
        self,
        pred_logits: torch.Tensor,
        pred_activities: torch.Tensor,
        true_presence: torch.Tensor,
        true_activities: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Return ``(total loss tensor, dict of float loss components)``.

        The dict has the keys ``'total'``, ``'classification'`` and
        ``'regression'``; the regression entry is 0.0 when no label is present.
        """
        classification = self.cls_loss(pred_logits, true_presence)

        # Regress activities only for isotopes that are actually present
        present = true_presence > 0.5
        any_present = bool(present.any())
        if any_present:
            regression = self.reg_loss(pred_activities[present], true_activities[present])
        else:
            regression = torch.tensor(0.0, device=pred_logits.device)

        combined = (
            self.classification_weight * classification +
            self.regression_weight * regression
        )

        components = {
            'total': combined.item(),
            'classification': classification.item(),
            'regression': regression.item() if any_present else 0.0
        }
        return combined, components
|
||||
|
||||
|
||||
@dataclass
class OptunaConfig:
    """Settings for the Optuna hyperparameter search (study, per-trial training, data)."""

    data_dir: str = "O:/master_data_collection/isotopev2"  # training data location
    model_dir: str = "models/optuna"  # output directory for the search
    study_name: str = "vega_v2_optimization"

    n_trials: int = 50  # number of trials in the study
    timeout_hours: float = 24.0  # time budget for the study, in hours
    n_startup_trials: int = 10  # random sampling before TPE

    max_epochs: int = 30  # shorter epochs for faster trials
    patience: int = 5  # early stopping patience

    train_split: float = 0.8  # data split fractions
    val_split: float = 0.1
    test_split: float = 0.1

    num_workers: int = 8  # fixed (not searched) data loading and AMP settings
    prefetch_factor: int = 4
    persistent_workers: bool = True
    use_amp: bool = True

    optimize_metric: str = "val_recall"  # objective metric; focus on recall

    seed: int = 42  # reproducibility
|
||||
|
||||
|
||||
def suggest_hyperparameters(trial: Trial) -> Dict:
    """
    Draw one complete hyperparameter set from an Optuna trial.

    The returned dictionary covers the model architecture, the optimizer and
    scheduler, the loss configuration and the decision threshold. Scheduler-
    and focal-loss-specific keys are only present when the corresponding
    option was selected.
    """
    # ---------- architecture: convolutional stack ----------
    conv_depth = trial.suggest_int("n_conv_layers", 2, 4)
    conv_widths = [
        trial.suggest_categorical(f"conv_ch_{layer}", [32, 64, 128, 256, 512])
        for layer in range(conv_depth)
    ]
    params = {"conv_channels": conv_widths}
    params["conv_kernel_size"] = trial.suggest_categorical("conv_kernel_size", [3, 5, 7, 9, 11])
    params["pool_size"] = trial.suggest_categorical("pool_size", [2, 3, 4])

    # ---------- architecture: fully connected stack ----------
    fc_depth = trial.suggest_int("n_fc_layers", 1, 3)
    params["fc_hidden_dims"] = [
        trial.suggest_categorical(f"fc_dim_{layer}", [128, 256, 512, 1024])
        for layer in range(fc_depth)
    ]

    # ---------- regularization ----------
    params["dropout_rate"] = trial.suggest_float("dropout_rate", 0.1, 0.5)
    params["spatial_dropout_rate"] = trial.suggest_float("spatial_dropout_rate", 0.05, 0.3)
    params["leaky_relu_slope"] = trial.suggest_float("leaky_relu_slope", 0.01, 0.2)

    # ---------- optimization ----------
    params["batch_size"] = trial.suggest_categorical("batch_size", [128, 256, 512, 1024])
    params["learning_rate"] = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    params["weight_decay"] = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
    params["optimizer"] = trial.suggest_categorical("optimizer", ["adam", "adamw"])

    # ---------- learning-rate schedule (conditional sub-space) ----------
    schedule = trial.suggest_categorical("scheduler", ["plateau", "cosine"])
    params["scheduler"] = schedule
    if schedule != "plateau":
        params["cosine_t_0"] = trial.suggest_int("cosine_t_0", 5, 15)
        params["cosine_t_mult"] = trial.suggest_int("cosine_t_mult", 1, 2)
    else:
        params["lr_factor"] = trial.suggest_float("lr_factor", 0.1, 0.5)
        params["lr_patience"] = trial.suggest_int("lr_patience", 3, 10)

    # ---------- loss (focal settings only drawn when focal loss is on) ----------
    focal_enabled = trial.suggest_categorical("use_focal_loss", [True, False])
    params["use_focal_loss"] = focal_enabled
    if focal_enabled:
        params["focal_alpha"] = trial.suggest_float("focal_alpha", 0.1, 0.5)
        params["focal_gamma"] = trial.suggest_float("focal_gamma", 1.0, 3.0)

    params["classification_weight"] = trial.suggest_float("classification_weight", 0.5, 2.0)
    params["regression_weight"] = trial.suggest_float("regression_weight", 0.01, 0.5, log=True)

    # ---------- decision threshold applied to sigmoid outputs ----------
    params["threshold"] = trial.suggest_float("threshold", 0.3, 0.7)

    return params
|
||||
|
||||
|
||||
def create_model_from_params(params: Dict, num_isotopes: int) -> VegaModel:
    """
    Build a VegaModel whose VegaConfig is filled from a hyperparameter dict.

    Every key listed below must be present in ``params``; a missing key
    raises KeyError.
    """
    config_keys = (
        "conv_channels",
        "conv_kernel_size",
        "pool_size",
        "fc_hidden_dims",
        "dropout_rate",
        "spatial_dropout_rate",
        "leaky_relu_slope",
        "classification_weight",
        "regression_weight",
    )
    # Keys are forwarded to VegaConfig under the same names.
    config_kwargs = {key: params[key] for key in config_keys}
    return VegaModel(VegaConfig(num_isotopes=num_isotopes, **config_kwargs))
|
||||
|
||||
|
||||
def train_single_trial(
    trial: Trial,
    params: Dict,
    train_loader,
    val_loader,
    device: torch.device,
    optuna_config: OptunaConfig
) -> float:
    """
    Train one Optuna trial and return the best value of the objective metric.

    Args:
        trial: Optuna trial; receives the per-epoch metric and decides pruning.
        params: Hyperparameters as produced by ``suggest_hyperparameters``.
        train_loader: Iterable of batch dicts with the keys 'spectrum',
            'presence_labels' and 'activity_labels'.
        val_loader: Validation loader with the same batch layout.
        device: Device the model and batches are moved to.
        optuna_config: Supplies ``max_epochs``, ``patience``, ``use_amp`` and
            ``optimize_metric``.

    Returns:
        The highest value of the objective metric seen over the epochs run.

    Raises:
        optuna.TrialPruned: When the pruner asks to stop this trial.
    """
    # Size the model from the isotope index instead of a literal 82 so that
    # trial models always have the same output width as the model built by
    # train_best_model, which reads the same attribute.
    num_isotopes = get_default_isotope_index().num_isotopes
    model = create_model_from_params(params, num_isotopes)
    model = model.to(device)

    # Loss function
    loss_fn = VegaLossV2(
        classification_weight=params["classification_weight"],
        regression_weight=params["regression_weight"],
        use_focal_loss=params.get("use_focal_loss", False),
        focal_alpha=params.get("focal_alpha", 0.25),
        focal_gamma=params.get("focal_gamma", 2.0)
    )

    # Optimizer: anything other than "adamw" falls back to Adam.
    optimizer_cls = AdamW if params["optimizer"] == "adamw" else Adam
    optimizer = optimizer_cls(
        model.parameters(),
        lr=params["learning_rate"],
        weight_decay=params["weight_decay"]
    )

    # Scheduler: cosine restarts step once per epoch, plateau steps on val loss.
    use_cosine = params["scheduler"] == "cosine"
    if use_cosine:
        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=params.get("cosine_t_0", 10),
            T_mult=params.get("cosine_t_mult", 1)
        )
    else:
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            patience=params.get("lr_patience", 5),
            factor=params.get("lr_factor", 0.5)
        )

    # Mixed precision is only used on CUDA devices.
    scaler = torch.amp.GradScaler('cuda') if optuna_config.use_amp and device.type == 'cuda' else None

    best_metric = 0.0
    epochs_without_improvement = 0
    threshold = params["threshold"]

    for epoch in range(optuna_config.max_epochs):
        # ---- training pass ----
        model.train()
        for batch in train_loader:
            spectra = batch['spectrum'].to(device)
            presence = batch['presence_labels'].to(device)
            activities = batch['activity_labels'].to(device)

            optimizer.zero_grad()

            if scaler is not None:
                with torch.amp.autocast('cuda'):
                    pred_logits, pred_activities = model(spectra)
                    loss, _ = loss_fn(pred_logits, pred_activities, presence, activities)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                pred_logits, pred_activities = model(spectra)
                loss, _ = loss_fn(pred_logits, pred_activities, presence, activities)
                loss.backward()
                optimizer.step()

        # ---- validation pass ----
        val_metrics = validate_model(model, val_loader, device, threshold, loss_fn)

        if use_cosine:
            scheduler.step()
        else:
            scheduler.step(val_metrics['val_loss'])

        # Fall back to recall when the configured metric is not in the result.
        current_metric = val_metrics.get(optuna_config.optimize_metric, val_metrics['val_recall'])

        # Report the intermediate value so the pruner can stop weak trials.
        trial.report(current_metric, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if current_metric > best_metric:
            best_metric = current_metric
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        # Early stopping
        if epochs_without_improvement >= optuna_config.patience:
            break

    return best_metric
|
||||
|
||||
|
||||
@torch.no_grad()
def validate_model(
    model: VegaModel,
    val_loader,
    device: torch.device,
    threshold: float,
    loss_fn: nn.Module
) -> Dict[str, float]:
    """
    Evaluate ``model`` on ``val_loader`` and return a dict of metrics.

    Args:
        model: Model returning ``(logits, activities)`` for a spectra batch.
        val_loader: Iterable of batch dicts with the keys 'spectrum',
            'presence_labels' and 'activity_labels'.
        device: Device the batches are moved to.
        threshold: Sigmoid probabilities >= threshold count as "present".
        loss_fn: Callable returning ``(loss, parts)`` like VegaLossV2.

    Returns:
        Plain Python floats under the keys 'val_loss', 'val_accuracy',
        'val_exact_match', 'val_auc_macro', 'val_f1_macro', 'val_precision',
        'val_recall' and 'val_hamming'. 'val_auc_macro' is always present and
        is 0.5 when the AUC cannot be computed.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()

    prob_chunks = []
    pred_chunks = []
    label_chunks = []
    total_loss = 0.0
    num_batches = 0

    for batch in val_loader:
        spectra = batch['spectrum'].to(device)
        presence = batch['presence_labels'].to(device)
        activities = batch['activity_labels'].to(device)

        pred_logits, pred_activities = model(spectra)
        loss, _ = loss_fn(pred_logits, pred_activities, presence, activities)
        total_loss += loss.item()
        num_batches += 1

        # Turn logits into probabilities, then into hard 0/1 predictions.
        probs = torch.sigmoid(pred_logits)
        preds = (probs >= threshold).float()

        prob_chunks.append(probs.cpu().numpy())
        pred_chunks.append(preds.cpu().numpy())
        label_chunks.append(presence.cpu().numpy())

    if num_batches == 0:
        # Fail with an explicit message instead of an opaque stacking error.
        raise ValueError("validate_model received a data loader that yields no batches")

    all_probs = np.vstack(prob_chunks)
    all_preds = np.vstack(pred_chunks)
    all_labels = np.vstack(label_chunks)

    # Values are cast to built-in floats so that metric dicts stored in
    # checkpoints or JSON files do not carry NumPy scalar objects.
    matches = all_preds == all_labels
    metrics = {
        'val_loss': total_loss / num_batches,
        'val_accuracy': float(matches.mean()),
        # Fraction of samples for which every label is predicted correctly.
        'val_exact_match': float(matches.all(axis=1).mean())
    }

    # AUC is only defined for label columns containing both classes.
    positives_per_column = all_labels.sum(axis=0)
    valid_cols = (positives_per_column > 0) & (positives_per_column < len(all_labels))
    auc_macro = 0.5  # neutral value used whenever the AUC cannot be computed
    if valid_cols.any():
        try:
            auc_macro = float(roc_auc_score(
                all_labels[:, valid_cols],
                all_probs[:, valid_cols],
                average='macro'
            ))
        except ValueError:
            auc_macro = 0.5
    metrics['val_auc_macro'] = auc_macro

    # F1 / precision / recall are computed over the flattened label matrix,
    # i.e. as a single binary problem with macro averaging over its classes.
    all_preds_flat = all_preds.flatten()
    all_labels_flat = all_labels.flatten()

    metrics['val_f1_macro'] = float(
        f1_score(all_labels_flat, all_preds_flat, average='macro', zero_division=0)
    )
    metrics['val_precision'] = float(
        precision_score(all_labels_flat, all_preds_flat, average='macro', zero_division=0)
    )
    metrics['val_recall'] = float(
        recall_score(all_labels_flat, all_preds_flat, average='macro', zero_division=0)
    )
    metrics['val_hamming'] = float(hamming_loss(all_labels, all_preds))

    return metrics
|
||||
|
||||
|
||||
def objective(trial: Trial, optuna_config: OptunaConfig) -> float:
    """
    Optuna objective: sample hyperparameters, train once, return the metric.

    Args:
        trial: The Optuna trial to draw hyperparameters from.
        optuna_config: Data, loader and training settings for the trial.

    Returns:
        The value of ``optuna_config.optimize_metric`` reached by the trial
        (the study is created with direction "maximize").

    Raises:
        optuna.TrialPruned: Re-raised unchanged when the trial is pruned.
        Exception: Any training error is logged and re-raised.
    """
    params = suggest_hyperparameters(trial)

    # Log the sampled configuration.
    print(f"\n{'='*60}")
    print(f"Trial {trial.number}")
    print(f"{'='*60}")
    for k, v in params.items():
        print(f" {k}: {v}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    isotope_index = get_default_isotope_index()

    # Loaders are rebuilt for every trial because the batch size is searched.
    train_loader, val_loader, _ = create_data_loaders(
        data_dir=Path(optuna_config.data_dir),
        batch_size=params["batch_size"],
        train_split=optuna_config.train_split,
        val_split=optuna_config.val_split,
        test_split=optuna_config.test_split,
        num_workers=optuna_config.num_workers,
        prefetch_factor=optuna_config.prefetch_factor,
        persistent_workers=optuna_config.persistent_workers,
        isotope_index=isotope_index,
        seed=optuna_config.seed
    )

    try:
        metric = train_single_trial(
            trial, params, train_loader, val_loader, device, optuna_config
        )
    except optuna.TrialPruned:
        # Pruning is normal control flow for the pruner, not a failure, so it
        # must not be reported by the generic handler below.
        print(f"Trial {trial.number} pruned")
        raise
    except Exception as e:
        print(f"Trial {trial.number} failed: {e}")
        raise

    print(f"Trial {trial.number} completed with {optuna_config.optimize_metric}: {metric:.4f}")
    return metric
|
||||
|
||||
|
||||
def run_optimization(config: OptunaConfig) -> optuna.Study:
    """
    Run the Optuna study described by ``config`` and persist its best result.

    The study is stored in a SQLite database inside ``config.model_dir`` and
    is resumed if a study with the same name already exists there. When at
    least one trial completed, the best value and parameters are written to
    ``best_params.json`` in the same directory.

    Args:
        config: Study, budget and data settings.

    Returns:
        The Optuna study, also when no trial completed.
    """
    model_dir = Path(config.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    # TPE sampler (random for the first n_startup_trials) + Hyperband pruner.
    sampler = TPESampler(
        n_startup_trials=config.n_startup_trials,
        seed=config.seed
    )
    pruner = HyperbandPruner(
        min_resource=3,
        max_resource=config.max_epochs,
        reduction_factor=3
    )

    # Create the study, or load it when it already exists in the database.
    storage = f"sqlite:///{model_dir / config.study_name}.db"
    study = optuna.create_study(
        study_name=config.study_name,
        storage=storage,
        load_if_exists=True,
        direction="maximize",  # Maximize the objective metric
        sampler=sampler,
        pruner=pruner
    )

    print("\n" + "=" * 60)
    print("VEGA V2 - OPTUNA HYPERPARAMETER OPTIMIZATION")
    print("=" * 60)
    print(f"Study name: {config.study_name}")
    print(f"Optimization metric: {config.optimize_metric}")
    print(f"Number of trials: {config.n_trials}")
    print(f"Timeout: {config.timeout_hours} hours")
    print(f"Data directory: {config.data_dir}")
    print("=" * 60 + "\n")

    study.optimize(
        lambda trial: objective(trial, config),
        n_trials=config.n_trials,
        timeout=config.timeout_hours * 3600,  # Optuna expects seconds
        show_progress_bar=True,
        gc_after_trial=True
    )

    print("\n" + "=" * 60)
    print("OPTIMIZATION COMPLETE")
    print("=" * 60)

    # best_trial raises ValueError when no trial has completed (for example
    # when every trial was pruned or the timeout hit first). Report that
    # instead of crashing at the very end of the run.
    try:
        best_trial = study.best_trial
    except ValueError:
        print("No trial completed; there are no best parameters to report or save.")
        return study

    best_value = best_trial.value
    best_params = best_trial.params

    print(f"Best trial: {best_trial.number}")
    print(f"Best {config.optimize_metric}: {best_value:.4f}")
    print("\nBest hyperparameters:")
    for k, v in best_params.items():
        print(f" {k}: {v}")

    # Save best parameters
    best_params_path = model_dir / "best_params.json"
    with open(best_params_path, 'w') as f:
        json.dump({
            'best_value': best_value,
            'best_params': best_params,
            'study_name': config.study_name,
            'optimize_metric': config.optimize_metric
        }, f, indent=2)
    print(f"\nBest parameters saved to: {best_params_path}")

    return study
|
||||
|
||||
|
||||
def train_best_model(
    study: optuna.Study,
    config: OptunaConfig,
    full_epochs: int = 100
) -> Tuple[VegaModel, Dict]:
    """
    Retrain the study's best configuration for ``full_epochs`` epochs.

    Writes three files into ``config.model_dir``: ``vega_v2_best.pt`` (saved
    whenever validation recall improves), ``vega_v2_final.pt`` (weights after
    the last epoch) and ``vega_v2_history.json`` (per-epoch metrics).

    Args:
        study: Finished Optuna study; ``study.best_params`` is read.
        config: Data, loader and runtime settings.
        full_epochs: Number of epochs to train.

    Returns:
        ``(model, summary)`` where ``model`` holds the weights after the last
        epoch and ``summary`` has the keys 'best_recall', 'test_metrics',
        'history' and 'params'.
    """
    print("\n" + "=" * 60)
    print("TRAINING BEST MODEL")
    print("=" * 60)

    best_params = study.best_params

    # study.best_params is flat (conv_ch_0, fc_dim_1, ...); rebuild the
    # list-valued entries expected by create_model_from_params.
    params = {}

    # CNN layers
    n_conv_layers = best_params.get("n_conv_layers", 3)
    params["conv_channels"] = [best_params.get(f"conv_ch_{i}", 128) for i in range(n_conv_layers)]
    params["conv_kernel_size"] = best_params.get("conv_kernel_size", 7)
    params["pool_size"] = best_params.get("pool_size", 2)

    # FC layers
    n_fc_layers = best_params.get("n_fc_layers", 2)
    params["fc_hidden_dims"] = [best_params.get(f"fc_dim_{i}", 256) for i in range(n_fc_layers)]

    # Scalar parameters are copied through when the study recorded them.
    for key in ["dropout_rate", "spatial_dropout_rate", "leaky_relu_slope",
                "batch_size", "learning_rate", "weight_decay", "optimizer",
                "scheduler", "lr_factor", "lr_patience", "cosine_t_0", "cosine_t_mult",
                "use_focal_loss", "focal_alpha", "focal_gamma",
                "classification_weight", "regression_weight", "threshold"]:
        if key in best_params:
            params[key] = best_params[key]

    # Defaults for anything the study did not record.
    params.setdefault("dropout_rate", 0.3)
    params.setdefault("spatial_dropout_rate", 0.1)
    params.setdefault("leaky_relu_slope", 0.1)
    params.setdefault("batch_size", 512)
    params.setdefault("learning_rate", 1e-3)
    params.setdefault("weight_decay", 1e-4)
    params.setdefault("optimizer", "adamw")
    params.setdefault("scheduler", "plateau")
    params.setdefault("classification_weight", 1.0)
    params.setdefault("regression_weight", 0.1)
    params.setdefault("threshold", 0.5)
    params.setdefault("use_focal_loss", True)
    params.setdefault("focal_alpha", 0.25)
    params.setdefault("focal_gamma", 2.0)

    print("Training with parameters:")
    for k, v in params.items():
        print(f" {k}: {v}")

    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    isotope_index = get_default_isotope_index()
    model_dir = Path(config.model_dir)
    # The directory is otherwise only created by run_optimization; create it
    # here too so this function also works when called on a loaded study.
    model_dir.mkdir(parents=True, exist_ok=True)

    # Create data loaders
    train_loader, val_loader, test_loader = create_data_loaders(
        data_dir=Path(config.data_dir),
        batch_size=params["batch_size"],
        train_split=config.train_split,
        val_split=config.val_split,
        test_split=config.test_split,
        num_workers=config.num_workers,
        prefetch_factor=config.prefetch_factor,
        persistent_workers=config.persistent_workers,
        isotope_index=isotope_index,
        seed=config.seed
    )

    # Create model
    model = create_model_from_params(params, isotope_index.num_isotopes)
    model = model.to(device)

    # Create loss
    loss_fn = VegaLossV2(
        classification_weight=params["classification_weight"],
        regression_weight=params["regression_weight"],
        use_focal_loss=params.get("use_focal_loss", False),
        focal_alpha=params.get("focal_alpha", 0.25),
        focal_gamma=params.get("focal_gamma", 2.0)
    )

    # Optimizer: anything other than "adamw" falls back to Adam.
    if params["optimizer"] == "adamw":
        optimizer = AdamW(model.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])
    else:
        optimizer = Adam(model.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])

    # Scheduler
    use_cosine = params.get("scheduler") == "cosine"
    if use_cosine:
        scheduler = CosineAnnealingWarmRestarts(
            optimizer, T_0=params.get("cosine_t_0", 10), T_mult=params.get("cosine_t_mult", 1)
        )
    else:
        scheduler = ReduceLROnPlateau(
            optimizer, mode='min', patience=params.get("lr_patience", 5), factor=params.get("lr_factor", 0.5)
        )

    # Mixed precision is only used on CUDA devices.
    scaler = torch.amp.GradScaler('cuda') if config.use_amp and device.type == 'cuda' else None

    # Training
    best_recall = 0.0
    threshold = params["threshold"]
    history = []

    for epoch in range(full_epochs):
        # Train
        model.train()
        train_loss = 0.0
        num_batches = 0

        for batch in train_loader:
            spectra = batch['spectrum'].to(device)
            presence = batch['presence_labels'].to(device)
            activities = batch['activity_labels'].to(device)

            optimizer.zero_grad()

            if scaler is not None:
                with torch.amp.autocast('cuda'):
                    pred_logits, pred_activities = model(spectra)
                    loss, _ = loss_fn(pred_logits, pred_activities, presence, activities)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                pred_logits, pred_activities = model(spectra)
                loss, _ = loss_fn(pred_logits, pred_activities, presence, activities)
                loss.backward()
                optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss /= num_batches

        # Validate. Metrics are converted to built-in floats so the saved
        # checkpoint and history contain no NumPy scalar objects.
        val_metrics = {
            name: float(value)
            for name, value in validate_model(model, val_loader, device, threshold, loss_fn).items()
        }

        # Scheduler step
        if use_cosine:
            scheduler.step()
        else:
            scheduler.step(val_metrics['val_loss'])

        # Log
        lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1:3d}/{full_epochs} | Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_metrics['val_loss']:.4f} | Recall: {val_metrics['val_recall']:.4f} | "
              f"F1: {val_metrics['val_f1_macro']:.4f} | Exact: {val_metrics['val_exact_match']:.4f} | LR: {lr:.2e}")

        # Save history
        history.append({
            'epoch': epoch,
            'train_loss': train_loss,
            **val_metrics,
            'lr': lr
        })

        # Save best model (selected on validation recall)
        if val_metrics['val_recall'] > best_recall:
            best_recall = val_metrics['val_recall']
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_recall': best_recall,
                'params': params,
                'val_metrics': val_metrics
            }
            torch.save(checkpoint, model_dir / "vega_v2_best.pt")
            print(f" └── New best! Saved model with recall: {best_recall:.4f}")

    # Save final model
    torch.save({
        'model_state_dict': model.state_dict(),
        'params': params,
        'history': history
    }, model_dir / "vega_v2_final.pt")

    # Save history
    with open(model_dir / "vega_v2_history.json", 'w') as f:
        json.dump(history, f, indent=2)

    # Test evaluation
    # NOTE(review): this evaluates the weights from the last epoch, not the
    # best-recall checkpoint saved above - confirm that this is intended.
    print("\n" + "=" * 60)
    print("TEST SET EVALUATION")
    print("=" * 60)
    test_metrics = {
        name: float(value)
        for name, value in validate_model(model, test_loader, device, threshold, loss_fn).items()
    }
    for k, v in test_metrics.items():
        print(f" {k}: {v:.4f}")

    return model, {
        'best_recall': best_recall,
        'test_metrics': test_metrics,
        'history': history,
        'params': params
    }
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point.

    Parses the options, runs the Optuna study and, when ``--train-best`` is
    given, retrains the best configuration. Returns 0 as the exit code.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Vega V2 - Optuna Hyperparameter Optimization")

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--data-dir", dict(type=str, default="O:/master_data_collection/isotopev2",
                            help="Data directory")),
        ("--model-dir", dict(type=str, default="models/optuna",
                             help="Model output directory")),
        ("--study-name", dict(type=str, default="vega_v2_optimization",
                              help="Optuna study name")),
        ("--n-trials", dict(type=int, default=50,
                            help="Number of Optuna trials")),
        ("--timeout", dict(type=float, default=24.0,
                           help="Timeout in hours")),
        ("--max-epochs", dict(type=int, default=30,
                              help="Max epochs per trial")),
        ("--optimize-metric", dict(type=str, default="val_recall",
                                   choices=["val_recall", "val_f1_macro", "val_auc_macro", "val_exact_match"],
                                   help="Metric to optimize")),
        ("--train-best", dict(action="store_true",
                              help="Train best model with full epochs after optimization")),
        ("--full-epochs", dict(type=int, default=100,
                               help="Epochs for training best model")),
        ("--workers", dict(type=int, default=8,
                           help="Number of data loading workers")),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    args = cli.parse_args()

    # Map the parsed options onto the study configuration.
    search_config = OptunaConfig(
        data_dir=args.data_dir,
        model_dir=args.model_dir,
        study_name=args.study_name,
        n_trials=args.n_trials,
        timeout_hours=args.timeout,
        max_epochs=args.max_epochs,
        optimize_metric=args.optimize_metric,
        num_workers=args.workers
    )

    finished_study = run_optimization(search_config)

    # The full-length retraining only runs on request.
    if args.train_best:
        train_best_model(finished_study, search_config, full_epochs=args.full_epochs)

    return 0
|
||||
|
||||
|
||||
# Script entry point: exit with the status code returned by main().
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
|
||||
Reference in New Issue
Block a user