API Reference

Core classes and functions are documented here via autodoc. These are the primary entry points for users; internal helper modules are omitted.

Sources

Source classes for JAX-bandflux, providing the v3.0 functional API.

This module provides source models for supernova light curve fitting:

  • SALT3Source: SALT3-NIR model with stretch and colour parameters

  • TimeSeriesSource: Generic spectral time series model (like sncosmo.TimeSeriesSource)

Both use a functional API: parameters are passed to methods as dictionaries rather than stored in the source object.

class jax_supernovae.source.SALT3Source(name='salt3-nir')

Bases: object

SALT3-NIR supernova source model with functional API.

This class provides a v3.0 functional API where model parameters are passed as arguments to methods rather than being stored as instance attributes.

Parameters:

name (str, optional) – Model name (default: ‘salt3-nir’)

Examples

Basic usage with string band names:

from jax_supernovae.source import SALT3Source

source = SALT3Source(name='salt3-nir')
params = {'x0': 1e-5, 'x1': 0.0, 'c': 0.0}
flux = source.bandflux(params, 'bessellb', 0.0, zp=27.5, zpsys='ab')

High-performance usage with precomputed data (for nested sampling):

from jax_supernovae.data import load_and_process_data
from jax_supernovae.source import SALT3Source

times, fluxes, fluxerrs, zps, band_indices, bands, bridges, fixed_z = \
    load_and_process_data('19dwz', fix_z=True)
source = SALT3Source()
params = {'x0': 1e-5, 'x1': 0.0, 'c': 0.0}
# Inside likelihood function:
t0 = 58650.0
z = fixed_z[0]  # fixed_z is (z, z_err) when fix_z=True
phases = (times - t0) / (1 + z)
band_names = [bands[i] for i in band_indices]
# Use a new name so the observed fluxes are not overwritten
model_fluxes = source.bandflux(params, band_names, phases, zp=zps, zpsys='ab',
                               band_indices=band_indices, bridges=bridges,
                               unique_bands=bands)

property param_names

List of SALT3 model parameter names for the functional API.

Returns the core SALT3 parameters that are passed via the params dict in the v3.0 functional API. Note that z and t0 are handled externally via phase calculations (phase = (time - t0) / (1 + z)).
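Because z and t0 are kept out of the params dict, the caller converts observation times to rest-frame phases before calling bandflux. A minimal NumPy sketch of that step; rest_frame_phase is a hypothetical helper, not part of jax_supernovae:

```python
import numpy as np

def rest_frame_phase(times, t0, z):
    """Rest-frame phase(s) in days: (time - t0) / (1 + z)."""
    return (np.asarray(times, dtype=float) - t0) / (1.0 + z)

# Two observations 10 and 20 observer-frame days after peak, at z = 1.0,
# map to rest-frame phases of 5 and 10 days (time dilation by 1 + z).
phases = rest_frame_phase([58660.0, 58670.0], t0=58650.0, z=1.0)
```

The resulting array is what the phases argument of bandflux and bandmag expects.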

minphase()

Minimum phase for which the model is defined.

maxphase()

Maximum phase for which the model is defined.

minwave()

Minimum wavelength for which the model is defined (Angstroms).

maxwave()

Maximum wavelength for which the model is defined (Angstroms).

__str__()

String representation.

__repr__()

Official string representation.

bandflux(params, bands, phases, zp=None, zpsys=None, band_indices=None, bridges=None, unique_bands=None, shifts=None)

Calculate bandflux using the optimized multiband path only.

Parameters:
  • params (dict) – Model parameters. Must contain ‘x0’ (float, amplitude), ‘x1’ (float, stretch) and ‘c’ (float, color). May optionally contain ‘z’ (float, redshift; default 0.0) and ‘t0’ (float, time of peak brightness; default 0.0).

  • bands (str or array-like) – Bandpass name(s). Can be a string for a single band or array of band names. If band_indices/unique_bands are provided, bands can be None.

  • phases (float or array) – Rest-frame phase(s) relative to t0.

  • zp (float or array, optional) – Zero point(s) for flux scaling (per observation).

  • zpsys (str, optional) – Zero point system (only ‘ab’ is currently supported).

  • band_indices (array, optional) – Indices into unique_bands/bridges arrays (for high-performance path).

  • bridges (tuple, optional) – Precomputed bridge data structures keyed to unique_bands.

  • unique_bands (list, optional) – List of unique band names corresponding to bridges.

  • shifts (array or list, optional) – Wavelength shifts in Angstroms for each unique band.

Returns:

flux – Bandflux value(s) with shape matching the requested bands/phases

Return type:

array

bandflux_batch(params, bands, phases, zp=None, zpsys=None, band_indices=None, bridges=None, unique_bands=None, shifts=None)

Batched bandflux evaluation over multiple parameter sets.

All core params (x0, x1, c) and optional z, t0 must be 1D arrays of the same length.
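A sketch of assembling a batched params dict from individual parameter sets. stack_param_sets is a hypothetical helper; it only enforces the documented constraint that every entry becomes a 1D array of one common length:

```python
import numpy as np

def stack_param_sets(param_sets):
    """Combine per-sample params dicts into one dict of equal-length 1D arrays."""
    keys = set(param_sets[0])
    if any(set(p) != keys for p in param_sets):
        raise ValueError("every parameter set must define the same keys")
    return {k: np.array([p[k] for p in param_sets], dtype=float) for k in sorted(keys)}

samples = [
    {'x0': 1.0e-5, 'x1': 0.0, 'c': 0.0},
    {'x0': 2.0e-5, 'x1': 0.5, 'c': 0.1},
    {'x0': 1.5e-5, 'x1': -0.3, 'c': -0.05},
]
batch_params = stack_param_sets(samples)
# Each entry is now a 1D array of length 3, e.g. batch_params['x1'] holds [0.0, 0.5, -0.3].
```

The resulting dict takes the place of the scalar params dict in a call such as source.bandflux_batch(batch_params, ...).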

bandmag(params, bands, magsys, phases, band_indices=None, bridges=None, unique_bands=None)

Calculate magnitude using the v3.0 functional API.

Parameters:
  • params (dict) – Model parameters (x0, x1, c, optionally z and t0)

  • bands (str or array-like) – Bandpass name(s)

  • magsys (str) – Magnitude system (‘ab’ or ‘vega’)

  • phases (float or array) – Rest-frame phase(s)

  • band_indices (array, optional) – For performance: indices into unique_bands/bridges arrays

  • bridges (tuple, optional) – For performance: precomputed bridge data structures

  • unique_bands (list, optional) – For performance: list of unique band names

Returns:

mag – Magnitude value(s)

Return type:

float or array

Notes

Magnitude is calculated as -2.5 * log10(flux / zp0), where zp0 is the zero-point bandflux of magsys in the same band.
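The conversion can be sketched in NumPy as below. zp0 is passed in directly rather than looked up from magsys, and returning NaN for non-positive flux mirrors what TimeSeriesSource.bandmag documents; whether SALT3Source.bandmag does the same is an assumption. flux_to_mag is a hypothetical helper:

```python
import numpy as np

def flux_to_mag(flux, zp0):
    """m = -2.5 * log10(flux / zp0), with NaN wherever flux <= 0."""
    flux = np.asarray(flux, dtype=float)
    # Replace non-positive ratios with NaN before taking the log.
    ratio = np.where(flux > 0, flux / zp0, np.nan)
    return -2.5 * np.log10(ratio)

mags = flux_to_mag([100.0, 1.0, 0.0], zp0=1.0)
# mags holds -5.0, 0.0 and NaN: a factor of 100 in flux is 5 magnitudes.
```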

class jax_supernovae.source.TimeSeriesSource(phase, wave, flux, zero_before=False, time_spline_degree=3, name=None, version=None)

Bases: object

JAX implementation of a custom SED time series source.

Matches the sncosmo.TimeSeriesSource API, but with functional parameter passing. Enables fitting arbitrary spectral time series models on GPU with JAX.

This class provides a flexible interface for fitting custom supernova models defined by a 2D grid of flux values across phase and wavelength. It uses bicubic interpolation (matching sncosmo) and supports both simple usage and high-performance modes for MCMC/nested sampling.

Parameters:
  • phase (array_like) – 1D array of phase values (days) defining the model grid. Must be sorted in ascending order. Shape (n_phase,)

  • wave (array_like) – 1D array of wavelength values (Angstroms) defining the model grid. Must be sorted in ascending order. Shape (n_wave,)

  • flux (array_like) – 2D array of flux values (erg/s/cm²/Å) with shape (n_phase, n_wave). flux[i, j] is the flux at phase[i] and wavelength wave[j].

  • zero_before (bool, optional) – If True, flux is zero for phases before minphase. If False, extrapolates using edge values. Default is False.

  • time_spline_degree (int, optional) – Degree of interpolation in time direction. 1 for linear, 3 for cubic. Default is 3 (matches sncosmo default).

  • name (str, optional) – Name for this source model.

  • version (str, optional) – Version identifier for this source model.
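TimeSeriesSource evaluates its grid with bicubic interpolation. To show how the (phase, wave) grid, edge extrapolation and zero_before interact, the sketch below swaps in simple bilinear interpolation via np.interp; it is not the library's interpolator, and eval_grid_bilinear is hypothetical:

```python
import numpy as np

def eval_grid_bilinear(phase_grid, wave_grid, flux_grid, p, w, zero_before=False):
    """Bilinear stand-in for the grid interpolation at one (phase, wave) point.

    flux_grid[i, j] is the flux at phase_grid[i], wave_grid[j]. Points outside the
    grid are clamped to edge values (np.interp's behaviour); with zero_before=True,
    phases before phase_grid[0] give zero flux.
    """
    if zero_before and p < phase_grid[0]:
        return 0.0
    # Interpolate each phase row in wavelength, then interpolate across phase.
    per_phase = np.array([np.interp(w, wave_grid, row) for row in flux_grid])
    return float(np.interp(p, phase_grid, per_phase))

phase_grid = np.array([0.0, 10.0, 20.0])
wave_grid = np.array([4000.0, 5000.0, 6000.0])
P, W = np.meshgrid(phase_grid, wave_grid, indexing='ij')
flux_grid = P + W / 1000.0  # linear in both axes, so bilinear interpolation is exact

inside = eval_grid_bilinear(phase_grid, wave_grid, flux_grid, 5.0, 4500.0)    # ~9.5
clamped = eval_grid_bilinear(phase_grid, wave_grid, flux_grid, -5.0, 4500.0)  # edge value, ~4.5
zeroed = eval_grid_bilinear(phase_grid, wave_grid, flux_grid, -5.0, 4500.0,
                            zero_before=True)                                 # 0.0
```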

Examples

Basic usage:

import numpy as np
from jax_supernovae import TimeSeriesSource

# Create simple Gaussian model
phase = np.linspace(-20, 50, 100)
wave = np.linspace(3000, 9000, 200)
# Gaussian in time and wavelength
p_grid, w_grid = np.meshgrid(phase, wave, indexing='ij')
flux = np.exp(-0.5 * (p_grid/10)**2) * np.exp(-0.5 * ((w_grid-5000)/1000)**2)
flux *= 1e-15  # Scale to realistic flux levels

source = TimeSeriesSource(phase, wave, flux)

# Calculate bandflux (functional API)
params = {'amplitude': 1.0}
flux_b = source.bandflux(params, 'bessellb', 0.0, zp=25.0, zpsys='ab')

Notes

  • Uses functional API: parameters passed to methods, not stored

  • Compatible with JAX JIT compilation and GPU acceleration

  • Bicubic interpolation in 2D (phase and wavelength)

  • Matches sncosmo numerical results to ~0.01%

property param_names

List of model parameter names.

minphase()

Minimum phase of the model.

maxphase()

Maximum phase of the model.

minwave()

Minimum wavelength of the model (Angstroms).

maxwave()

Maximum wavelength of the model (Angstroms).

__str__()

String representation.

__repr__()

Official string representation.

bandflux(params, bands, phases, zp=None, zpsys=None, band_indices=None, bridges=None, unique_bands=None)

Calculate bandflux using the functional API.

Parameters:
  • params (dict) – Parameter dictionary. Must contain ‘amplitude’.

  • bands (str, list, or None) – Bandpass name(s). Use None in optimised mode with bridges.

  • phases (float or array_like) – Rest-frame phase(s) at which to evaluate flux (days).

  • zp (float or array_like, optional) – Zero point(s). If provided, zpsys must also be given.

  • zpsys (str, optional) – Zero point system (e.g., ‘ab’). Required if zp is provided.

  • band_indices (array_like, optional) – (Optimised mode) Integer indices mapping observations to unique_bands.

  • bridges (tuple of dict, optional) – (Optimised mode) Pre-computed bandpass bridges.

  • unique_bands (list, optional) – (Optimised mode) List of unique band names corresponding to bridges.

Returns:

Bandflux value(s). Shape matches input phases.

Return type:

float or jnp.array

bandmag(params, bands, magsys, phases, band_indices=None, bridges=None, unique_bands=None)

Calculate magnitude using the functional API.

Parameters:
  • params (dict) – Model parameters. Must contain ‘amplitude’.

  • bands (str or array-like) – Bandpass name(s)

  • magsys (str) – Magnitude system (e.g., ‘ab’)

  • phases (float or array) – Rest-frame phase(s)

  • band_indices (optional) – For high-performance mode

  • bridges (optional) – For high-performance mode

  • unique_bands (optional) – For high-performance mode

Returns:

mag – Magnitude value(s). Returns NaN for flux ≤ 0.

Return type:

float or array

Bandpasses

Bandpass handling for JAX supernova models.

class jax_supernovae.bandpasses.Bandpass(wave, trans, integration_spacing=5.0, name=None)

Bases: object

Bandpass filter class.

__call__(wave, shift=0.0)

Get interpolated transmission at given wavelengths with optional shift.

Parameters:
  • wave (array_like) – Wavelengths at which to evaluate transmission

  • shift (float, optional) – Constant wavelength shift to apply (in Angstroms)
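The shift argument moves the transmission curve in wavelength before interpolation. A NumPy sketch, assuming that a positive shift moves the filter redward (so the shifted curve at a wavelength equals the original at wavelength - shift) and that transmission is zero outside the tabulated range; both conventions and the helper shifted_transmission are assumptions:

```python
import numpy as np

def shifted_transmission(wave, bp_wave, bp_trans, shift=0.0):
    """Transmission at `wave` after moving the curve by `shift` Angstroms.

    Assumed convention: T_shifted(wave) = T(wave - shift); zero outside the table.
    """
    return np.interp(np.asarray(wave, dtype=float) - shift, bp_wave, bp_trans,
                     left=0.0, right=0.0)

bp_wave = np.array([4000.0, 4500.0, 5000.0])
bp_trans = np.array([0.0, 1.0, 0.0])  # triangular toy filter peaking at 4500 A

unshifted = shifted_transmission([4500.0, 4600.0], bp_wave, bp_trans)           # [1.0, 0.8]
shifted = shifted_transmission([4500.0, 4600.0], bp_wave, bp_trans, shift=100)  # peak now at 4600 A
```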

property name

Optional human-readable bandpass name.

minwave()

Get minimum wavelength.

maxwave()

Get maximum wavelength.

property wave

Get wavelength array.

property trans

Get transmission array.

property integration_wave

Get pre-computed integration wavelength grid.

property integration_spacing

Get integration grid spacing.

jax_supernovae.bandpasses.load_bandpass(band)

Load a bandpass from file.

Parameters:

band (str) – Name of the bandpass to load

Returns:

bandpass – A Bandpass object containing the filter transmission curve.

Return type:

Bandpass

jax_supernovae.bandpasses.register_bandpass(name, bandpass, force=False)

Register a bandpass with a given name.

Parameters:
  • name (str) – Name to register the bandpass under

  • bandpass (Bandpass) – Bandpass object to register

  • force (bool, optional) – If True, overwrite any existing bandpass with the same name

Return type:

None

Raises:

ValueError – If a bandpass with the given name already exists and force=False

jax_supernovae.bandpasses.get_bandpass(name)

Get a bandpass from the registry.

Parameters:

name (str or Bandpass) – Name of the bandpass or a Bandpass object

Returns:

bandpass – The requested bandpass

Return type:

Bandpass

Notes

Bandpasses must be registered before use. Common bands (Bessell, SDSS, etc.) are automatically registered when SALT3Source is initialized. For custom bands, use register_bandpass() or register_all_bandpasses() before JIT compilation.
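register_bandpass and get_bandpass behave like a name-keyed registry. The sketch below models the documented rules (a duplicate name raises ValueError unless force=True; lookup accepts either a name or a Bandpass object) with a plain dict. SimpleBandpass, register_sketch and get_sketch are hypothetical stand-ins, and the KeyError for an unknown name is an assumption, since the error for a missing name is not documented above:

```python
class SimpleBandpass:
    """Hypothetical stand-in for jax_supernovae.bandpasses.Bandpass."""
    def __init__(self, name):
        self.name = name

_registry = {}

def register_sketch(name, bandpass, force=False):
    """Store bandpass under name; refuse silent overwrites unless force=True."""
    if name in _registry and not force:
        raise ValueError(f"bandpass '{name}' already registered (use force=True)")
    _registry[name] = bandpass

def get_sketch(name_or_bandpass):
    """Return a registered bandpass; Bandpass-like objects pass straight through."""
    if isinstance(name_or_bandpass, SimpleBandpass):
        return name_or_bandpass
    if name_or_bandpass not in _registry:
        # Assumed error type: the documented API does not specify one.
        raise KeyError(f"bandpass '{name_or_bandpass}' is not registered")
    return _registry[name_or_bandpass]

first = SimpleBandpass('my_band')
register_sketch('my_band', first)
try:
    register_sketch('my_band', SimpleBandpass('my_band'))  # duplicate, no force
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
replacement = SimpleBandpass('my_band')
register_sketch('my_band', replacement, force=True)  # explicit overwrite
```

With the real functions, this registration has to happen before JIT compilation.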

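jax_supernovae.bandpasses.load_custom_bandpasses(bandpass_files)

Load custom bandpasses from a list of file paths or from a dictionary mapping bandpass names to file paths.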
Load custom bandpasses from a list of file paths.

Parameters:

bandpass_files (list of str or dict) – List of file paths to bandpass files, or a dictionary mapping bandpass names to file paths

Returns:

Dictionary mapping bandpass names to Bandpass objects

Return type:

dict

jax_supernovae.bandpasses.register_all_bandpasses(custom_bandpass_files=None, svo_filters=None)

Register bandpasses in JAX and return dictionaries of bandpasses and bridges.

Parameters:
  • custom_bandpass_files (list or dict, optional) – List of file paths to custom bandpass files, or a dictionary mapping bandpass names to file paths

  • svo_filters (list, optional) – List of dictionaries containing SVO filter information. Each dictionary should have the keys ‘name’ (name to register the bandpass under), ‘filter_id’ (SVO filter identifier, e.g. ‘UKIRT/WFCAM.J’) and, optionally, ‘variants’ (list of variant names to register using the same bandpass).
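The svo_filters argument is plain data. A sketch of a well-formed value, plus a hypothetical helper that lists every name that would be registered once variants are expanded; the bandpass names are illustrative, and the expansion rule only restates the ‘variants’ description above:

```python
svo_filters = [
    # 'wfcam_j' and its variants are illustrative names
    {'name': 'wfcam_j', 'filter_id': 'UKIRT/WFCAM.J', 'variants': ['J', 'wfcamj']},
    {'name': 'wfcam_h', 'filter_id': 'UKIRT/WFCAM.H'},  # 'variants' omitted
]

def names_to_register(filters):
    """Every (name, filter_id) pair, counting each variant as an extra name."""
    pairs = []
    for entry in filters:
        for alias in [entry['name'], *entry.get('variants', [])]:
            pairs.append((alias, entry['filter_id']))
    return pairs

registered = names_to_register(svo_filters)
```

The list itself is what would be passed as register_all_bandpasses(svo_filters=svo_filters).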

Returns:

A tuple (bandpass_dict, bridges_dict), where bandpass_dict maps bandpass names to Bandpass objects and bridges_dict maps bandpass names to precomputed bridge data.

Return type:

tuple

Data utilities

Data loading and processing utilities for JAX supernova models.

jax_supernovae.data.load_hsf_data(object_name, base_dir='data')

Load HSF data for a given object.

Parameters:
  • object_name (str) – Name of the object (e.g., ‘19agl’)

  • base_dir (str) – Base directory containing the data files. Defaults to ‘data’. The expected layout is either [base_dir]/Ia/[object_name]/all.phot or any .dat/.phot file containing the object name.

Returns:

Table containing the processed data, with columns: time (observation times, from mjd), band (filter/band names), flux (flux measurements), fluxerr (flux measurement errors) and zp (zero points; defaults to 27.5 if not present).

Return type:

astropy.table.Table

Raises:
  • FileNotFoundError – If no data file is found for the given object

  • ValueError – If required columns are missing from the data file
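A sketch of the file lookup described for base_dir, using pathlib. find_phot_file is hypothetical, and both the recursive fallback search and matching on the file name (rather than the full path) are assumptions about how "any .dat/.phot file containing the object name" is resolved:

```python
import tempfile
from pathlib import Path

def find_phot_file(object_name, base_dir='data'):
    """Resolve a photometry file: preferred layout first, then a name search."""
    base = Path(base_dir)
    preferred = base / 'Ia' / object_name / 'all.phot'
    if preferred.is_file():
        return preferred
    # Fallback: any .dat/.phot file whose name contains the object name.
    for pattern in ('*.dat', '*.phot'):
        for candidate in sorted(base.rglob(pattern)):
            if object_name in candidate.name:
                return candidate
    raise FileNotFoundError(f"no data file found for {object_name} under {base}")

# Demonstration on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp:
    preferred = Path(tmp) / 'Ia' / '19agl' / 'all.phot'
    preferred.parent.mkdir(parents=True)
    preferred.write_text('')
    (Path(tmp) / 'lc_19dwz.dat').write_text('')
    hit_preferred = find_phot_file('19agl', tmp).relative_to(tmp).as_posix()
    hit_fallback = find_phot_file('19dwz', tmp).relative_to(tmp).as_posix()
```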

jax_supernovae.data.load_redshift(object_name, redshift_file='data/redshifts.dat', targets_file='data/targets.dat')

Load redshift for a given object.

First tries redshifts.dat (high-quality spectroscopic redshifts), then falls back to targets.dat if object not found.
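The lookup order and the error handling from the Returns section can be sketched over already-parsed tables. pick_redshift and the in-memory dict layouts are hypothetical, the values are made up, and raising KeyError when neither file lists the object is an assumption, since the failure mode is not documented here:

```python
def pick_redshift(object_name, redshifts, targets, default_err=0.001):
    """Lookup order: redshifts.dat first, then targets.dat.

    redshifts: {name: (z, z_plus, z_minus, flag)}, a hypothetical parsed redshifts.dat
    targets:   {name: (z, flag)}, a hypothetical parsed targets.dat
    """
    if object_name in redshifts:
        z, z_plus, z_minus, flag = redshifts[object_name]
        # Symmetric error: the larger of the two one-sided errors.
        return z, max(abs(z_plus), abs(z_minus)), flag
    if object_name in targets:
        z, flag = targets[object_name]
        return z, default_err, flag
    # Assumed behaviour; the failure mode is not documented.
    raise KeyError(f"no redshift found for {object_name}")

# Made-up values for illustration only
redshifts = {'sn_a': (0.030, 0.002, 0.003, 's')}
targets = {'sn_a': (0.031, 'spu'), 'sn_b': (0.045, 'spu')}
```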

Parameters:
  • object_name (str) – Name of the object (e.g., ‘19agl’)

  • redshift_file (str) – Path to redshifts.dat file

  • targets_file (str) – Path to targets.dat file (fallback)

Returns:

(redshift, redshift_err, flag), where redshift is the heliocentric redshift; redshift_err is the symmetric error (the larger of the plus/minus errors for redshifts.dat, or a default of 0.001 for targets.dat); and flag is the reliability flag (‘s’ = strong, ‘w’ = weak, ‘n’ = no features, or ‘spu’ from targets.dat).

Return type:

tuple

Raises:
jax_supernovae.data.load_and_process_data(sn_name, data_dir='data', fix_z=False)

Load and process supernova data, including bandpass registration and data array setup.

Parameters:
  • sn_name (str) – Name of the supernova to load (e.g., ‘19agl’)

  • data_dir (str) – Directory containing the data files. Defaults to ‘data’.

  • fix_z (bool) – Whether to fix redshift to value from redshifts.dat

Returns:

Contains the processed data arrays and bridges:

  • times (jnp.array): Observation times

  • fluxes (jnp.array): Flux measurements

  • fluxerrs (jnp.array): Flux measurement errors

  • zps (jnp.array): Zero points

  • band_indices (jnp.array): Band indices

  • unique_bands (list): List of unique band names

  • bridges (tuple): Precomputed bridge data for each band

  • fixed_z (tuple or None): (z, z_err) if fix_z is True, else None

Return type:

tuple
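band_indices and unique_bands together encode which filter each observation used. A NumPy sketch of one way to build such a pair from per-observation band labels; the labels are illustrative, and whether load_and_process_data orders unique_bands the same way (np.unique sorts them) is an assumption:

```python
import numpy as np

# Illustrative band labels, one per observation
observed_bands = np.array(['ztfg', 'ztfr', 'ztfg', 'atlaso', 'ztfr'])
unique_arr, band_indices = np.unique(observed_bands, return_inverse=True)
unique_bands = unique_arr.tolist()  # ['atlaso', 'ztfg', 'ztfr'] (np.unique sorts)
# band_indices is [1, 2, 1, 0, 2]; unique_arr[band_indices] recovers observed_bands
```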

jax_supernovae.data.get_all_supernovae_with_redshifts(redshift_file='data/redshifts.dat')

Get all supernovae that have measured redshifts in redshifts.dat.

Parameters:

redshift_file (str) – Path to redshifts.dat file

Returns:

List of tuples (sn_name, z, z_err, flag) for all supernovae with redshifts

Return type:

list

SALT3 helpers

SALT3-NIR model implementation in JAX.

jax_supernovae.salt3.precompute_bandflux_bridge(bandpass)

Precompute static components for a given bandpass.

Parameters:

bandpass (Bandpass) – Bandpass object to precompute components for

Returns:

Dictionary containing:

  • ‘wave’: the integration wavelength grid

  • ‘dwave’: spacing between grid points

  • ‘trans’: the transmission values computed on the grid

  • ‘wave_original’: original wavelength array for shift interpolation

  • ‘trans_original’: original transmission array

  • ‘zpbandflux_ab’: AB zeropoint normalization for this band

Return type:

dict
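A sketch of how a dict with these keys could be assembled in NumPy. The midpoint integration grid is similar to the one sncosmo uses, and the AB term assumes an AB reference spectrum of f_nu = 10**(-0.4 * 48.6) erg/s/cm²/Hz integrated in photon units; make_bridge and integration_grid are hypothetical helpers, not the library's code:

```python
import math
import numpy as np

H_ERG_S = 6.62607015e-27      # Planck constant [erg s]
AB_FNU = 10 ** (-0.4 * 48.6)  # assumed AB reference flux density [erg/s/cm^2/Hz]

def integration_grid(low, high, target_spacing=5.0):
    """Midpoint grid over [low, high] with spacing <= target_spacing."""
    n = int(math.ceil((high - low) / target_spacing))
    spacing = (high - low) / n
    return low + spacing * (np.arange(n) + 0.5), spacing

def make_bridge(wave_original, trans_original, integration_spacing=5.0):
    """Build a bridge-like dict with the keys listed above."""
    wave, dwave = integration_grid(wave_original[0], wave_original[-1], integration_spacing)
    trans = np.interp(wave, wave_original, trans_original)
    # For f_lambda = f_nu * c / lambda^2, the photon rate per Angstrom is f_nu / (h * lambda).
    zp_ab = AB_FNU / H_ERG_S * np.sum(trans / wave) * dwave
    return {'wave': wave, 'dwave': dwave, 'trans': trans,
            'wave_original': wave_original, 'trans_original': trans_original,
            'zpbandflux_ab': zp_ab}

bridge = make_bridge(np.array([4000.0, 5000.0, 6000.0]), np.ones(3))  # toy box filter
```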

jax_supernovae.salt3.optimized_salt3_bandflux(phase, wave, dwave, trans, params, zp=None, zpsys=None, shift=0.0, wave_original=None, trans_original=None)

Calculate bandflux for a single bandpass using precomputed static data.

Parameters:
  • phase (array or scalar) – Observer-frame phase(s) at which to compute the flux

  • wave (array) – Wavelength grid for integration

  • dwave (float) – Spacing between wavelength grid points

  • trans (array) – Transmission values on the wavelength grid (used if shift=0)

  • params (dict) – Dictionary containing model parameters ‘z’, ‘t0’, ‘x0’, ‘x1’ and ‘c’. Optional dust parameters: ‘dust_type’ (int, dust law index: 0 = ccm89, 1 = od94, 2 = f99), ‘ebv’ (float, E(B-V) value) and ‘r_v’ (float, R_V value; default 3.1).

  • zp (float or None, optional) – Zero point for flux scaling

  • zpsys (str or None, optional) – Magnitude system (e.g. ‘ab’)

  • shift (float, optional) – Constant wavelength shift to apply to transmission curve (in Angstroms)

  • wave_original (array, optional) – Original wavelength array (required if shift != 0)

  • trans_original (array, optional) – Original transmission array (required if shift != 0)

Returns:

Flux in photons/s/cm^2

Return type:

float or array
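The photon-flux integral can be sketched as a sum over the integration grid, where each wavelength bin contributes f_lambda * lambda / (h c) * T(lambda) * dwave photons/s/cm². photon_bandflux is hypothetical, the SALT3 spectrum is replaced by a flat toy spectrum, and the zp rescaling (flux * 10**(0.4 * zp) / zero-point bandflux) follows sncosmo's convention, assumed here to match how zp and zpsys behave:

```python
import numpy as np

HC_ERG_AA = 6.62607015e-27 * 2.99792458e18  # h * c in erg * Angstrom

def photon_bandflux(f_lambda, wave, trans, dwave, zp=None, zpbandflux=None):
    """Photon flux through a bandpass, photons/s/cm^2.

    Sums f_lambda * wave / (h c) * trans over the grid. If zp is given, the
    result is rescaled by 10**(0.4 * zp) / zpbandflux (assumed convention).
    """
    flux = np.sum(f_lambda * wave * trans) * dwave / HC_ERG_AA
    if zp is not None:
        flux = flux * 10 ** (0.4 * zp) / zpbandflux
    return flux

wave = np.arange(4002.5, 6000.0, 5.0)  # midpoint grid, dwave = 5 A
trans = np.ones_like(wave)             # idealised box filter
f_lambda = np.full_like(wave, 1e-15)   # flat spectrum, erg/s/cm^2/A

raw = photon_bandflux(f_lambda, wave, trans, 5.0)
# A source exactly as bright as the zero-point flux maps to 10**(0.4 * zp).
scaled = photon_bandflux(f_lambda, wave, trans, 5.0, zp=25.0, zpbandflux=raw)
```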

jax_supernovae.salt3.optimized_salt3_multiband_flux(phase, bridges, params, zps=None, zpsys=None, shifts=None)

Calculate fluxes for multiple bandpasses with transmission shifts.

Parameters:
  • phase (array) – Observer-frame phases

  • bridges (list of dict) – Precomputed bridge data for each bandpass

  • params (dict) – Model parameters

  • zps (list or array or None, optional) – Zero points for each bandpass

  • zpsys (str or None, optional) – Magnitude system

  • shifts (list or array or None, optional) – Constant wavelength shifts for each bandpass (in Angstroms)

Returns:

Array of flux values for each phase and band

Return type:

array
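The multiband result holds a value for every phase and band. When each observation was taken in a single band, the matching entries can be picked out with integer indexing. The sketch uses a made-up flux grid with rows as phases and columns as bands (an assumed orientation), and whether SALT3Source.bandflux selects per-observation values exactly this way internally is an assumption. The same indexing also works on jax.numpy arrays:

```python
import numpy as np

# Made-up per-band model fluxes: rows = observation phases, columns = unique bands
model_grid = np.array([[1.0, 10.0, 100.0],
                       [2.0, 20.0, 200.0],
                       [3.0, 30.0, 300.0]])
band_indices = np.array([2, 0, 1])  # band actually observed at each phase

# Pick one entry per row: (row i, column band_indices[i])
per_obs = model_grid[np.arange(len(band_indices)), band_indices]
# per_obs is [100.0, 2.0, 30.0]
```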