Files
radiacode/train/vega_ml/training/vega/dataset_2d.py
Jacquin Antoine 745a64b342 Pipeline complet Radiacode 103 - identification automatique d'isotopes
- VegaModel CNN-FCNN 34.5M params, 82 isotopes, val acc 99.89%
- Generation 50k spectres synthetiques 1D (12-24h durees)
- Entrainement 100 epochs sur RTX 5060 Ti (CUDA 12.8, Blackwell)
- Detection continue avec soustraction du background
- Capture background 24h avec gestion deconnexion
- Docker Compose : conteneur train (GPU) + detect (CPU/USB)
- Modele entraine inclus (vega_best.pt, 395 Mo)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-19 12:29:56 +02:00

309 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Dataset for 2D Vega Model
Loads 2D spectra (time × channels) and pads/truncates to fixed dimensions.
"""
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from .isotope_index import IsotopeIndex, get_default_isotope_index
@dataclass
class SpectrumSample2D:
    """Record bundling one 2D spectrum array with its descriptive metadata."""

    # Identifier of the sample.
    sample_id: str
    # 2D array laid out as (time_intervals, channels).
    spectrum: np.ndarray
    # Names of the isotopes present in the sample.
    isotopes_present: List[str]
    # Mapping isotope name -> activity (presumably in Bq, per the field name).
    activities_bq: Dict[str, float]
    # Duration in seconds (per the field name).
    duration_seconds: float
    # Detector label; the expected values are not visible in this module.
    detector: str
class SpectrumDataset2D(Dataset):
    """
    PyTorch Dataset for 2D gamma spectra stored as ``spectrum_<id>.npy`` files.

    Each spectrum is brought to a fixed number of time rows (sub-sampled or
    zero-padded), divided by its maximum, and returned together with
    multi-hot presence labels and normalized activity labels so that samples
    can be stacked into batches.

    Labels are read either from per-sample ``spectrum_<id>.json`` files in the
    ``spectra/`` directory or from a single ``labels.json`` in ``data_dir``.
    """

    # Filename prefix shared by spectrum arrays and their per-sample labels.
    _FILE_PREFIX = "spectrum_"

    def __init__(
        self,
        data_dir: Path,
        isotope_index: Optional["IsotopeIndex"] = None,
        max_activity_bq: float = 1000.0,
        target_time_intervals: int = 60,
        transform=None
    ):
        """
        Initialize the dataset.

        Args:
            data_dir: Path to directory containing spectra/ subdirectory
            isotope_index: Index mapping isotope names to indices; the
                project default is used when None
            max_activity_bq: Maximum activity for normalization (must be > 0)
            target_time_intervals: Fixed time dimension (pad/truncate to
                this, must be >= 1)
            transform: Optional callable applied to the spectrum tensor

        Raises:
            ValueError: If max_activity_bq or target_time_intervals is not
                strictly positive.
            FileNotFoundError: If no label files can be found.
        """
        # Fail early: these values would otherwise only blow up later,
        # inside __getitem__ (division by zero / empty array).
        if max_activity_bq <= 0:
            raise ValueError(
                f"max_activity_bq must be positive, got {max_activity_bq}"
            )
        if target_time_intervals < 1:
            raise ValueError(
                f"target_time_intervals must be >= 1, got {target_time_intervals}"
            )

        self.data_dir = Path(data_dir)
        self.spectra_dir = self.data_dir / "spectra"
        # Explicit None check so that a falsy-but-valid index object passed
        # by the caller is never silently replaced by the default one.
        if isotope_index is None:
            isotope_index = get_default_isotope_index()
        self.isotope_index = isotope_index
        self.max_activity_bq = max_activity_bq
        self.target_time_intervals = target_time_intervals
        self.transform = transform

        # Detect label format and load sample list
        self.use_individual_labels = self._detect_label_format()
        if self.use_individual_labels:
            self.sample_ids = self._scan_for_samples()
            self.metadata = None
            print("Using individual label files (efficient mode)")
        else:
            self.metadata = self._load_metadata()
            self.sample_ids = list(self.metadata['samples'].keys())
            print("Using combined labels.json (legacy mode)")

        print(f"Loaded 2D dataset with {len(self.sample_ids)} samples")
        print(f"Target shape: ({target_time_intervals}, 1023)")
        print(f"Isotope index has {self.isotope_index.num_isotopes} isotopes")

    def _detect_label_format(self) -> bool:
        """
        Detect whether to use individual JSON files or combined labels.json.

        Returns:
            True when at least one per-sample JSON file exists, False when
            only the combined labels.json exists.

        Raises:
            FileNotFoundError: If neither label source is present.
        """
        # Only the existence of one file matters, so stop at the first match
        # instead of listing every label file of a large dataset.
        first_label = next(self.spectra_dir.glob("spectrum_*.json"), None)
        if first_label is not None:
            return True

        labels_path = self.data_dir / "labels.json"
        if labels_path.exists():
            return False

        raise FileNotFoundError(
            f"No label files found. Expected either:\n"
            f" - Individual files: {self.spectra_dir}/spectrum_*.json\n"
            f" - Combined file: {self.data_dir}/labels.json"
        )

    def _scan_for_samples(self) -> List[str]:
        """
        Scan directory for sample IDs based on .npy files.

        Returns:
            Sample IDs (file stem without the ``spectrum_`` prefix), ordered
            by sorted file path.
        """
        # Strip only the leading prefix: str.replace would also remove the
        # text from inside an ID that happens to contain "spectrum_".
        prefix_len = len(self._FILE_PREFIX)
        npy_files = sorted(self.spectra_dir.glob("spectrum_*.npy"))
        return [npy_path.stem[prefix_len:] for npy_path in npy_files]

    def _load_metadata(self) -> Dict:
        """
        Load the combined labels.json metadata file.

        Raises:
            FileNotFoundError: If labels.json does not exist in data_dir.
        """
        labels_path = self.data_dir / "labels.json"
        if not labels_path.exists():
            raise FileNotFoundError(f"Labels file not found: {labels_path}")
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(labels_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_sample_label(self, sample_id: str) -> Dict:
        """Load label for a single sample from whichever source is in use."""
        if self.use_individual_labels:
            json_path = self.spectra_dir / f"spectrum_{sample_id}.json"
            with open(json_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        return self.metadata['samples'][sample_id]

    def _lookup_isotope(self, isotope_name: str) -> Optional[int]:
        """Return the label index of an isotope, or None if it is unknown."""
        try:
            return self.isotope_index.name_to_index(isotope_name)
        except KeyError:
            # Isotopes missing from the index are skipped on purpose
            # (best-effort labelling), not treated as errors.
            return None

    def _pad_or_truncate(self, spectrum: np.ndarray) -> np.ndarray:
        """
        Pad or truncate spectrum to target time dimension.

        Args:
            spectrum: 2D array (time, channels)

        Returns:
            Array of shape (target_time_intervals, channels)
        """
        current_time, num_channels = spectrum.shape[0], spectrum.shape[1]
        target_time = self.target_time_intervals

        if current_time == target_time:
            return spectrum

        if current_time > target_time:
            # Truncate: take evenly spaced intervals to preserve temporal coverage
            indices = np.linspace(0, current_time - 1, target_time, dtype=int)
            return spectrum[indices, :]

        # Pad with zeros at the end
        padded = np.zeros((target_time, num_channels), dtype=spectrum.dtype)
        padded[:current_time, :] = spectrum
        return padded

    def __len__(self) -> int:
        """Return the number of samples found at construction time."""
        return len(self.sample_ids)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single sample.

        Returns:
            Dictionary containing:
            - spectrum: Tensor of shape (target_time_intervals, num_channels)
            - presence_labels: Binary tensor (num_isotopes,)
            - activity_labels: Tensor (num_isotopes,) with normalized activities
            - sample_id: String identifier

        Raises:
            ValueError: If the stored array is neither 1D nor 2D.
        """
        sample_id = self.sample_ids[idx]
        sample_meta = self._load_sample_label(sample_id)

        # Load spectrum
        spectrum_path = self.spectra_dir / f"spectrum_{sample_id}.npy"
        spectrum = np.load(spectrum_path)

        # Ensure 2D: a 1D spectrum is treated as a single time interval
        if spectrum.ndim == 1:
            spectrum = spectrum.reshape(1, -1)
        if spectrum.ndim != 2:
            raise ValueError(
                f"Expected a 1D or 2D spectrum in {spectrum_path}, "
                f"got shape {spectrum.shape}"
            )

        # Pad/truncate to fixed time dimension
        spectrum = self._pad_or_truncate(spectrum)

        # Normalize (max normalization); an all-zero spectrum is left as is
        max_val = spectrum.max()
        if max_val > 0:
            spectrum = spectrum / max_val

        # Convert to tensor
        spectrum_tensor = torch.tensor(spectrum, dtype=torch.float32)

        # Apply transform if provided
        if self.transform:
            spectrum_tensor = self.transform(spectrum_tensor)

        num_isotopes = self.isotope_index.num_isotopes

        # Create presence labels (multi-hot over the isotope index)
        presence_labels = torch.zeros(num_isotopes, dtype=torch.float32)
        for isotope_name in sample_meta['isotopes']:
            slot = self._lookup_isotope(isotope_name)
            if slot is not None:
                presence_labels[slot] = 1.0

        # Create activity labels (normalized by max_activity_bq, capped at 1)
        activity_labels = torch.zeros(num_isotopes, dtype=torch.float32)
        activities = sample_meta.get('source_activities_bq', {})
        for isotope_name, activity in activities.items():
            slot = self._lookup_isotope(isotope_name)
            if slot is not None:
                activity_labels[slot] = min(activity / self.max_activity_bq, 1.0)

        return {
            'spectrum': spectrum_tensor,
            'presence_labels': presence_labels,
            'activity_labels': activity_labels,
            'sample_id': sample_id
        }
def collate_fn_2d(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Merge per-sample dictionaries into one batch dictionary.

    The three tensor fields are stacked along a new leading dimension; the
    sample identifiers are gathered into a plain list under 'sample_ids'.
    """
    tensor_fields = ('spectrum', 'presence_labels', 'activity_labels')
    collated = {
        field: torch.stack([sample[field] for sample in batch])
        for field in tensor_fields
    }
    collated['sample_ids'] = [sample['sample_id'] for sample in batch]
    return collated
def create_data_loaders_2d(
    data_dir: Path,
    batch_size: int = 32,
    train_split: float = 0.8,
    val_split: float = 0.1,
    test_split: float = 0.1,
    num_workers: int = 4,
    target_time_intervals: int = 60,
    isotope_index: Optional["IsotopeIndex"] = None,
    max_activity_bq: float = 1000.0,
    seed: int = 42
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Create train, validation, and test data loaders for 2D data.

    Args:
        data_dir: Dataset root, forwarded to SpectrumDataset2D
        batch_size: Batch size used by all three loaders
        train_split: Fraction of samples assigned to the training set
        val_split: Fraction of samples assigned to the validation set
        test_split: Kept for interface compatibility. The test set always
            receives the samples left over after the train and validation
            sets; a warning is emitted when this value disagrees with that
            remainder.
        num_workers: Worker processes per loader
        target_time_intervals: Forwarded to SpectrumDataset2D
        isotope_index: Forwarded to SpectrumDataset2D
        max_activity_bq: Forwarded to SpectrumDataset2D
        seed: Seed of the generator driving the random split

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        ValueError: If the fractions produce a negative split size.
    """
    import warnings

    # Create full dataset
    dataset = SpectrumDataset2D(
        data_dir=data_dir,
        isotope_index=isotope_index,
        max_activity_bq=max_activity_bq,
        target_time_intervals=target_time_intervals
    )

    # Calculate split sizes; the test set takes whatever is left over
    total = len(dataset)
    train_size = int(total * train_split)
    val_size = int(total * val_split)
    test_size = total - train_size - val_size
    if min(train_size, val_size, test_size) < 0:
        raise ValueError(
            f"Invalid split fractions: train_split={train_split}, "
            f"val_split={val_split} give sizes train={train_size}, "
            f"val={val_size}, test={test_size}"
        )

    # test_split never drives the split; tell the caller when the requested
    # value does not match what the test set actually receives.
    remainder = 1.0 - train_split - val_split
    if abs(remainder - test_split) > 1e-6:
        warnings.warn(
            f"test_split={test_split} is ignored; the test set receives the "
            f"remaining fraction {remainder:.4f}",
            stacklevel=2,
        )

    # Split dataset (seeded so the partition is reproducible)
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], generator=generator
    )
    print(f"Dataset splits: train={train_size}, val={val_size}, test={test_size}")

    # Options shared by the three loaders. Pinned host memory only helps
    # host-to-GPU copies, so request it only when CUDA is available; this
    # avoids PyTorch's warning on CPU-only hosts.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn_2d,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=num_workers > 0,
    )

    # Create loaders: only the training loader shuffles
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader
if __name__ == "__main__":
    # Smoke test: load the first sample of a dataset and print a summary.
    import sys

    # An optional first command-line argument overrides the default
    # dataset location.
    default_dir = "O:/master_data_collection/isotopev2"
    data_dir = Path(sys.argv[1] if len(sys.argv) > 1 else default_dir)

    dataset = SpectrumDataset2D(data_dir, target_time_intervals=60)
    if len(dataset) == 0:
        # Indexing an empty dataset would only raise a bare IndexError.
        raise SystemExit(f"No samples found in {data_dir}")

    sample = dataset[0]
    print("\nSample:")
    print(f" Spectrum shape: {sample['spectrum'].shape}")
    print(f" Presence labels: {sample['presence_labels'].sum().item():.0f} isotopes")
    print(f" Sample ID: {sample['sample_id']}")