JAX-bandflux Documentation
JAX-bandflux: Differentiable Supernova Light Curve Modeling
JAX-bandflux is a Python package for supernova light curve modeling built on JAX. It provides differentiable implementations of core SNCosmo functionality, enabling efficient gradient-based optimization and GPU acceleration for supernova cosmology research.
Why JAX-bandflux?
JAX-bandflux provides several key advantages over traditional supernova light curve modeling frameworks:
Differentiability: By implementing the SALT3 model in JAX, JAX-bandflux enables automatic differentiation of the entire modeling pipeline, allowing for efficient gradient-based optimization.
Performance: JAX’s just-in-time (JIT) compilation and GPU acceleration can substantially reduce run times, especially for large-scale analyses involving many supernovae.
Flexibility: The modular design allows for easy customization of bandpasses, models, and optimization strategies.
Compatibility: JAX-bandflux maintains compatibility with existing SNCosmo data formats and models, making it easy to integrate into existing workflows.
Research-Friendly: The codebase supports both standard analyses and novel approaches to supernova cosmology.
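The differentiability and JIT points above can be illustrated with a generic JAX sketch. It assumes a hypothetical Gaussian-pulse light curve (`toy_flux`) in place of SALT3, and none of the names below come from the jax_supernovae API. jax.grad differentiates a chi-squared through the model, and jax.jit compiles each gradient-descent step.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy light-curve model (NOT SALT3 and not the jax_supernovae API):
# a Gaussian pulse in time with parameters (t0, amplitude, width).
def toy_flux(params, t):
    t0, amplitude, width = params
    return amplitude * jnp.exp(-0.5 * ((t - t0) / width) ** 2)


def chi2(params, t, flux_obs, flux_err):
    """Chi-squared of the toy model against observed fluxes."""
    resid = (flux_obs - toy_flux(params, t)) / flux_err
    return jnp.sum(resid ** 2)


@jax.jit
def gd_step(params, t, flux_obs, flux_err, lr=1e-3):
    """One plain gradient-descent step; jax.grad differentiates chi2 w.r.t. params."""
    g = jax.grad(chi2)(params, t, flux_obs, flux_err)
    return params - lr * g


# Synthetic, noiseless data generated from known parameters.
t = jnp.linspace(-20.0, 40.0, 30)
true_params = jnp.array([5.0, 10.0, 8.0])
flux_err = jnp.full_like(t, 0.5)
flux_obs = toy_flux(true_params, t)

params = jnp.array([6.0, 9.0, 7.0])  # deliberately offset starting guess
chi2_start = chi2(params, t, flux_obs, flux_err)
for _ in range(200):
    params = gd_step(params, t, flux_obs, flux_err)
chi2_end = chi2(params, t, flux_obs, flux_err)
print(float(chi2_start), float(chi2_end))
```

The first call to `gd_step` triggers compilation; later calls reuse the compiled function. In practice a dedicated optimizer would replace this fixed-step loop.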
Key Features
Differentiable implementation of the SALT3 model for supernova light curves
GPU-accelerated flux calculations using JAX
Flexible bandpass management for various astronomical filters
Efficient data loading and processing routines
Support for gradient-based optimization and nested sampling
Comprehensive documentation and examples
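For orientation, a bandpass flux calculation in SNCosmo's convention (photons/s/cm^2 when no zero point is given) integrates the SED against the transmission, weighted by the inverse photon energy: F = ∫ f_λ(λ) T(λ) λ / (h c) dλ. The sketch below computes this with plain jax.numpy on a shared wavelength grid and uses jax.vmap to evaluate many SEDs at once. The helper names, the top-hat filter, and the flat SEDs are hypothetical illustrations, not the way jax_supernovae structures its own flux code.

```python
import jax
import jax.numpy as jnp

# Physical constants in CGS-style units (exact SI-defined values).
H_ERG_S = 6.62607015e-27  # Planck constant [erg s]
C_AA_S = 2.99792458e18    # speed of light [Angstrom / s]
HC_ERG_AA = H_ERG_S * C_AA_S  # h*c [erg Angstrom]


def trapezoid(y, x):
    """Trapezoid-rule integral of y(x) along the last axis of y."""
    dx = x[1:] - x[:-1]
    return jnp.sum(0.5 * (y[..., 1:] + y[..., :-1]) * dx, axis=-1)


def bandflux_photons(wave, f_lambda, trans):
    """Hypothetical helper: photon flux [photons / s / cm^2] through a bandpass.

    wave:     wavelength grid [Angstrom], shape (n_wave,)
    f_lambda: observer-frame SED [erg / s / cm^2 / Angstrom], shape (n_wave,)
    trans:    dimensionless bandpass transmission, shape (n_wave,)
    """
    # Dividing by the photon energy h*c/lambda converts energy flux to photon flux.
    integrand = f_lambda * trans * wave / HC_ERG_AA
    return trapezoid(integrand, wave)


# vmap maps over a batch of SEDs (axis 0 of f_lambda) while sharing the
# wavelength grid and transmission; jit compiles the batched function.
batched_bandflux = jax.jit(jax.vmap(bandflux_photons, in_axes=(None, 0, None)))

# Example: a top-hat filter (4000-5000 Angstrom) and 100 flat SEDs of
# increasing amplitude, standing in for many supernovae.
wave = jnp.linspace(3000.0, 6000.0, 3001)
trans = jnp.where((wave >= 4000.0) & (wave <= 5000.0), 1.0, 0.0)
amplitudes = jnp.linspace(1e-16, 1e-15, 100)
seds = amplitudes[:, None] * jnp.ones_like(wave)[None, :]
fluxes = batched_bandflux(wave, seds, trans)  # shape (100,)
```

Because the integral is linear in the SED, the batched fluxes scale directly with the SED amplitudes; the same vmap pattern extends to batches of model parameters.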
Package Structure
JAX-bandflux is organized into several key components:
jax_supernovae/
├── salt3.py # SALT3 model implementation
├── bandpasses.py # Bandpass management
├── data.py # Data loading and processing
├── utils.py # Utility functions
├── constants.py # Physical constants
├── data/ # Example data files
└── sncosmo-modelfiles/ # Model and bandpass files
[Diagram: relationships between key components]
Support and Contributing
See the top-level CONTRIBUTING.md for issue/PR guidelines.
For help, open a GitHub issue with a minimal example and details of your environment (Python, JAX, and jaxlib versions; CPU or GPU; CUDA version).
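The environment details requested above can be collected with a short script like the following. It is a convenience sketch that relies only on the standard library and public JAX functions; the `environment_report` name is hypothetical and not part of JAX-bandflux.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

import jax


def environment_report():
    """Hypothetical helper: collect version and device details for a bug report."""

    def pkg_version(name):
        # Reads the installed distribution's version; works for pip-installed packages.
        try:
            return version(name)
        except PackageNotFoundError:
            return "not installed"

    return {
        "python": platform.python_version(),
        "jax": jax.__version__,
        "jaxlib": pkg_version("jaxlib"),
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

The CUDA version is not covered here; recent JAX releases also provide jax.print_environment_info(), which reports similar details and includes nvidia-smi output when a GPU is available.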
Quick Links
Installation - Installation instructions
Quickstart - Get started quickly with basic examples
API Differences from SNCosmo - How JAX-bandflux differs from SNCosmo
TimeSeriesSource - Custom SED models with TimeSeriesSource
Data Loading - Learn how to load and process supernova data
Bandpass Loading - Working with astronomical filters and bandpasses
Generating Model Fluxes - Computing model fluxes using the SALT3 model
Dust Extinction - Applying dust extinction to supernova models
Sampling - Techniques for parameter estimation and sampling
Getting Help
If you encounter any issues or have questions about JAX-bandflux, please:
Check the Installation and Quickstart guides
Review the documentation sections for specific functionality
Look for similar issues in the GitHub repository
Open a new issue if your problem hasn’t been addressed