API Reference
Core classes and functions are documented here via autodoc. These are the primary entry points for users; internal helper modules are omitted.
Sources
Source classes for JAX-bandflux, providing the v3.0 functional API.
This module provides source models for supernova light curve fitting:
- SALT3Source: SALT3-NIR model with stretch and colour parameters
- TimeSeriesSource: Generic spectral time series model (like sncosmo.TimeSeriesSource)
Both use a functional API in which parameters are passed to methods as dictionaries rather than stored on the source object.
- class jax_supernovae.source.SALT3Source(name='salt3-nir')[source]
Bases: object
SALT3-NIR supernova source model with functional API.
This class provides a v3.0 functional API where model parameters are passed as arguments to methods rather than being stored as instance attributes.
- Parameters:
name (str, optional) – Model name (default: ‘salt3-nir’)
Examples
Basic usage with string band names:
    from jax_supernovae.source import SALT3Source

    source = SALT3Source(name='salt3-nir')
    params = {'x0': 1e-5, 'x1': 0.0, 'c': 0.0}
    flux = source.bandflux(params, 'bessellb', 0.0, zp=27.5, zpsys='ab')
High-performance usage with precomputed data (for nested sampling):
    from jax_supernovae.data import load_and_process_data
    from jax_supernovae.source import SALT3Source

    times, fluxes, fluxerrs, zps, band_indices, bands, bridges, fixed_z = load_and_process_data('19dwz', fix_z=True)
    source = SALT3Source()

    # Inside likelihood function (params is the dict of sampled x0, x1, c):
    t0 = 58650.0
    z = fixed_z[0]
    phases = (times - t0) / (1 + z)
    band_names = [bands[i] for i in band_indices]
    # Use a separate name for the model prediction so the observed fluxes are not overwritten
    model_fluxes = source.bandflux(
        params, band_names, phases, zp=zps, zpsys='ab',
        band_indices=band_indices, bridges=bridges, unique_bands=bands,
    )
- property param_names
List of SALT3 model parameter names for the functional API.
Returns the core SALT3 parameters that are passed via the params dict in the v3.0 functional API. Note that z and t0 are handled externally via phase calculations (phase = (time - t0) / (1 + z)).
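The phase conversion quoted above can be sketched in plain Python. This is only an illustration of the formula phase = (time - t0) / (1 + z); the helper name rest_frame_phase and the sample values are hypothetical and not part of jax_supernovae.

```python
def rest_frame_phase(times, t0, z):
    """Convert observer-frame times to rest-frame phases (days)."""
    return [(t - t0) / (1.0 + z) for t in times]


# Hypothetical observation times (MJD), time of peak, and redshift.
times = [58640.0, 58650.0, 58672.0]
phases = rest_frame_phase(times, t0=58650.0, z=0.1)
```

In the library examples the same expression is applied directly to a JAX array of times, so no helper is needed there.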
- bandflux(params, bands, phases, zp=None, zpsys=None, band_indices=None, bridges=None, unique_bands=None, shifts=None)[source]
Calculate bandflux using the optimized multiband path only.
- Parameters:
params (dict) – Model parameters.
  Must contain:
  - ‘x0’: float - Amplitude parameter
  - ‘x1’: float - Stretch parameter
  - ‘c’: float - Color parameter
  May optionally contain:
  - ‘z’: float - Redshift (default: 0.0)
  - ‘t0’: float - Time of peak brightness (default: 0.0)
bands (str or array-like) – Bandpass name(s). Can be a string for a single band or array of band names. If band_indices/unique_bands are provided, bands can be None.
phases (float or array) – Rest-frame phase(s) relative to t0.
zp (float or array, optional) – Zero point(s) for flux scaling (per observation).
zpsys (str, optional) – Zero point system (‘ab’ currently supported).
band_indices (array, optional) – Indices into unique_bands/bridges arrays (for high-performance path).
bridges (tuple, optional) – Precomputed bridge data structures keyed to unique_bands.
unique_bands (list, optional) – List of unique band names corresponding to bridges.
shifts (array or list, optional) – Wavelength shifts in Angstroms for each unique band.
- Returns:
flux – Bandflux value(s) with shape matching the requested bands/phases
- Return type:
array
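The relationship between per-observation band names, unique_bands and band_indices can be illustrated with a small sketch. The helper index_bands is hypothetical, and the ordering of unique bands used here (first appearance) is an assumption; the library may order them differently.

```python
def index_bands(band_names):
    """Return (unique_bands, band_indices) for a per-observation band list."""
    unique_bands = []
    lookup = {}
    band_indices = []
    for name in band_names:
        if name not in lookup:
            # First time this band is seen: give it the next free slot.
            lookup[name] = len(unique_bands)
            unique_bands.append(name)
        band_indices.append(lookup[name])
    return unique_bands, band_indices


obs_bands = ['bessellb', 'bessellv', 'bessellb', 'bessellr', 'bessellv']
unique_bands, band_indices = index_bands(obs_bands)

# Each observation recovers its band name through its index.
recovered = [unique_bands[i] for i in band_indices]
```

The point of the indirection is that per-band data (such as the bridges) is stored once per unique band, while each observation carries only an integer index.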
- bandflux_batch(params, bands, phases, zp=None, zpsys=None, band_indices=None, bridges=None, unique_bands=None, shifts=None)[source]
Batched bandflux evaluation over multiple parameter sets.
All core params (x0, x1, c) and optional z, t0 must be 1D arrays of the same length.
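The shape requirement for batched parameters can be checked before calling bandflux_batch. The following is a minimal sketch of such a check; check_batch_params is a hypothetical helper, not a library function, and plain lists stand in for 1D arrays.

```python
def check_batch_params(params, required=('x0', 'x1', 'c'), optional=('z', 't0')):
    """Return the common batch length, or raise ValueError."""
    missing = [key for key in required if key not in params]
    if missing:
        raise ValueError(f"missing parameters: {missing}")
    # Collect the length of every parameter that is present.
    lengths = {key: len(params[key]) for key in required + optional if key in params}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"parameter lengths differ: {lengths}")
    return next(iter(lengths.values()))


batch = {
    'x0': [1e-5, 2e-5, 3e-5],
    'x1': [0.0, 0.5, -0.5],
    'c': [0.0, 0.1, -0.1],
}
n_sets = check_batch_params(batch)
```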
- bandmag(params, bands, magsys, phases, band_indices=None, bridges=None, unique_bands=None)[source]
Calculate magnitude using v3.0 functional API.
- Parameters:
params (dict) – Model parameters (x0, x1, c, optionally z and t0)
bands (str or array-like) – Bandpass name(s)
magsys (str) – Magnitude system (‘ab’ or ‘vega’)
phases (float or array) – Rest-frame phase(s)
band_indices (array, optional) – For performance: indices into unique_bands/bridges arrays
bridges (tuple, optional) – For performance: precomputed bridge data structures
unique_bands (list, optional) – For performance: list of unique band names
- Returns:
mag – Magnitude value(s)
- Return type:
float or array
Notes
Magnitude is calculated as -2.5 * log10(flux/zp0)
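As a worked instance of the formula in the note, the sketch below converts a flux to a magnitude. It assumes zp0 is the reference bandflux of the chosen magnitude system in the same band; flux_to_mag is a hypothetical helper and the numbers are illustrative only.

```python
import math


def flux_to_mag(flux, zp0):
    """Magnitude from a bandflux and a reference bandflux zp0."""
    return -2.5 * math.log10(flux / zp0)


# A flux 100 times smaller than the reference corresponds to 5 magnitudes.
mag = flux_to_mag(flux=1.0, zp0=100.0)
```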
- class jax_supernovae.source.TimeSeriesSource(phase, wave, flux, zero_before=False, time_spline_degree=3, name=None, version=None)[source]
Bases: object
JAX implementation of custom SED time series source.
Matches sncosmo.TimeSeriesSource API with functional parameter passing. Enables fitting arbitrary spectral time series models on GPU with JAX.
This class provides a flexible interface for fitting custom supernova models defined by a 2D grid of flux values across phase and wavelength. It uses bicubic interpolation (matching sncosmo) and supports both simple usage and high-performance modes for MCMC/nested sampling.
- Parameters:
phase (array_like) – 1D array of phase values (days) defining the model grid. Must be sorted in ascending order. Shape (n_phase,)
wave (array_like) – 1D array of wavelength values (Angstroms) defining the model grid. Must be sorted in ascending order. Shape (n_wave,)
flux (array_like) – 2D array of flux values (erg/s/cm²/Å) with shape (n_phase, n_wave). flux[i, j] is the flux at phase[i] and wavelength wave[j].
zero_before (bool, optional) – If True, flux is zero for phases before minphase. If False, extrapolates using edge values. Default is False.
time_spline_degree (int, optional) – Degree of interpolation in time direction. 1 for linear, 3 for cubic. Default is 3 (matches sncosmo default).
name (str, optional) – Name for this source model.
version (str, optional) – Version identifier for this source model.
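The effect of zero_before can be illustrated in one dimension. The sketch below uses linear interpolation rather than the cubic interpolation the class defaults to, and its handling of phases after the last grid point (holding the last value) is an assumption made only to keep the function total. eval_in_time is a hypothetical helper.

```python
def eval_in_time(phase_grid, values, phase, zero_before=False):
    """Linearly interpolate values over phase_grid with edge handling."""
    if phase < phase_grid[0]:
        # Before the first grid point: zero, or hold the edge value.
        return 0.0 if zero_before else values[0]
    if phase >= phase_grid[-1]:
        # Assumption for this sketch: hold the last value at late times.
        return values[-1]
    for i in range(len(phase_grid) - 1):
        lo, hi = phase_grid[i], phase_grid[i + 1]
        if lo <= phase < hi:
            w = (phase - lo) / (hi - lo)
            return (1.0 - w) * values[i] + w * values[i + 1]


phase_grid = [-20.0, 0.0, 50.0]
values = [1.0, 3.0, 2.0]

early_edge = eval_in_time(phase_grid, values, -30.0)                    # holds values[0]
early_zero = eval_in_time(phase_grid, values, -30.0, zero_before=True)  # forced to zero
inside = eval_in_time(phase_grid, values, -10.0)                        # halfway between 1.0 and 3.0
```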
Examples
Basic usage:
    import numpy as np
    from jax_supernovae import TimeSeriesSource

    # Create simple Gaussian model
    phase = np.linspace(-20, 50, 100)
    wave = np.linspace(3000, 9000, 200)

    # Gaussian in time and wavelength
    p_grid, w_grid = np.meshgrid(phase, wave, indexing='ij')
    flux = np.exp(-0.5 * (p_grid/10)**2) * np.exp(-0.5 * ((w_grid-5000)/1000)**2)
    flux *= 1e-15  # Scale to realistic flux levels

    source = TimeSeriesSource(phase, wave, flux)

    # Calculate bandflux (functional API)
    params = {'amplitude': 1.0}
    flux_b = source.bandflux(params, 'bessellb', 0.0, zp=25.0, zpsys='ab')
Notes
Uses functional API: parameters passed to methods, not stored
Compatible with JAX JIT compilation and GPU acceleration
Bicubic interpolation in 2D (phase and wavelength)
Matches sncosmo numerical results to ~0.01%
- property param_names
List of model parameter names.
- bandflux(params, bands, phases, zp=None, zpsys=None, band_indices=None, bridges=None, unique_bands=None)[source]
Calculate bandflux using functional API.
- Parameters:
params (dict) – Parameter dictionary. Must contain ‘amplitude’.
bands (str, list, or None) – Bandpass name(s). Use None in optimised mode with bridges.
phases (float or array_like) – Rest-frame phase(s) at which to evaluate flux (days).
zp (float or array_like, optional) – Zero point(s). If provided, zpsys must also be given.
zpsys (str, optional) – Zero point system (e.g., ‘ab’). Required if zp is provided.
band_indices (array_like, optional) – (Optimised mode) Integer indices mapping observations to unique_bands.
bridges (tuple of dict, optional) – (Optimised mode) Pre-computed bandpass bridges.
unique_bands (list, optional) – (Optimised mode) List of unique band names corresponding to bridges.
- Returns:
Bandflux value(s). Shape matches input phases.
- Return type:
float or jnp.array
- bandmag(params, bands, magsys, phases, band_indices=None, bridges=None, unique_bands=None)[source]
Calculate magnitude using functional API.
- Parameters:
params (dict) – Model parameters. Must contain ‘amplitude’.
bands (str or array-like) – Bandpass name(s)
magsys (str) – Magnitude system (e.g., ‘ab’)
phases (float or array) – Rest-frame phase(s)
band_indices (optional) – For high-performance mode
bridges (optional) – For high-performance mode
unique_bands (optional) – For high-performance mode
- Returns:
mag – Magnitude value(s). Returns NaN for flux ≤ 0.
- Return type:
float or array
Bandpasses
Bandpass handling for JAX supernova models.
- class jax_supernovae.bandpasses.Bandpass(wave, trans, integration_spacing=5.0, name=None)[source]
Bases: object
Bandpass filter class.
- __call__(wave, shift=0.0)[source]
Get interpolated transmission at given wavelengths with optional shift.
- Parameters:
wave (array_like) – Wavelengths at which to evaluate transmission
shift (float, optional) – Constant wavelength shift to apply (in Angstroms)
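One way to picture a shifted transmission lookup is the sketch below. It assumes that a positive shift moves the whole curve towards longer wavelengths, so the shifted curve is T(wave - shift), and that transmission is zero outside the tabulated range. Both the sign convention and the helper name transmission are assumptions, not confirmed library behaviour.

```python
def transmission(wave_grid, trans_grid, wave, shift=0.0):
    """Linearly interpolate a transmission curve moved by `shift` Angstroms."""
    x = wave - shift  # assumption: positive shift moves the curve redward
    if x < wave_grid[0] or x > wave_grid[-1]:
        return 0.0  # outside the tabulated curve
    for i in range(len(wave_grid) - 1):
        lo, hi = wave_grid[i], wave_grid[i + 1]
        if lo <= x <= hi:
            w = (x - lo) / (hi - lo)
            return (1.0 - w) * trans_grid[i] + w * trans_grid[i + 1]


# Hypothetical triangular filter peaking at 5000 Angstroms.
wave_grid = [4000.0, 5000.0, 6000.0]
trans_grid = [0.0, 1.0, 0.0]

peak = transmission(wave_grid, trans_grid, 5000.0)
shifted_peak = transmission(wave_grid, trans_grid, 5100.0, shift=100.0)
```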
- property name
Optional human-readable bandpass name.
- property wave
Get wavelength array.
- property trans
Get transmission array.
- property integration_wave
Get pre-computed integration wavelength grid.
- property integration_spacing
Get integration grid spacing.
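The integration grid and its spacing can be illustrated with a midpoint grid of the kind sncosmo builds for its bandpasses. Whether jax_supernovae constructs its grid in exactly this way is an assumption; the sketch only shows how a target spacing such as integration_spacing=5.0 maps to an evenly spaced grid.

```python
import math


def integration_grid(low, high, target_spacing):
    """Midpoint grid covering [low, high] with spacing close to the target."""
    span = high - low
    n_bins = int(math.ceil(span / target_spacing))
    spacing = span / n_bins  # actual spacing, never larger than the target
    grid = [low + (i + 0.5) * spacing for i in range(n_bins)]
    return grid, spacing


grid, spacing = integration_grid(4000.0, 4012.0, 5.0)
```

The actual spacing is adjusted downwards so that a whole number of bins fits the wavelength range exactly.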
- jax_supernovae.bandpasses.register_bandpass(name, bandpass, force=False)[source]
Register a bandpass with a given name.
- Parameters:
- Return type:
None
- Raises:
ValueError – If a bandpass with the given name already exists and force=False
- jax_supernovae.bandpasses.get_bandpass(name)[source]
Get a bandpass from the registry.
- Parameters:
name (str or Bandpass) – Name of the bandpass or a Bandpass object
- Returns:
bandpass – The requested bandpass
- Return type:
Bandpass
Notes
Bandpasses must be registered before use. Common bands (Bessell, SDSS, etc.) are automatically registered when SALT3Source is initialized. For custom bands, use register_bandpass() or register_all_bandpasses() before JIT compilation.
- jax_supernovae.bandpasses.load_custom_bandpasses(bandpass_files)[source]
Load custom bandpasses from a list of file paths.
- jax_supernovae.bandpasses.register_all_bandpasses(custom_bandpass_files=None, svo_filters=None)[source]
Register bandpasses in JAX and return dictionaries of bandpasses and bridges.
- Parameters:
custom_bandpass_files (list or dict, optional) – List of file paths to custom bandpass files, or a dictionary mapping bandpass names to file paths
svo_filters (list, optional) – List of dictionaries containing SVO filter information. Each dictionary should have the following keys:
  - ‘name’: Name to register the bandpass under
  - ‘filter_id’: SVO filter identifier (e.g., ‘UKIRT/WFCAM.J’)
  - ‘variants’: Optional list of variant names to register using the same bandpass
- Returns:
A tuple containing:
  - bandpass_dict: Dictionary mapping bandpass names to Bandpass objects
  - bridges_dict: Dictionary mapping bandpass names to precomputed bridge data
- Return type:
tuple
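The svo_filters argument of register_all_bandpasses is described as a list of dictionaries. A hedged example of that structure follows. Only the key names and the identifier 'UKIRT/WFCAM.J' come from the parameter description; the registration names, the second filter and the variant aliases are hypothetical, as is the helper all_registered_names.

```python
svo_filters = [
    {
        'name': 'wfcam_j',              # hypothetical registration name
        'filter_id': 'UKIRT/WFCAM.J',   # identifier quoted in the parameter description
        'variants': ['J', 'ukirt_j'],   # hypothetical aliases for the same bandpass
    },
    {
        'name': 'wfcam_h',              # hypothetical second entry, no 'variants' key
        'filter_id': 'UKIRT/WFCAM.H',
    },
]


def all_registered_names(filters):
    """List every name a filter specification would register."""
    names = []
    for spec in filters:
        names.append(spec['name'])
        names.extend(spec.get('variants', []))
    return names


names = all_registered_names(svo_filters)
```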
Data utilities
Data loading and processing utilities for JAX supernova models.
- jax_supernovae.data.load_hsf_data(object_name, base_dir='data')[source]
Load HSF data for a given object.
- Parameters:
- Returns:
Table containing the processed data with columns:
  - time: observation times (from mjd)
  - band: filter/band names
  - flux: flux measurements
  - fluxerr: flux measurement errors
  - zp: zero points (defaults to 27.5 if not present)
- Return type:
astropy.table.Table
- Raises:
FileNotFoundError – If no data file is found for the given object
ValueError – If required columns are missing from the data file
- jax_supernovae.data.load_redshift(object_name, redshift_file='data/redshifts.dat', targets_file='data/targets.dat')[source]
Load redshift for a given object.
First tries redshifts.dat (high-quality spectroscopic redshifts), then falls back to targets.dat if object not found.
- Parameters:
- Returns:
(redshift, redshift_err, flag) where:
  - redshift is the heliocentric redshift
  - redshift_err is the symmetric error (max of plus/minus for redshifts.dat, or 0.001 default for targets.dat)
  - flag is the reliability flag (‘s’=strong, ‘w’=weak, ‘n’=no features, or ‘spu’ from targets.dat)
- Return type:
tuple
- Raises:
FileNotFoundError – If neither redshift file nor targets file found
ValueError – If object not found in either file
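The fallback order and the error convention can be sketched with in-memory tables standing in for the two files. The dictionary layouts, the object names and the helper lookup_redshift are hypothetical; only the lookup order, the max-of-errors rule, the 0.001 default and the 'spu' flag are taken from the description above.

```python
def lookup_redshift(name, redshifts, targets):
    """Return (z, z_err, flag), preferring the spectroscopic table."""
    if name in redshifts:
        z, err_plus, err_minus, flag = redshifts[name]
        # Symmetric error: the larger of the two one-sided errors.
        return z, max(err_plus, err_minus), flag
    if name in targets:
        # Fallback table carries no error, so a default is used.
        return targets[name], 0.001, 'spu'
    raise ValueError(f"object {name!r} not found in either table")


# Hypothetical table contents.
redshifts = {'sn_a': (0.0412, 0.0003, 0.0005, 's')}
targets = {'sn_a': 0.041, 'sn_b': 0.0730}

from_primary = lookup_redshift('sn_a', redshifts, targets)
from_fallback = lookup_redshift('sn_b', redshifts, targets)
```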
- jax_supernovae.data.load_and_process_data(sn_name, data_dir='data', fix_z=False)[source]
Load and process supernova data, including bandpass registration and data array setup.
- Parameters:
- Returns:
Contains processed data arrays and bridges:
  - times (jnp.array): Observation times
  - fluxes (jnp.array): Flux measurements
  - fluxerrs (jnp.array): Flux measurement errors
  - zps (jnp.array): Zero points
  - band_indices (jnp.array): Band indices
  - unique_bands (list): List of unique band names
  - bridges (tuple): Precomputed bridge data for each band
  - fixed_z (tuple or None): If fix_z is True, returns (z, z_err), else None
- Return type:
tuple
SALT3 helpers
SALT3-NIR model implementation in JAX.
- jax_supernovae.salt3.precompute_bandflux_bridge(bandpass)[source]
Precompute static components for a given bandpass.
- Parameters:
bandpass (Bandpass) – Bandpass object to precompute components for
- Returns:
Dictionary containing:
  - ‘wave’: the integration wavelength grid
  - ‘dwave’: spacing between grid points
  - ‘trans’: the transmission values computed on the grid
  - ‘wave_original’: original wavelength array for shift interpolation
  - ‘trans_original’: original transmission array
  - ‘zpbandflux_ab’: AB zeropoint normalization for this band
- Return type:
dict
- jax_supernovae.salt3.optimized_salt3_bandflux(phase, wave, dwave, trans, params, zp=None, zpsys=None, shift=0.0, wave_original=None, trans_original=None)[source]
Calculate bandflux for a single bandpass using precomputed static data.
- Parameters:
phase (array or scalar) – Observer-frame phase(s) at which to compute the flux
wave (array) – Wavelength grid for integration
dwave (float) – Spacing between wavelength grid points
trans (array) – Transmission values on the wavelength grid (used if shift=0)
params (dict) – Dictionary containing model parameters: ‘z’, ‘t0’, ‘x0’, ‘x1’, ‘c’
  Optional dust parameters:
  - ‘dust_type’: int, dust law index (0=ccm89, 1=od94, 2=f99)
  - ‘ebv’: float, E(B-V) value
  - ‘r_v’: float, R_V value (default: 3.1)
zp (float or None, optional) – Zero point for flux scaling
zpsys (str or None, optional) – Magnitude system (e.g. ‘ab’)
shift (float, optional) – Constant wavelength shift to apply to transmission curve (in Angstroms)
wave_original (array, optional) – Original wavelength array (required if shift != 0)
trans_original (array, optional) – Original transmission array (required if shift != 0)
- Returns:
Flux in photons/s/cm^2
- Return type:
float or array
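How wave, dwave, trans and zp could combine into a flux in photons/s/cm^2 can be sketched with the band integral used by sncosmo: the sum of wavelength × transmission × flux density, multiplied by dwave and divided by hc. That jax_supernovae applies the same expression and the same zero-point scaling is an assumption; band_integral and apply_zeropoint are hypothetical helpers.

```python
HC_ERG_AA = 1.9864458571489284e-08  # Planck constant times speed of light, erg * Angstrom


def band_integral(wave, trans, flux_density, dwave):
    """Photon-count bandflux: sum(wave * trans * flux) * dwave / (h c)."""
    total = sum(w * t * f for w, t, f in zip(wave, trans, flux_density))
    return total * dwave / HC_ERG_AA


def apply_zeropoint(bandflux, zp, zpbandflux):
    """Rescale a bandflux to zero point zp given the reference bandflux."""
    return bandflux * 10.0 ** (0.4 * zp) / zpbandflux


# Single-bin check: a flux density of hc / wave erg/s/cm^2/Angstrom
# corresponds to one photon per second per cm^2 per Angstrom.
wave = [5000.0]
trans = [1.0]
flux_density = [HC_ERG_AA / 5000.0]
photons = band_integral(wave, trans, flux_density, dwave=1.0)

scaled = apply_zeropoint(2.0, zp=2.5, zpbandflux=4.0)
```

In the bridge returned by precompute_bandflux_bridge, the entry ‘zpbandflux_ab’ would play the role of zpbandflux in this sketch.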
- jax_supernovae.salt3.optimized_salt3_multiband_flux(phase, bridges, params, zps=None, zpsys=None, shifts=None)[source]
Calculate fluxes for multiple bandpasses with transmission shifts.
- Parameters:
phase (array) – Observer-frame phases
bridges (list of dict) – Precomputed bridge data for each bandpass
params (dict) – Model parameters
zps (list or array or None, optional) – Zero points for each bandpass
zpsys (str or None, optional) – Magnitude system
shifts (list or array or None, optional) – Constant wavelength shifts for each bandpass (in Angstroms)
- Returns:
Array of flux values for each phase and band
- Return type:
array