"""
Source classes for JAX-bandflux providing v3.0 functional API.
This module provides source models for supernova light curve fitting:
- SALT3Source: SALT3-NIR model with stretch and colour parameters
- TimeSeriesSource: Generic spectral time series model (like sncosmo.TimeSeriesSource)
Both use a functional API in which parameters are passed as dictionaries to methods
rather than stored in the source object.
"""
import jax
import jax.numpy as jnp
import numpy as np
from jax_supernovae.salt3 import optimized_salt3_multiband_flux, precompute_bandflux_bridge
from jax_supernovae.bandpasses import register_all_bandpasses, get_bandpass
from jax_supernovae.timeseries import (
timeseries_bandflux,
timeseries_multiband_flux
)
class SALT3Source:
    """SALT3-NIR supernova source model with functional API.

    This class provides a v3.0 functional API where model parameters are passed
    as arguments to methods rather than being stored as instance attributes.

    Parameters
    ----------
    name : str, optional
        Model name (default: 'salt3-nir'). Only 'salt3-nir' is supported.

    Examples
    --------
    Basic usage with string band names::

        source = SALT3Source(name='salt3-nir')
        params = {'x0': 1e-5, 'x1': 0.0, 'c': 0.0}
        flux = source.bandflux(params, 'bessellb', 0.0, zp=27.5, zpsys='ab')

    High-performance usage with precomputed data (for nested sampling)::

        from jax_supernovae.data import load_and_process_data
        data = load_and_process_data('19dwz', fix_z=True)
        times, fluxes, fluxerrs, zps, band_indices, bands, bridges, fixed_z = data
        source = SALT3Source()

        # Inside likelihood function:
        params = {'x0': 1e-5, 'x1': 0.0, 'c': 0.0}
        t0 = 58650.0
        z = fixed_z[0]
        phases = (times - t0) / (1 + z)
        model_fluxes = source.bandflux(params, None, phases, zp=zps, zpsys='ab',
                                       band_indices=band_indices, bridges=bridges,
                                       unique_bands=bands)
    """

    def __init__(self, name='salt3-nir'):
        """Initialise the SALT3 source.

        Parameters
        ----------
        name : str
            Model name (currently only 'salt3-nir' is supported).

        Raises
        ------
        ValueError
            If ``name`` is not 'salt3-nir'.
        """
        if name != 'salt3-nir':
            raise ValueError(f"Only 'salt3-nir' model is supported, got '{name}'")
        self.name = name
        # Register all bandpasses so get_bandpass() can resolve band names later.
        register_all_bandpasses()
        # Precomputed bridges keyed by band name, to avoid slow rebuilding.
        self._bridge_cache = {}
        # Jitted bandflux functions keyed by (band set, zpsys, whether zp is applied).
        self._compiled_cache = {}

    @property
    def param_names(self):
        """List of SALT3 model parameter names for the functional API.

        Returns the core SALT3 parameters that are passed via the params dict
        in the v3.0 functional API. Note that z and t0 are handled externally
        via phase calculations (phase = (time - t0) / (1 + z)).
        """
        return ['x0', 'x1', 'c']

    def minphase(self):
        """Minimum phase for which the model is defined."""
        return -20.0

    def maxphase(self):
        """Maximum phase for which the model is defined."""
        return 50.0

    def minwave(self):
        """Minimum wavelength for which the model is defined (Angstroms)."""
        return 2000.0

    def maxwave(self):
        """Maximum wavelength for which the model is defined (Angstroms)."""
        return 20000.0

    def __str__(self):
        """String representation."""
        return f"SALT3Source(name='{self.name}', v3.0 functional API)"

    def __repr__(self):
        """Official string representation."""
        return f"SALT3Source(name='{self.name}')"

    def _get_bridges(self, unique_bands):
        """Return cached bridges for ``unique_bands`` (same order), computing missing ones."""
        for band in unique_bands:
            if band not in self._bridge_cache:
                self._bridge_cache[band] = precompute_bandflux_bridge(get_bandpass(band))
        return tuple(self._bridge_cache[band] for band in unique_bands)

    def _get_compiled_bandflux(self, band_key, bridges, zpsys, apply_zp):
        """Return a jitted bandflux function for a fixed band set.

        NOTE: the cache key is ``(band_key, zpsys, apply_zp)``. ``bridges`` is
        captured in the compiled closure the first time a key is seen and is
        not consulted on later cache hits, so bridges passed for the same band
        names on subsequent calls are assumed to be equivalent.

        Raises
        ------
        ValueError
            If ``zpsys`` is neither None nor 'ab'.
        """
        cache_key = (band_key, zpsys, apply_zp)
        cached = self._compiled_cache.get(cache_key)
        if cached is not None:
            return cached
        if zpsys not in (None, 'ab'):
            raise ValueError(f"Unsupported magnitude system: {zpsys}")
        # AB zero-point bandflux of each unique band, used to normalise by zp.
        band_zp_denoms = jnp.array([bridge['zpbandflux_ab'] for bridge in bridges])

        def _fn(phases, band_indices, params, zps, shifts):
            # Model flux for every (observation, unique band) pair ...
            flux_matrix = optimized_salt3_multiband_flux(
                phases, bridges, params, zps=None, zpsys=zpsys, shifts=shifts
            )
            # ... from which each observation keeps its own band's column.
            gathered = flux_matrix[jnp.arange(len(phases)), band_indices]
            if apply_zp:
                zp_norms = 10 ** (0.4 * zps)
                gathered = gathered * (zp_norms / band_zp_denoms[band_indices])
            return gathered

        compiled = jax.jit(_fn)
        self._compiled_cache[cache_key] = compiled
        return compiled

    def bandflux(self, params, bands, phases, zp=None, zpsys=None,
                 band_indices=None, bridges=None, unique_bands=None, shifts=None):
        """Calculate bandflux using the optimized multiband path only.

        Parameters
        ----------
        params : dict
            Model parameters. Must contain:
            - 'x0': float - Amplitude parameter
            - 'x1': float - Stretch parameter
            - 'c': float - Color parameter
            May optionally contain:
            - 'z': float - Redshift (default: 0.0)
            - 't0': float - Time of peak brightness (default: 0.0)
        bands : str or array-like
            Bandpass name(s). Can be a string for a single band or array of band names.
            If `band_indices`/`unique_bands` are provided, `bands` can be None.
        phases : float or array
            Rest-frame phase(s) relative to t0.
        zp : float or array, optional
            Zero point(s) for flux scaling (per observation).
        zpsys : str, optional
            Zero point system ('ab' currently supported).
        band_indices : int or array, optional
            Indices into unique_bands/bridges arrays (for high-performance path).
        bridges : tuple, optional
            Precomputed bridge data structures keyed to unique_bands.
        unique_bands : list, optional
            List of unique band names corresponding to bridges.
        shifts : float, array or list, optional
            Wavelength shifts in Angstroms, one per unique band (a single value
            is applied to every band). May be a traced JAX array.

        Returns
        -------
        flux : array or scalar
            One bandflux value per observation after broadcasting phases against
            bands. A scalar is returned only when ``phases`` is a Python/NumPy
            scalar, ``bands`` is a single string (or None), and exactly one
            observation results.

        Raises
        ------
        ValueError
            If required params are missing, bands cannot be resolved, the
            phase/band/zp/shift lengths are incompatible, ``zp`` is given
            without ``zpsys``, or ``zpsys`` is unsupported.
        """
        if 'x0' not in params or 'x1' not in params or 'c' not in params:
            raise ValueError("params must contain 'x0', 'x1', and 'c'")

        scalar_input = np.isscalar(phases) and (isinstance(bands, str) or bands is None)

        full_params = {
            'z': params.get('z', 0.0),
            't0': params.get('t0', 0.0),
            'x0': params['x0'],
            'x1': params['x1'],
            'c': params['c'],
        }
        phases_arr = jnp.atleast_1d(jnp.array(phases))

        # Resolve band metadata: explicit indices, or derive them from names.
        if band_indices is not None and unique_bands is not None:
            unique_bands_list = list(unique_bands)
            # atleast_1d so a scalar index behaves like a one-element list.
            band_indices_arr = jnp.atleast_1d(jnp.asarray(band_indices))
        else:
            if bands is None:
                raise ValueError("bands must be provided when band_indices are not supplied")
            # Map each band to the position at which it first appears; dicts keep
            # insertion order, so the keys are the unique bands in that order.
            first_index = {}
            band_index_list = [
                first_index.setdefault(band, len(first_index))
                for band in np.atleast_1d(np.array(bands))
            ]
            unique_bands_list = list(first_index)
            band_indices_arr = jnp.array(band_index_list)

        # Bridges: use provided ones, else cached/computed per unique band.
        bridges_to_use = tuple(bridges) if bridges is not None else self._get_bridges(unique_bands_list)
        n_bridges = len(bridges_to_use)

        # Broadcast phases and band indices to a common observation count.
        phase_len = phases_arr.shape[0]
        band_len = band_indices_arr.shape[0]
        if phase_len == band_len:
            phases_eval, band_indices_eval = phases_arr, band_indices_arr
        elif phase_len == 1:
            phases_eval = jnp.broadcast_to(phases_arr, (band_len,))
            band_indices_eval = band_indices_arr
        elif band_len == 1:
            phases_eval = phases_arr
            # broadcast_to (not int(...)) so traced indices also work.
            band_indices_eval = jnp.broadcast_to(band_indices_arr, (phase_len,))
        else:
            raise ValueError(
                f"Incompatible shapes: bands ({band_len}) and phases ({phase_len}). "
                "Either must be same length, or one must be length 1."
            )
        n_obs = phases_eval.shape[0]

        # Zeropoints are per observation.
        if zp is not None and zpsys is None:
            raise ValueError("zpsys must be provided when zp is specified")
        apply_zp = zp is not None
        if apply_zp:
            zps_arr = jnp.atleast_1d(jnp.array(zp))
            if len(zps_arr) == 1:
                zps_arr = jnp.full(n_obs, zps_arr[0])
            elif len(zps_arr) != n_obs:
                raise ValueError(f"zp length ({len(zps_arr)}) must match phases length ({n_obs})")
        else:
            # Placeholder so the jitted function always receives an array.
            zps_arr = jnp.zeros(n_obs)

        # Wavelength shifts are per unique band.
        if shifts is None:
            shifts_array = jnp.zeros(n_bridges)
        else:
            shifts_arr = jnp.atleast_1d(jnp.asarray(shifts, dtype=float))
            if shifts_arr.shape[0] == 1:
                shifts_array = jnp.broadcast_to(shifts_arr, (n_bridges,))
            elif shifts_arr.shape[0] == n_bridges:
                shifts_array = shifts_arr
            else:
                raise ValueError(
                    f"shifts length ({shifts_arr.shape[0]}) must match unique bands ({n_bridges})"
                )

        # Canonicalize zpsys to a hashable value for caching/static use.
        zpsys_key = zpsys
        if isinstance(zpsys, (list, tuple, np.ndarray)):
            if len(zpsys) == 0:
                zpsys_key = None
            elif all(z == zpsys[0] for z in zpsys):
                zpsys_key = zpsys[0]
            else:
                raise ValueError("Array-valued zpsys with mixed entries is not supported")

        compiled_fn = self._get_compiled_bandflux(
            tuple(unique_bands_list), bridges_to_use, zpsys_key, apply_zp
        )
        gathered_flux = compiled_fn(phases_eval, band_indices_eval, full_params, zps_arr, shifts_array)

        if scalar_input and gathered_flux.shape[0] == 1:
            return gathered_flux[0]
        return gathered_flux

    def bandflux_batch(self, params, bands, phases, zp=None, zpsys=None,
                       band_indices=None, bridges=None, unique_bands=None, shifts=None):
        """Batched bandflux evaluation over multiple parameter sets.

        Each of 'x0', 'x1', 'c' (required) and 'z', 't0' (optional, default
        0.0) must be 1D with either the common batch length or length 1; a
        length-1 (or scalar) value is broadcast to the batch length.

        Remaining arguments are forwarded unchanged to :meth:`bandflux`.

        Returns
        -------
        flux : array
            Results of :meth:`bandflux` for each parameter set, stacked along
            a new leading batch axis (via ``jax.vmap``).

        Raises
        ------
        ValueError
            If a required parameter is missing, a parameter is not 1D, or batch
            lengths are inconsistent.
        """
        for key in ('x0', 'x1', 'c'):
            if key not in params:
                raise ValueError(f"params must contain '{key}' for batched evaluation")

        names = ('x0', 'x1', 'c', 'z', 't0')
        defaults = {'z': 0.0, 't0': 0.0}
        columns = {}
        for name in names:
            arr = jnp.atleast_1d(jnp.asarray(params.get(name, defaults.get(name))))
            if arr.ndim != 1:
                raise ValueError(f"Parameter '{name}' must be 1D for batched evaluation")
            columns[name] = arr

        batch_size = max(arr.shape[0] for arr in columns.values())
        for name in names:
            length = columns[name].shape[0]
            if length == batch_size:
                continue
            if length != 1:
                raise ValueError(f"Parameter '{name}' batch size {length} != {batch_size}")
            # Scalars / omitted optional params apply to every batch element.
            columns[name] = jnp.broadcast_to(columns[name], (batch_size,))

        def single(param_vec):
            p = {
                'x0': param_vec[0],
                'x1': param_vec[1],
                'c': param_vec[2],
                'z': param_vec[3],
                't0': param_vec[4],
            }
            return self.bandflux(
                p, bands, phases, zp=zp, zpsys=zpsys,
                band_indices=band_indices, bridges=bridges,
                unique_bands=unique_bands, shifts=shifts
            )

        params_stack = jnp.stack([columns[name] for name in names], axis=1)
        return jax.vmap(single)(params_stack)

    def bandmag(self, params, bands, magsys, phases, band_indices=None,
                bridges=None, unique_bands=None):
        """Calculate magnitude using v3.0 functional API.

        Parameters
        ----------
        params : dict
            Model parameters (x0, x1, c, optionally z and t0)
        bands : str or array-like
            Bandpass name(s)
        magsys : str
            Magnitude system. Only 'ab' is currently supported; other values
            are rejected by :meth:`bandflux` with a ValueError.
        phases : float or array
            Rest-frame phase(s)
        band_indices : array, optional
            For performance: indices into unique_bands/bridges arrays
        bridges : tuple, optional
            For performance: precomputed bridge data structures
        unique_bands : list, optional
            For performance: list of unique band names

        Returns
        -------
        mag : float or array
            Magnitude value(s); NaN where the flux is not positive.

        Notes
        -----
        With zp=0, bandflux divides the model flux by the band's AB zero-point
        flux, so the magnitude is -2.5 * log10(flux / zpflux_ab).
        """
        flux = self.bandflux(params, bands, phases, zp=0.0, zpsys=magsys,
                             band_indices=band_indices, bridges=bridges,
                             unique_bands=unique_bands)
        # Avoid log of zero/negative values.
        flux_safe = jnp.where(flux > 0, flux, jnp.nan)
        return -2.5 * jnp.log10(flux_safe)
class TimeSeriesSource:
    """JAX implementation of custom SED time series source.

    Matches sncosmo.TimeSeriesSource API with functional parameter passing.
    Enables fitting arbitrary spectral time series models on GPU with JAX.
    This class provides a flexible interface for fitting custom supernova models
    defined by a 2D grid of flux values across phase and wavelength. It uses
    bicubic interpolation (matching sncosmo) and supports both simple usage and
    high-performance modes for MCMC/nested sampling.

    Parameters
    ----------
    phase : array_like
        1D array of phase values (days) defining the model grid.
        Must be non-empty and strictly increasing. Shape (n_phase,)
    wave : array_like
        1D array of wavelength values (Angstroms) defining the model grid.
        Must be non-empty and strictly increasing. Shape (n_wave,)
    flux : array_like
        2D array of flux values (erg/s/cm²/Å) with shape (n_phase, n_wave).
        flux[i, j] is the flux at phase[i] and wavelength wave[j].
    zero_before : bool, optional
        If True, flux is zero for phases before minphase. If False,
        extrapolates using edge values. Default is False.
    time_spline_degree : int, optional
        Degree of interpolation in time direction. 1 for linear, 3 for cubic.
        Default is 3 (matches sncosmo default).
    name : str, optional
        Name for this source model.
    version : str, optional
        Version identifier for this source model.

    Examples
    --------
    Basic usage::

        import numpy as np
        from jax_supernovae import TimeSeriesSource

        # Create simple Gaussian model
        phase = np.linspace(-20, 50, 100)
        wave = np.linspace(3000, 9000, 200)

        # Gaussian in time and wavelength
        p_grid, w_grid = np.meshgrid(phase, wave, indexing='ij')
        flux = np.exp(-0.5 * (p_grid/10)**2) * np.exp(-0.5 * ((w_grid-5000)/1000)**2)
        flux *= 1e-15  # Scale to realistic flux levels
        source = TimeSeriesSource(phase, wave, flux)

        # Calculate bandflux (functional API)
        params = {'amplitude': 1.0}
        flux_b = source.bandflux(params, 'bessellb', 0.0, zp=25.0, zpsys='ab')

    Notes
    -----
    - Uses functional API: parameters passed to methods, not stored
    - Compatible with JAX JIT compilation and GPU acceleration
    - Bicubic interpolation in 2D (phase and wavelength)
    - Matches sncosmo numerical results to ~0.01%
    """

    _param_names = ['amplitude']

    def __init__(self, phase, wave, flux, zero_before=False,
                 time_spline_degree=3, name=None, version=None):
        """Initialise TimeSeriesSource.

        Raises
        ------
        ValueError
            If the grids have the wrong dimensionality, are empty, are not
            strictly increasing, do not match ``flux``'s shape, or if
            ``time_spline_degree`` is not 1 or 3.
        """
        # Convert to numpy for validation
        phase = np.asarray(phase)
        wave = np.asarray(wave)
        flux = np.asarray(flux)

        # Validate inputs
        if phase.ndim != 1:
            raise ValueError(f"phase must be 1D array, got shape {phase.shape}")
        if wave.ndim != 1:
            raise ValueError(f"wave must be 1D array, got shape {wave.shape}")
        if flux.ndim != 2:
            raise ValueError(f"flux must be 2D array, got shape {flux.shape}")
        # The cached bounds below read phase[0]/wave[0], so reject empty grids.
        if phase.size == 0 or wave.size == 0:
            raise ValueError("phase and wave grids must be non-empty")
        if flux.shape != (len(phase), len(wave)):
            raise ValueError(
                f"flux shape {flux.shape} must match (len(phase), len(wave)) = "
                f"({len(phase)}, {len(wave)})"
            )

        # Check grids are (strictly) increasing
        if not np.all(np.diff(phase) > 0):
            raise ValueError("phase grid must be sorted in ascending order")
        if not np.all(np.diff(wave) > 0):
            raise ValueError("wave grid must be sorted in ascending order")

        # Validate time_spline_degree
        if time_spline_degree not in [1, 3]:
            raise ValueError(
                f"time_spline_degree must be 1 (linear) or 3 (cubic), "
                f"got {time_spline_degree}"
            )

        # Store metadata
        self.name = name
        self.version = version
        self._zero_before = zero_before
        self._time_degree = time_spline_degree

        # Convert to JAX arrays
        self._phase = jnp.asarray(phase)
        self._wave = jnp.asarray(wave)
        self._flux = jnp.asarray(flux)

        # Cache bounds for quick access
        self._minphase = float(phase[0])
        self._maxphase = float(phase[-1])
        self._minwave = float(wave[0])
        self._maxwave = float(wave[-1])

        # Register all bandpasses to ensure they're available
        register_all_bandpasses()

    @property
    def param_names(self):
        """List of model parameter names (a fresh copy, safe to mutate)."""
        return list(self._param_names)

    def minphase(self):
        """Minimum phase of model."""
        return self._minphase

    def maxphase(self):
        """Maximum phase of model."""
        return self._maxphase

    def minwave(self):
        """Minimum wavelength of model."""
        return self._minwave

    def maxwave(self):
        """Maximum wavelength of model."""
        return self._maxwave

    def __str__(self):
        """String representation."""
        name_str = f"'{self.name}'" if self.name else 'unnamed'
        return (f"TimeSeriesSource({name_str}, "
                f"phase=[{self._minphase:.1f}, {self._maxphase:.1f}] days, "
                f"wave=[{self._minwave:.0f}, {self._maxwave:.0f}] Å)")

    def __repr__(self):
        """Official string representation."""
        return (f"TimeSeriesSource(name={self.name!r}, version={self.version!r}, "
                f"zero_before={self._zero_before}, time_spline_degree={self._time_degree})")

    def bandflux(self, params, bands, phases, zp=None, zpsys=None,
                 band_indices=None, bridges=None, unique_bands=None):
        """Calculate bandflux using functional API.

        Parameters
        ----------
        params : dict
            Parameter dictionary. Must contain 'amplitude'.
        bands : str, list, or None
            Bandpass name(s). Use None in optimised mode with bridges.
        phases : float or array_like
            Rest-frame phase(s) at which to evaluate flux (days).
        zp : float or array_like, optional
            Zero point(s), one per observation (a single value is broadcast).
            If provided, zpsys must also be given.
        zpsys : str, optional
            Zero point system (e.g., 'ab'). Required if zp is provided.
        band_indices : array_like, optional
            (Optimised mode) Integer indices mapping observations to unique_bands.
        bridges : tuple of dict, optional
            (Optimised mode) Pre-computed bandpass bridges.
        unique_bands : list, optional
            (Optimised mode) List of unique band names corresponding to bridges.

        Returns
        -------
        float or jnp.array
            Bandflux value(s), one per observation after broadcasting phases
            against bands. A scalar is returned when both ``phases`` is a
            scalar and ``bands`` is a single string.

        Raises
        ------
        ValueError
            If 'amplitude' is missing, ``zp`` is given without ``zpsys``,
            ``bands`` is None outside optimised mode, or the band/phase/zp
            lengths are incompatible.
        """
        if 'amplitude' not in params:
            raise ValueError("params must contain 'amplitude'")
        if zp is not None and zpsys is None:
            raise ValueError('zpsys must be given if zp is not None')

        amplitude = params['amplitude']
        scalar_input = np.isscalar(phases) and isinstance(bands, str)
        phases = jnp.atleast_1d(jnp.array(phases))
        zps = None if zp is None else jnp.atleast_1d(jnp.array(zp))

        # High-performance path: precomputed bridges and integer band indices.
        if bridges is not None and band_indices is not None and unique_bands is not None:
            band_indices_arr = jnp.array(band_indices, dtype=jnp.int32)
            if zps is not None:
                if len(zps) == 1:
                    zps = jnp.full(len(phases), zps[0])
                elif len(zps) != len(phases):
                    raise ValueError(
                        f"zp length ({len(zps)}) must match phases length ({len(phases)})"
                    )
            model_fluxes = timeseries_multiband_flux(
                phases, bridges, band_indices_arr,
                self._phase, self._wave, self._flux,
                amplitude, self._zero_before, self._minphase, self._time_degree,
                zps=zps, zpsys=zpsys
            )
            return model_fluxes[0] if scalar_input else model_fluxes

        # Standard path: resolve band names and build bridges on the fly.
        if bands is None:
            raise ValueError("bands must be provided when band_indices are not supplied")
        bands_list = [bands] if isinstance(bands, str) else list(bands)
        n_bands, n_phases = len(bands_list), len(phases)
        if n_bands == n_phases or n_bands == 1:
            n_obs = n_phases
        elif n_phases == 1:
            n_obs = n_bands
        else:
            raise ValueError(
                f"Incompatible shapes: bands ({n_bands}) and phases ({n_phases}). "
                "Either must be same length, or one must be length 1."
            )

        # zp is aligned with observations (after broadcasting), not with phases,
        # so a single phase observed in several bands gets one zp per band.
        if zps is not None:
            if len(zps) == 1:
                zps = jnp.full(n_obs, zps[0])
            elif len(zps) != n_obs:
                raise ValueError(
                    f"zp length ({len(zps)}) must match number of observations ({n_obs})"
                )

        # Build each distinct band's bridge once per call. Non-string band
        # objects are keyed by identity since they may not be hashable.
        bridge_cache = {}
        fluxes = []
        for i in range(n_obs):
            band = bands_list[i if n_bands > 1 else 0]
            phase = phases[i if n_phases > 1 else 0]
            key = band if isinstance(band, str) else id(band)
            if key not in bridge_cache:
                bridge_cache[key] = precompute_bandflux_bridge(get_bandpass(band))
            fluxes.append(timeseries_bandflux(
                phase, bridge_cache[key], self._phase, self._wave, self._flux,
                amplitude, self._zero_before, self._minphase, self._time_degree,
                zp=None if zps is None else zps[i], zpsys=zpsys
            ))
        result = jnp.stack(fluxes)
        return result[0] if scalar_input else result

    def bandmag(self, params, bands, magsys, phases, band_indices=None,
                bridges=None, unique_bands=None):
        """Calculate magnitude using functional API.

        Parameters
        ----------
        params : dict
            Model parameters. Must contain 'amplitude'.
        bands : str or array-like
            Bandpass name(s)
        magsys : str
            Magnitude system (e.g., 'ab'), passed to bandflux as ``zpsys``.
        phases : float or array
            Rest-frame phase(s)
        band_indices, bridges, unique_bands : optional
            For high-performance mode

        Returns
        -------
        mag : float or array
            -2.5 * log10(flux), where flux is bandflux evaluated with zp=0 in
            ``magsys``. Returns NaN for flux ≤ 0.
        """
        flux = self.bandflux(params, bands, phases, zp=0.0, zpsys=magsys,
                             band_indices=band_indices, bridges=bridges,
                             unique_bands=unique_bands)
        flux_safe = jnp.where(flux > 0, flux, jnp.nan)
        return -2.5 * jnp.log10(flux_safe)