"""
Training Data Visualization Script
Generates an interactive HTML dashboard with Plotly visualizations to explore
the synthetic training data distribution, isotope combinations, activities,
durations, and sample spectra.
Usage:
python -m synthetic_spectra.visualize_training_data
python -m synthetic_spectra.visualize_training_data --data-dir data/synthetic/spectra
python -m synthetic_spectra.visualize_training_data --output report.html --max-samples 1000
Output:
An interactive HTML file that can be opened in any browser.
"""
import argparse
import json
import sys
from pathlib import Path
from collections import Counter, defaultdict
from itertools import combinations
from typing import Dict, List, Tuple, Optional
import numpy as np
# Add the parent directory to sys.path so the `synthetic_spectra` package can be
# imported even when this file is executed directly instead of via `python -m`.
sys.path.insert(0, str(Path(__file__).parent.parent))

try:
    import plotly.graph_objects as go
    import plotly.express as px
    from plotly.subplots import make_subplots
except ImportError:
    # Plotly is a hard requirement for every chart in this script; fail fast
    # with an actionable message on stderr (it is an error, not normal output).
    print("Error: Plotly is required. Install it with: pip install plotly", file=sys.stderr)
    sys.exit(1)
from synthetic_spectra.ground_truth.isotope_data import (
ISOTOPE_DATABASE,
IsotopeCategory,
get_isotopes_by_category,
)
def load_all_metadata(data_dir: Path, max_samples: Optional[int] = None) -> List[Dict]:
    """Load the ``*.json`` metadata files found directly inside *data_dir*.

    Args:
        data_dir: Directory containing the per-sample ``.json`` metadata files.
        max_samples: If given and smaller than the number of files found, a
            reproducible random subset of exactly this many files is loaded
            (files that then fail to parse are skipped, so fewer samples may
            be returned).

    Returns:
        One dict per successfully parsed file, in filename order, each with an
        added ``'_filename'`` key holding the file stem (used elsewhere to
        locate the matching ``.npy`` spectrum). Files that cannot be read, are
        not valid JSON, or whose top-level value is not a JSON object are
        reported with a warning and skipped.
    """
    json_files = sorted(data_dir.glob("*.json"))
    if max_samples is not None and len(json_files) > max_samples:
        # Randomly sample if we have too many. A private RandomState(42) draws
        # exactly the same subset that np.random.seed(42) + np.random.choice()
        # would, but without resetting the process-wide NumPy RNG state.
        rng = np.random.RandomState(42)
        indices = rng.choice(len(json_files), max_samples, replace=False)
        json_files = [json_files[i] for i in sorted(indices)]

    metadata_list = []
    print(f"Loading {len(json_files)} metadata files...")
    for i, json_file in enumerate(json_files, start=1):
        try:
            # JSON is UTF-8 by specification; don't depend on the locale default.
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f" Warning: Could not load {json_file}: {e}")
        else:
            if isinstance(data, dict):
                data['_filename'] = json_file.stem
                metadata_list.append(data)
            else:
                print(f" Warning: Could not load {json_file}: top-level JSON value is not an object")
        if i % 1000 == 0:
            print(f" Loaded {i}/{len(json_files)} files...")
    print(f"Loaded {len(metadata_list)} samples successfully.")
    return metadata_list
def load_sample_spectra(data_dir: Path, sample_ids: List[str]) -> Dict[str, np.ndarray]:
    """Load the ``<sample_id>.npy`` spectrum arrays that exist in *data_dir*.

    Ids without a matching ``.npy`` file are skipped silently; files that
    exist but fail to load are reported with a warning and skipped.

    Returns:
        Mapping of sample id -> loaded array, in the order of *sample_ids*.
    """
    loaded: Dict[str, np.ndarray] = {}
    for sid in sample_ids:
        path = data_dir / f"{sid}.npy"
        if not path.exists():
            continue
        try:
            loaded[sid] = np.load(path)
        except Exception as err:
            # Best-effort: a single bad spectrum must not abort the report.
            print(f" Warning: Could not load spectrum {path}: {err}")
    return loaded
def compute_statistics(metadata_list: List[Dict]) -> Dict:
    """Aggregate dataset-level statistics from per-sample metadata dicts.

    Each metadata dict is read through the optional keys ``'isotopes'``
    (list of isotope names), ``'source_activities_bq'`` (isotope -> activity),
    ``'duration_seconds'`` and ``'detector'``; ``'_filename'`` (added by
    ``load_all_metadata``) is required. Keys that are present but ``null`` are
    treated like missing keys.

    Returns:
        A dict with:

        - ``total_samples``: number of metadata entries.
        - ``isotope_counts``: Counter of isotope occurrences.
        - ``isotope_cooccurrence``: count per alphabetically sorted isotope pair.
        - ``num_isotopes_distribution``: Counter of isotopes-per-sample.
        - ``durations``: durations of the samples that report one (samples
          without ``duration_seconds`` are excluded instead of counted as 0 s).
        - ``activities``: isotope -> list of reported activities.
        - ``detectors``: Counter of detector names (``'unknown'`` if absent).
        - ``category_counts``: Counter of ``ISOTOPE_DATABASE`` category values
          for isotopes present in the database.
        - ``samples_by_num_isotopes``: isotope count -> list of sample ids.
    """
    stats = {
        'total_samples': len(metadata_list),
        'isotope_counts': Counter(),
        'isotope_cooccurrence': defaultdict(int),
        'num_isotopes_distribution': Counter(),
        'durations': [],
        'activities': defaultdict(list),
        'detectors': Counter(),
        'category_counts': Counter(),
        'samples_by_num_isotopes': defaultdict(list),
    }
    for meta in metadata_list:
        # `or` also normalizes explicit JSON nulls, which would otherwise crash
        # the iterations below.
        isotopes = meta.get('isotopes') or []
        source_activities = meta.get('source_activities_bq') or {}
        duration = meta.get('duration_seconds')
        detector = meta.get('detector') or 'unknown'

        # Count isotopes and the database category of each one.
        for iso in isotopes:
            stats['isotope_counts'][iso] += 1
            if iso in ISOTOPE_DATABASE:
                cat = ISOTOPE_DATABASE[iso].category.value
                stats['category_counts'][cat] += 1

        # Count isotope pairs; sorting first makes (a, b) and (b, a) one key.
        for pair in combinations(sorted(isotopes), 2):
            stats['isotope_cooccurrence'][pair] += 1

        # Number of isotopes distribution.
        num_iso = len(isotopes)
        stats['num_isotopes_distribution'][num_iso] += 1
        stats['samples_by_num_isotopes'][num_iso].append(meta['_filename'])

        # Only record real durations: a missing value is not a 0-second
        # measurement and would drag down the mean/min shown in the report.
        if duration is not None:
            stats['durations'].append(duration)

        # Activities per isotope.
        for iso, activity in source_activities.items():
            stats['activities'][iso].append(activity)

        stats['detectors'][detector] += 1
    return stats
def create_isotope_frequency_chart(stats: Dict) -> go.Figure:
    """Create a bar chart of isotope frequencies, colored by isotope category.

    Bars are sorted by descending count from ``stats['isotope_counts']``.
    Isotopes absent from ``ISOTOPE_DATABASE``, or whose category has no entry
    in the color map below, are drawn in a neutral grey.
    """
    isotope_counts = stats['isotope_counts']
    # Sort by frequency (highest first).
    sorted_isotopes = sorted(isotope_counts.items(), key=lambda x: x[1], reverse=True)
    isotopes = [iso for iso, _ in sorted_isotopes]
    counts = [count for _, count in sorted_isotopes]

    category_colors = {
        'natural_background': '#2ecc71',
        'primordial': '#27ae60',
        'cosmogenic': '#1abc9c',
        'u238_chain': '#e74c3c',
        'th232_chain': '#c0392b',
        'u235_chain': '#d35400',
        'calibration': '#3498db',
        'industrial': '#9b59b6',
        'medical': '#f1c40f',
        'reactor_fallout': '#e67e22',
        'activation': '#95a5a6',
    }
    fallback_color = '#7f8c8d'
    colors = [
        category_colors.get(ISOTOPE_DATABASE[iso].category.value, fallback_color)
        if iso in ISOTOPE_DATABASE else fallback_color
        for iso in isotopes
    ]

    fig = go.Figure(data=[
        go.Bar(
            x=isotopes,
            y=counts,
            marker_color=colors,
            # Plotly hover labels are HTML: <br> breaks the line and
            # <extra></extra> hides the secondary trace-name box.
            hovertemplate="<b>%{x}</b><br>Count: %{y}<extra></extra>",
        )
    ])
    fig.update_layout(
        title="Isotope Frequency Distribution",
        xaxis_title="Isotope",
        yaxis_title="Number of Samples",
        xaxis_tickangle=-45,
        height=500,
        showlegend=False,
    )
    return fig
def create_category_pie_chart(stats: Dict) -> go.Figure:
    """Create a donut chart of isotope occurrences grouped by category.

    Reads ``stats['category_counts']`` (category value -> count). Known
    category values are shown with human-readable names; unknown ones are
    shown as-is. Returns a placeholder figure when there is no category data.
    """
    category_counts = stats['category_counts']
    if not category_counts:
        return go.Figure().add_annotation(
            text="No category data available",
            xref="paper", yref="paper", x=0.5, y=0.5,
            showarrow=False,
        )

    # Pretty names for categories.
    pretty_names = {
        'natural_background': 'Natural Background',
        'primordial': 'Primordial',
        'cosmogenic': 'Cosmogenic',
        'u238_chain': 'U-238 Chain',
        'th232_chain': 'Th-232 Chain',
        'u235_chain': 'U-235 Chain',
        'calibration': 'Calibration',
        'industrial': 'Industrial',
        'medical': 'Medical',
        'reactor_fallout': 'Reactor/Fallout',
        'activation': 'Activation Products',
    }
    labels = [pretty_names.get(category, category) for category in category_counts]
    values = list(category_counts.values())

    fig = go.Figure(data=[
        go.Pie(
            labels=labels,
            values=values,
            hole=0.4,
            hovertemplate="<b>%{label}</b><br>Count: %{value}<br>%{percent}<extra></extra>",
        )
    ])
    fig.update_layout(
        title="Isotope Categories Distribution",
        height=450,
    )
    return fig
def create_num_isotopes_histogram(stats: Dict) -> go.Figure:
    """Create a bar chart of how many source isotopes each sample contains.

    Reads ``stats['num_isotopes_distribution']`` (isotope count -> number of
    samples). Each bar is labelled with its share of all samples. Returns a
    placeholder figure when there is nothing to plot.
    """
    num_iso_dist = stats['num_isotopes_distribution']
    x = sorted(num_iso_dist)
    y = [num_iso_dist[k] for k in x]
    total = sum(y)
    if not total:
        # Matches the other charts' placeholder behavior and avoids dividing by 0.
        return go.Figure().add_annotation(
            text="No sample complexity data available",
            xref="paper", yref="paper", x=0.5, y=0.5,
            showarrow=False,
        )

    percentages = [f"{(v / total) * 100:.1f}%" for v in y]
    fig = go.Figure(data=[
        go.Bar(
            # Categorical axis so e.g. "1", "2", "3" are evenly spaced labels.
            x=[str(k) for k in x],
            y=y,
            text=percentages,
            textposition='auto',
            marker_color='#3498db',
            hovertemplate="<b>%{x} isotopes</b><br>Count: %{y}<br>%{text}<extra></extra>",
        )
    ])
    fig.update_layout(
        title="Sample Complexity (Number of Isotopes per Sample)",
        xaxis_title="Number of Source Isotopes",
        yaxis_title="Number of Samples",
        height=400,
    )
    return fig
def create_duration_histogram(stats: Dict) -> go.Figure:
    """Create a 50-bin histogram of measurement durations with summary stats.

    Reads ``stats['durations']`` (seconds). The top-right annotation shows the
    mean, population standard deviation and range. Returns a placeholder
    figure when no durations are available.
    """
    durations = stats['durations']
    if not durations:
        return go.Figure().add_annotation(
            text="No duration data available",
            xref="paper", yref="paper", x=0.5, y=0.5,
            showarrow=False,
        )

    fig = go.Figure(data=[
        go.Histogram(
            x=durations,
            nbinsx=50,
            marker_color='#9b59b6',
            hovertemplate="Duration: %{x:.1f}s<br>Count: %{y}<extra></extra>",
        )
    ])
    fig.update_layout(
        title="Measurement Duration Distribution",
        xaxis_title="Duration (seconds)",
        yaxis_title="Number of Samples",
        height=400,
    )

    # Add statistics annotation.
    values = np.asarray(durations, dtype=float)
    fig.add_annotation(
        text=(
            f"Mean: {values.mean():.1f}s | Std: {values.std():.1f}s | "
            f"Range: [{values.min():.1f}, {values.max():.1f}]s"
        ),
        xref="paper", yref="paper",
        x=0.98, y=0.98,
        xanchor='right', yanchor='top',
        showarrow=False,
        bgcolor="white",
        bordercolor="black",
        borderwidth=1,
        font=dict(size=11),
    )
    return fig
def create_activity_boxplot(stats: Dict, top_n: int = 30) -> go.Figure:
    """Create one box plot of source activity (Bq) per isotope.

    Args:
        stats: Statistics dict; ``stats['activities']`` maps isotope -> list of
            activities.
        top_n: Maximum number of isotopes to show, for readability. The
            isotopes with the highest median activity are kept.

    Returns a placeholder figure when there is no activity data.
    """
    activities = stats['activities']
    if not activities:
        return go.Figure().add_annotation(
            text="No activity data available",
            xref="paper", yref="paper", x=0.5, y=0.5,
            showarrow=False,
        )

    # Sort by median activity, highest first; empty lists rank as 0.
    ranked = sorted(
        activities,
        key=lambda iso: np.median(activities[iso]) if activities[iso] else 0,
        reverse=True,
    )

    fig = go.Figure()
    for iso in ranked[:top_n]:
        fig.add_trace(go.Box(
            y=activities[iso],
            name=iso,
            boxpoints='outliers',
            # Doubled braces keep %{y:.2f} literal inside the f-string.
            hovertemplate=f"<b>{iso}</b><br>Activity: %{{y:.2f}} Bq<extra></extra>",
        ))
    fig.update_layout(
        title=f"Activity Distribution by Isotope (Top {top_n})",
        xaxis_title="Isotope",
        yaxis_title="Activity (Bq)",
        xaxis_tickangle=-45,
        height=500,
        showlegend=False,
    )
    return fig
def create_cooccurrence_heatmap(stats: Dict, top_n: int = 20) -> go.Figure:
    """Create a symmetric heatmap of how often isotope pairs share a sample.

    Args:
        stats: Statistics dict; uses ``stats['isotope_counts']`` to pick the
            most frequent isotopes and ``stats['isotope_cooccurrence']``
            (alphabetically sorted pair -> count) for the cell values.
        top_n: Maximum number of most frequent isotopes on each axis.

    The diagonal is left at 0. Returns a placeholder figure when there are no
    multi-isotope samples.
    """
    cooccurrence = stats['isotope_cooccurrence']
    isotope_counts = stats['isotope_counts']
    if not cooccurrence:
        return go.Figure().add_annotation(
            text="No co-occurrence data (need multi-isotope samples)",
            xref="paper", yref="paper", x=0.5, y=0.5,
            showarrow=False,
        )

    # Get the (up to) top_n most frequent isotopes.
    top_isotopes = [iso for iso, _ in isotope_counts.most_common(top_n)]
    n = len(top_isotopes)

    matrix = np.zeros((n, n), dtype=int)
    for i, j in combinations(range(n), 2):
        # Co-occurrence keys are alphabetically sorted pairs.
        pair = tuple(sorted((top_isotopes[i], top_isotopes[j])))
        matrix[i, j] = matrix[j, i] = cooccurrence.get(pair, 0)

    fig = go.Figure(data=go.Heatmap(
        z=matrix,
        x=top_isotopes,
        y=top_isotopes,
        colorscale='Blues',
        hovertemplate="%{x} + %{y}<br>Co-occurrences: %{z}<extra></extra>",
    ))
    fig.update_layout(
        # Report the actual number shown, which may be fewer than top_n.
        title=f"Isotope Co-occurrence Matrix (Top {n} Isotopes)",
        xaxis_tickangle=-45,
        height=600,
        width=700,
    )
    return fig
def create_activity_vs_duration_scatter(metadata_list: List[Dict]) -> go.Figure:
    """Create a scatter plot of total source activity vs measurement duration.

    One point per sample that has a positive ``duration_seconds`` and a
    non-empty ``source_activities_bq`` mapping; the y value is the sum of its
    activities and the marker color is its number of isotopes. Returns a
    placeholder figure when no sample qualifies.
    """
    durations = []
    total_activities = []
    num_isotopes = []
    sample_ids = []
    for meta in metadata_list:
        # `or` guards against explicit nulls (None > 0 would raise TypeError).
        duration = meta.get('duration_seconds') or 0
        activities = meta.get('source_activities_bq') or {}
        if duration > 0 and activities:
            durations.append(duration)
            total_activities.append(sum(activities.values()))
            num_isotopes.append(len(meta.get('isotopes') or []))
            sample_ids.append(meta['_filename'])

    if not durations:
        return go.Figure().add_annotation(
            text="No data available",
            xref="paper", yref="paper", x=0.5, y=0.5,
            showarrow=False,
        )

    fig = go.Figure(data=go.Scatter(
        x=durations,
        y=total_activities,
        mode='markers',
        marker=dict(
            size=6,
            color=num_isotopes,
            colorscale='Viridis',
            colorbar=dict(title="# Isotopes"),
            opacity=0.6,
        ),
        text=sample_ids,
        hovertemplate=(
            "<b>%{text}</b><br>Duration: %{x:.1f}s<br>"
            "Total Activity: %{y:.2f} Bq<extra></extra>"
        ),
    ))
    fig.update_layout(
        title="Total Source Activity vs Measurement Duration",
        xaxis_title="Duration (seconds)",
        yaxis_title="Total Activity (Bq)",
        height=500,
    )
    return fig
def create_sample_spectrum_plot(
    spectra: Dict[str, np.ndarray],
    metadata_list: List[Dict],
    max_traces: int = 6,
) -> go.Figure:
    """Overlay the time-integrated spectra of a few samples.

    Args:
        spectra: Sample id -> spectrum array. 2-D arrays are summed over
            axis 0 (presumably time intervals — verify against the generator);
            1-D arrays are plotted as-is.
        metadata_list: Metadata dicts, used to label each trace with its
            isotopes (looked up by ``'_filename'``).
        max_traces: Maximum number of spectra to draw (first N in dict order).

    Returns a placeholder figure when *spectra* is empty.
    """
    if not spectra:
        return go.Figure().add_annotation(
            text="No spectrum data loaded",
            xref="paper", yref="paper", x=0.5, y=0.5,
            showarrow=False,
        )

    meta_lookup = {m['_filename']: m for m in metadata_list}
    colors = px.colors.qualitative.Set2
    fig = go.Figure()
    for i, (sample_id, spectrum) in enumerate(list(spectra.items())[:max_traces]):
        # Sum across time intervals to get the total spectrum.
        total_spectrum = spectrum.sum(axis=0) if spectrum.ndim == 2 else spectrum

        # Energy axis (keV): channels are assumed to be linear from 20 to
        # 3000 keV, as in the original hard-coded 1023-point axis. Deriving the
        # point count from the data keeps x and y the same length for any
        # channel count.
        energy = np.linspace(20, 3000, total_spectrum.shape[-1])

        meta = meta_lookup.get(sample_id, {})
        isotopes = meta.get('isotopes') or ['Unknown']
        label = f"{sample_id[-6:]}: {', '.join(isotopes)}"
        fig.add_trace(go.Scatter(
            x=energy,
            y=total_spectrum,
            mode='lines',
            name=label,
            line=dict(color=colors[i % len(colors)], width=1),
            # Doubled braces keep the Plotly %{...} placeholders literal.
            hovertemplate=(
                f"<b>{label}</b><br>Energy: %{{x:.1f}} keV<br>"
                f"Counts: %{{y:.2f}}<extra></extra>"
            ),
        ))

    fig.update_layout(
        title="Sample Spectra (Time-Integrated)",
        xaxis_title="Energy (keV)",
        yaxis_title="Normalized Counts",
        height=500,
        legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
        hovermode='closest',
    )
    return fig
def create_3d_spectrum_surface(
    spectrum: np.ndarray,
    sample_id: str,
    max_time_points: int = 100,
    max_energy_points: int = 256,
) -> go.Figure:
    """Create a 3-D surface of one 2-D spectrum (time x energy -> counts).

    Args:
        spectrum: 2-D array indexed ``[interval, channel]``; anything else
            yields a placeholder figure.
        sample_id: Shown in the title.
        max_time_points: Upper bound (>= 1) on plotted time rows.
        max_energy_points: Upper bound (>= 1) on plotted energy columns.

    Larger spectra are decimated (every k-th row/column) so the surface stays
    responsive in the browser.
    """
    if spectrum.ndim != 2:
        return go.Figure().add_annotation(
            text="Spectrum must be 2D",
            xref="paper", yref="paper", x=0.5, y=0.5,
            showarrow=False,
        )

    num_intervals, num_channels = spectrum.shape
    # Interval index, labelled in seconds below — NOTE(review): this presumes
    # 1-second intervals; confirm against the spectrum generator.
    time_axis = np.arange(num_intervals)
    # Assumed linear 20-3000 keV calibration (same as the 1-D spectra plot).
    energy_axis = np.linspace(20, 3000, num_channels)

    # Ceil-divide the step so the result never exceeds the cap; floor division
    # left e.g. 199 rows (step 1) or 341 of 1023 channels (step 3).
    if num_intervals > max_time_points:
        step = -(-num_intervals // max_time_points)
        spectrum = spectrum[::step, :]
        time_axis = time_axis[::step]
    if num_channels > max_energy_points:
        ch_step = -(-num_channels // max_energy_points)
        spectrum = spectrum[:, ::ch_step]
        energy_axis = energy_axis[::ch_step]

    fig = go.Figure(data=[
        go.Surface(
            z=spectrum,
            x=energy_axis,
            y=time_axis,
            colorscale='Viridis',
            hovertemplate=(
                "Time: %{y}s<br>Energy: %{x:.1f} keV<br>"
                "Counts: %{z:.3f}<extra></extra>"
            ),
        )
    ])
    fig.update_layout(
        title=f"3D Spectrum View: {sample_id}",
        scene=dict(
            xaxis_title="Energy (keV)",
            yaxis_title="Time (s)",
            zaxis_title="Counts",
        ),
        height=600,
    )
    return fig
def create_summary_table(stats: Dict) -> str:
    """Render the headline dataset statistics as an HTML ``<div>`` fragment.

    Uses ``total_samples``, ``isotope_counts``, ``num_isotopes_distribution``,
    ``durations``, ``activities`` and ``detectors`` from *stats*. Values that
    cannot be computed (no durations / no activities) are shown as ``n/a``.
    Detector names come from input metadata and are HTML-escaped.
    """
    from html import escape

    total = stats['total_samples']
    num_unique_isotopes = len(stats['isotope_counts'])
    isotope_total = sum(k * v for k, v in stats['num_isotopes_distribution'].items())
    avg_isotopes_per_sample = isotope_total / total if total else 0
    durations = stats['durations']
    activities_all = [a for acts in stats['activities'].values() for a in acts]

    if durations:
        duration_range = f"{min(durations):.1f}s - {max(durations):.1f}s"
        mean_duration = f"{np.mean(durations):.1f}s"
    else:
        duration_range = mean_duration = "n/a"
    if activities_all:
        activity_range = f"{min(activities_all):.2f} - {max(activities_all):.2f} Bq"
    else:
        activity_range = "n/a"
    # str() tolerates non-string detector values; escape() keeps metadata from
    # injecting markup into the report.
    detectors = ', '.join(escape(str(name)) for name in stats['detectors'])

    rows = [
        ("Total Samples", f"{total:,}"),
        ("Unique Isotopes", f"{num_unique_isotopes}"),
        ("Avg Isotopes per Sample", f"{avg_isotopes_per_sample:.2f}"),
        ("Duration Range", duration_range),
        ("Mean Duration", mean_duration),
        ("Activity Range", activity_range),
        ("Detectors", detectors),
    ]
    row_html = "\n".join(
        f"    <tr><th>{label}</th><td>{value}</td></tr>" for label, value in rows
    )
    return (
        '<div class="summary">\n'
        '  <h2>Dataset Summary</h2>\n'
        '  <table>\n'
        f'{row_html}\n'
        '  </table>\n'
        '</div>\n'
    )
def create_isotope_database_summary() -> go.Figure:
    """Build a sunburst of ISOTOPE_DATABASE: root -> category -> isotope.

    Each category node is labelled with its member count; every isotope leaf
    has weight 1, so with ``branchvalues="total"`` the wedge sizes are
    proportional to the number of isotopes per category.
    """
    display_names = {
        'natural_background': 'Natural Background',
        'primordial': 'Primordial',
        'cosmogenic': 'Cosmogenic',
        'u238_chain': 'U-238 Chain',
        'th232_chain': 'Th-232 Chain',
        'u235_chain': 'U-235 Chain',
        'calibration': 'Calibration',
        'industrial': 'Industrial',
        'medical': 'Medical',
        'reactor_fallout': 'Reactor/Fallout',
        'activation': 'Activation',
    }
    root = "Isotope Database"

    # Group isotope names by category value, keeping database order.
    grouped: Dict[str, List[str]] = {}
    for iso_name, record in ISOTOPE_DATABASE.items():
        grouped.setdefault(record.category.value, []).append(iso_name)

    # Parallel arrays describing the hierarchy, root node first.
    ids, labels, parents, values = [root], [root], [""], [len(ISOTOPE_DATABASE)]
    for category, members in grouped.items():
        display = display_names.get(category, category)
        ids.append(display)
        labels.append(f"{display} ({len(members)})")
        parents.append(root)
        values.append(len(members))
        # Leaf ids are namespaced by category so they stay unique.
        ids.extend(f"{display}/{member}" for member in members)
        labels.extend(members)
        parents.extend([display] * len(members))
        values.extend([1] * len(members))

    sunburst = go.Sunburst(
        ids=ids,
        labels=labels,
        parents=parents,
        values=values,
        branchvalues="total",
        hovertemplate="%{label}",
    )
    fig = go.Figure(sunburst)
    fig.update_layout(
        title=f"Isotope Database Structure ({len(ISOTOPE_DATABASE)} isotopes)",
        height=600,
    )
    return fig
# Report layout: (section heading, "What this shows" text, figure keys in
# display order). These strings are inserted into the HTML verbatim, so they
# must already be HTML-safe (note the escaped ampersand in section 3).
_REPORT_SECTIONS = [
    (
        "1. Isotope Distribution",
        "The frequency of each isotope across all training samples. "
        "Imbalanced distributions may lead to model bias towards common isotopes.",
        ["isotope_freq", "category_pie"],
    ),
    (
        "2. Sample Complexity",
        "Distribution of how many source isotopes are present per sample. "
        "Mix of single and multi-isotope samples helps the model handle real-world complexity.",
        ["num_isotopes"],
    ),
    (
        "3. Temporal &amp; Activity Analysis",
        "Distribution of measurement durations and source activities. "
        "Varied durations simulate different counting scenarios.",
        ["duration_hist", "activity_duration", "activity_box"],
    ),
    (
        "4. Isotope Co-occurrence",
        "Which isotopes frequently appear together in training samples. "
        "This helps understand potential confusion pairs and realistic combinations.",
        ["cooccurrence"],
    ),
    (
        "5. Sample Spectra Visualization",
        "Actual spectrum shapes from the training data. "
        "Each peak corresponds to gamma emission lines from the source isotopes.",
        ["sample_spectra", "spectrum_3d"],
    ),
    (
        "6. Isotope Database Overview",
        "The complete isotope database structure organized by category. "
        "Click to explore the hierarchy.",
        ["isotope_db"],
    ),
]

# Optional sub-headings rendered directly above specific figures.
_FIGURE_SUBHEADINGS = {"spectrum_3d": "3D Time-Energy-Counts View"}

_REPORT_STYLE = """\
body { font-family: -apple-system, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
       max-width: 1200px; margin: 0 auto; padding: 20px; color: #2c3e50; background: #f8f9fa; }
h1 { border-bottom: 3px solid #3498db; padding-bottom: 10px; }
section, .summary { background: #fff; border-radius: 8px; margin: 20px 0; padding: 20px;
                    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); }
.description { color: #555; }
table { border-collapse: collapse; }
th, td { border-bottom: 1px solid #eee; padding: 6px 16px; text-align: left; }
"""


def _render_report_html(stats: Dict, figures: Dict[str, "go.Figure"]) -> str:
    """Assemble the standalone HTML document from the summary and figures."""
    parts = [
        '<!DOCTYPE html>\n'
        '<html lang="en">\n'
        '<head>\n'
        '<meta charset="utf-8">\n'
        '<title>Synthetic Training Data Visualization</title>\n'
        f'<style>\n{_REPORT_STYLE}</style>\n'
        '</head>\n'
        '<body>\n'
        '<h1>Synthetic Gamma Spectra Training Data Analysis</h1>\n',
        create_summary_table(stats),
    ]
    # Figures are rendered as fragments; plotly.js itself must be loaded once,
    # before any figure script runs, so it is attached (via the CDN) to the
    # first figure emitted and omitted from all later ones.
    plotlyjs_loaded = False
    for title, description, keys in _REPORT_SECTIONS:
        parts.append(
            f'<section>\n<h2>{title}</h2>\n'
            f'<p class="description"><b>What this shows:</b> {description}</p>\n'
        )
        for key in keys:
            fig = figures.get(key)
            if fig is None:
                continue  # optional figures (e.g. the 3D view) may be absent
            subheading = _FIGURE_SUBHEADINGS.get(key)
            if subheading:
                parts.append(f"<h3>{subheading}</h3>\n")
            parts.append(fig.to_html(
                full_html=False,
                include_plotlyjs=False if plotlyjs_loaded else 'cdn',
            ))
            plotlyjs_loaded = True
        parts.append("</section>\n")
    parts.append("</body>\n</html>\n")
    return "".join(parts)


def generate_html_report(
    data_dir: Path,
    output_file: Path,
    max_samples: Optional[int] = None
):
    """Generate the complete interactive HTML report.

    Args:
        data_dir: Directory containing the per-sample ``.json``/``.npy`` files.
        output_file: Destination HTML path; missing parent directories are
            created.
        max_samples: Optional cap on the number of metadata files analyzed.

    Prints an error and writes nothing when no metadata could be loaded.
    The written page loads plotly.js from the Plotly CDN.
    """
    print("=" * 60)
    print("Training Data Visualization Report Generator")
    print("=" * 60)

    metadata_list = load_all_metadata(data_dir, max_samples)
    if not metadata_list:
        print("Error: No metadata files found!", file=sys.stderr)
        return

    print("\nComputing statistics...")
    stats = compute_statistics(metadata_list)

    # Only a handful of spectra are plotted, so load just the first few.
    print("\nLoading sample spectra for visualization...")
    sample_ids = [m['_filename'] for m in metadata_list[:10]]
    spectra = load_sample_spectra(data_dir, sample_ids)

    print("\nGenerating visualizations...")
    figures = {
        'isotope_freq': create_isotope_frequency_chart(stats),
        'category_pie': create_category_pie_chart(stats),
        'num_isotopes': create_num_isotopes_histogram(stats),
        'duration_hist': create_duration_histogram(stats),
        'activity_box': create_activity_boxplot(stats),
        'cooccurrence': create_cooccurrence_heatmap(stats),
        'activity_duration': create_activity_vs_duration_scatter(metadata_list),
        'sample_spectra': create_sample_spectrum_plot(spectra, metadata_list),
        'isotope_db': create_isotope_database_summary(),
    }
    # Add the 3D view of the first loaded spectrum, if any.
    if spectra:
        first_id = next(iter(spectra))
        figures['spectrum_3d'] = create_3d_spectrum_surface(spectra[first_id], first_id)

    print("\nBuilding HTML report...")
    html_content = _render_report_html(stats, figures)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    output_file.write_text(html_content, encoding='utf-8')

    print("\nReport generated successfully!")
    print(f" Output: {output_file.absolute()}")
    print("\nOpen in your browser to view the interactive visualizations.")
def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: parse arguments and generate the report.

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Exits with status 1 if ``--data-dir`` is not an existing directory, and
    with argparse's usage-error status (2) if ``--max-samples`` is not a
    positive integer.
    """
    parser = argparse.ArgumentParser(
        description="Generate interactive HTML visualization of training data"
    )
    parser.add_argument(
        '--data-dir',
        type=str,
        default='data/synthetic/spectra',
        help='Directory containing spectrum .json and .npy files'
    )
    parser.add_argument(
        '--output',
        type=str,
        default='training_data_report.html',
        help='Output HTML file name'
    )
    parser.add_argument(
        '--max-samples',
        type=int,
        default=None,
        help='Maximum number of samples to analyze (for faster generation)'
    )
    args = parser.parse_args(argv)

    # 0 or negative would otherwise silently produce an empty report or a
    # NumPy error deep inside the sampling code.
    if args.max_samples is not None and args.max_samples < 1:
        parser.error("--max-samples must be a positive integer")

    data_dir = Path(args.data_dir)
    output_file = Path(args.output)
    # is_dir(): an existing regular file is not a usable data directory.
    if not data_dir.is_dir():
        print(f"Error: Data directory not found: {data_dir}", file=sys.stderr)
        sys.exit(1)

    generate_html_report(data_dir, output_file, args.max_samples)


if __name__ == "__main__":
    main()