- VegaModel CNN-FCNN 34.5M params, 82 isotopes, val acc 99.89% - Generation 50k spectres synthetiques 1D (12-24h durees) - Entrainement 100 epochs sur RTX 5060 Ti (CUDA 12.8, Blackwell) - Detection continue avec soustraction du background - Capture background 24h avec gestion deconnexion - Docker Compose : conteneur train (GPU) + detect (CPU/USB) - Modele entraine inclus (vega_best.pt, 395 Mo) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
309 lines
10 KiB
Python
309 lines
10 KiB
Python
"""
|
||
Dataset for 2D Vega Model
|
||
|
||
Loads 2D spectra (time × channels) and pads/truncates to fixed dimensions.
|
||
"""
|
||
|
||
import json
|
||
import numpy as np
|
||
import torch
|
||
from torch.utils.data import Dataset, DataLoader, random_split
|
||
from pathlib import Path
|
||
from typing import Dict, List, Optional, Tuple
|
||
from dataclasses import dataclass
|
||
|
||
from .isotope_index import IsotopeIndex, get_default_isotope_index
|
||
|
||
|
||
@dataclass
class SpectrumSample2D:
    """
    Record describing one 2D spectrum together with its labels.

    Attributes:
        sample_id: Identifier of the sample.
        spectrum: 2D array laid out as (time_intervals, channels).
        isotopes_present: Names of the isotopes present in the sample.
        activities_bq: Mapping from isotope name to its activity value.
        duration_seconds: Duration associated with the sample.
        detector: Name of the detector.
    """

    # Field order defines the positional constructor signature; keep it stable.
    sample_id: str
    spectrum: np.ndarray
    isotopes_present: List[str]
    activities_bq: Dict[str, float]
    duration_seconds: float
    detector: str
|
||
|
||
|
||
class SpectrumDataset2D(Dataset):
    """
    PyTorch Dataset for 2D gamma spectra.

    Spectra are read from ``<data_dir>/spectra/spectrum_<id>.npy``. Labels come
    either from a sibling ``spectrum_<id>.json`` file per sample or, when no
    such files exist, from a combined ``<data_dir>/labels.json``.

    Pads or truncates time dimension to fixed size for batch processing.
    """

    # Filename prefix shared by spectrum arrays and per-sample label files.
    _FILE_PREFIX = "spectrum_"

    def __init__(
        self,
        data_dir: Path,
        isotope_index: Optional[IsotopeIndex] = None,
        max_activity_bq: float = 1000.0,
        target_time_intervals: int = 60,
        transform=None
    ):
        """
        Initialize the dataset.

        Args:
            data_dir: Path to directory containing spectra/ subdirectory
            isotope_index: Index mapping isotope names to indices
            max_activity_bq: Maximum activity for normalization
            target_time_intervals: Fixed time dimension (pad/truncate to this)
            transform: Optional transform to apply

        Raises:
            ValueError: If target_time_intervals is smaller than 1.
            FileNotFoundError: If no label files can be found.
        """
        # Fail early: a non-positive time dimension would otherwise only
        # surface inside __getitem__ (max() of an empty array).
        if target_time_intervals < 1:
            raise ValueError(
                f"target_time_intervals must be >= 1, got {target_time_intervals}"
            )

        self.data_dir = Path(data_dir)
        self.spectra_dir = self.data_dir / "spectra"
        # Explicit None check so a caller-supplied index is never replaced
        # just because it happens to evaluate as falsy.
        self.isotope_index = (
            isotope_index if isotope_index is not None
            else get_default_isotope_index()
        )
        self.max_activity_bq = max_activity_bq
        self.target_time_intervals = target_time_intervals
        self.transform = transform

        # Detect label format and load sample list
        self.use_individual_labels = self._detect_label_format()

        if self.use_individual_labels:
            self.sample_ids = self._scan_for_samples()
            self.metadata = None
            print("Using individual label files (efficient mode)")
        else:
            self.metadata = self._load_metadata()
            self.sample_ids = list(self.metadata['samples'].keys())
            print("Using combined labels.json (legacy mode)")

        print(f"Loaded 2D dataset with {len(self.sample_ids)} samples")
        print(f"Target shape: ({target_time_intervals}, 1023)")
        print(f"Isotope index has {self.isotope_index.num_isotopes} isotopes")

    def _detect_label_format(self) -> bool:
        """
        Detect whether to use individual JSON files or combined labels.json.

        Returns:
            True when at least one ``spectrum_*.json`` exists in the spectra
            directory, False when only the combined ``labels.json`` exists.

        Raises:
            FileNotFoundError: If neither label source is present.
        """
        # any() stops at the first match instead of listing every label file.
        if any(self.spectra_dir.glob("spectrum_*.json")):
            return True

        if (self.data_dir / "labels.json").exists():
            return False

        raise FileNotFoundError(
            f"No label files found. Expected either:\n"
            f" - Individual files: {self.spectra_dir}/spectrum_*.json\n"
            f" - Combined file: {self.data_dir}/labels.json"
        )

    def _scan_for_samples(self) -> List[str]:
        """
        Scan directory for sample IDs based on .npy files.

        Returns:
            Sample IDs (file stem without the ``spectrum_`` prefix), ordered
            by sorted file path.
        """
        # Strip only the leading prefix: str.replace would also remove any
        # later "spectrum_" inside the ID and break the file lookup.
        prefix_len = len(self._FILE_PREFIX)
        return [
            npy_path.stem[prefix_len:]
            for npy_path in sorted(self.spectra_dir.glob("spectrum_*.npy"))
        ]

    def _load_metadata(self) -> Dict:
        """
        Load the combined labels.json metadata file.

        Raises:
            FileNotFoundError: If labels.json does not exist.
        """
        labels_path = self.data_dir / "labels.json"
        if not labels_path.exists():
            raise FileNotFoundError(f"Labels file not found: {labels_path}")

        with open(labels_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_sample_label(self, sample_id: str) -> Dict:
        """Load label for a single sample from its JSON file or the combined metadata."""
        if self.use_individual_labels:
            json_path = self.spectra_dir / f"spectrum_{sample_id}.json"
            with open(json_path, 'r', encoding='utf-8') as f:
                return json.load(f)

        return self.metadata['samples'][sample_id]

    def _pad_or_truncate(self, spectrum: np.ndarray) -> np.ndarray:
        """
        Pad or truncate spectrum to target time dimension.

        Args:
            spectrum: 2D array (time, channels)

        Returns:
            Array of shape (target_time_intervals, channels)
        """
        current_time, num_channels = spectrum.shape[0], spectrum.shape[1]
        target_time = self.target_time_intervals

        if current_time == target_time:
            return spectrum

        if current_time > target_time:
            # Truncate: take evenly spaced intervals to preserve temporal coverage
            indices = np.linspace(0, current_time - 1, target_time, dtype=int)
            return spectrum[indices, :]

        # Pad with zeros at the end
        padded = np.zeros((target_time, num_channels), dtype=spectrum.dtype)
        padded[:current_time, :] = spectrum
        return padded

    def _isotope_position(self, isotope_name: str) -> Optional[int]:
        """Return the index of an isotope, or None when the index raises KeyError for it."""
        try:
            return self.isotope_index.name_to_index(isotope_name)
        except KeyError:
            # Deliberate best-effort: labels naming isotopes outside the
            # index are ignored rather than failing the whole sample.
            return None

    def _build_label_tensors(self, sample_meta: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build the presence and normalized activity label vectors for a sample.

        Args:
            sample_meta: Label dictionary with an 'isotopes' list and an
                optional 'source_activities_bq' mapping.

        Returns:
            (presence_labels, activity_labels), both float32 tensors of
            length ``isotope_index.num_isotopes``.
        """
        num_isotopes = self.isotope_index.num_isotopes

        presence_labels = torch.zeros(num_isotopes, dtype=torch.float32)
        for isotope_name in sample_meta['isotopes']:
            position = self._isotope_position(isotope_name)
            if position is not None:
                presence_labels[position] = 1.0

        activity_labels = torch.zeros(num_isotopes, dtype=torch.float32)
        for isotope_name, activity in sample_meta.get('source_activities_bq', {}).items():
            position = self._isotope_position(isotope_name)
            if position is not None:
                # Activities are scaled by max_activity_bq and capped at 1.0.
                activity_labels[position] = min(activity / self.max_activity_bq, 1.0)

        return presence_labels, activity_labels

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.sample_ids)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single sample.

        Returns:
            Dictionary containing:
            - spectrum: Tensor of shape (target_time_intervals, num_channels)
            - presence_labels: Binary tensor (num_isotopes,)
            - activity_labels: Tensor (num_isotopes,) with normalized activities
            - sample_id: String identifier
        """
        sample_id = self.sample_ids[idx]
        sample_meta = self._load_sample_label(sample_id)

        # Load spectrum
        spectrum = np.load(self.spectra_dir / f"spectrum_{sample_id}.npy")

        # Ensure 2D: a 1D spectrum becomes a single time interval
        if spectrum.ndim == 1:
            spectrum = spectrum.reshape(1, -1)

        # Pad/truncate to fixed time dimension
        spectrum = self._pad_or_truncate(spectrum)

        # Normalize (max normalization); an all-zero spectrum is left as is
        max_val = spectrum.max()
        if max_val > 0:
            spectrum = spectrum / max_val

        # Convert to tensor
        spectrum_tensor = torch.tensor(spectrum, dtype=torch.float32)

        # Apply transform if provided
        if self.transform is not None:
            spectrum_tensor = self.transform(spectrum_tensor)

        presence_labels, activity_labels = self._build_label_tensors(sample_meta)

        return {
            'spectrum': spectrum_tensor,
            'presence_labels': presence_labels,
            'activity_labels': activity_labels,
            'sample_id': sample_id
        }
|
||
|
||
|
||
def collate_fn_2d(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Merge per-sample dictionaries into one batch dictionary.

    The tensor entries are stacked along a new leading dimension; the sample
    identifiers are gathered into a plain list under 'sample_ids'.
    """
    tensor_keys = ('spectrum', 'presence_labels', 'activity_labels')
    collated = {
        key: torch.stack([item[key] for item in batch])
        for key in tensor_keys
    }
    collated['sample_ids'] = [item['sample_id'] for item in batch]
    return collated
|
||
|
||
|
||
def create_data_loaders_2d(
    data_dir: Path,
    batch_size: int = 32,
    train_split: float = 0.8,
    val_split: float = 0.1,
    test_split: float = 0.1,
    num_workers: int = 4,
    target_time_intervals: int = 60,
    isotope_index: Optional[IsotopeIndex] = None,
    max_activity_bq: float = 1000.0,
    seed: int = 42
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Create train, validation, and test data loaders for 2D data.

    Args:
        data_dir: Dataset root passed to SpectrumDataset2D.
        batch_size: Batch size used by all three loaders.
        train_split: Fraction of samples assigned to the training set.
        val_split: Fraction of samples assigned to the validation set.
        test_split: Accepted for interface compatibility; the test set
            always receives every sample not assigned to train or
            validation, so this value does not size anything.
        num_workers: Number of worker processes per loader.
        target_time_intervals: Fixed time dimension passed to the dataset.
        isotope_index: Isotope index passed to the dataset.
        max_activity_bq: Activity normalization constant passed to the dataset.
        seed: Seed of the generator that drives the random split.

    Returns:
        (train_loader, val_loader, test_loader). Only the train loader shuffles.

    Raises:
        ValueError: If a split fraction is negative or train_split and
            val_split together exceed 1.
    """
    # Validate before touching the filesystem. Without this check an
    # oversubscribed split yields a negative test size that random_split
    # accepts (the lengths still sum to the dataset size) and silently
    # returns truncated or empty subsets.
    if train_split < 0 or val_split < 0:
        raise ValueError(
            f"Split fractions must be non-negative, got "
            f"train_split={train_split}, val_split={val_split}"
        )
    if train_split + val_split > 1.0 + 1e-9:
        raise ValueError(
            f"train_split + val_split must not exceed 1, got "
            f"{train_split} + {val_split}"
        )

    # Create full dataset
    dataset = SpectrumDataset2D(
        data_dir=data_dir,
        isotope_index=isotope_index,
        max_activity_bq=max_activity_bq,
        target_time_intervals=target_time_intervals
    )

    # Calculate split sizes; the test set absorbs the rounding remainder
    total = len(dataset)
    train_size = int(total * train_split)
    val_size = int(total * val_split)
    test_size = total - train_size - val_size
    if test_size < 0:
        raise ValueError(
            f"Split sizes exceed dataset size: train={train_size}, "
            f"val={val_size}, total={total}"
        )

    # Split dataset (seeded generator keeps the partition reproducible)
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], generator=generator
    )

    print(f"Dataset splits: train={train_size}, val={val_size}, test={test_size}")

    # Options shared by all three loaders; only shuffling differs.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn_2d,
        pin_memory=True,
        # persistent_workers is only enabled when workers are in use
        persistent_workers=num_workers > 0
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader
|
||
|
||
|
||
if __name__ == "__main__":
    # Smoke test: load one sample and print a short summary.
    # Usage: pass the dataset root as the first command-line argument;
    # without it the original development path is used.
    import sys

    default_data_dir = "O:/master_data_collection/isotopev2"
    data_dir = Path(sys.argv[1] if len(sys.argv) > 1 else default_data_dir)

    dataset = SpectrumDataset2D(data_dir, target_time_intervals=60)
    if len(dataset) == 0:
        # Indexing an empty dataset would raise a bare IndexError.
        raise SystemExit(f"No samples found under {data_dir}")

    sample = dataset[0]

    print("\nSample:")
    print(f" Spectrum shape: {sample['spectrum'].shape}")
    print(f" Presence labels: {sample['presence_labels'].sum().item():.0f} isotopes")
    print(f" Sample ID: {sample['sample_id']}")
|