TimeSeriesSource

TimeSeriesSource is a JAX-bandflux class for fitting custom supernova spectral energy distributions (SEDs). It provides a JAX/GPU-accelerated implementation matching sncosmo’s TimeSeriesSource API whilst using a functional parameter-passing approach for optimal performance in MCMC and nested sampling applications.

Key Features

  • Custom SED Models: Fit any spectral time series defined on a 2D (phase × wavelength) grid

  • Bicubic Interpolation: Reproduces sncosmo's interpolation scheme using JAX primitives

  • Functional API: Parameters passed as dictionaries for JAX compatibility

  • Two-Tier Performance: Simple mode for convenience, optimised mode for speed

  • JIT Compatible: Works seamlessly in JIT-compiled likelihood functions

  • GPU Accelerated: Runs efficiently on GPUs via JAX

  • Numerical Accuracy: Matches sncosmo to <0.01% (tested)

API Comparison: sncosmo vs JAX-bandflux

Constructor (Nearly Identical)

sncosmo:

source = sncosmo.TimeSeriesSource(phase, wave, flux,
                                   zero_before=False,
                                   time_spline_degree=3,
                                   name=None, version=None)

JAX-bandflux:

source = TimeSeriesSource(phase, wave, flux,  # Same signature!
                          zero_before=False,
                          time_spline_degree=3,
                          name=None, version=None)

Method Calls (Functional API)

sncosmo (stateful):

source.set(amplitude=1.0)
flux = source.bandflux('bessellb', 0.0, zp=25.0, zpsys='ab')

JAX-bandflux (functional):

params = {'amplitude': 1.0}
flux = source.bandflux(params, 'bessellb', 0.0, zp=25.0, zpsys='ab')

The key difference: JAX-bandflux passes parameters as a dictionary to each method call, enabling JAX to trace parameter dependencies for autodiff and JIT compilation.
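
To make the tracing point concrete, here is a minimal sketch that uses a toy model rather than the library itself. The names toy_bandflux, template, and observed are hypothetical stand-ins; only jax.jit and jax.grad are real JAX calls. Because the parameters arrive as a dictionary argument, JAX can differentiate with respect to them and return the gradient in the same dictionary structure:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins: model fluxes at amplitude = 1 and some observed fluxes
template = jnp.array([1.0, 2.0, 3.0])
observed = jnp.array([2.1, 3.9, 6.2])

def toy_bandflux(params):
    """Toy model: flux scales linearly with params['amplitude']."""
    return params['amplitude'] * template

@jax.jit
def chi2(params):
    """Unweighted sum of squared residuals for the toy model."""
    residuals = observed - toy_bandflux(params)
    return jnp.sum(residuals ** 2)

# jax.grad differentiates with respect to the leaves of the params dictionary
grads = jax.grad(chi2)({'amplitude': 1.0})
print(grads['amplitude'])
```

The same pattern should carry over to a likelihood built on source.bandflux, since the parameters are passed in as arguments rather than stored as object state.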

Basic Usage

Creating a TimeSeriesSource

import numpy as np
from jax_supernovae import TimeSeriesSource

# Define your model grid
phase = np.linspace(-20, 50, 100)  # Days
wave = np.linspace(3000, 9000, 200)  # Angstroms

# Create flux array (phase × wavelength)
p_grid, w_grid = np.meshgrid(phase, wave, indexing='ij')
time_profile = np.exp(-0.5 * (p_grid / 12.0)**2)
wave_profile = np.exp(-0.5 * ((w_grid - 5500.0) / 1200.0)**2)
flux_grid = time_profile * wave_profile * 1e-15

# Create source
source = TimeSeriesSource(phase, wave, flux_grid,
                          zero_before=False,
                          time_spline_degree=3,
                          name='my_model')
print(source.param_names)
['amplitude']

Simple Photometry

# Define parameters
params = {'amplitude': 1.0}

# Single observation at peak (phase=0)
flux_b = source.bandflux(params, 'bessellb', 0.0, zp=25.0, zpsys='ab')
print(f"B-band flux at peak: {float(flux_b):.4e}")

# Light curve (multiple phases, same band)
phases = np.linspace(-10, 30, 10)
fluxes_b = source.bandflux(params, 'bessellb', phases, zp=25.0, zpsys='ab')
print("B-band light curve:")
for p, f in zip(phases[:5], fluxes_b[:5]):  # Show first 5
    print(f"  Phase {p:+6.1f}d: {float(f):8.2f}")

# Multi-band observation (same phase, different bands)
bands = ['bessellb', 'bessellv', 'bessellr']
phases_same = np.zeros(3)
fluxes_multi = source.bandflux(params, bands, phases_same, zp=25.0, zpsys='ab')
print("Flux at peak in different bands:")
for band, flux in zip(bands, fluxes_multi):
    print(f"  {band:10s}: {float(flux):8.2f}")
B-band flux at peak: 1.1498e+03
B-band light curve:
  Phase  -10.0d:   812.49
  Phase   -5.6d:  1032.93
  Phase   -1.1d:  1144.86
  Phase   +3.3d:  1106.27
  Phase   +7.8d:   931.95
Flux at peak in different bands:
  bessellb  :  1149.78
  bessellv  :  2636.44
  bessellr  :  2538.65

Plotting a Custom Model Light Curve

Visualise the custom SED model across multiple bands:

import matplotlib.pyplot as plt
import numpy as np
from jax_supernovae import TimeSeriesSource

# Create custom SED model
phase = np.linspace(-20, 50, 100)
wave = np.linspace(3000, 9000, 200)
p_grid, w_grid = np.meshgrid(phase, wave, indexing='ij')
time_profile = np.exp(-0.5 * (p_grid / 12.0)**2)
wave_profile = np.exp(-0.5 * ((w_grid - 5500.0) / 1200.0)**2)
flux_grid = time_profile * wave_profile * 1e-15

source = TimeSeriesSource(phase, wave, flux_grid,
                          zero_before=False,
                          time_spline_degree=3,
                          name='my_model')
params = {'amplitude': 1.0}

# Generate light curve data
lc_phases = np.linspace(-15, 40, 100)
bands_plot = ['bessellb', 'bessellv', 'bessellr']
colors = {'bessellb': 'blue', 'bessellv': 'green', 'bessellr': 'red'}

plt.figure(figsize=(10, 6))
for band in bands_plot:
    flux = source.bandflux(params, band, lc_phases, zp=25.0, zpsys='ab')
    plt.plot(lc_phases, np.array(flux), color=colors[band], label=band.upper(), linewidth=2)

plt.xlabel('Phase (days)', fontsize=12)
plt.ylabel('Flux (zp=25.0)', fontsize=12)
plt.title('TimeSeriesSource Custom Model', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
[Figure: light curves of the custom model in the B, V, and R bands, flux (zp=25.0) against phase (days)]

Calculate Magnitudes

# Magnitude in AB system
mag_b = source.bandmag(params, 'bessellb', 'ab', 0.0)
print(f"B-band magnitude at peak: {float(mag_b):.2f} mag")

# Multi-band magnitudes
print("Magnitudes at peak:")
for band in ['bessellb', 'bessellv', 'bessellr']:
    mag = source.bandmag(params, band, 'ab', 0.0)
    print(f"  {band:10s}: {float(mag):.2f} mag")
B-band magnitude at peak: 17.35 mag
Magnitudes at peak:
  bessellb  : 17.35 mag
  bessellv  : 16.45 mag
  bessellr  : 16.49 mag
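
The magnitudes above are consistent with the zero-point-scaled fluxes printed earlier under the usual convention mag = zp - 2.5 log10(flux). The helper below is a hypothetical standalone check, not part of JAX-bandflux, and it assumes that convention. The NaN branch mirrors the documented bandmag behaviour for non-positive flux.

```python
import math

def flux_to_mag(flux, zp):
    """Convert a zero-point-scaled flux to a magnitude (NaN if flux <= 0)."""
    if flux <= 0.0:
        return float('nan')
    return zp - 2.5 * math.log10(flux)

# Peak fluxes copied from the simple photometry output (zp=25.0, AB system)
peak_fluxes = {'bessellb': 1149.78, 'bessellv': 2636.44, 'bessellr': 2538.65}
mags = {band: flux_to_mag(f, 25.0) for band, f in peak_fluxes.items()}

for band, mag in mags.items():
    print(f"  {band:10s}: {mag:.2f} mag")   # 17.35, 16.45, 16.49
```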

High-Performance Mode

For MCMC, nested sampling, or any application requiring many model evaluations, use the optimised mode with pre-computed bridges:

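import jax.numpy as jnp
# Module paths below are assumed; adjust them to your installed jax_supernovae version
from jax_supernovae.bandpasses import get_bandpass
from jax_supernovae.salt3 import precompute_bandflux_bridge

# Example: 30 observations in 3 bands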
n_obs = 30
obs_phases = np.linspace(-10, 40, n_obs)
band_names = ['bessellb', 'bessellv', 'bessellr'] * (n_obs // 3)
zps = jnp.ones(n_obs) * 25.0

# Pre-compute bridges ONCE (outside the likelihood)
unique_bands = ['bessellb', 'bessellv', 'bessellr']
bridges = tuple(precompute_bandflux_bridge(get_bandpass(b))
                for b in unique_bands)

# Create band indices mapping each observation to its bridge
band_to_idx = {'bessellb': 0, 'bessellv': 1, 'bessellr': 2}
band_indices = jnp.array([band_to_idx[b] for b in band_names])

# Fast calculation (10-100x faster than simple mode)
params = {'amplitude': 1.0}
fluxes = source.bandflux(params, None, obs_phases,
                         zp=zps, zpsys='ab',
                         band_indices=band_indices,
                         bridges=bridges,
                         unique_bands=unique_bands)
print(f"Computed {len(fluxes)} fluxes using optimized mode")
print(f"Mean flux: {float(jnp.mean(fluxes)):.2e}, range: [{float(jnp.min(fluxes)):.2e}, {float(jnp.max(fluxes)):.2e}]")
Computed 30 fluxes using optimized mode
Mean flux: 9.92e+02, range: [9.81e+00, 2.60e+03]
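
For real data the band list is rarely this regular. As a sketch, the hypothetical helper below (not part of JAX-bandflux) derives unique_bands and the matching integer indices from any sequence of band names, so that the index order and the bridge order cannot drift apart.

```python
def build_band_indices(band_names):
    """Return (unique_bands, indices), with bands listed in first-seen order."""
    band_to_idx = {}
    for name in band_names:
        # Assign the next free index the first time a band appears
        band_to_idx.setdefault(name, len(band_to_idx))
    unique_bands = list(band_to_idx)
    indices = [band_to_idx[name] for name in band_names]
    return unique_bands, indices

# Irregular band sequence, as might come from a real observing log
example_bands = ['bessellv', 'bessellb', 'bessellv', 'bessellr']
unique_bands, indices = build_band_indices(example_bands)
print(unique_bands)  # ['bessellv', 'bessellb', 'bessellr']
print(indices)       # [0, 1, 0, 2]
```

Build the bridges by iterating over the returned unique_bands in order, and wrap the index list with jnp.array before passing it as band_indices.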

JIT-Compiled Likelihood Functions

TimeSeriesSource works seamlessly in JIT-compiled functions:

import jax

# Generate synthetic data
true_amplitude = 2.0
np.random.seed(123)
true_fluxes = np.array(source.bandflux({'amplitude': true_amplitude}, None, obs_phases,
                                       zp=zps, zpsys='ab',
                                       band_indices=band_indices,
                                       bridges=bridges,
                                       unique_bands=unique_bands))
flux_errors = np.abs(true_fluxes) * 0.05
observed_fluxes = jnp.array(true_fluxes + np.random.normal(0, flux_errors))
flux_errors = jnp.array(flux_errors)

# Define JIT-compiled likelihood
@jax.jit
def loglikelihood(amplitude):
    """Calculate log-likelihood for given amplitude."""
    params = {'amplitude': amplitude}
    model_fluxes = source.bandflux(params, None, obs_phases,
                                   zp=zps, zpsys='ab',
                                   band_indices=band_indices,
                                   bridges=bridges,
                                   unique_bands=unique_bands)
    chi2 = jnp.sum(((observed_fluxes - model_fluxes) / flux_errors)**2)
    return -0.5 * chi2

# Evaluate likelihood at true amplitude
logL_true = loglikelihood(2.0)
print(f"Log-likelihood at true amplitude (2.0): {float(logL_true):.2f}")

# Test at wrong amplitude
logL_wrong = loglikelihood(1.0)
print(f"Log-likelihood at wrong amplitude (1.0): {float(logL_wrong):.2f}")

print(f"Difference in log-likelihood: {float(logL_true - logL_wrong):.1f}")
Log-likelihood at true amplitude (2.0): -20.47
Log-likelihood at wrong amplitude (1.0): -1533.88
Difference in log-likelihood: 1513.4
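
Because amplitude is a pure scaling factor, this particular chi-square is quadratic in amplitude and has a closed-form minimum, which makes a handy sanity check on sampler output. With template fluxes m_i evaluated at amplitude = 1, data d_i, and errors sigma_i, the minimum sits at A = sum(d_i m_i / sigma_i^2) / sum(m_i^2 / sigma_i^2), with uncertainty (sum(m_i^2 / sigma_i^2))^(-1/2). The function below is a hypothetical standalone sketch with made-up numbers, assuming fixed Gaussian errors:

```python
def best_fit_amplitude(observed, template, errors):
    """Closed-form least-squares amplitude and its 1-sigma uncertainty."""
    numerator = sum(d * m / s**2 for d, m, s in zip(observed, template, errors))
    denominator = sum(m * m / s**2 for m, s in zip(template, errors))
    return numerator / denominator, denominator ** -0.5

# Made-up template fluxes (amplitude = 1) with 10% errors
template = [100.0, 200.0, 150.0]
errors = [10.0, 20.0, 15.0]
# Noise-free "data" generated at amplitude = 2
observed = [2.0 * m for m in template]

amp_hat, amp_err = best_fit_amplitude(observed, template, errors)
print(f"amplitude = {amp_hat:.3f} +/- {amp_err:.3f}")
```

With noisy data, a sampler's posterior mean for amplitude should land within a few amp_err of amp_hat when the prior is flat and broad.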

Parameters

Constructor Parameters

| Parameter          | Type       | Default  | Description                                                             |
|--------------------|------------|----------|-------------------------------------------------------------------------|
| phase              | array_like | Required | 1D array of phase values (days). Must be sorted ascending.              |
| wave               | array_like | Required | 1D array of wavelength values (Å). Must be sorted ascending.            |
| flux               | array_like | Required | 2D array of flux values (erg/s/cm²/Å). Shape: (len(phase), len(wave)).  |
| zero_before        | bool       | False    | If True, flux is zero for phase < minphase. If False, extrapolates.     |
| time_spline_degree | int        | 3        | Time interpolation degree: 1 (linear) or 3 (cubic).                     |
| name               | str        | None     | Optional name for the model.                                            |
| version            | str        | None     | Optional version identifier.                                            |

Model Parameters (Functional API)

The functional API requires passing parameters as a dictionary to each method call:

| Parameter | Type  | Description                        |
|-----------|-------|------------------------------------|
| amplitude | float | Scaling factor for the model flux. |

Example:

params = {'amplitude': 1.5}
flux = source.bandflux(params, 'bessellb', 0.0, zp=25.0, zpsys='ab')
print(f"Flux with amplitude=1.5: {float(flux):.2f}")

# Compare to amplitude=1.0
flux_1 = source.bandflux({'amplitude': 1.0}, 'bessellb', 0.0, zp=25.0, zpsys='ab')
print(f"Ratio: {float(flux / flux_1):.2f} (expected: 1.5)")
Flux with amplitude=1.5: 1724.67
Ratio: 1.50 (expected: 1.5)

Methods

bandflux

Calculate bandflux through specified bandpass(es).

Signature:

bandflux(params, bands, phases, zp=None, zpsys=None, **kwargs)

Parameters:

  • params (dict): Must contain 'amplitude'

  • bands (str, list, or None): Bandpass name(s). Use None for optimised mode.

  • phases (float or array): Rest-frame phase(s) in days

  • zp (float or array, optional): Zero point(s)

  • zpsys (str, optional): Zero point system ('ab', etc.)

  • band_indices (array, optional): For optimised mode

  • bridges (tuple, optional): For optimised mode

  • unique_bands (list, optional): For optimised mode

Returns:

  • float or array: Bandflux value(s) matching input shape

bandmag

Calculate magnitude through specified bandpass(es).

Signature:

bandmag(params, bands, magsys, phases, **kwargs)

Parameters:

  • params (dict): Must contain 'amplitude'

  • bands (str or list): Bandpass name(s)

  • magsys (str): Magnitude system ('ab', etc.)

  • phases (float or array): Rest-frame phase(s)

  • Additional kwargs for optimised mode

Returns:

  • float or array: Magnitude value(s). Returns NaN for flux ≤ 0.

Properties

  • param_names: List of parameter names (['amplitude'])

  • minphase(): Minimum phase of model (days)

  • maxphase(): Maximum phase of model (days)

  • minwave(): Minimum wavelength of model (Å)

  • maxwave(): Maximum wavelength of model (Å)

Advanced Topics

Interpolation Methods

TimeSeriesSource supports two interpolation methods:

Cubic Interpolation (default):

source = TimeSeriesSource(phase, wave, flux, time_spline_degree=3)

  • Uses bicubic interpolation (same as sncosmo)

  • Smooth light curves

  • Better for well-sampled grids

Linear Interpolation:

source = TimeSeriesSource(phase, wave, flux, time_spline_degree=1)

  • Uses bilinear interpolation

  • Faster computation

  • Better for coarse grids or performance-critical applications

Zero-Before Behaviour

zero_before=False (default):

  • Extrapolates flux for phases before minphase

  • Uses edge values from the grid

  • Suitable for models where early-time flux is uncertain

zero_before=True:

source = TimeSeriesSource(phase, wave, flux, zero_before=True)

  • Returns exactly zero for phase < minphase

  • Suitable for models that should not have flux before explosion

  • Matches sncosmo’s behaviour when zero_before=True
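
The two settings can be pictured with a one-dimensional toy light curve. The function below is a hypothetical illustration of the rules described above, not the library's implementation. It uses plain linear interpolation inside the grid, and holding the last value after maxphase is an assumption made only so the sketch handles every input.

```python
from bisect import bisect_right

def toy_light_curve(phase, grid_phase, grid_flux, zero_before=False):
    """Evaluate a 1D toy light curve with the documented early-phase rules."""
    if phase < grid_phase[0]:
        # Before minphase: exactly zero, or the edge value of the grid
        return 0.0 if zero_before else grid_flux[0]
    if phase >= grid_phase[-1]:
        # Assumption for this sketch only: hold the last grid value
        return grid_flux[-1]
    # Linear interpolation between the two bracketing grid points
    i = bisect_right(grid_phase, phase) - 1
    t = (phase - grid_phase[i]) / (grid_phase[i + 1] - grid_phase[i])
    return (1.0 - t) * grid_flux[i] + t * grid_flux[i + 1]

grid_phase = [-20.0, 0.0, 20.0]
grid_flux = [1.0, 5.0, 2.0]

print(toy_light_curve(-30.0, grid_phase, grid_flux, zero_before=False))  # 1.0
print(toy_light_curve(-30.0, grid_phase, grid_flux, zero_before=True))   # 0.0
print(toy_light_curve(-10.0, grid_phase, grid_flux, zero_before=True))   # 3.0
```

Note that the zeroing applies only to phase < minphase: a phase exactly at minphase still returns the grid value.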

Handling Redshift

TimeSeriesSource works in the rest frame. Convert observer-frame times to rest-frame phases before calling it:

z = 0.5
t0 = 58650.0
times_obs = np.array([58640, 58650, 58660, 58670])
phases_rest = (times_obs - t0) / (1 + z)
print(phases_rest)
[-6.66666667  0.          6.66666667 13.33333333]

Performance Tips

  1. Use Optimised Mode for Fitting: Pre-compute bridges once, reuse many times

  2. JIT Compile Likelihood Functions: Use @jax.jit for 10-100x speedup

  3. Batch Observations: Process multiple observations together when possible

  4. Appropriate Grid Resolution: Balance accuracy vs memory/compute

  5. Use GPU When Available: JAX automatically uses GPU if available

Comparison with SALT3Source

| Feature     | TimeSeriesSource           | SALT3Source                 |
|-------------|----------------------------|-----------------------------|
| Model Type  | Custom SED                 | SALT3-NIR only              |
| Parameters  | amplitude                  | x0, x1, c                   |
| Flexibility | Any 2D flux grid           | Fixed SALT3 model           |
| Use Case    | Custom models, rare events | Type Ia SNe standardisation |
| Performance | Comparable                 | Comparable                  |

Both classes coexist and can be used together in the same analysis.

Common Issues

Q: Why does my model return NaN?

Check that:

  1. Your phase/wavelength ranges cover the requested observations

  2. Flux values are finite (no NaN/Inf in input grid)

  3. For magnitudes: flux must be positive
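
These checks can be partly automated before constructing a source. The helper below is a hypothetical sketch, not part of JAX-bandflux. It works on nested lists or NumPy arrays and treats repeated grid values as unsorted, which is an assumption.

```python
import math

def check_grid(phase, wave, flux, obs_phases=()):
    """Return a list of human-readable problems with a model grid."""
    problems = []
    # Grid axes must increase strictly (assumption: duplicates are rejected)
    if any(b <= a for a, b in zip(phase, phase[1:])):
        problems.append("phase is not sorted ascending")
    if any(b <= a for a, b in zip(wave, wave[1:])):
        problems.append("wave is not sorted ascending")
    # Flux must have one row per phase and one column per wavelength
    if len(flux) != len(phase) or any(len(row) != len(wave) for row in flux):
        problems.append("flux shape is not (len(phase), len(wave))")
    # Non-finite grid values propagate into every interpolated flux
    if not all(math.isfinite(v) for row in flux for v in row):
        problems.append("flux contains NaN or Inf")
    # Requested phases should fall inside the grid's phase range
    if any(p < phase[0] or p > phase[-1] for p in obs_phases):
        problems.append("some observation phases lie outside the grid")
    return problems

bad_flux = [[1.0, 2.0], [3.0, float('nan')], [5.0, 6.0]]
print(check_grid([-20.0, 0.0, 20.0], [3000.0, 9000.0], bad_flux))
# ['flux contains NaN or Inf']
```

An empty list means none of these particular problems were found; it does not rule out other causes such as a bandpass extending beyond the wavelength grid.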

Q: Why is simple mode slow?

Simple mode creates bandpass bridges on-the-fly. For repeated calculations (MCMC/nested sampling), use optimised mode with pre-computed bridges.

Q: Can I use this with nested sampling?

Yes! TimeSeriesSource is designed for this. Use optimised mode with the JIT-compiled likelihood pattern shown above. See Sampling for complete nested sampling examples.