"""SALT3-NIR model implementation in JAX."""
import jax
import jax.numpy as jnp
import jax.lax as lax
import numpy as np
import sncosmo
import os
import math
from jax_supernovae.bandpasses import HC_ERG_AA, MODEL_BANDFLUX_SPACING
from functools import partial
from jax import vmap
import importlib.resources
from jax_supernovae import dust
from jax_supernovae.utils import bandflux_integration, apply_zeropoint
# Constants
H_ERG_S = 6.62607015e-27 # Planck constant in erg*s
# Get package directory
PACKAGE_DIR = os.path.dirname(__file__)
# Model directory - hardcoded path
MODEL_DIR = os.path.join(PACKAGE_DIR, 'data/models/salt3-nir/salt3nir-p22')
def read_griddata_file(filename):
"""Read 2-d grid data from a text file.
Parameters
----------
filename : str
Path to the file containing grid data
Returns
-------
tuple
(phase, wavelength, values) where:
- phase is an array of unique phase values
- wavelength is an array of unique wavelength values
- values is a 2D grid of flux values
Notes
-----
Each line in the file has values `x0 x1 y` (phase, wavelength, flux), space separated.
"""
# Read data from file
data = np.loadtxt(filename)
# Get unique phase and wavelength values, ensuring they're sorted
phase = np.sort(np.unique(data[:, 0]))
wave = np.sort(np.unique(data[:, 1]))
# Create empty grid
values = np.zeros((len(phase), len(wave)))
# Map each data point to its position in the grid
for p, w, v in data:
pi = np.searchsorted(phase, p)
wi = np.searchsorted(wave, w)
values[pi, wi] = v
return phase, wave, values
# Read M0 and M1 data
m0_file = os.path.join(MODEL_DIR, 'salt3_template_0.dat')
m1_file = os.path.join(MODEL_DIR, 'salt3_template_1.dat')
cl_file = os.path.join(MODEL_DIR, 'salt3_color_correction.dat')
# Read data and apply scaling (match SNCosmo exactly)
SCALE_FACTOR = 1e-12
phase_grid, wave_grid, m0_data = read_griddata_file(m0_file)
_, _, m1_data = read_griddata_file(m1_file)
# Apply scale factor to data (match SNCosmo exactly)
m0_data = m0_data * SCALE_FACTOR
m1_data = m1_data * SCALE_FACTOR
# Convert to JAX arrays
phase_grid = jnp.array(phase_grid)
wave_grid = jnp.array(wave_grid)
m0_data = jnp.array(m0_data) # Scale factor already applied
m1_data = jnp.array(m1_data) # Scale factor already applied
# Read color law coefficients
with open(cl_file, 'r') as f:
words = f.read().split()
ncoeffs = int(words[0])
colorlaw_coeffs = jnp.array([float(word) for word in words[1: 1 + ncoeffs]])
colorlaw_range = [3000., 7000.] # Default range
for i in range(1+ncoeffs, len(words), 2):
if words[i] == 'Salt2ExtinctionLaw.min_lambda':
colorlaw_range[0] = float(words[i+1])
elif words[i] == 'Salt2ExtinctionLaw.max_lambda':
colorlaw_range[1] = float(words[i+1])
@jax.jit
def kernval(x):
"""Compute kernel value for bicubic interpolation.
Parameters
----------
x : float or array
Input value(s) for kernel function
Returns
-------
float or array
Kernel value(s)
Notes
-----
This matches SNCosmo's implementation exactly:
W(x) = (a+2)*x**3-(a+3)*x**2+1 for x<=1
W(x) = a( x**3-5*x**2+8*x-4) for 1<x<2
W(x) = 0 for x>2
where a=-0.5
"""
x = jnp.abs(x)
a = -0.5 # This matches SNCosmo's value
# Calculate the result for each case
case1 = (a + 2) * x**3 - (a + 3) * x**2 + 1 # x <= 1
case2 = a * (x**3 - 5 * x**2 + 8 * x - 4) # 1 < x < 2
# Use where to select the appropriate result
result = jnp.where(x <= 1, case1,
jnp.where(x < 2, case2, 0.0))
return result
@jax.jit
def find_index(values, x):
"""Find index i such that values[i] <= x < values[i+1].
Parameters
----------
values : array
Sorted array of values
x : float
Value to find in the array
Returns
-------
int
Index i such that values[i] <= x < values[i+1]
"""
i = jnp.searchsorted(values, x) - 1
i = jnp.clip(i, 0, len(values) - 2) # Ensure we stay within bounds
return i.astype(jnp.int32)
@jax.jit
def compute_interpolation_weights(x, values):
"""Compute interpolation weights and indices.
Parameters
----------
x : float
The point to interpolate at
values : array
The grid values to interpolate between
Returns
-------
tuple
(indices, normalized coordinates, in_bounds, near_boundary)
"""
# Find indices
i = find_index(values, x)
# Check bounds and boundaries
in_bounds = (x >= values[0]) & (x <= values[-1])
near_boundary = (i <= 0) | (i >= len(values) - 2)
# Calculate normalized coordinates
dx = (x - values[i]) / (values[i + 1] - values[i])
return i, dx, in_bounds, near_boundary
@jax.jit
def interpolate_2d(phase, wave, data):
"""Perform 2D interpolation on gridded data.
Parameters
----------
phase : float
Phase value to interpolate at
wave : float
Wavelength value to interpolate at
data : array
2D grid of values to interpolate from
Returns
-------
float
Interpolated value
"""
# Compute weights for both dimensions
ix, dx, x_in_bounds, x_near_boundary = compute_interpolation_weights(phase, phase_grid)
iy, dy, y_in_bounds, y_near_boundary = compute_interpolation_weights(wave, wave_grid)
# Check if we need to use linear interpolation
near_boundary = x_near_boundary | y_near_boundary
# Get corner values for linear interpolation
z00 = data[ix, iy]
z01 = data[ix, iy + 1]
z10 = data[ix + 1, iy]
z11 = data[ix + 1, iy + 1]
# Linear interpolation
linear_result = (z00 * (1 - dx) * (1 - dy) +
z10 * dx * (1 - dy) +
z01 * (1 - dx) * dy +
z11 * dx * dy)
# For bicubic interpolation, pad the array with edge values
padded = jnp.pad(data, ((1, 1), (1, 1)), mode='edge')
# Get 4x4 grid for bicubic interpolation
ix_pad = ix + 1 # Adjust for padding
iy_pad = iy + 1
grid = lax.dynamic_slice(padded, (ix_pad - 1, iy_pad - 1), (4, 4))
# Calculate bicubic weights
wx = jnp.array([
kernval(dx + 1.0),
kernval(dx),
kernval(dx - 1.0),
kernval(dx - 2.0)
])
wy = jnp.array([
kernval(dy + 1.0),
kernval(dy),
kernval(dy - 1.0),
kernval(dy - 2.0)
])
# Calculate bicubic interpolation
cubic_result = jnp.sum(jnp.outer(wx, wy) * grid)
# Use linear interpolation near boundaries, bicubic otherwise
result = jnp.where(near_boundary, linear_result, cubic_result)
# Return 0 if out of bounds, interpolated value otherwise
return jnp.where(x_in_bounds & y_in_bounds, result, 0.0)
@jax.jit
def salt3_m0_single(phase, wave):
"""Get the M0 component at a single phase and wavelength.
Parameters
----------
phase : float
Rest-frame phase in days
wave : float
Rest-frame wavelength in Angstroms
Returns
-------
float
M0 component value
"""
return interpolate_2d(phase, wave, m0_data)
@jax.jit
def salt3_m1_single(phase, wave):
"""Get the M1 component at a single phase and wavelength.
Parameters
----------
phase : float
Rest-frame phase in days
wave : float
Rest-frame wavelength in Angstroms
Returns
-------
float
M1 component value
"""
return interpolate_2d(phase, wave, m1_data)
@jax.jit
def salt3_m0(phase, wave):
"""Get the M0 component at the given phase and wavelength.
Args:
phase (float or array): Rest-frame phase in days
wave (float or array): Rest-frame wavelength in Angstroms
Returns:
float or array: M0 component value(s)
"""
phase = jnp.asarray(phase)
wave = jnp.asarray(wave)
# Handle scalar inputs
if phase.ndim == 0 and wave.ndim == 0:
return salt3_m0_single(phase, wave)
# Handle 2D inputs with broadcasting
if phase.ndim == 2 and wave.ndim == 2:
# First vmap over phases (axis 0)
phase_mapped = jax.vmap(lambda p: jax.vmap(lambda w: salt3_m0_single(p, w))(wave[0, :]))(phase[:, 0])
return phase_mapped
# Handle array inputs of same size
if phase.ndim == 1 and wave.ndim == 1 and phase.shape == wave.shape:
return jax.vmap(lambda p, w: salt3_m0_single(p, w))(phase, wave)
# Handle broadcasting case (phase array with single wavelength)
if phase.ndim == 1 and wave.ndim == 0:
return jax.vmap(lambda p: salt3_m0_single(p, wave))(phase)
# Handle broadcasting case (single phase with wavelength array)
if phase.ndim == 0 and wave.ndim == 1:
return jax.vmap(lambda w: salt3_m0_single(phase, w))(wave)
# Handle broadcasting case (phase array with wave array of different size)
if phase.ndim == 1 and wave.ndim == 1:
# First map over phases, then over wavelengths
return jax.vmap(lambda p: jax.vmap(lambda w: salt3_m0_single(p, w))(wave))(phase)
raise ValueError("Unsupported input shapes for salt3_m0")
@jax.jit
def salt3_m1(phase, wave):
"""Get the M1 component at the given phase and wavelength.
Args:
phase (float or array): Rest-frame phase in days
wave (float or array): Rest-frame wavelength in Angstroms
Returns:
float or array: M1 component value(s)
"""
phase = jnp.asarray(phase)
wave = jnp.asarray(wave)
# Handle scalar inputs
if phase.ndim == 0 and wave.ndim == 0:
return salt3_m1_single(phase, wave)
# Handle 2D inputs with broadcasting
if phase.ndim == 2 and wave.ndim == 2:
# First vmap over phases (axis 0)
phase_mapped = jax.vmap(lambda p: jax.vmap(lambda w: salt3_m1_single(p, w))(wave[0, :]))(phase[:, 0])
return phase_mapped
# Handle array inputs of same size
if phase.ndim == 1 and wave.ndim == 1 and phase.shape == wave.shape:
return jax.vmap(lambda p, w: salt3_m1_single(p, w))(phase, wave)
# Handle broadcasting case (phase array with single wavelength)
if phase.ndim == 1 and wave.ndim == 0:
return jax.vmap(lambda p: salt3_m1_single(p, wave))(phase)
# Handle broadcasting case (single phase with wavelength array)
if phase.ndim == 0 and wave.ndim == 1:
return jax.vmap(lambda w: salt3_m1_single(phase, w))(wave)
# Handle broadcasting case (phase array with wave array of different size)
if phase.ndim == 1 and wave.ndim == 1:
# First map over phases, then over wavelengths
return jax.vmap(lambda p: jax.vmap(lambda w: salt3_m1_single(p, w))(wave))(phase)
raise ValueError("Unsupported input shapes for salt3_m1")
@jax.jit
def salt3_colorlaw(wave):
"""Calculate SALT3 color law at given wavelength."""
wave = jnp.asarray(wave)
# Define constants (exactly as in SNCosmo)
B_WAVE = 4302.57
V_WAVE = 5428.55
v_minus_b = V_WAVE - B_WAVE
# Calculate normalized wavelength
l = (wave - B_WAVE) / v_minus_b
l_lo = (colorlaw_range[0] - B_WAVE) / v_minus_b
l_hi = (colorlaw_range[1] - B_WAVE) / v_minus_b
# Calculate polynomial coefficients
alpha = 1. - jnp.sum(colorlaw_coeffs)
coeffs = jnp.concatenate([jnp.array([0., alpha]), colorlaw_coeffs])
coeffs_rev = jnp.flipud(coeffs)
# Calculate derivative coefficients
prime_coeffs = jnp.arange(len(coeffs)) * coeffs
prime_coeffs = prime_coeffs[1:] # Remove first element (0)
prime_coeffs_rev = jnp.flipud(prime_coeffs)
# Calculate polynomial values at boundaries
p_lo = jnp.polyval(coeffs_rev, l_lo)
pprime_lo = jnp.polyval(prime_coeffs_rev, l_lo)
p_hi = jnp.polyval(coeffs_rev, l_hi)
pprime_hi = jnp.polyval(prime_coeffs_rev, l_hi)
# Calculate extinction for each region
extinction = jnp.where(
l < l_lo,
p_lo + pprime_lo * (l - l_lo), # Blue side
jnp.where(
l > l_hi,
p_hi + pprime_hi * (l - l_hi), # Red side
jnp.polyval(coeffs_rev, l) # In between
)
)
# Return negative extinction to match SNCosmo's convention
return -extinction
@partial(jax.jit, static_argnames=['bandpass', 'zpsys'])
def salt3_bandflux(phase, bandpass, params, zp=None, zpsys=None):
"""Calculate bandflux for SALT3 model.
Parameters
----------
Rest-frame phase in days relative to maximum brightness.
bandpass : Bandpass object
Bandpass to calculate flux through.
params : dict
Model parameters including z, t0, x0, x1, c.
Optional dust parameters:
- 'dust_type': int, dust law index (0=ccm89, 1=od94, 2=f99)
- 'ebv': float, E(B-V) value
- 'r_v': float, R_V value (default: 3.1)
zp : float or None, optional
Zero point for flux. If None, no scaling is applied.
zpsys : str, optional
Magnitude system for zero point. Must be provided if zp is not None.
Default is None.
Returns
-------
float or array_like
Flux in photons/s/cm^2. Return value is float if phase is scalar,
array if phase is array. If zp and zpsys are given, flux is scaled
to the requested zeropoint.
"""
# Check that if zp is provided, zpsys must also be provided
if zp is not None and zpsys is None:
raise ValueError('zpsys must be given if zp is not None')
# Check if input is scalar BEFORE converting to array
is_scalar = jnp.ndim(phase) == 0
# Convert inputs to arrays
phase = jnp.atleast_1d(phase)
# Get parameters
z = params['z']
t0 = params['t0']
x0 = params['x0']
x1 = params['x1']
c = params['c']
# Convert to rest-frame phase
a = 1.0 / (1.0 + z) # Scale factor
restphase = (phase - t0) * a
# Use pre-computed integration grid from bandpass
wave = bandpass.integration_wave
dwave = bandpass.integration_spacing
restwave = wave * a
trans = bandpass(wave)
# Pre-compute color law for all wavelengths
cl = salt3_colorlaw(restwave)
# Compute M0 and M1 components for all phases and wavelengths at once
m0 = salt3_m0(restphase[:, None], restwave[None, :])
m1 = salt3_m1(restphase[:, None], restwave[None, :])
# Calculate rest-frame flux for all phases and wavelengths at once
rest_flux = x0 * (m0 + x1 * m1) * 10**(-0.4 * cl[None, :] * c) * a
# Apply dust extinction if parameters are provided
has_dust = 'dust_type' in params and 'ebv' in params
if has_dust:
dust_type_idx = params['dust_type']
ebv = params['ebv']
r_v = params.get('r_v', 3.1) # Default R_V = 3.1 if not specified
# Get the appropriate dust law function based on the index
if dust_type_idx == 0:
dust_law = dust.ccm89_extinction
elif dust_type_idx == 1:
dust_law = dust.od94_extinction
elif dust_type_idx == 2:
dust_law = dust.f99_extinction
else:
# Default to CCM89
dust_law = dust.ccm89_extinction
# Calculate extinction for each wavelength
extinction = dust_law(restwave, ebv, r_v)
# Apply extinction to rest-frame flux
rest_flux = dust.apply_extinction(rest_flux, extinction[None, :])
# Integrate flux through bandpass using shared integration function
# bandflux_integration expects flux with shape (..., N_wave)
result = bandflux_integration(wave, trans, rest_flux, dwave)
# Apply zero point if provided
if zp is not None:
# Get the magsystem's zpbandflux for this bandpass
if zpsys == 'ab':
# For AB system, calculate zpbandflux
# AB spectrum is 3631 x 10^{-23} erg/s/cm^2/Hz
# Convert to F_lambda: 3631e-23 * c / wave^2 erg/s/cm^2/AA
# Then integrate: sum(f * trans * wave) * dwave / (hc)
zpbandflux = 3631e-23 * dwave / H_ERG_S * jnp.sum(trans / wave)
else:
raise ValueError(f"Unsupported magnitude system: {zpsys}")
# Scale the flux according to the zeropoint (exactly like sncosmo)
zpnorm = 10.**(0.4 * zp) / zpbandflux
result = result * zpnorm
# Return scalar if input was scalar
if is_scalar:
result = result[0]
return result
@partial(jax.jit, static_argnames=['bandpasses', 'zpsys'])
def salt3_multiband_flux(phase, bandpasses, params, zps=None, zpsys=None):
"""Calculate flux for multiple bandpasses at once.
Args:
phase (array-like): Phase(s) in observer frame.
bandpasses (list): List of Bandpass objects.
params (dict): Model parameters including z, t0, x0, x1, c.
zps (array-like, optional): Zero points for each bandpass.
zpsys (str, optional): Magnitude system (e.g. 'ab').
Returns:
array-like: Flux values for each phase and bandpass combination.
"""
# Convert inputs to arrays
phase = jnp.atleast_1d(phase)
n_phase = len(phase)
n_bands = len(bandpasses)
# Initialize output array
result = jnp.zeros((n_phase, n_bands))
# Calculate flux for each bandpass
for i in range(n_bands):
zp = zps[i] if zps is not None else None
band_flux = salt3_bandflux(phase, bandpasses[i], params, zp=zp, zpsys=zpsys)
result = result.at[:, i].set(band_flux)
return result
[docs]
def precompute_bandflux_bridge(bandpass):
"""Precompute static components for a given bandpass.
Parameters
----------
bandpass : Bandpass
Bandpass object to precompute components for
Returns
-------
dict
Dictionary containing:
- 'wave': the integration wavelength grid
- 'dwave': spacing between grid points
- 'trans': the transmission values computed on the grid
- 'wave_original': original wavelength array for shift interpolation
- 'trans_original': original transmission array
- 'zpbandflux_ab': AB zeropoint normalization for this band
"""
wave = bandpass.integration_wave
dwave = bandpass.integration_spacing
trans = bandpass(wave)
# Zeropoint normalization is constant per bandpass (AB system)
zpbandflux_ab = 3631e-23 * dwave / H_ERG_S * jnp.sum(trans / wave)
# Store original arrays for shift interpolation
return {
'wave': wave,
'dwave': dwave,
'trans': trans,
'wave_original': bandpass.wave,
'trans_original': bandpass.trans,
'zpbandflux_ab': zpbandflux_ab,
}
@jax.jit
def compute_shifted_transmission(wave, wave_original, trans_original, shift):
"""Compute transmission values with wavelength shift.
Parameters
----------
wave : array
Wavelengths at which to evaluate transmission
wave_original : array
Original wavelength array from bandpass
trans_original : array
Original transmission array from bandpass
shift : float or array
Wavelength shift(s) to apply
Returns
-------
array
Shifted transmission values
"""
# Apply shift - if shift is callable, it should be evaluated outside JIT
effective_wave = wave - shift
# Use existing interp function from utils
from jax_supernovae.utils import interp
return interp(effective_wave, wave_original, trans_original)
[docs]
@partial(jax.jit, static_argnames=['zpsys'])
def optimized_salt3_bandflux(phase, wave, dwave, trans, params,
zp=None, zpsys=None, shift=0.0,
wave_original=None, trans_original=None):
"""Calculate bandflux for a single bandpass using precomputed static data.
Parameters
----------
phase : array or scalar
Observer-frame phase(s) at which to compute the flux
wave : array
Wavelength grid for integration
dwave : float
Spacing between wavelength grid points
trans : array
Transmission values on the wavelength grid (used if shift=0)
params : dict
Dictionary containing model parameters: 'z', 't0', 'x0', 'x1', 'c'
Optional dust parameters:
- 'dust_type': int, dust law index (0=ccm89, 1=od94, 2=f99)
- 'ebv': float, E(B-V) value
- 'r_v': float, R_V value (default: 3.1)
zp : float or None, optional
Zero point for flux scaling
zpsys : str or None, optional
Magnitude system (e.g. 'ab')
shift : float, optional
Constant wavelength shift to apply to transmission curve (in Angstroms)
wave_original : array, optional
Original wavelength array (required if shift != 0)
trans_original : array, optional
Original transmission array (required if shift != 0)
Returns
-------
float or array
Flux in photons/s/cm^2
"""
if zp is not None and zpsys is None:
raise ValueError('zpsys must be given if zp is not None')
# Check if input is scalar BEFORE converting to array
is_scalar = jnp.ndim(phase) == 0
# Convert inputs to arrays
phase = jnp.atleast_1d(phase)
z = params['z']
t0 = params['t0']
x0 = params['x0']
x1 = params['x1']
c = params['c']
# Calculate scaling factor and transform phase to rest-frame.
a = 1.0 / (1.0 + z)
restphase = (phase - t0) * a
# Scale the integration grid to rest-frame wavelengths.
restwave = wave * a
# Get transmission values - use shifted version if shift is non-zero
# Use jnp.where to handle conditional logic in JAX
shift_is_nonzero = jnp.abs(shift) > 0.0
has_original_arrays = (wave_original is not None) and (trans_original is not None)
if has_original_arrays:
# Apply shift and recompute transmission
trans_computed = compute_shifted_transmission(
wave, wave_original, trans_original, shift
)
# Use jnp.where to select between shifted and original transmission
trans_shifted = jnp.where(shift_is_nonzero, trans_computed, trans)
else:
# Use pre-computed transmission (backward compatibility)
trans_shifted = trans
# Compute colour law on the restwave grid.
cl = salt3_colorlaw(restwave)
# Compute m0 and m1 components over the 2D grid.
m0 = salt3_m0(restphase[:, None], restwave[None, :])
m1 = salt3_m1(restphase[:, None], restwave[None, :])
# Compute rest-frame flux including the colour law effect.
rest_flux = x0 * (m0 + x1 * m1) * 10**(-0.4 * cl[None, :] * c) * a
# Apply dust extinction if parameters are provided
has_dust = 'dust_type' in params and 'ebv' in params
# Define a function to apply dust extinction based on dust_type
def apply_ccm89(restwave, ebv, r_v):
return dust.ccm89_extinction(restwave, ebv, r_v)
def apply_od94(restwave, ebv, r_v):
return dust.od94_extinction(restwave, ebv, r_v)
def apply_f99(restwave, ebv, r_v):
return dust.f99_extinction(restwave, ebv, r_v)
# Apply dust extinction conditionally
if has_dust:
ebv = params['ebv']
r_v = params.get('r_v', 3.1) # Default R_V = 3.1 if not specified
dust_type_idx = params['dust_type']
# Use a JAX-friendly approach to select the dust law
extinction = jnp.zeros_like(restwave)
extinction = jnp.where(dust_type_idx == 0, apply_ccm89(restwave, ebv, r_v), extinction)
extinction = jnp.where(dust_type_idx == 1, apply_od94(restwave, ebv, r_v), extinction)
extinction = jnp.where(dust_type_idx == 2, apply_f99(restwave, ebv, r_v), extinction)
# Apply extinction to rest-frame flux
rest_flux = dust.apply_extinction(rest_flux, extinction[None, :])
# Use trans_shifted (which handles transmission shifts) instead of trans
# Integrate using shared integration function
result = bandflux_integration(wave, trans_shifted, rest_flux, dwave)
# Apply zero point correction if required.
if zp is not None:
if zpsys == 'ab':
# Note: zpbandflux should also use shifted transmission
zpbandflux = 3631e-23 * dwave / H_ERG_S * jnp.sum(trans_shifted / wave)
else:
raise ValueError(f"Unsupported magnitude system: {zpsys}")
zpnorm = 10**(0.4 * zp) / zpbandflux
result = result * zpnorm
# Return scalar if input was scalar
if is_scalar:
result = result[0]
return result
[docs]
@partial(jax.jit, static_argnames=['zpsys'])
def optimized_salt3_multiband_flux(phase, bridges, params, zps=None, zpsys=None, shifts=None):
"""Calculate fluxes for multiple bandpasses with transmission shifts.
Parameters
----------
phase : array
Observer-frame phases
bridges : list of dict
Precomputed bridge data for each bandpass
params : dict
Model parameters
zps : list or array or None, optional
Zero points for each bandpass
zpsys : str or None, optional
Magnitude system
shifts : list or array or None, optional
Constant wavelength shifts for each bandpass (in Angstroms)
Returns
-------
array
Array of flux values for each phase and band
"""
phase = jnp.atleast_1d(phase)
n_phase = len(phase)
n_bands = len(bridges)
result = jnp.zeros((n_phase, n_bands))
# Default shifts to zero if not provided
if shifts is None:
shifts = [0.0] * n_bands
for i in range(n_bands):
bp_bridge = bridges[i]
curr_zp = zps[i] if zps is not None else None
curr_shift = shifts[i]
# Extract original arrays if available
wave_original = bp_bridge.get('wave_original', None)
trans_original = bp_bridge.get('trans_original', None)
band_flux = optimized_salt3_bandflux(
phase,
bp_bridge['wave'],
bp_bridge['dwave'],
bp_bridge['trans'],
params,
zp=curr_zp,
zpsys=zpsys,
shift=curr_shift,
wave_original=wave_original,
trans_original=trans_original
)
result = result.at[:, i].set(band_flux)
return result