Bandpass Loading
JAX-bandflux can load astronomical filters from several sources: built-in definitions, wavelength and transmission arrays, text files, and the SVO Filter Profile Service.
Built-in Bandpasses
JAX-bandflux includes common astronomical filters:
from jax_supernovae.bandpasses import get_bandpass, Bandpass, register_bandpass
from jax_supernovae.salt3 import precompute_bandflux_bridge
# Access a built-in bandpass
bp_b = get_bandpass('bessellb')
print(f"Bessell B: {bp_b.minwave():.0f} - {bp_b.maxwave():.0f} Angstroms")
# Compare wavelength coverage across filters
for name in ['bessellb', 'bessellv', 'bessellr', 'besselli']:
    bp = get_bandpass(name)
    print(f" {name:10s}: {bp.minwave():7.1f} - {bp.maxwave():7.1f} A")
Bessell B: 3600 - 5600 Angstroms
bessellb : 3600.0 - 5600.0 A
bessellv : 4700.0 - 7000.0 A
bessellr : 5500.0 - 9000.0 A
besselli : 7000.0 - 9200.0 A
Available built-in bandpasses:
Bessell: bessellb, bessellv, bessellr, besselli, bessellux
SDSS: g, r, i, z
ZTF: ztfg, ztfr
ATLAS: c, o
2MASS: H
Creating Custom Bandpasses
Create a bandpass from wavelength and transmission arrays:
import numpy as np
# Define a Gaussian bandpass
wavelengths = np.linspace(4000, 5000, 100)
transmission = np.exp(-((wavelengths - 4500) / 200)**2)
# Create Bandpass object
bandpass = Bandpass(wavelengths, transmission, name='custom_g')
print(f"Custom bandpass '{bandpass.name}':")
print(f" Wavelength range: {bandpass.minwave():.1f} - {bandpass.maxwave():.1f} A")
print(f" Peak wavelength: ~4500 A (Gaussian center)")
print(f" Number of points: {len(wavelengths)}")
Custom bandpass 'custom_g':
Wavelength range: 4000.0 - 5000.0 A
Peak wavelength: ~4500 A (Gaussian center)
Number of points: 100
The Bandpass class automatically:
Normalizes transmission values
Creates integration grids for flux calculations
Provides interpolation for arbitrary wavelengths
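The behaviours listed above can be pictured with a small stand-in class. This is a sketch, not the library's implementation: `ToyBandpass` is a hypothetical name, normalization to a peak of 1 is an assumed convention (the library may normalize differently), and transmission is taken to be zero outside the tabulated range. The integration grid is left out to keep the sketch short.

```python
import numpy as np


class ToyBandpass:
    """Hypothetical stand-in for a bandpass; not the jax_supernovae class."""

    def __init__(self, wave, trans, name=None):
        self.wave = np.asarray(wave, dtype=float)
        trans = np.asarray(trans, dtype=float)
        # Assumed convention: scale so that the peak transmission is 1.
        self.trans = trans / trans.max()
        self.name = name

    def minwave(self):
        return float(self.wave[0])

    def maxwave(self):
        return float(self.wave[-1])

    def __call__(self, wave):
        # Linear interpolation; zero transmission outside the tabulated range.
        return np.interp(wave, self.wave, self.trans, left=0.0, right=0.0)


wavelengths = np.linspace(4000, 5000, 101)
transmission = 0.8 * np.exp(-((wavelengths - 4500) / 200) ** 2)
toy = ToyBandpass(wavelengths, transmission, name="toy_g")

peak = float(toy(np.array([4500.0]))[0])     # tabulated peak, rescaled to 1
outside = float(toy(np.array([3000.0]))[0])  # outside the table
print(f"peak={peak:.3f}, outside={outside:.3f}")
```

Calling the object like a function mirrors the `bp(wave_query)` usage shown elsewhere on this page.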
Registering Bandpasses
Register custom bandpasses for use by name:
# Register the bandpass
register_bandpass('my_filter', bandpass, force=True)
# Now accessible by name
retrieved = get_bandpass('my_filter')
print(f"Retrieved '{retrieved.name}' as 'my_filter'")
print(f"Range: {retrieved.minwave():.1f} - {retrieved.maxwave():.1f} A")
Retrieved 'custom_g' as 'my_filter'
Range: 4000.0 - 5000.0 A
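Note that the output above reports the object's own name ('custom_g') even though it was fetched as 'my_filter': the registry key and the `name` attribute are independent. The registry pattern itself can be sketched in a few lines. The `register` and `get` helpers below are hypothetical, and the meaning of `force` (allow overwriting an existing key) is an assumption about the library, not something stated on this page.

```python
_REGISTRY = {}


def register(name, obj, force=False):
    """Store obj under name; refuse to overwrite unless force=True."""
    if name in _REGISTRY and not force:
        raise ValueError(f"{name!r} is already registered; pass force=True to replace it")
    _REGISTRY[name] = obj


def get(name):
    """Return the object registered under name."""
    try:
        return _REGISTRY[name]
    except KeyError:
        raise KeyError(f"nothing registered as {name!r}") from None


register("my_filter", "first version")
register("my_filter", "second version", force=True)  # overwrite allowed
current = get("my_filter")

try:
    register("my_filter", "third version")  # no force: rejected
    rejected = False
except ValueError:
    rejected = True

print(current, rejected)
```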
Loading from Files
Load bandpass data from text files:
from jax_supernovae.bandpasses import load_bandpass_from_file
# Load from two-column file (wavelength, transmission)
bandpass = load_bandpass_from_file('my_filter.dat', name='my_filter')
# Skip header rows if needed
bandpass = load_bandpass_from_file('my_filter.dat', skiprows=1, name='my_filter')
Expected file format:
# wavelength(Å) transmission
4000.0 0.001
4050.0 0.050
4100.0 0.200
...
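A file in this layout can be produced and checked with NumPy alone. The sketch below does not call `load_bandpass_from_file`; it only shows that a two-column table with a '#' header line round-trips through `np.savetxt` and `np.loadtxt`. The file name and values are illustrative.

```python
import os
import tempfile

import numpy as np

wavelengths = np.array([4000.0, 4050.0, 4100.0])
transmission = np.array([0.001, 0.050, 0.200])

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "my_filter.dat")
    # np.savetxt writes the header as a '#' comment line.
    np.savetxt(path, np.column_stack([wavelengths, transmission]),
               fmt="%.1f %.3f", header="wavelength(A) transmission")
    # np.loadtxt skips '#' comment lines by default.
    table = np.loadtxt(path)

wave_read, trans_read = table[:, 0], table[:, 1]
print(table.shape)
```

If a header line does not start with '#', the `skiprows` argument shown above is the documented way to skip it.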
Spanish Virtual Observatory (SVO)
Download filters from the SVO Filter Profile Service:
from jax_supernovae.bandpasses import create_bandpass_from_svo
# Download UKIRT J-band filter
bandpass = create_bandpass_from_svo('UKIRT/WFCAM.J', output_dir='filter_data')
# Register for later use
register_bandpass('ukirt_j', bandpass)
The SVO provides thousands of filters from major observatories and instruments. Browse available filters at: http://svo2.cab.inta-csic.es/svo/theory/fps/
Pre-computing Bridges
For high-performance calculations, pre-compute integration bridges (the wavelength grid and transmission values used to integrate flux over a bandpass):
# Pre-compute bridge for a bandpass
bridge = precompute_bandflux_bridge(bandpass)
print(f"Bridge structure keys: {sorted(bridge.keys())}")
print(f"Wavelength grid: {bridge['wave'].shape[0]} points")
print(f"Grid spacing: {bridge['dwave']} Angstroms")
print(f"Wavelength range: {float(bridge['wave'][0]):.1f} - {float(bridge['wave'][-1]):.1f} A")
Bridge structure keys: ['dwave', 'trans', 'trans_original', 'wave', 'wave_original', 'zpbandflux_ab']
Wavelength grid: 200 points
Grid spacing: 5.0 Angstroms
Wavelength range: 4002.5 - 4997.5 A
Bridges contain:
wave: Integration wavelength grid
trans: Transmission values on the grid
dwave: Grid spacing (typically 5.0 Å)
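The numbers in the output above (200 points, 5.0 Å spacing, 4002.5 to 4997.5 Å) are consistent with a grid of bin centres, which is what a midpoint-rule integration would use. The sketch below rebuilds such a grid with NumPy. `midpoint_grid` is a hypothetical helper and the midpoint convention is an assumption inferred from the output, not taken from the library source.

```python
import math

import numpy as np


def midpoint_grid(minwave, maxwave, target_spacing=5.0):
    """Return bin-centre wavelengths and the spacing actually used."""
    nbins = int(math.ceil((maxwave - minwave) / target_spacing))
    dwave = (maxwave - minwave) / nbins
    wave = minwave + dwave * (np.arange(nbins) + 0.5)
    return wave, dwave


# Gaussian filter tabulated between 4000 and 5000 Angstroms.
wave_tab = np.linspace(4000, 5000, 100)
trans_tab = np.exp(-((wave_tab - 4500) / 200) ** 2)

wave, dwave = midpoint_grid(wave_tab[0], wave_tab[-1])
trans = np.interp(wave, wave_tab, trans_tab)

# Midpoint rule: the integral of T(lambda) becomes a plain sum times dwave.
area = float(np.sum(trans) * dwave)
print(len(wave), dwave, wave[0], wave[-1])
```

With everything on a fixed grid, each later integral reduces to a weighted sum over precomputed arrays, which is presumably why pre-computing pays off inside a likelihood.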
Multiple Bandpasses
Pre-compute bridges for all bands in your dataset:
# Define unique bands
unique_bands = ['bessellb', 'bessellv', 'bessellr']
# Pre-compute all bridges
bridges = tuple(precompute_bandflux_bridge(get_bandpass(b))
                for b in unique_bands)
print(f"Pre-computed {len(bridges)} bridges")
# Inspect bridge sizes
for band, bridge in zip(unique_bands, bridges):
    print(f" {band}: {bridge['wave'].shape[0]} points, {float(bridge['wave'][0]):.0f}-{float(bridge['wave'][-1]):.0f} A")
# Create band index mapping
band_to_idx = {b: i for i, b in enumerate(unique_bands)}
print(f"Band index mapping: {band_to_idx}")
Pre-computed 3 bridges
bessellb: 400 points, 3602-5598 A
bessellv: 460 points, 4702-6998 A
bessellr: 700 points, 5502-8998 A
Band index mapping: {'bessellb': 0, 'bessellv': 1, 'bessellr': 2}
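A mapping like `band_to_idx` is typically used to turn per-observation band labels into an integer array, since strings cannot be passed through JAX-traced code. A minimal sketch, assuming a hypothetical list of observation labels (`obs_bands`); how the library consumes the indices is not shown here.

```python
import numpy as np

unique_bands = ["bessellb", "bessellv", "bessellr"]
band_to_idx = {b: i for i, b in enumerate(unique_bands)}

# Hypothetical per-observation band labels, e.g. read from a light-curve table.
obs_bands = ["bessellv", "bessellb", "bessellr", "bessellv", "bessellb"]

# One integer per observation, pointing into the tuple of bridges.
band_indices = np.array([band_to_idx[b] for b in obs_bands])
print(band_indices)

# Comparing against an index selects the observations taken in that band.
counts = {b: int(np.sum(band_indices == i)) for b, i in band_to_idx.items()}
print(counts)
```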
Bandpass Properties
Query bandpass characteristics:
bp = get_bandpass('bessellv')
# Wavelength range
print(f"Bessell V properties:")
print(f" Wavelength range: {bp.minwave():.1f} - {bp.maxwave():.1f} A")
print(f" Central wavelength: ~{(bp.minwave() + bp.maxwave()) / 2:.0f} A")
# Access raw data
print(f" Data points: {len(bp.wave)} wavelengths, {len(bp.trans)} transmission values")
print(f" Peak transmission: {float(max(bp.trans)):.3f}")
Bessell V properties:
Wavelength range: 4700.0 - 7000.0 A
Central wavelength: ~5850 A
Data points: 24 wavelengths, 24 transmission values
Peak transmission: 1.000
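The "central wavelength" printed above is the midpoint of the wavelength range. For an asymmetric filter, a transmission-weighted mean is a common alternative. The sketch below compares the two on a made-up filter curve; the helper names and the curve are hypothetical, and the trapezoidal rule is written out by hand rather than relying on a particular NumPy version.

```python
import numpy as np


def trapezoid(y, x):
    """Trapezoidal integral of y(x) on a possibly non-uniform grid."""
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


def effective_wavelength(wave, trans):
    """Transmission-weighted mean wavelength: int(w * T) / int(T)."""
    return trapezoid(wave * trans, wave) / trapezoid(trans, wave)


# Hypothetical asymmetric filter: sharp blue edge, long red tail.
wave = np.array([4700.0, 5000.0, 5300.0, 6000.0, 7000.0])
trans = np.array([0.0, 0.5, 1.0, 0.3, 0.0])

midpoint = 0.5 * (wave[0] + wave[-1])
lam_eff = effective_wavelength(wave, trans)
print(f"midpoint: {midpoint:.0f} A, weighted mean: {lam_eff:.0f} A")
```

For this curve the weighted mean falls blueward of the midpoint, because most of the transmission sits near the blue edge.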
Interpolation
Bandpasses support interpolation at arbitrary wavelengths:
# Get transmission at specific wavelengths
wave_query = np.array([5000.0, 5500.0, 6000.0, 6500.0])
trans_values = bp(wave_query)
print("Bessell V transmission:")
for w, t in zip(wave_query, trans_values):
    print(f" {w:.0f} A: {float(t):.3f}")
Bessell V transmission:
5000 A: 0.485
5500 A: 0.865
6000 A: 0.317
6500 A: 0.037
Wavelength Shifts
Apply wavelength shifts to bandpasses (useful for filter calibration):
# Get transmission with different wavelength shifts
wave_query = np.array([5500.0])
print("Effect of wavelength shift on transmission at 5500 A:")
for shift in [-20.0, -10.0, 0.0, 10.0, 20.0]:
    trans = bp(wave_query, shift=shift)
    print(f" Shift {shift:+5.1f} A: transmission = {float(trans[0]):.4f}")
Effect of wavelength shift on transmission at 5500 A:
Shift -20.0 A: transmission = 0.8422
Shift -10.0 A: transmission = 0.8538
Shift +0.0 A: transmission = 0.8653
Shift +10.0 A: transmission = 0.8743
Shift +20.0 A: transmission = 0.8833
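In the output above, transmission at 5500 A rises with positive shift. Since 5500 A lies on the red side of the Bessell V peak, this is the pattern expected if a positive shift moves the whole curve toward longer wavelengths, that is, T_shifted(λ) = T(λ − shift). The sketch below implements that convention on a toy filter. The sign convention is inferred from the output rather than from the library source, and `shifted_transmission` is a hypothetical helper.

```python
import numpy as np


def shifted_transmission(wave_query, wave_tab, trans_tab, shift=0.0):
    """Evaluate a filter curve that has been moved redward by `shift`."""
    # Moving the curve by +shift is the same as looking up the original
    # curve at (wavelength - shift).
    return np.interp(np.asarray(wave_query) - shift, wave_tab, trans_tab,
                     left=0.0, right=0.0)


# Hypothetical triangular filter peaking at 5000 Angstroms.
wave_tab = np.array([4000.0, 5000.0, 6000.0])
trans_tab = np.array([0.0, 1.0, 0.0])

query = np.array([5500.0])  # on the red (falling) side of the peak
results = {s: float(shifted_transmission(query, wave_tab, trans_tab, shift=s)[0])
           for s in (-100.0, 0.0, 100.0)}
for s, t in results.items():
    print(f"shift {s:+6.1f} A: transmission = {t:.2f}")
```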
Registering All Standard Bandpasses
Register all standard bandpasses at once:
from jax_supernovae.bandpasses import register_all_bandpasses
# Register standard and custom bandpasses
bandpass_dict, bridges_dict = register_all_bandpasses(
    custom_bandpass_files={'my_filter': 'path/to/filter.dat'},
    svo_filters={'jwst_f150w': 'JWST/NIRCam.F150W'}
)
print(f"Registered {len(bandpass_dict)} bandpasses")
Best Practices
Pre-compute bridges once: Do this outside your likelihood function
# GOOD: Pre-compute once
bridges = tuple(precompute_bandflux_bridge(get_bandpass(b)) for b in unique_bands)
@jax.jit
def likelihood(params):
    return source.bandflux(params, None, phases, bridges=bridges, ...)
Use consistent naming: Register custom bandpasses with descriptive names
Cache SVO downloads: Use output_dir to save downloaded filters locally
Check wavelength coverage: Ensure your bandpass covers the relevant wavelength range for your supernova model
source = SALT3Source()
bp = get_bandpass('bessellv')
# Verify bandpass is within model range
print(f"Model range: {source.minwave():.0f} - {source.maxwave():.0f} A")
print(f"Bandpass range: {bp.minwave():.0f} - {bp.maxwave():.0f} A")
print(f"Bandpass within model? {bp.minwave() > source.minwave() and bp.maxwave() < source.maxwave()}")
Model range: 2000 - 20000 A
Bandpass range: 4700 - 7000 A
Bandpass within model? True
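For a supernova at redshift z, the same check is usually made against the redshifted model range, because a rest-frame model covers (1 + z) times its wavelength limits in the observer frame. The sketch below applies this to the ranges printed on this page. `band_within_model` is a hypothetical helper, and treating the 2000 to 20000 A model range as rest-frame is an assumption.

```python
def band_within_model(band_min, band_max, model_min, model_max, z=0.0):
    """True if an observer-frame band lies inside the redshifted model range."""
    # Assumed convention: model limits are rest-frame and scale with (1 + z).
    return band_min > model_min * (1.0 + z) and band_max < model_max * (1.0 + z)


# Ranges taken from the outputs on this page (Angstroms).
bands = {
    "bessellb": (3600.0, 5600.0),
    "bessellv": (4700.0, 7000.0),
    "bessellr": (5500.0, 9000.0),
    "besselli": (7000.0, 9200.0),
}
model_min, model_max = 2000.0, 20000.0

coverage = {
    z: [name for name, (lo, hi) in bands.items()
        if band_within_model(lo, hi, model_min, model_max, z=z)]
    for z in (0.0, 1.0, 1.5)
}
for z, names in coverage.items():
    print(f"z = {z}: {names}")
```

Bluer bands drop out first as redshift increases, so the check is worth repeating for each object rather than once per survey.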