Sampling
This section describes how to perform basic parameter sampling and optimization with JAX-Supernovae. It focuses on defining an objective function and using simple optimization techniques to fit SALT3 model parameters to supernova light curve data. When parameters are evaluated in batches on a GPU (as is common in JAX samplers), the fused bandflux kernels can be roughly 100× faster per parameter set than serial SNCosmo while matching its fluxes to within 0.001% (see Leeney et al. 2025).
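As a rough illustration of what evaluating parameters in batches means in JAX, the sketch below maps a chi-squared function over several parameter sets at once with `jax.vmap`. The Gaussian `toy_flux` model and all names in the block are hypothetical stand-ins chosen for this example; it is not the SALT3 model and not the library's fused kernel.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a light-curve model: a Gaussian pulse in time.
# This is NOT the SALT3 model; it only illustrates batching.
def toy_flux(params, times):
    t0 = params[0]
    amplitude = params[1]
    return amplitude * jnp.exp(-0.5 * ((times - t0) / 10.0) ** 2)

def toy_chi2(params, times, fluxes, fluxerrs):
    residuals = (fluxes - toy_flux(params, times)) / fluxerrs
    return jnp.sum(residuals ** 2)

# Noiseless synthetic data generated from known parameters
times = jnp.linspace(-20.0, 40.0, 25)
true_params = jnp.array([5.0, 2.0])
fluxes = toy_flux(true_params, times)
fluxerrs = jnp.full_like(times, 0.1)

# Map over the leading axis of the parameter array only; the data is shared
batched_chi2 = jax.jit(jax.vmap(toy_chi2, in_axes=(0, None, None, None)))

# Three parameter sets: the true one, a shifted peak, a halved amplitude
param_batch = jnp.array([[5.0, 2.0], [0.0, 2.0], [5.0, 1.0]])
chi2_values = batched_chi2(param_batch, times, fluxes, fluxerrs)
print(chi2_values.shape)
```

The call returns one chi-squared value per row of `param_batch`, which is the shape a batched sampler typically consumes.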
Defining an Objective Function
The first step in parameter fitting is to define an objective function that quantifies the goodness of fit between model predictions and observed data. For supernova light curve fitting, a common choice is the chi-squared statistic:
import jax.numpy as jnp
from jax_supernovae.salt3 import optimized_salt3_multiband_flux

# times, fluxes, fluxerrs, zps, band_indices, bridges and fixed_z are assumed
# to be defined already (see load_and_process_data in the complete example below).
def objective(parameters):
    """
    Objective function for SALT3 parameter fitting.

    Parameters:
    - parameters: Array of [t0, x0, x1, c] (assuming fixed redshift)

    Returns:
    - chi2: Chi-squared value
    """
    # Create parameter dictionary
    params = {
        'z': fixed_z[0],  # Fixed redshift
        't0': parameters[0],
        'x0': parameters[1],
        'x1': parameters[2],
        'c': parameters[3]
    }
    # Calculate model fluxes
    model_fluxes = optimized_salt3_multiband_flux(times, bridges, params, zps=zps, zpsys='ab')
    # Index the model fluxes with band_indices to match observations
    model_fluxes = model_fluxes[jnp.arange(len(times)), band_indices]
    # Calculate chi-squared
    chi2 = jnp.sum(((fluxes - model_fluxes) / fluxerrs)**2)
    return chi2
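The indexing step in the objective can be easier to follow on a small array. The sketch below assumes, as the indexing in the objective implies, that the model returns one row per observation time and one column per band. The numbers and the `model_grid` name are made up for illustration.

```python
import jax.numpy as jnp

# Hypothetical model output: one row per observation time, one column per band
model_grid = jnp.array([
    [1.0, 10.0, 100.0],
    [2.0, 20.0, 200.0],
    [3.0, 30.0, 300.0],
    [4.0, 40.0, 400.0],
])

# Band in which each observation was taken
band_indices = jnp.array([0, 2, 1, 2])

# Pair each row with its band column: picks [1.0, 200.0, 30.0, 400.0]
rows = jnp.arange(model_grid.shape[0])
model_fluxes = model_grid[rows, band_indices]

# Made-up observations and uncertainties
fluxes = jnp.array([1.5, 198.0, 30.0, 404.0])
fluxerrs = jnp.array([0.5, 2.0, 1.0, 4.0])

# Normalized residuals are [1, -1, 0, 1], so chi2 is 3.0
chi2 = jnp.sum(((fluxes - model_fluxes) / fluxerrs) ** 2)
print(float(chi2))  # 3.0
```

Each observation contributes the squared residual in units of its own uncertainty, so only the flux in the band that was actually observed enters the sum.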
Basic Sampling with scipy.optimize
Once the objective function is defined, we can use optimization methods from SciPy to find the best-fit parameters:
from scipy.optimize import minimize
import numpy as np
# Initial parameter values
initial_params = np.array([
    58650.0,  # t0
    1e-5,     # x0
    0.0,      # x1
    0.0       # c
])
# Parameter bounds
bounds = [
    (58600.0, 58700.0),  # t0
    (1e-6, 1e-4),        # x0
    (-3.0, 3.0),         # x1
    (-0.3, 0.3)          # c
]
# Optimize the parameters
result = minimize(
    objective,
    initial_params,
    method='L-BFGS-B',
    bounds=bounds
)
# Print the results
print("Optimization successful:", result.success)
print("Number of function evaluations:", result.nfev)
# Extract best-fit parameters
best_params = {
    'z': fixed_z[0],
    't0': result.x[0],
    'x0': result.x[1],
    'x1': result.x[2],
    'c': result.x[3]
}
print("\nBest-fit parameters:")
for name, value in best_params.items():
    # Use significant digits so small values such as x0 (~1e-5) stay readable
    print(f"{name:>10} = {value:.6g}")
print(f"\nFinal chi-squared: {result.fun:.2f}")
Complete Sampling Example
Here is a complete example that demonstrates the entire process of loading data, defining an objective function, and optimizing parameters:
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize
from jax_supernovae.data import load_and_process_data
from jax_supernovae.salt3 import optimized_salt3_multiband_flux
# Enable float64 precision
jax.config.update("jax_enable_x64", True)
# Load data
times, fluxes, fluxerrs, zps, band_indices, unique_bands, bridges, fixed_z = load_and_process_data(
    sn_name='19dwz',
    data_dir='data',
    fix_z=True
)
# Define the objective function
def objective(parameters):
    # Create parameter dictionary
    params = {
        'z': fixed_z[0],  # Fixed redshift
        't0': parameters[0],
        'x0': parameters[1],
        'x1': parameters[2],
        'c': parameters[3]
    }
    # Calculate model fluxes
    model_fluxes = optimized_salt3_multiband_flux(times, bridges, params, zps=zps, zpsys='ab')
    # Index the model fluxes with band_indices to match observations
    model_fluxes = model_fluxes[jnp.arange(len(times)), band_indices]
    # Calculate chi-squared
    chi2 = jnp.sum(((fluxes - model_fluxes) / fluxerrs)**2)
    return float(chi2)
# Initial parameter values
initial_params = np.array([
    58650.0,  # t0
    1e-5,     # x0
    0.0,      # x1
    0.0       # c
])
# Parameter bounds
bounds = [
    (58600.0, 58700.0),  # t0
    (1e-6, 1e-4),        # x0
    (-3.0, 3.0),         # x1
    (-0.3, 0.3)          # c
]
# Optimize the parameters
result = minimize(
    objective,
    initial_params,
    method='L-BFGS-B',
    bounds=bounds
)
# Print the results
print("Optimization successful:", result.success)
print("Number of function evaluations:", result.nfev)
# Extract best-fit parameters
best_params = {
    'z': fixed_z[0],
    't0': result.x[0],
    'x0': result.x[1],
    'x1': result.x[2],
    'c': result.x[3]
}
print("\nBest-fit parameters:")
for name, value in best_params.items():
    # Use significant digits so small values such as x0 (~1e-5) stay readable
    print(f"{name:>10} = {value:.6g}")
print(f"\nFinal chi-squared: {result.fun:.2f}")
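The fitted parameters span very different scales (t0 is of order 10^4 while x0 is of order 10^-5), which can make a quasi-Newton optimizer behave poorly. One possible remedy, sketched below, is to optimize in a unit-scaled space built from the bounds used above. The helpers `to_unit`, `from_unit`, and `make_scaled_objective` are hypothetical names introduced for this example and are not part of JAX-Supernovae.

```python
import numpy as np

# Bounds from the example above, in the order [t0, x0, x1, c]
bounds = np.array([
    (58600.0, 58700.0),
    (1e-6, 1e-4),
    (-3.0, 3.0),
    (-0.3, 0.3),
])
lower, upper = bounds[:, 0], bounds[:, 1]

def to_unit(parameters):
    """Map physical parameters onto the unit interval [0, 1]."""
    return (np.asarray(parameters) - lower) / (upper - lower)

def from_unit(unit_parameters):
    """Map unit-interval values back to physical parameters."""
    return lower + np.asarray(unit_parameters) * (upper - lower)

def make_scaled_objective(objective):
    """Wrap an objective so that it accepts unit-scaled parameters."""
    def scaled(unit_parameters):
        return objective(from_unit(unit_parameters))
    return scaled

# Starting point and bounds expressed in the scaled space
unit_start = to_unit([58650.0, 1e-5, 0.0, 0.0])
unit_bounds = [(0.0, 1.0)] * 4
print(unit_start)
```

Under these assumptions, the optimizer would be called as `minimize(make_scaled_objective(objective), unit_start, method='L-BFGS-B', bounds=unit_bounds)`, and `from_unit(result.x)` would recover the physical best-fit values.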