Data Loading
JAX-bandflux provides utilities for loading supernova photometry data and preparing it for model fitting.
Synthetic Data
For testing and development, you can create synthetic observations:
import jax.numpy as jnp

# Create synthetic observation times
times = jnp.array([58650.0, 58655.0, 58660.0, 58665.0, 58670.0])
# Synthetic fluxes and errors
fluxes = jnp.array([100.0, 150.0, 180.0, 160.0, 120.0])
fluxerrs = jnp.array([5.0, 6.0, 7.0, 6.5, 5.5])
# All observations in one band
unique_bands = ['bessellb']
band_indices = jnp.zeros(5, dtype=jnp.int32)
print(f"Created {len(times)} observations")
print(f"Flux range: {float(jnp.min(fluxes)):.1f} to {float(jnp.max(fluxes)):.1f}")
Created 5 observations
Flux range: 100.0 to 180.0
Multi-Band Synthetic Data
Generate observations across multiple bands:
import numpy as np

# Generate multi-band synthetic data
source = SALT3Source()
params = {'x0': 1e-4, 'x1': 0.5, 'c': 0.0}
z = 0.05
# Observation setup
obs_times = np.array([0, 5, 10, 15, 20]) # Days from peak
bands = ['bessellb', 'bessellv', 'bessellr']
# Generate observations
np.random.seed(42)
all_times, all_fluxes, all_errors, all_bands = [], [], [], []
for band in bands:
    # Convert observer-frame days to rest-frame phase
    phases = obs_times / (1 + z)
    true_flux = np.array(source.bandflux(params, band, phases, zp=27.5, zpsys='ab'))
    # 5% Gaussian noise on each flux
    noise = np.random.normal(0, np.abs(true_flux) * 0.05)
    all_times.extend(obs_times)
    all_fluxes.extend(true_flux + noise)
    all_errors.extend(np.abs(true_flux) * 0.05)
    all_bands.extend([band] * len(obs_times))
print(f"Generated {len(all_times)} observations across {len(bands)} bands")
print(f"Band distribution: {[(b, all_bands.count(b)) for b in bands]}")
Generated 15 observations across 3 bands
Band distribution: [('bessellb', 5), ('bessellv', 5), ('bessellr', 5)]
Data Structure
For fitting, you need the following arrays:
# Convert to JAX arrays
times = jnp.array(all_times)
fluxes = jnp.array(all_fluxes)
fluxerrs = jnp.array(all_errors)
zps = jnp.full(len(times), 27.5)
# Band indexing for optimized mode
unique_bands = ['bessellb', 'bessellv', 'bessellr']
band_to_idx = {b: i for i, b in enumerate(unique_bands)}
band_indices = jnp.array([band_to_idx[b] for b in all_bands])
# Pre-compute bridges
bridges = tuple(precompute_bandflux_bridge(get_bandpass(b))
                for b in unique_bands)
print(f"Pre-computed {len(bridges)} bridges for bands: {unique_bands}")
# Inspect bridge structure
print(f"B-band bridge wavelength grid: {bridges[0]['wave'].shape[0]} points")
print(f"Grid spacing: {bridges[0]['dwave']} Angstroms")
Pre-computed 3 bridges for bands: ['bessellb', 'bessellv', 'bessellr']
B-band bridge wavelength grid: 400 points
Grid spacing: 5.0 Angstroms
Required Data Arrays
| Array | Type | Description |
|---|---|---|
| times | float array | Observation times (MJD or days from reference) |
| fluxes | float array | Observed flux values |
| fluxerrs | float array | Flux uncertainties (1-sigma) |
| zps | float array | Zero points for each observation (typically 27.5 for AB mags) |
| band_indices | int array | Index into unique_bands |
| bridges | tuple | Pre-computed bandpass integration grids |
| unique_bands | list | List of unique bandpass names |
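The relationships between these arrays can be checked before fitting. The following is a minimal sketch, not part of jax_supernovae: the helper name validate_fit_arrays and the specific checks are assumptions chosen to mirror the table above (one entry per observation, one bridge per unique band, integer indices that point into unique_bands).

```python
import numpy as np

def validate_fit_arrays(times, fluxes, fluxerrs, zps, band_indices, bridges, unique_bands):
    """Hypothetical helper: raise ValueError if the fitting inputs are inconsistent."""
    n_obs = len(times)
    per_obs = {"fluxes": fluxes, "fluxerrs": fluxerrs, "zps": zps,
               "band_indices": band_indices}
    # Every per-observation array must have one entry per observation time.
    for name, arr in per_obs.items():
        if len(arr) != n_obs:
            raise ValueError(f"{name} has length {len(arr)}, expected {n_obs}")
    # Bridges and band names are parallel sequences.
    if len(bridges) != len(unique_bands):
        raise ValueError("need exactly one bridge per unique band")
    idx = np.asarray(band_indices)
    if not np.issubdtype(idx.dtype, np.integer):
        raise ValueError("band_indices must be an integer array")
    if n_obs and (idx.min() < 0 or idx.max() >= len(unique_bands)):
        raise ValueError("band_indices must index into unique_bands")
    if np.any(np.asarray(fluxerrs) <= 0):
        raise ValueError("fluxerrs must be strictly positive")
    return n_obs

# Toy inputs; the empty dicts are placeholders standing in for real bridges.
toy = dict(
    times=np.array([0.0, 5.0, 10.0]),
    fluxes=np.array([100.0, 150.0, 120.0]),
    fluxerrs=np.array([5.0, 6.0, 5.5]),
    zps=np.full(3, 27.5),
    band_indices=np.array([0, 1, 0]),
    bridges=({}, {}),
    unique_bands=["bessellb", "bessellv"],
)
n_checked = validate_fit_arrays(**toy)
```

Running a check like this once, outside any jitted function, tends to give clearer error messages than a shape mismatch raised deep inside the model evaluation.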
Loading Real Data
For real supernova data in HSF format, use load_and_process_data:
from jax_supernovae.data import load_and_process_data
# Load data for a specific supernova
result = load_and_process_data('19dwz', fix_z=True)
times, fluxes, fluxerrs, zps, band_indices, unique_bands, bridges, fixed_z = result
# fixed_z contains (z, z_err) if fix_z=True
z, z_err = fixed_z
print(f"Redshift: {z:.4f} ± {z_err:.4f}")
This function:
- Loads photometry from data/photometry/{sn_name}.dat
- Loads redshift from data/redshifts.dat (with fallback to data/targets.dat)
- Registers all required bandpasses
- Pre-computes bridges for each unique band
- Returns all arrays ready for fitting
Data File Format
The HSF photometry format expects tab-separated columns:
# time band flux fluxerr zp
58650.0 bessellb 123.45 6.17 27.5
58651.0 bessellv 156.78 7.84 27.5
...
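For inspecting a file by hand, the layout above can be read with the standard library alone. This is a sketch under stated assumptions, not the loader's own implementation: it assumes the five columns appear in exactly the order shown, that lines starting with # are comments, and the function name read_photometry_table is hypothetical.

```python
import csv
import io

def read_photometry_table(handle):
    """Hypothetical reader for tab-separated rows of: time, band, flux, fluxerr, zp."""
    table = {"time": [], "band": [], "flux": [], "fluxerr": [], "zp": []}
    for row in csv.reader(handle, delimiter="\t"):
        # Skip blank lines and '#' header/comment lines.
        if not row or row[0].lstrip().startswith("#"):
            continue
        time, band, flux, fluxerr, zp = (field.strip() for field in row[:5])
        table["time"].append(float(time))
        table["band"].append(band)
        table["flux"].append(float(flux))
        table["fluxerr"].append(float(fluxerr))
        table["zp"].append(float(zp))
    return table

# In-memory sample using the two rows shown above.
sample = (
    "# time\tband\tflux\tfluxerr\tzp\n"
    "58650.0\tbessellb\t123.45\t6.17\t27.5\n"
    "58651.0\tbessellv\t156.78\t7.84\t27.5\n"
)
table = read_photometry_table(io.StringIO(sample))
```

For a file on disk, pass an open file object instead of the io.StringIO wrapper.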
Redshifts
Redshifts are loaded from two possible sources:
- Primary: data/redshifts.dat - High-quality spectroscopic redshifts
- Fallback: data/targets.dat - All targets with potentially lower-quality redshifts
With fix_z=True, the loader reads the redshift from redshifts.dat. To supply
your own value, add a line to that file with the fields:
SN instrument z_hel plus minus flag
For example: 19dwz SNIFS 0.04608 5.2e-06 7.8e-07 s
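To sanity-check an entry before adding it, the line can be split into its six fields. The sketch below assumes the fields are whitespace-separated and that plus and minus are numeric; the RedshiftEntry type and parse_redshift_line function are hypothetical names and not part of jax_supernovae.

```python
from typing import NamedTuple

class RedshiftEntry(NamedTuple):
    """Hypothetical container mirroring the six fields of a redshift line."""
    sn: str
    instrument: str
    z_hel: float
    plus: float
    minus: float
    flag: str

def parse_redshift_line(line):
    # Unpacking raises ValueError if the line does not have exactly six fields.
    sn, instrument, z_hel, plus, minus, flag = line.split()
    return RedshiftEntry(sn, instrument, float(z_hel), float(plus), float(minus), flag)

entry = parse_redshift_line("19dwz SNIFS 0.04608 5.2e-06 7.8e-07 s")
```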
Using the Data
Once loaded, compute model fluxes:
# Data summary
print(f"Data: {len(times)} observations, {len(set(all_bands))} bands")
# Compute model fluxes
z = 0.05
t0 = 0.0
phases = (times - t0) / (1 + z)
model = source.bandflux(
params, None, phases, zp=zps, zpsys='ab',
band_indices=band_indices,
bridges=bridges,
unique_bands=unique_bands
)
print(f"Computed {len(model)} model fluxes")
# Compare observed vs model
print("First 5 observations:")
for i in range(5):
    print(f" Time {float(times[i]):7.1f}: obs={float(fluxes[i]):7.2f} ± {float(fluxerrs[i]):.2f}, model={float(model[i]):.2f}")
Data: 15 observations, 3 bands
Computed 15 model fluxes
First 5 observations:
Time 0.0: obs= 639.50 ± 31.20, model=624.01
Time 5.0: obs= 534.18 ± 26.90, model=537.90
Time 10.0: obs= 415.89 ± 20.14, model=402.84
Time 15.0: obs= 278.35 ± 12.93, model=258.65
Time 20.0: obs= 159.07 ± 8.05, model=160.96
Computing Chi-Squared
chi2 = jnp.sum(((fluxes - model) / fluxerrs)**2)
print(f"Chi-squared: {float(chi2):.2f} for {len(fluxes)} data points")
print(f"Reduced chi-squared: {float(chi2) / (len(fluxes) - 3):.2f}")
Chi-squared: 13.84 for 15 data points
Reduced chi-squared: 1.15
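The statistic above is the sum of squared, error-weighted residuals, and the reduced value divides it by the degrees of freedom (number of data points minus number of fitted parameters). A small reusable version might look like the sketch below. It uses NumPy so that it runs without JAX; the function name chi_squared is hypothetical, and the inputs are the five rounded rows from the printed comparison above, so the result covers only that subset rather than all 15 observations.

```python
import numpy as np

def chi_squared(fluxes, model, fluxerrs, n_params):
    """Return (chi2, reduced chi2) for Gaussian errors; hypothetical helper."""
    resid = (np.asarray(fluxes) - np.asarray(model)) / np.asarray(fluxerrs)
    chi2 = float(np.sum(resid ** 2))
    dof = resid.size - n_params
    if dof <= 0:
        raise ValueError("need more data points than fitted parameters")
    return chi2, chi2 / dof

# First five observations as printed above (values rounded to two decimals).
obs = np.array([639.50, 534.18, 415.89, 278.35, 159.07])
err = np.array([31.20, 26.90, 20.14, 12.93, 8.05])
mod = np.array([624.01, 537.90, 402.84, 258.65, 160.96])

chi2_subset, reduced_subset = chi_squared(obs, mod, err, n_params=3)
print(f"chi2 = {chi2_subset:.2f}, reduced = {reduced_subset:.2f}")
```

Passing n_params explicitly avoids hard-coding the number of free parameters when the set of fitted quantities changes, for example when the redshift is fixed rather than fitted.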
Preparing Your Own Data
The loader expects a simple ASCII table per supernova. By default it looks for
data/<SN_NAME>/all.phot (or any .phot/.dat containing the object name).
Required columns (case-insensitive aliases in parentheses):
- time (or mjd): observation times in MJD
- band (or bandpass): filter name matching a registered band
- flux: calibrated flux in linear units consistent with zp/zpsys
- fluxerr: 1-sigma uncertainty on flux
Optional columns:
- zp: zero point (defaults to 27.5 if missing)
- zpsys: zero-point system, typically ab
Band names recognised by default: g, r, i, z, ztfg, ztfr, c, o, H plus
bessellb, bessellv, bessellr, besselli, bessellux (from sncosmo). Custom
bandpasses can be registered via register_all_bandpasses(custom_bandpass_files=...).
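When preparing a table from another source, it can help to normalise the column names yourself before handing the file to the loader. The sketch below illustrates the aliasing and the default zero point described above; it is not the loader's implementation, and COLUMN_ALIASES, REQUIRED_COLUMNS and normalise_columns are hypothetical names.

```python
# Aliases and required names follow the column lists above.
COLUMN_ALIASES = {"mjd": "time", "bandpass": "band"}
REQUIRED_COLUMNS = ("time", "band", "flux", "fluxerr")
DEFAULT_ZP = 27.5

def normalise_columns(table):
    """Return a copy of `table` (column name -> list) with canonical column names."""
    out = {}
    for name, values in table.items():
        key = name.strip().lower()  # column names are case-insensitive
        out[COLUMN_ALIASES.get(key, key)] = list(values)
    missing = [c for c in REQUIRED_COLUMNS if c not in out]
    if missing:
        raise KeyError(f"missing required columns: {missing}")
    # Fill in the default zero point when the column is absent.
    out.setdefault("zp", [DEFAULT_ZP] * len(out["time"]))
    return out

raw = {
    "MJD": [58650.0, 58651.0],
    "Bandpass": ["ztfg", "ztfr"],
    "flux": [123.45, 156.78],
    "fluxerr": [6.17, 7.84],
}
clean = normalise_columns(raw)
```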
Converting Magnitudes to Flux
If your data are in magnitudes, convert to flux:
zp = 23.9
mag = 20.0
magerr = 0.05
flux_conv = 10 ** (-0.4 * (mag - zp))
fluxerr_conv = flux_conv * np.log(10) * 0.4 * magerr
print(f"Magnitude {mag:.1f} mag → Flux {flux_conv:.2f} ± {fluxerr_conv:.2f}")
# Example: converting multiple magnitudes
mags = np.array([19.0, 20.0, 21.0, 22.0])
fluxes_from_mag = 10 ** (-0.4 * (mags - zp))
print("Magnitude → Flux conversion:")
for m, f in zip(mags, fluxes_from_mag):
    print(f" {m:.1f} mag → {f:.2f}")
Magnitude 20.0 mag → Flux 36.31 ± 1.67
Magnitude → Flux conversion:
19.0 mag → 91.20
20.0 mag → 36.31
21.0 mag → 14.45
22.0 mag → 5.75
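The error formula follows from first-order propagation: since flux = 10^(-0.4 (mag - zp)), the derivative with respect to mag is -0.4 ln(10) flux, so fluxerr = 0.4 ln(10) flux magerr. The sketch below wraps the conversion and its inverse in two functions (hypothetical names mag_to_flux and flux_to_mag) so that whole arrays, including per-point uncertainties, can be converted and round-tripped as a consistency check.

```python
import numpy as np

def mag_to_flux(mag, magerr, zp):
    """Convert magnitudes and 1-sigma errors to linear flux at zero point `zp`."""
    mag = np.asarray(mag, dtype=float)
    flux = 10.0 ** (-0.4 * (mag - zp))
    # First-order error propagation: |dflux/dmag| = 0.4 * ln(10) * flux
    fluxerr = flux * 0.4 * np.log(10.0) * np.asarray(magerr, dtype=float)
    return flux, fluxerr

def flux_to_mag(flux, fluxerr, zp):
    """Inverse conversion; only valid for strictly positive flux."""
    flux = np.asarray(flux, dtype=float)
    mag = zp - 2.5 * np.log10(flux)
    magerr = 2.5 / np.log(10.0) * np.asarray(fluxerr, dtype=float) / flux
    return mag, magerr

mags = np.array([19.0, 20.0, 21.0, 22.0])
magerrs = np.full_like(mags, 0.05)

flux, fluxerr = mag_to_flux(mags, magerrs, zp=23.9)
mag_back, magerr_back = flux_to_mag(flux, fluxerr, zp=23.9)
```

The linear approximation is adequate for small magnitude errors; for low signal-to-noise points, where the flux may be consistent with zero or negative, working directly in flux is usually preferable to converting from magnitudes.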
Multiple Supernovae
For fitting multiple supernovae simultaneously:
from jax_supernovae.data import load_multiple_supernovae
# Load multiple supernovae with shared band structure
sn_names = ['19dwz', '19agl', '19bcf']
data = load_multiple_supernovae(sn_names, fix_z=True)
# Access data for all supernovae
print(f"Loaded {data['n_sne']} supernovae")
print(f"Unique bands: {data['unique_bands']}")
# Individual supernova data
for i, name in enumerate(sn_names):
    times_i = data['times_list'][i]
    print(f"{name}: {len(times_i)} observations")
# Combined data for joint fitting
all_times = data['all_times']
sn_indices = data['sn_indices'] # Which SN each observation belongs to
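To make the role of sn_indices concrete, the sketch below builds a combined array and a matching index array from per-supernova lists using NumPy. How load_multiple_supernovae constructs these internally is not shown here, so treat this as one plausible layout rather than the library's implementation; the toy times and the t0_per_sn values are invented for illustration.

```python
import numpy as np

# Toy per-supernova observation times (three objects with different cadences).
times_list = [
    np.array([0.0, 5.0, 10.0]),
    np.array([2.0, 7.0]),
    np.array([1.0, 3.0, 6.0, 9.0]),
]

# Concatenate everything and record which supernova each observation came from.
all_times = np.concatenate(times_list)
sn_indices = np.repeat(np.arange(len(times_list)), [len(t) for t in times_list])

# Recover the observations belonging to the second supernova (index 1).
times_sn1 = all_times[sn_indices == 1]

# Broadcast one parameter per supernova to one value per observation.
t0_per_sn = np.array([0.0, 1.5, -0.5])  # hypothetical per-object t0 values
t0_per_obs = t0_per_sn[sn_indices]
```

Indexing a per-object parameter vector with sn_indices keeps the whole dataset in flat arrays, which suits vectorised likelihood evaluation in a joint fit.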
See Sampling for examples of joint fitting with nested sampling.