# Source code for jax_supernovae.bandpasses

"""Bandpass handling for JAX supernova models."""
import os
import jax.numpy as jnp
import numpy as np
from functools import partial
import math
from jax_supernovae.utils import interp
from jax_supernovae.constants import HC_ERG_AA, C_AA_PER_S, MODEL_BANDFLUX_SPACING
import requests

# Directory containing this module; the bundled bandpass data files are
# resolved relative to it. Made absolute so those paths remain valid even if
# the process changes its working directory after import.
PACKAGE_DIR = os.path.dirname(os.path.abspath(__file__))

class Bandpass:
    """Bandpass filter: a tabulated transmission curve plus an integration grid.

    Parameters
    ----------
    wave : array_like
        Wavelengths at which the transmission is tabulated.
    trans : array_like
        Transmission values, one per entry of ``wave``.
    integration_spacing : float, optional
        Target spacing of the integration grid. The actual spacing is
        shrunk so that an integer number of equal steps exactly spans
        ``[minwave, maxwave]``. Defaults to ``MODEL_BANDFLUX_SPACING``.
    name : str, optional
        Optional human-readable bandpass name.

    Raises
    ------
    ValueError
        If ``wave`` and ``trans`` have different shapes, if
        ``integration_spacing`` is not positive, or if ``wave`` does not
        span a positive wavelength range.
    """

    def __init__(self, wave, trans, integration_spacing=MODEL_BANDFLUX_SPACING,
                 name=None):
        """Initialize bandpass with wavelength and transmission arrays."""
        self._wave = jnp.asarray(wave)
        self._trans = jnp.asarray(trans)
        self._name = name

        if self._wave.shape != self._trans.shape:
            raise ValueError(
                "wave and trans must have the same shape, got "
                f"{self._wave.shape} and {self._trans.shape}"
            )
        # Written as ``not > 0`` so that NaN is rejected as well.
        if not integration_spacing > 0:
            raise ValueError(
                f"integration_spacing must be positive, got {integration_spacing}"
            )

        self._minwave = float(jnp.min(self._wave))
        self._maxwave = float(jnp.max(self._wave))

        # Pre-compute integration grid to match sncosmo: split the full range
        # into the smallest integer number of equal steps no wider than the
        # requested spacing, and sample at the step midpoints.
        range_diff = self._maxwave - self._minwave
        if not range_diff > 0:
            # Without this guard n_steps would be 0 and the division below
            # would fail with an uninformative ZeroDivisionError.
            raise ValueError(
                "Bandpass wavelength array must span a positive range, got "
                f"minwave={self._minwave}, maxwave={self._maxwave}"
            )
        n_steps = math.ceil(range_diff / integration_spacing)
        self._integration_spacing = range_diff / n_steps

        # Create grid starting at minwave + 0.5 * spacing (bin midpoints)
        self._integration_wave = jnp.linspace(
            self._minwave + 0.5 * self._integration_spacing,
            self._maxwave - 0.5 * self._integration_spacing,
            n_steps,
        )

    def __repr__(self):
        return (
            f"{type(self).__name__}(name={self._name!r}, "
            f"minwave={self._minwave}, maxwave={self._maxwave})"
        )

    def __call__(self, wave, shift=0.0):
        """Get interpolated transmission at given wavelengths with optional shift.

        Parameters
        ----------
        wave : array_like
            Wavelengths at which to evaluate transmission.
        shift : float, optional
            Constant wavelength shift to apply (in Angstroms). The curve is
            evaluated at ``wave - shift``.

        Returns
        -------
        array
            Transmission interpolated from the tabulated curve.
        """
        wave = jnp.asarray(wave)
        # Apply constant shift
        effective_wave = wave - shift
        return interp(effective_wave, self._wave, self._trans)

    @property
    def name(self):
        """Optional human-readable bandpass name."""
        return self._name

    def minwave(self):
        """Get minimum wavelength."""
        return self._minwave

    def maxwave(self):
        """Get maximum wavelength."""
        return self._maxwave

    @property
    def wave(self):
        """Get wavelength array."""
        return self._wave

    @property
    def trans(self):
        """Get transmission array."""
        return self._trans

    @property
    def integration_wave(self):
        """Get pre-computed integration wavelength grid (bin midpoints)."""
        return self._integration_wave

    @property
    def integration_spacing(self):
        """Get the actual (adjusted) integration grid spacing."""
        return self._integration_spacing
# Registry to store bandpasses
_BANDPASSES = {}


def get_bandpass_filepath(band):
    """Resolve a built-in bandpass name to its data file on disk.

    Parameters
    ----------
    band : str
        Built-in bandpass name: ATLAS ('c', 'o'), ZTF ('ztfg', 'ztfr'),
        SDSS ('g', 'r', 'i', 'z') or 2MASS ('H').

    Returns
    -------
    str
        Path to the bandpass file under ``<PACKAGE_DIR>/sncosmo-modelfiles``.

    Raises
    ------
    ValueError
        If ``band`` is not one of the built-in names.
    FileNotFoundError
        If the mapped file does not exist on disk.
    """
    relative_paths = {
        # ATLAS
        'c': 'bandpasses/atlas/Atlas.Cyan',
        'o': 'bandpasses/atlas/Atlas.Orange',
        # ZTF
        'ztfg': 'bandpasses/ztf/P48_g.dat',
        'ztfr': 'bandpasses/ztf/P48_R.dat',
        # SDSS
        'g': 'bandpasses/sdss/sdss_g.dat',
        'r': 'bandpasses/sdss/sdss_r.dat',
        'i': 'bandpasses/sdss/sdss_i.dat',
        'z': 'bandpasses/sdss/sdss_z.dat',
        # 2MASS
        'H': 'bandpasses/2mass/2mass.H',
    }

    relative = relative_paths.get(band)
    if relative is None:
        known = list(relative_paths)
        raise ValueError(f"Unknown bandpass: {band}. Available bandpasses: {known}")

    # Data files live in the sncosmo-modelfiles tree next to this module.
    resolved = os.path.join(PACKAGE_DIR, 'sncosmo-modelfiles', relative)
    if os.path.exists(resolved):
        return resolved
    raise FileNotFoundError(f"Bandpass file not found: {resolved}")
def load_bandpass(band):
    """Load one of the built-in bandpasses from its data file.

    Parameters
    ----------
    band : str
        Name of the built-in bandpass to load (see get_bandpass_filepath).

    Returns
    -------
    bandpass : Bandpass
        A Bandpass object containing the filter transmission curve.

    Raises
    ------
    ValueError
        If ``band`` is unknown, or the file cannot be parsed into a Bandpass.
    FileNotFoundError
        If the bandpass file is missing.
    """
    fname = get_bandpass_filepath(band)
    # ZTF files have a one-line header; all other built-in files are plain
    # two-column (wavelength, transmission) tables.
    skiprows = 1 if band in ('ztfg', 'ztfr') else 0
    try:
        data = np.loadtxt(fname, skiprows=skiprows)
        return Bandpass(
            wave=jnp.array(data[:, 0]),
            trans=jnp.array(data[:, 1]),
            name=band,
        )
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Bandpass file for '{band}' not found at {fname}"
        ) from e
    except Exception as e:
        # Keep the original error as the explicit cause for debugging.
        raise ValueError(f"Error loading bandpass file for '{band}': {e}") from e
def load_bandpass_from_file(filepath, skiprows=0, name=None):
    """Load a bandpass from a custom two-column file.

    Parameters
    ----------
    filepath : str
        Path to the bandpass file (wavelength in column 0, transmission in
        column 1).
    skiprows : int, optional
        Number of header rows to skip.
    name : str, optional
        Name to give the bandpass. If None, uses the filename without its
        extension.

    Returns
    -------
    tuple
        (name, bandpass) where:
        - name is the bandpass name that was used
        - bandpass is the Bandpass object

    Raises
    ------
    FileNotFoundError
        If ``filepath`` does not exist.
    ValueError
        If the file cannot be read or parsed into a Bandpass.
    """
    try:
        # If name not provided, use filename without extension
        if name is None:
            name = os.path.splitext(os.path.basename(filepath))[0]

        data = np.loadtxt(filepath, skiprows=skiprows)
        bandpass = Bandpass(
            wave=jnp.array(data[:, 0]),
            trans=jnp.array(data[:, 1]),
            name=name,
        )
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Bandpass file not found at {filepath}") from e
    except Exception as e:
        raise ValueError(
            f"Error loading bandpass file from {filepath}: {e}"
        ) from e
    return name, bandpass


def create_bandpass_from_svo(filter_id, output_dir='filter_data', force_download=False):
    """Create a Bandpass from a locally cached SVO Filter Profile Service file.

    This function does not download anything itself. It reads a two-column
    (wavelength, transmission) file that must already exist at
    ``<output_dir>/<filter_id with '/' replaced by '_'>.dat``, for example
    one fetched with ``examples/download_svo_filter.py``.

    Parameters
    ----------
    filter_id : str
        The SVO filter identifier, e.g., 'UKIRT/WFCAM.J'.
    output_dir : str, optional
        Directory holding the cached filter file. Default is 'filter_data'
        (relative to the current working directory).
    force_download : bool, optional
        Request a fresh download instead of using the cached file. Because
        this function cannot download, passing True always raises
        FileNotFoundError with instructions for fetching the file.

    Returns
    -------
    Bandpass
        A Bandpass object for the specified filter.

    Raises
    ------
    FileNotFoundError
        If the cached file is missing or cannot be parsed, or if
        ``force_download`` is True.
    """
    local_filename = os.path.join(output_dir, f"{filter_id.replace('/', '_')}.dat")
    download_hint = (
        f"Please use the download_svo_filter.py script to download it:\n"
        f"python examples/download_svo_filter.py --filter {filter_id}"
    )

    if not os.path.exists(local_filename):
        raise FileNotFoundError(
            f"Filter profile file for {filter_id} not found at {local_filename}. "
            + download_hint
        )
    if force_download:
        # Previously this case claimed the (existing) file was "not found".
        raise FileNotFoundError(
            f"A fresh download of filter profile {filter_id} was requested "
            f"(force_download=True), but this function cannot download files. "
            + download_hint
        )

    print(f"Loading filter {filter_id} from local file {local_filename}")
    try:
        data = np.loadtxt(local_filename)
        wave, trans = data[:, 0], data[:, 1]
        return Bandpass(wave=jnp.array(wave), trans=jnp.array(trans), name=filter_id)
    except Exception as e:
        raise FileNotFoundError(
            f"Failed to load filter profile from {local_filename}: {e}"
        ) from e
def register_bandpass(name, bandpass, force=False):
    """Store ``bandpass`` in the module-level registry under ``name``.

    Parameters
    ----------
    name : str
        Registry key for the bandpass.
    bandpass : Bandpass
        Bandpass object to store.
    force : bool, optional
        If True, silently replace an existing entry with the same key.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``name`` is already registered and ``force`` is False.
    """
    # The registry dict is mutated in place and never rebound, so no
    # ``global`` declaration is required.
    already_present = name in _BANDPASSES
    if already_present and not force:
        raise ValueError(f"Bandpass '{name}' already exists in registry")
    _BANDPASSES[name] = bandpass
def get_bandpass(name):
    """Look up a registered bandpass, passing Bandpass instances through.

    Parameters
    ----------
    name : str or Bandpass
        Registry key of the bandpass, or a Bandpass object, which is
        returned unchanged.

    Returns
    -------
    bandpass : Bandpass
        The requested bandpass.

    Raises
    ------
    ValueError
        If ``name`` is not a Bandpass and is not a key of the registry.

    Notes
    -----
    Lookup only consults the registry; nothing is loaded on demand. Register
    bands with register_bandpass() or register_all_bandpasses() before use,
    in particular before JIT compilation.
    """
    if isinstance(name, Bandpass):
        return name
    if name in _BANDPASSES:
        return _BANDPASSES[name]

    registered = list(_BANDPASSES.keys())
    raise ValueError(
        f"Bandpass '{name}' not found in registry. "
        f"Available bandpasses: {registered}. "
        f"Bandpasses must be registered before use, especially before JIT compilation."
    )
def load_custom_bandpasses(bandpass_files):
    """Load and register custom bandpasses from files, best-effort.

    Parameters
    ----------
    bandpass_files : str, os.PathLike, list of str, or dict
        A single file path, a list of file paths (each registered under its
        filename without extension), or a dictionary mapping bandpass names
        to file paths.

    Returns
    -------
    dict
        Dictionary mapping bandpass names to Bandpass objects for every file
        that loaded successfully. Each of these is also registered (with
        ``force=True``). Files that fail to load are reported with a printed
        warning and skipped.
    """
    custom_bandpasses = {}
    if not bandpass_files:
        return custom_bandpasses

    # A bare path would otherwise be iterated character by character.
    if isinstance(bandpass_files, (str, os.PathLike)):
        bandpass_files = [bandpass_files]

    if isinstance(bandpass_files, dict):
        # Dictionary mapping names to file paths
        for name, filepath in bandpass_files.items():
            try:
                _, bandpass = load_bandpass_from_file(filepath, name=name)
                register_bandpass(name, bandpass, force=True)
                custom_bandpasses[name] = bandpass
                print(f"Registered custom bandpass '{name}' from {filepath}")
            except Exception as e:
                print(f"Warning: Failed to load custom bandpass '{name}' from {filepath}: {e}")
    else:
        # Iterable of file paths; names come from the filenames
        for filepath in bandpass_files:
            try:
                name, bandpass = load_bandpass_from_file(filepath)
                register_bandpass(name, bandpass, force=True)
                custom_bandpasses[name] = bandpass
                print(f"Registered custom bandpass '{name}' from {filepath}")
            except Exception as e:
                print(f"Warning: Failed to load custom bandpass from {filepath}: {e}")

    return custom_bandpasses
def register_all_bandpasses(custom_bandpass_files=None, svo_filters=None):
    """Register all known bandpasses and precompute their bandflux bridges.

    Bandpasses are gathered from four sources, in this order:

    1. Data files bundled under ``<PACKAGE_DIR>/sncosmo-modelfiles`` (ZTF,
       ATLAS, SDSS and 2MASS H).
    2. Bessell bands from ``sncosmo``, if it is installed.
    3. Locally cached SVO Filter Profile Service files (``svo_filters``).
    4. Custom bandpass files (``custom_bandpass_files``).

    Every source is best-effort: a bandpass that fails to load, or whose
    bridge cannot be precomputed, is reported with a printed warning and is
    left out of both returned dictionaries, so the two always share the same
    keys. A later source overwrites an earlier one with the same name.

    Parameters
    ----------
    custom_bandpass_files : list or dict, optional
        List of file paths to custom bandpass files, or a dictionary mapping
        bandpass names to file paths.
    svo_filters : list, optional
        List of dictionaries containing SVO filter information. Each
        dictionary should have the following keys:
        - 'name': Name to register the bandpass under
        - 'filter_id': SVO filter identifier (e.g., 'UKIRT/WFCAM.J')
        - 'variants': Optional list of variant names to register using the
          same bandpass

    Returns
    -------
    tuple
        A tuple containing:
        - bandpass_dict: Dictionary mapping bandpass names to Bandpass objects
        - bridges_dict: Dictionary mapping bandpass names to precomputed
          bridge data
    """
    from jax_supernovae.salt3 import precompute_bandflux_bridge

    bandpass_info = [
        # ZTF bandpasses (files carry a one-line header)
        {'name': 'ztfg', 'file': os.path.join(PACKAGE_DIR, 'sncosmo-modelfiles/bandpasses/ztf/P48_g.dat'), 'skiprows': 1},
        {'name': 'ztfr', 'file': os.path.join(PACKAGE_DIR, 'sncosmo-modelfiles/bandpasses/ztf/P48_R.dat'), 'skiprows': 1},
        # ATLAS bandpasses
        {'name': 'c', 'file': os.path.join(PACKAGE_DIR, 'sncosmo-modelfiles/bandpasses/atlas/Atlas.Cyan'), 'skiprows': 0},
        {'name': 'o', 'file': os.path.join(PACKAGE_DIR, 'sncosmo-modelfiles/bandpasses/atlas/Atlas.Orange'), 'skiprows': 0},
        # SDSS bandpasses
        {'name': 'g', 'file': os.path.join(PACKAGE_DIR, 'sncosmo-modelfiles/bandpasses/sdss/sdss_g.dat'), 'skiprows': 0},
        {'name': 'r', 'file': os.path.join(PACKAGE_DIR, 'sncosmo-modelfiles/bandpasses/sdss/sdss_r.dat'), 'skiprows': 0},
        {'name': 'i', 'file': os.path.join(PACKAGE_DIR, 'sncosmo-modelfiles/bandpasses/sdss/sdss_i.dat'), 'skiprows': 0},
        {'name': 'z', 'file': os.path.join(PACKAGE_DIR, 'sncosmo-modelfiles/bandpasses/sdss/sdss_z.dat'), 'skiprows': 0},
        # 2MASS bandpasses
        {'name': 'H', 'file': os.path.join(PACKAGE_DIR, 'sncosmo-modelfiles/bandpasses/2mass/2mass.H'), 'skiprows': 0},
    ]

    # Commonly used bands loaded from sncosmo (Bessell filters)
    sncosmo_bands = ['bessellb', 'bessellv', 'bessellr', 'besselli', 'bessellux']

    bandpass_dict = {}
    bridges_dict = {}

    def _add(name, bandpass):
        """Precompute the bridge, then register and record ``bandpass``.

        The bridge is computed first so that a failure leaves neither the
        registry nor the returned dicts holding a bandpass without a bridge.
        """
        bridge = precompute_bandflux_bridge(bandpass)
        register_bandpass(name, bandpass, force=True)
        bandpass_dict[name] = bandpass
        bridges_dict[name] = bridge

    # 1. Standard bandpasses bundled with the package
    for info in bandpass_info:
        try:
            data = np.loadtxt(info['file'], skiprows=info['skiprows'])
            wave, trans = data[:, 0], data[:, 1]
            _add(info['name'], Bandpass(wave, trans, name=info['name']))
        except Exception as e:
            print(f"Warning: Failed to load bandpass {info['name']}: {e}")

    # 2. Bessell and other common bands from sncosmo (optional dependency)
    try:
        import sncosmo
    except ImportError:
        print("Warning: sncosmo not available, skipping Bessell filter registration")
    else:
        for band_name in sncosmo_bands:
            try:
                snc_bandpass = sncosmo.get_bandpass(band_name)
                _add(band_name, Bandpass(snc_bandpass.wave, snc_bandpass.trans, name=band_name))
            except Exception as e:
                print(f"Warning: Failed to load bandpass {band_name} from sncosmo: {e}")

    # 3. SVO filter bandpasses
    if svo_filters:
        for filter_info in svo_filters:
            try:
                bandpass = create_bandpass_from_svo(filter_info['filter_id'])
                _add(filter_info['name'], bandpass)
                print(f"Registered {filter_info['name']} bandpass from SVO Filter Profile Service")

                # Register variants (aliases sharing the same bandpass) if any
                if 'variants' in filter_info and filter_info['variants']:
                    for variant in filter_info['variants']:
                        _add(variant, bandpass)
                    print(f"Registered {len(filter_info['variants'])} variants of {filter_info['name']} bandpass")
            except Exception as e:
                print(f"Warning: Failed to create {filter_info['name']} bandpass from SVO: {e}")

    # 4. Custom bandpasses. load_custom_bandpasses() registers them itself,
    # so only the bridge is computed here. A failed bridge is skipped with a
    # warning rather than aborting the whole registration; such a band stays
    # in the global registry but is left out of both returned dicts.
    if custom_bandpass_files:
        custom_bandpasses = load_custom_bandpasses(custom_bandpass_files)
        for name, bandpass in custom_bandpasses.items():
            try:
                bridge = precompute_bandflux_bridge(bandpass)
            except Exception as e:
                print(f"Warning: Failed to precompute bridge for custom bandpass {name}: {e}")
                continue
            bandpass_dict[name] = bandpass
            bridges_dict[name] = bridge

    return bandpass_dict, bridges_dict