API Differences from SNCosmo
JAX-bandflux provides a SALT3Source implementation that aims to be as similar as possible to SNCosmo’s API, while accommodating JAX’s requirements for just-in-time (JIT) compilation and GPU acceleration.
Design Philosophy
We’ve designed JAX-bandflux to:
Maintain numerical consistency: Results match SNCosmo within 0.001%
Enable JIT compilation: All code paths are compatible with @jax.jit
Maximize performance: Precomputation strategies for repeated calculations
Minimize API changes: Keep the interface as familiar as possible
Key Differences
1. Functional Parameter API
The most significant difference is how parameters are handled.
SNCosmo approach (stateful):
import sncosmo
# Parameters stored in the source object
source = sncosmo.get_source('salt3-nir')
# (z and t0 are parameters of sncosmo.Model, not of a Source, so they are not set here)
source['x0'] = 1e-5
source['x1'] = 0.0
source['c'] = 0.0
# Bandflux uses stored parameters
flux = source.bandflux('bessellb', 0.0, zp=27.5, zpsys='ab')
JAX-bandflux approach (functional):
# Parameters passed as dictionary argument
source = SALT3Source()
params = {
    'x0': 1e-5,
    'x1': 0.0,
    'c': 0.0,
}
# Bandflux receives parameters as argument
flux = source.bandflux(params, 'bessellb', 0.0, zp=27.5, zpsys='ab')
print(f"Flux: {float(flux):.4e}")
Flux: 6.2422e+01
Why this difference?
JAX’s JIT compiler traces pure functions: any object state read during tracing is frozen into the compiled code, so later mutations are silently ignored. By passing parameters as function arguments rather than storing them as attributes, we enable:
JIT compilation of the entire likelihood function
Automatic differentiation through the model
GPU acceleration without state management issues
The functional API is a requirement for JAX compatibility, not a design choice.
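To illustrate the point, here is a minimal sketch with a toy model rather than SALT3. The class `StatefulToy` and the function `functional_flux` are hypothetical stand-ins, not part of JAX-bandflux. The stateful version keeps returning the value it was traced with, while the functional version picks up new parameters and can be differentiated.

```python
import jax
import jax.numpy as jnp


class StatefulToy:
    """Hypothetical model that stores its amplitude as an attribute."""

    def __init__(self):
        self.x0 = 1.0

    def flux(self, phase):
        return self.x0 * jnp.exp(-0.5 * phase**2)


def functional_flux(params, phase):
    """Pure function: everything it needs arrives as an argument."""
    return params["x0"] * jnp.exp(-0.5 * phase**2)


toy = StatefulToy()
stateful_jit = jax.jit(toy.flux)
functional_jit = jax.jit(functional_flux)

first = stateful_jit(0.0)   # traced while toy.x0 == 1.0
toy.x0 = 2.0                # mutate the object after tracing
stale = stateful_jit(0.0)   # cached trace still uses x0 == 1.0

fresh = functional_jit({"x0": 2.0}, 0.0)  # new value flows in as data

# Gradients with respect to the parameter dictionary come for free
grad_x0 = jax.grad(lambda p: functional_flux(p, 0.0))({"x0": 2.0})

print(float(first), float(stale), float(fresh), float(grad_x0["x0"]))
```

The stale result is not an error message but a silently wrong number, which is why stored state is avoided altogether.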
2. Precomputed Bridges for Performance
SNCosmo approach (bandpass by name):
import sncosmo
source = sncosmo.get_source('salt3-nir')
source.set(x0=1e-5, x1=0.0, c=0.0)
# Bandpass loaded/interpolated each time
for i in range(100000):  # Nested sampling
    flux = source.bandflux('bessellb', phases[i], zp=27.5, zpsys='ab')
JAX-bandflux approach (with precomputed bridges):
# Pre-compute bridges ONCE
unique_bands = ['bessellb', 'bessellv', 'bessellr']
bridges = tuple(precompute_bandflux_bridge(get_bandpass(b))
                for b in unique_bands)
# Create test data
phases = jnp.linspace(-10, 30, 20)
band_names = ['bessellb'] * 7 + ['bessellv'] * 7 + ['bessellr'] * 6
band_to_idx = {b: i for i, b in enumerate(unique_bands)}
band_indices = jnp.array([band_to_idx[b] for b in band_names])
zps = jnp.full(len(phases), 27.5)
# Fast calculation using bridges
params = {'x0': 1e-5, 'x1': 0.0, 'c': 0.0}
fluxes = source.bandflux(
    params, None, phases, zp=zps, zpsys='ab',
    band_indices=band_indices,
    bridges=bridges,
    unique_bands=unique_bands
)
print(f"Computed {len(fluxes)} fluxes using bridges")
Computed 20 fluxes using bridges
What are bridges?
Bridges are precomputed data structures containing:
wave: Integration wavelength grid (e.g., [3622.5, 3627.5, …, 5617.5] Å)
dwave: Grid spacing (e.g., 5.0 Å)
trans: Precomputed transmission values on the grid
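To make the three fields concrete, here is a minimal sketch that builds a bridge-like dictionary for a made-up triangular filter. The helpers `toy_transmission`, `make_bridge`, and `integrate_over_band` are hypothetical and are not part of JAX-bandflux. The real `precompute_bandflux_bridge` may store more than this, and the real integrand carries physical factors that are omitted here.

```python
import math

import jax.numpy as jnp


def toy_transmission(wave):
    """Made-up triangular filter peaking at 4620 Å (not a real Bessell curve)."""
    return jnp.clip(1.0 - jnp.abs(wave - 4620.0) / 1000.0, 0.0, 1.0)


def make_bridge(minwave, maxwave, target_spacing=5.0):
    """Sample the filter at bin centres and keep wave, dwave and trans."""
    n_bins = math.ceil((maxwave - minwave) / target_spacing)
    dwave = (maxwave - minwave) / n_bins
    wave = minwave + dwave * (jnp.arange(n_bins) + 0.5)
    return {"wave": wave, "dwave": dwave, "trans": toy_transmission(wave)}


def integrate_over_band(bridge, spectrum_fn):
    """Midpoint-rule integral of spectrum × transmission on the stored grid."""
    integrand = spectrum_fn(bridge["wave"]) * bridge["trans"]
    return jnp.sum(integrand) * bridge["dwave"]


bridge = make_bridge(3620.0, 5620.0)
flat = integrate_over_band(bridge, lambda w: jnp.ones_like(w))

print(bridge["wave"][:2], bridge["wave"][-1], bridge["dwave"])
print(float(flat))  # area under the toy filter for a flat spectrum
```

With these limits the grid reproduces the example values above: the first centre is 3622.5 Å, the last is 5617.5 Å, and the spacing is 5.0 Å.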
Why bridges?
For nested sampling or MCMC, you may evaluate the likelihood 100,000+ times. Without bridges:
Each evaluation: Load filter file → Create grid → Interpolate transmission → Integrate
Total: 100,000 × (file I/O + interpolation + integration)
With bridges (precomputed once):
Setup: Load filter files → Create grids → Store in bridges
Each evaluation: Look up precomputed grid → Integrate
Total: 1 × (file I/O + interpolation) + 100,000 × integration
Speedup: ~100x faster
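The cost accounting above can be sketched with simple counters. This is an illustration only: `run_sampler`, `build_grid`, and `integrate` are hypothetical stand-ins for the expensive setup and the cheap per-evaluation step, and the weights are made up.

```python
from collections import Counter


def run_sampler(n_evals, precompute):
    """Count setup and integration steps for n_evals likelihood calls."""
    work = Counter()

    def build_grid(band):
        # Stand-in for the expensive part: file I/O, gridding, interpolation.
        work["setup"] += 1
        return [0.25, 0.5, 0.25]  # made-up transmission weights

    def integrate(grid, amplitude):
        # Stand-in for the cheap part that must run on every evaluation.
        work["integrate"] += 1
        return amplitude * sum(grid)

    bridge = build_grid("bessellb") if precompute else None
    for i in range(n_evals):
        grid = bridge if precompute else build_grid("bessellb")
        integrate(grid, amplitude=float(i))
    return work


without_bridges = run_sampler(100_000, precompute=False)
with_bridges = run_sampler(100_000, precompute=True)

print("without bridges:", dict(without_bridges))
print("with bridges:   ", dict(with_bridges))
```

The integration count is identical in both runs; only the setup count drops from one per evaluation to one in total.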
Performance Comparison
The following table shows performance for a typical nested sampling run:
| Configuration | Time per likelihood call | 100k iterations |
|---|---|---|
| SNCosmo | ~10 ms | ~16 hours |
| JAX-bandflux (no JIT) | ~8 ms | ~13 hours |
| JAX-bandflux (JIT) | ~1 ms | ~1.6 hours |
| JAX-bandflux (bridges) | ~0.1 ms | ~10 minutes |
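As a sanity check on how the two timing columns relate, the run totals can be divided by the per-call times. The arithmetic below suggests that each total corresponds to roughly 6 million likelihood calls, or about 60 per iteration, rather than one call per iteration. This is an inference from the rounded figures in the table, not a number stated by the benchmark itself.

```python
ITERATIONS = 100_000

# (per-call time in ms, quoted run total in seconds), copied from the table
table = {
    "SNCosmo": (10.0, 16 * 3600),
    "JAX-bandflux (no JIT)": (8.0, 13 * 3600),
    "JAX-bandflux (JIT)": (1.0, 1.6 * 3600),
    "JAX-bandflux (bridges)": (0.1, 10 * 60),
}

calls_per_iteration = {}
for name, (per_call_ms, total_s) in table.items():
    implied_calls = total_s * 1000.0 / per_call_ms
    calls_per_iteration[name] = implied_calls / ITERATIONS
    one_call_each_s = ITERATIONS * per_call_ms / 1000.0
    print(f"{name:24s} implied calls: {implied_calls:.2e}  "
          f"per iteration: {calls_per_iteration[name]:.1f}  "
          f"100k single calls: {one_call_each_s:.0f} s")
```

The ratios between rows are unaffected by this: the roughly 100x gap between SNCosmo and the bridged configuration holds in either column.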
Migration Guide
Converting SNCosmo code to JAX-bandflux:
Step 1: Change parameter assignment
# OLD (SNCosmo)
source['x0'] = 1e-5
source['x1'] = 0.0
# NEW (JAX-bandflux)
params = {'x0': 1e-5, 'x1': 0.0, 'c': 0.0}
Step 2: Pass parameters to bandflux
# OLD (SNCosmo)
flux = source.bandflux('bessellb', phase)
# NEW (JAX-bandflux)
flux = source.bandflux(params, 'bessellb', phase)
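When existing code already holds a configured SNCosmo object, the dictionary can be built from its parameter names and values instead of being retyped. The helper below is a hypothetical convenience, not part of either library. It assumes the object exposes parallel `param_names` and `parameters` sequences, and it keeps only the x0, x1, c keys used in this guide.

```python
def params_from_source(param_names, parameters, keep=("x0", "x1", "c")):
    """Build a params dict from parallel name/value sequences."""
    lookup = dict(zip(param_names, parameters))
    missing = [name for name in keep if name not in lookup]
    if missing:
        raise KeyError(f"missing parameters: {missing}")
    return {name: float(lookup[name]) for name in keep}


# With a configured SNCosmo object this would be called as
#   params_from_source(source.param_names, source.parameters)
# Plain lists stand in here so the snippet runs without sncosmo installed.
params = params_from_source(
    ["z", "t0", "x0", "x1", "c"],
    [0.1, 58650.0, 1e-5, 0.0, 0.0],
)
print(params)
```

Extra entries such as z and t0 are dropped on purpose, since they are not passed in the params dictionary in the examples on this page.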
Step 3: (Optional) Use bridges for performance
# Pre-compute bridges
bridges = tuple(precompute_bandflux_bridge(get_bandpass(b))
                for b in unique_bands)
# Use bridges in bandflux
flux = source.bandflux(
    params, None, phases, zp=zps, zpsys='ab',
    band_indices=band_indices,
    bridges=bridges,
    unique_bands=unique_bands
)
print(f"Mean flux: {float(jnp.mean(flux)):.2e}")
Mean flux: 3.91e+01
Numerical Consistency
Despite the API differences, JAX-bandflux maintains numerical consistency with SNCosmo:
Model components (M0, M1, color law): Match to machine precision with x64 enabled; float32 remains close but may exceed machine-precision-level tolerances
Integration grids: Identical 5.0 Å spacing
Bandflux values: Match within 0.001% (limited by interpolation differences)
Our comprehensive test suite (tests/test_salt3nir_consistency.py) verifies:
✓ Component-level agreement (M0, M1, colorlaw)
✓ Single bandflux calculations
✓ Array-valued phases and bands
✓ Zeropoint scaling
✓ Multi-band light curves
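A tolerance of 0.001% corresponds to a relative tolerance of 1e-5. The sketch below shows one way such a comparison could be written; the function `within_percent` and the numbers are illustrative only and are not taken from the test suite.

```python
def within_percent(reference, candidate, percent=0.001):
    """True if every candidate value is within `percent` % of its reference."""
    rel_tol = percent / 100.0  # 0.001 % -> 1e-5
    return all(
        abs(c - r) <= rel_tol * abs(r)
        for r, c in zip(reference, candidate)
    )


# Made-up values: a 5e-6 relative offset passes, a 2e-5 offset fails
reference = [100.0, 250.0]
close_enough = [100.0005, 250.0]
too_far = [100.002, 250.0]

print(within_percent(reference, close_enough))
print(within_percent(reference, too_far))
```

A purely relative check like this breaks down when the reference flux is zero, so a real comparison would normally add a small absolute tolerance as well.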
Example: Full Workflow Comparison
SNCosmo:
import sncosmo
import numpy as np
# Setup
source = sncosmo.get_source('salt3-nir')
# z and t0 are parameters of sncosmo.Model, not of a Source, so only x0, x1, c are set
source.set(x0=1e-5, x1=0.0, c=0.0)
# Likelihood function
# (times, bands, zps, fluxes and fluxerrs are the observation arrays, assumed defined)
def loglikelihood(params):
    source.set(x0=10**params[0], x1=params[1], c=params[2])
    model_fluxes = []
    for i, (time, band) in enumerate(zip(times, bands)):
        flux = source.bandflux(band, time, zp=zps[i], zpsys='ab')
        model_fluxes.append(flux)
    model_fluxes = np.array(model_fluxes)
    chi2 = np.sum((fluxes - model_fluxes)**2 / fluxerrs**2)
    return -0.5 * chi2
JAX-bandflux:
# Setup (create bridges)
source = SALT3Source()
unique_bands = ['bessellb', 'bessellv', 'bessellr']
bridges = tuple(precompute_bandflux_bridge(get_bandpass(b))
                for b in unique_bands)
# Example data
times = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
band_names = ['bessellb', 'bessellv', 'bessellr', 'bessellb', 'bessellv', 'bessellr']
band_to_idx = {b: i for i, b in enumerate(unique_bands)}
band_indices = jnp.array([band_to_idx[b] for b in band_names])
fluxes = jnp.ones(6) * 100.0 # Example observed fluxes
fluxerrs = jnp.ones(6) * 5.0 # Example errors
zps = jnp.full(6, 27.5)
z = 0.1
# JIT-compiled likelihood function
@jax.jit
def loglikelihood(log_x0, x1, c):
    params = {'x0': 10**log_x0, 'x1': x1, 'c': c}
    model_fluxes = source.bandflux(
        params, None, times / (1 + z), zp=zps, zpsys='ab',
        band_indices=band_indices,
        bridges=bridges,
        unique_bands=unique_bands
    )
    chi2 = jnp.sum((fluxes - model_fluxes)**2 / fluxerrs**2)
    return -0.5 * chi2
# Evaluate
logL = loglikelihood(-5.0, 0.0, 0.0)
print(f"Log-likelihood: {float(logL):.2f}")
Log-likelihood: -246.91
Key improvements in JAX-bandflux version:
JIT compiled: ~10x faster after warmup
Vectorized: All fluxes calculated at once
GPU ready: Works on GPU with no code changes
Differentiable: Can compute gradients with jax.grad(loglikelihood)
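One detail worth knowing when taking gradients: by default `jax.grad` differentiates with respect to the first positional argument only. The sketch below uses a toy light curve, not the SALT3 model, to show how `argnums` requests gradients for all three parameters. The function `toy_model` and the data are made up for illustration.

```python
import jax
import jax.numpy as jnp

times = jnp.array([-1.0, 0.0, 1.0])
observed = jnp.array([0.5, 1.0, 0.5])
errors = jnp.full(3, 0.1)


def toy_model(log_x0, x1, c, t):
    """Hypothetical stand-in light curve; NOT the SALT3 model."""
    width = 1.0 + 0.1 * x1
    return 10.0**log_x0 * jnp.exp(-0.5 * (t / width) ** 2) * (1.0 - c)


@jax.jit
def loglikelihood(log_x0, x1, c):
    model = toy_model(log_x0, x1, c, times)
    return -0.5 * jnp.sum(((observed - model) / errors) ** 2)


# argnums=(0, 1, 2) returns one gradient per parameter, in order
value, grads = jax.value_and_grad(loglikelihood, argnums=(0, 1, 2))(0.0, 0.0, 0.0)

print(f"logL = {float(value):.3f}")
print("d logL / d(log_x0, x1, c) =", [float(g) for g in grads])
```

`jax.value_and_grad` returns the likelihood and its gradients in one pass, which suits gradient-based samplers and optimizers that need both.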
Summary
| Feature | SNCosmo | JAX-bandflux |
|---|---|---|
| Parameter storage | Object attributes | Function arguments |
| Bandflux call | source.bandflux(band, phase) | source.bandflux(params, band, phase) |
| JIT compilation | ❌ Not supported | ✅ Supported |
| GPU acceleration | ❌ Not supported | ✅ Supported |
| Automatic differentiation | ❌ Not supported | ✅ Supported |
| Precomputed bridges | ❌ Not available | ✅ Optional (~100x faster) |
| Numerical accuracy | Reference | Within 0.001% |
Both APIs are designed for the same purpose (supernova light curve modeling), but JAX-bandflux trades a slightly different calling convention for significant performance gains and GPU/gradient support.