API Differences from SNCosmo

JAX-bandflux provides a SALT3Source implementation that aims to be as similar as possible to SNCosmo’s API, while accommodating JAX’s requirements for just-in-time (JIT) compilation and GPU acceleration.

Design Philosophy

We’ve designed JAX-bandflux to:

  1. Maintain numerical consistency: Results match SNCosmo within 0.001%

  2. Enable JIT compilation: All code paths are compatible with @jax.jit

  3. Maximize performance: Precomputation strategies for repeated calculations

  4. Minimize API changes: Keep the interface as familiar as possible

Key Differences

1. Functional Parameter API

The most significant difference is how parameters are handled.

SNCosmo approach (stateful):

import sncosmo

# Parameters stored in the source object
source = sncosmo.get_source('salt3-nir')
source['z'] = 0.1
source['t0'] = 58650.0
source['x0'] = 1e-5
source['x1'] = 0.0
source['c'] = 0.0

# Bandflux uses stored parameters
flux = source.bandflux('bessellb', 0.0, zp=27.5, zpsys='ab')

JAX-bandflux approach (functional):

# Parameters passed as dictionary argument
source = SALT3Source()
params = {
    'x0': 1e-5,
    'x1': 0.0,
    'c': 0.0
}

# Bandflux receives parameters as argument
flux = source.bandflux(params, 'bessellb', 0.0, zp=27.5, zpsys='ab')
print(f"Flux: {float(flux):.4e}")
Flux: 6.2422e+01

Why this difference?

JAX’s JIT compiler cannot handle mutable object state. By passing parameters as function arguments rather than storing them as attributes, we enable:

  • JIT compilation of the entire likelihood function

  • Automatic differentiation through the model

  • GPU acceleration without state management issues

The functional API is a requirement for JAX compatibility, not a design choice.

2. Precomputed Bridges for Performance

SNCosmo approach (bandpass by name):

import sncosmo

source = sncosmo.get_source('salt3-nir')
source.set(z=0.1, t0=58650.0, x0=1e-5, x1=0.0, c=0.0)

# Bandpass loaded/interpolated each time
for i in range(100000):  # Nested sampling
    flux = source.bandflux('bessellb', phases[i], zp=27.5, zpsys='ab')

JAX-bandflux approach (with precomputed bridges):

# Pre-compute bridges ONCE
unique_bands = ['bessellb', 'bessellv', 'bessellr']
bridges = tuple(precompute_bandflux_bridge(get_bandpass(b))
                for b in unique_bands)

# Create test data
phases = jnp.linspace(-10, 30, 20)
band_names = ['bessellb'] * 7 + ['bessellv'] * 7 + ['bessellr'] * 6
band_to_idx = {b: i for i, b in enumerate(unique_bands)}
band_indices = jnp.array([band_to_idx[b] for b in band_names])
zps = jnp.full(len(phases), 27.5)

# Fast calculation using bridges
params = {'x0': 1e-5, 'x1': 0.0, 'c': 0.0}
fluxes = source.bandflux(
    params, None, phases, zp=zps, zpsys='ab',
    band_indices=band_indices,
    bridges=bridges,
    unique_bands=unique_bands
)
print(f"Computed {len(fluxes)} fluxes using bridges")
Computed 20 fluxes using bridges

What are bridges?

Bridges are precomputed data structures containing:

  • wave: Integration wavelength grid (e.g., [3622.5, 3627.5, …, 5617.5] Å)

  • dwave: Grid spacing (e.g., 5.0 Å)

  • trans: Precomputed transmission values on the grid

Why bridges?

For nested sampling or MCMC, you may evaluate the likelihood 100,000+ times. Without bridges:

  • Each evaluation: Load filter file → Create grid → Interpolate transmission → Integrate

  • Total: 100,000 × (file I/O + interpolation + integration)

With bridges (precomputed once):

  • Setup: Load filter files → Create grids → Store in bridges

  • Each evaluation: Lookup precomputed grid → Integrate

  • Total: 1 × (file I/O + interpolation) + 100,000 × integration

Speedup: ~100x faster

Performance Comparison

The following table shows performance for a typical nested sampling run:

Configuration

Time per likelihood call

100k iterations

SNCosmo

~10 ms

~16 hours

JAX-bandflux (no JIT)

~8 ms

~13 hours

JAX-bandflux (JIT)

~1 ms

~1.6 hours

JAX-bandflux (bridges)

~0.1 ms

~10 minutes

Migration Guide

Converting SNCosmo code to JAX-bandflux:

Step 1: Change parameter assignment

# OLD (SNCosmo)
source['x0'] = 1e-5
source['x1'] = 0.0

# NEW (JAX-bandflux)
params = {'x0': 1e-5, 'x1': 0.0, 'c': 0.0}

Step 2: Pass parameters to bandflux

# OLD (SNCosmo)
flux = source.bandflux('bessellb', phase)

# NEW (JAX-bandflux)
flux = source.bandflux(params, 'bessellb', phase)

Step 3: (Optional) Use bridges for performance

# Pre-compute bridges
bridges = tuple(precompute_bandflux_bridge(get_bandpass(b))
                for b in unique_bands)

# Use bridges in bandflux
flux = source.bandflux(
    params, None, phases, zp=zps, zpsys='ab',
    band_indices=band_indices,
    bridges=bridges,
    unique_bands=unique_bands
)
print(f"Mean flux: {float(jnp.mean(flux)):.2e}")
Mean flux: 3.91e+01

Numerical Consistency

Despite the API differences, JAX-bandflux maintains numerical consistency with SNCosmo:

  • Model components (M0, M1, color law): Match to machine precision with x64 enabled; float32 remains close but may exceed machine-precision-level tolerances

  • Integration grids: Identical 5.0 Å spacing

  • Bandflux values: Match within 0.001% (limited by interpolation differences)

Our comprehensive test suite (tests/test_salt3nir_consistency.py) verifies:

  • ✓ Component-level agreement (M0, M1, colorlaw)

  • ✓ Single bandflux calculations

  • ✓ Array-valued phases and bands

  • ✓ Zeropoint scaling

  • ✓ Multi-band light curves

Example: Full Workflow Comparison

SNCosmo:

import sncosmo
import numpy as np

# Setup
source = sncosmo.get_source('salt3-nir')
source.set(z=0.1, t0=58650.0, x0=1e-5, x1=0.0, c=0.0)

# Likelihood function
def loglikelihood(params):
    source.set(x0=10**params[0], x1=params[1], c=params[2])
    model_fluxes = []
    for i, (time, band) in enumerate(zip(times, bands)):
        flux = source.bandflux(band, time, zp=zps[i], zpsys='ab')
        model_fluxes.append(flux)
    model_fluxes = np.array(model_fluxes)
    chi2 = np.sum((fluxes - model_fluxes)**2 / fluxerrs**2)
    return -0.5 * chi2

JAX-bandflux:

# Setup (create bridges)
source = SALT3Source()
unique_bands = ['bessellb', 'bessellv', 'bessellr']
bridges = tuple(precompute_bandflux_bridge(get_bandpass(b))
                for b in unique_bands)

# Example data
times = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
band_names = ['bessellb', 'bessellv', 'bessellr', 'bessellb', 'bessellv', 'bessellr']
band_to_idx = {b: i for i, b in enumerate(unique_bands)}
band_indices = jnp.array([band_to_idx[b] for b in band_names])
fluxes = jnp.ones(6) * 100.0  # Example observed fluxes
fluxerrs = jnp.ones(6) * 5.0  # Example errors
zps = jnp.full(6, 27.5)
z = 0.1

# JIT-compiled likelihood function
@jax.jit
def loglikelihood(log_x0, x1, c):
    params = {'x0': 10**log_x0, 'x1': x1, 'c': c}
    model_fluxes = source.bandflux(
        params, None, times / (1 + z), zp=zps, zpsys='ab',
        band_indices=band_indices,
        bridges=bridges,
        unique_bands=unique_bands
    )
    chi2 = jnp.sum((fluxes - model_fluxes)**2 / fluxerrs**2)
    return -0.5 * chi2

# Evaluate
logL = loglikelihood(-5.0, 0.0, 0.0)
print(f"Log-likelihood: {float(logL):.2f}")
Log-likelihood: -246.91

Key improvements in JAX-bandflux version:

  1. JIT compiled: ~10x faster after warmup

  2. Vectorized: All fluxes calculated at once

  3. GPU ready: Works on GPU with no code changes

  4. Differentiable: Can compute gradients with jax.grad(loglikelihood)

Summary

Feature

SNCosmo

JAX-bandflux

Parameter storage

Object attributes

Function arguments

Bandflux call

source.bandflux(band)

source.bandflux(params, band)

JIT compilation

❌ Not supported

✅ Supported

GPU acceleration

❌ Not supported

✅ Supported

Automatic differentiation

❌ Not supported

✅ Supported

Precomputed bridges

❌ Not available

✅ Optional (~100x faster)

Numerical accuracy

Reference

Within 0.001%

Both APIs are designed for the same purpose (supernova light curve modeling), but JAX-bandflux trades a slightly different calling convention for significant performance gains and GPU/gradient support.