Sampling

This section describes basic parameter estimation with JAX-Supernovae. It covers defining an objective function and using simple optimization techniques to fit SALT3 model parameters to supernova light-curve data. When many parameter sets are evaluated in batches on a GPU (as is common in JAX-based samplers), the fused bandflux kernels can be roughly 100× faster per parameter set than serial SNCosmo evaluation, while matching SNCosmo fluxes to within 0.001% (see Leeney et al. 2025).
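The batching pattern referred to above can be sketched with `jax.vmap`, which turns a single-parameter-set objective into one that evaluates a whole batch in a single compiled call. This is a minimal sketch, not the library's fused kernels: `toy_multiband_flux` is a hypothetical three-parameter stand-in (peak time, amplitude, width) for `optimized_salt3_multiband_flux`, and the data are synthetic and noise-free.

```python
import jax
import jax.numpy as jnp

# float64 so chi-squared values are not limited by float32 rounding
jax.config.update("jax_enable_x64", True)

N_BANDS = 3


def toy_multiband_flux(times, theta):
    """Hypothetical stand-in for a multiband flux model (NOT the SALT3 model).

    theta = [t0, amplitude, width]; returns fluxes of shape (n_times, N_BANDS).
    """
    t0, amplitude, width = theta[0], theta[1], theta[2]
    shape = amplitude * jnp.exp(-0.5 * ((times - t0) / width) ** 2)
    band_scale = jnp.array([1.0, 0.8, 0.6])  # arbitrary per-band scaling
    return shape[:, None] * band_scale[None, :]


def chi2(theta, times, band_indices, fluxes, fluxerrs):
    """Chi-squared for ONE parameter vector, same structure as the objective below."""
    model = toy_multiband_flux(times, theta)
    model = model[jnp.arange(times.shape[0]), band_indices]
    return jnp.sum(((fluxes - model) / fluxerrs) ** 2)


# Synthetic, noise-free data generated from known parameters
times = jnp.linspace(-20.0, 40.0, 30)
band_indices = jnp.arange(30) % N_BANDS
true_theta = jnp.array([5.0, 2.0, 8.0])
fluxerrs = jnp.full(30, 0.05)
fluxes = toy_multiband_flux(times, true_theta)[jnp.arange(30), band_indices]

# 1000 parameter sets scattered around the truth, plus the truth itself as row 0
key = jax.random.PRNGKey(0)
thetas = true_theta + jnp.array([0.5, 0.05, 0.5]) * jax.random.normal(key, (1000, 3))
thetas = jnp.concatenate([true_theta[None, :], thetas], axis=0)

# vmap maps chi2 over the leading axis of thetas only (the data are shared);
# jit compiles the whole batch into one XLA computation
batched_chi2 = jax.jit(jax.vmap(chi2, in_axes=(0, None, None, None, None)))
values = batched_chi2(thetas, times, band_indices, fluxes, fluxerrs)
print(values.shape)  # (1001,)
```

The `in_axes` tuple is the design choice that matters here: only the parameter argument is batched, so the observations are not copied per parameter set.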

Defining an Objective Function

The first step in parameter fitting is to define an objective function that quantifies the goodness of fit between model predictions and observed data. For supernova light-curve fitting, a common choice is the chi-squared statistic, the sum over all observations of ((observed flux - model flux) / flux uncertainty)^2:

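import jax.numpy as jnp
from jax_supernovae.salt3 import optimized_salt3_multiband_flux

# times, fluxes, fluxerrs, zps, band_indices, bridges and fixed_z are assumed
# to be loaded already, e.g. with load_and_process_data as in the complete
# example at the end of this section
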
def objective(parameters):
    """
    Objective function for SALT3 parameter fitting.

    Parameters:
    - parameters: Array of [t0, x0, x1, c] (assuming fixed redshift)

    Returns:
    - chi2: Chi-squared value
    """
    # Create parameter dictionary
    params = {
        'z': fixed_z[0],  # Fixed redshift
        't0': parameters[0],
        'x0': parameters[1],
        'x1': parameters[2],
        'c': parameters[3]
    }

    # Calculate model fluxes
    model_fluxes = optimized_salt3_multiband_flux(times, bridges, params, zps=zps, zpsys='ab')
    # model_fluxes has one column per band: keep, for each observation, the band it was taken in
    model_fluxes = model_fluxes[jnp.arange(len(times)), band_indices]

    # Calculate chi-squared
    chi2 = jnp.sum(((fluxes - model_fluxes) / fluxerrs)**2)

    return chi2
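The indexing step `model_fluxes[jnp.arange(len(times)), band_indices]` is easiest to see on a small array. Assuming, as that indexing implies, that the flux function returns one row per observation time and one column per band, pairing the row indices with `band_indices` picks exactly one model flux per observation. The numbers below are made up purely for illustration.

```python
import jax.numpy as jnp

# Toy model fluxes: 4 observation times (rows) x 3 bands (columns)
model_fluxes = jnp.array([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0],
    [10.0, 11.0, 12.0],
])
# Band in which each of the 4 observations was taken
band_indices = jnp.array([0, 2, 1, 0])

# Row i is paired with column band_indices[i]
selected = model_fluxes[jnp.arange(4), band_indices]
print(selected)  # [ 1.  6.  8. 10.]

# Chi-squared against toy observations
fluxes = jnp.array([1.5, 6.0, 7.0, 10.0])
fluxerrs = jnp.array([0.5, 1.0, 0.5, 2.0])
chi2 = jnp.sum(((fluxes - selected) / fluxerrs) ** 2)
print(chi2)  # 5.0, from residuals in units of sigma of 1, 0, -2, 0
```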

Basic Optimization with scipy.optimize

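Once the objective function is defined, we can minimize it with SciPy's optimizers. L-BFGS-B is used here because it supports bounds on each parameter; since no gradient function is supplied, SciPy estimates the gradient by finite differences: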

from scipy.optimize import minimize
import numpy as np

# Initial parameter values
initial_params = np.array([
    58650.0,  # t0
    1e-5,     # x0
    0.0,      # x1
    0.0       # c
])

# Parameter bounds
bounds = [
    (58600.0, 58700.0),  # t0
    (1e-6, 1e-4),        # x0
    (-3.0, 3.0),         # x1
    (-0.3, 0.3)          # c
]

# Optimize the parameters
result = minimize(
    objective,
    initial_params,
    method='L-BFGS-B',
    bounds=bounds
)

# Print the results
print("Optimization successful:", result.success)
print("Number of function evaluations:", result.nfev)

# Extract best-fit parameters
best_params = {
    'z': fixed_z[0],
    't0': result.x[0],
    'x0': result.x[1],
    'x1': result.x[2],
    'c': result.x[3]
}

print("\nBest-fit parameters:")
for name, value in best_params.items():
    # .10g keeps significant digits for small values such as x0 (~1e-5)
    print(f"{name:>10} = {value:.10g}")

print(f"\nFinal chi-squared: {result.fun:.2f}")

Complete Fitting Example

Here is a complete example that demonstrates the entire process of loading data, defining an objective function, and optimizing parameters:

import jax

# Enable float64 precision before any other JAX code (including
# jax_supernovae) creates arrays
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize
from jax_supernovae.data import load_and_process_data
from jax_supernovae.salt3 import optimized_salt3_multiband_flux

# Load data
times, fluxes, fluxerrs, zps, band_indices, unique_bands, bridges, fixed_z = load_and_process_data(
    sn_name='19dwz',
    data_dir='data',
    fix_z=True
)

# Define the objective function
def objective(parameters):
    # Create parameter dictionary
    params = {
        'z': fixed_z[0],  # Fixed redshift
        't0': parameters[0],
        'x0': parameters[1],
        'x1': parameters[2],
        'c': parameters[3]
    }

    # Calculate model fluxes
    model_fluxes = optimized_salt3_multiband_flux(times, bridges, params, zps=zps, zpsys='ab')
    # model_fluxes has one column per band: keep, for each observation, the band it was taken in
    model_fluxes = model_fluxes[jnp.arange(len(times)), band_indices]

    # Calculate chi-squared
    chi2 = jnp.sum(((fluxes - model_fluxes) / fluxerrs)**2)

    # Return a Python float for SciPy; note that this also means JAX
    # cannot differentiate or jit-compile this function
    return float(chi2)

# Initial parameter values
initial_params = np.array([
    58650.0,  # t0
    1e-5,     # x0
    0.0,      # x1
    0.0       # c
])

# Parameter bounds
bounds = [
    (58600.0, 58700.0),  # t0
    (1e-6, 1e-4),        # x0
    (-3.0, 3.0),         # x1
    (-0.3, 0.3)          # c
]

# Optimize the parameters
result = minimize(
    objective,
    initial_params,
    method='L-BFGS-B',
    bounds=bounds
)

# Print the results
print("Optimization successful:", result.success)
print("Number of function evaluations:", result.nfev)

# Extract best-fit parameters
best_params = {
    'z': fixed_z[0],
    't0': result.x[0],
    'x0': result.x[1],
    'x1': result.x[2],
    'c': result.x[3]
}

print("\nBest-fit parameters:")
for name, value in best_params.items():
    # .10g keeps significant digits for small values such as x0 (~1e-5)
    print(f"{name:>10} = {value:.10g}")

print(f"\nFinal chi-squared: {result.fun:.2f}")
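The optimizer returns a single best fit. To sample the posterior instead, the same chi-squared can be turned into a Gaussian log-likelihood, -chi2/2 (assuming independent Gaussian flux errors), and passed to a sampler. The sketch below is a minimal random-walk Metropolis sampler written with `jax.lax.scan`. It is not a sampler provided by JAX-Supernovae: it reuses a hypothetical toy model (`toy_multiband_flux`), synthetic data, a flat prior inside illustrative bounds, and hand-picked step sizes. With the real model, the objective would have to stay pure JAX (no `float()` conversion) for this to compile.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_multiband_flux(times, theta):
    """Hypothetical stand-in flux model: theta = [t0, amplitude, width]."""
    shape = theta[1] * jnp.exp(-0.5 * ((times - theta[0]) / theta[2]) ** 2)
    band_scale = jnp.array([1.0, 0.8, 0.6])
    return shape[:, None] * band_scale[None, :]


# Synthetic, noise-free data from known parameters
times = jnp.linspace(-20.0, 40.0, 30)
band_indices = jnp.arange(30) % 3
true_theta = jnp.array([5.0, 2.0, 8.0])
fluxerrs = jnp.full(30, 0.05)
fluxes = toy_multiband_flux(times, true_theta)[jnp.arange(30), band_indices]

# Flat prior inside these illustrative bounds
lower = jnp.array([-10.0, 0.1, 1.0])
upper = jnp.array([20.0, 10.0, 20.0])


def log_posterior(theta):
    model = toy_multiband_flux(times, theta)[jnp.arange(times.shape[0]), band_indices]
    chi2 = jnp.sum(((fluxes - model) / fluxerrs) ** 2)
    in_bounds = jnp.all((theta >= lower) & (theta <= upper))
    # Gaussian likelihood: log L = -chi2 / 2 (up to a constant)
    return jnp.where(in_bounds, -0.5 * chi2, -jnp.inf)


def random_walk_metropolis(key, theta0, step_sizes, n_steps):
    def step(carry, step_key):
        theta, logp = carry
        key_prop, key_accept = jax.random.split(step_key)
        proposal = theta + step_sizes * jax.random.normal(key_prop, theta.shape)
        logp_prop = log_posterior(proposal)
        # Accept with probability min(1, exp(logp_prop - logp))
        accept = jnp.log(jax.random.uniform(key_accept)) < logp_prop - logp
        theta = jnp.where(accept, proposal, theta)
        logp = jnp.where(accept, logp_prop, logp)
        return (theta, logp), (theta, accept)

    keys = jax.random.split(key, n_steps)
    _, (chain, accepted) = jax.lax.scan(step, (theta0, log_posterior(theta0)), keys)
    return chain, accepted


chain, accepted = random_walk_metropolis(
    jax.random.PRNGKey(0),
    theta0=true_theta + jnp.array([0.2, 0.02, 0.2]),  # e.g. start near the best fit
    step_sizes=jnp.array([0.1, 0.02, 0.08]),
    n_steps=5000,
)
burned = chain[1000:]  # discard burn-in
print("acceptance rate:", float(accepted.mean()))
print("posterior mean:", burned.mean(axis=0))
```

Proposals outside the bounds get a log-posterior of -inf and are always rejected, which is how the flat prior is enforced without any branching inside the compiled loop.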