JAX-bandflux Documentation

JAX-bandflux: Differentiable Supernova Light Curve Modeling


JAX-bandflux is a Python package that implements supernova light curve modeling in JAX. It provides a differentiable implementation of core SNCosmo functionality, enabling gradient-based optimization and GPU acceleration for supernova cosmology research.
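To make "differentiable SNCosmo functionality" concrete, the sketch below writes a synthetic-photometry bandflux (the photon-flux integral SNCosmo uses, ∫ f(λ) T(λ) λ / (hc) dλ) in plain JAX and differentiates it with `jax.grad`. The functions `toy_spectrum` and `toy_bandflux`, the Gaussian SED and the top-hat filter are hypothetical stand-ins, not the JAX-bandflux API.

```python
import jax
import jax.numpy as jnp

HC_ERG_AA = 6.62607015e-27 * 2.99792458e18  # h * c in erg * Angstrom

def toy_spectrum(wave, amplitude):
    """Hypothetical SED: a Gaussian in wavelength (erg/s/cm^2/Angstrom) scaled by `amplitude`."""
    return amplitude * jnp.exp(-0.5 * ((wave - 5000.0) / 800.0) ** 2)

def toy_bandflux(amplitude, wave, trans):
    """Photon flux (photons/s/cm^2): integral of f(lambda) T(lambda) lambda / (h c) d lambda."""
    integrand = toy_spectrum(wave, amplitude) * trans * wave / HC_ERG_AA
    # Trapezoidal rule, written out explicitly so it works across JAX versions.
    return jnp.sum(0.5 * (integrand[1:] + integrand[:-1]) * jnp.diff(wave))

wave = jnp.linspace(3000.0, 8000.0, 501)  # wavelength grid (Angstrom)
trans = jnp.where((wave > 4500.0) & (wave < 5500.0), 1.0, 0.0)  # hypothetical top-hat filter

flux = toy_bandflux(1e-15, wave, trans)
dflux_damp = jax.grad(toy_bandflux)(1e-15, wave, trans)  # d(flux) / d(amplitude)
```

Because the bandflux is linear in the amplitude, the gradient equals `flux / amplitude`; the same `jax.grad` call works unchanged for parameters that enter nonlinearly.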

Why JAX-bandflux?

JAX-bandflux provides several key advantages over traditional supernova light curve modeling frameworks:

  • Differentiability: By implementing the SALT3 model in JAX, JAX-bandflux enables automatic differentiation of the entire modeling pipeline, allowing for efficient gradient-based optimization.

  • Performance: JAX’s just-in-time (JIT) compilation and GPU acceleration can substantially speed up model evaluation, especially in large-scale analyses involving many supernovae.

  • Flexibility: The modular design allows for easy customization of bandpasses, models, and optimization strategies.

  • Compatibility: JAX-bandflux maintains compatibility with existing SNCosmo data formats and models, making it easy to integrate into existing workflows.

  • Research-friendly: The codebase supports both standard analyses and novel approaches to supernova cosmology.
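The performance point rests on two JAX transformations: `jax.vmap` vectorizes a single-object function over a batch, and `jax.jit` compiles the result. The sketch below applies both to a hypothetical Gaussian-pulse light curve (`toy_light_curve`, not SALT3) evaluated for 1000 supernovae at shared epochs.

```python
import jax
import jax.numpy as jnp

def toy_light_curve(params, times):
    """Hypothetical light curve: a Gaussian pulse in time (a stand-in for SALT3).

    params = (amplitude, t0, width) for a single supernova.
    """
    amplitude, t0, width = params[0], params[1], params[2]
    return amplitude * jnp.exp(-0.5 * ((times - t0) / width) ** 2)

# vmap maps over the leading axis of `params` (one row per supernova) while
# sharing `times`; jit compiles the batched function on its first call.
batched_light_curves = jax.jit(jax.vmap(toy_light_curve, in_axes=(0, None)))

n_sne = 1000
amplitudes = jnp.linspace(0.5, 2.0, n_sne)
params = jnp.stack(
    [amplitudes, jnp.zeros(n_sne), jnp.full(n_sne, 10.0)], axis=1
)  # shape (n_sne, 3)
times = jnp.linspace(-20.0, 60.0, 81)  # shared observation epochs (days)

fluxes = batched_light_curves(params, times)  # shape (n_sne, len(times))
```

Later calls with the same input shapes reuse the compiled code. Real surveys observe each supernova at different epochs; a common approach is to pad to a fixed length and carry a mask, which this sketch does not show.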

Key Features

  • Differentiable implementation of the SALT3 model for supernova light curves

  • GPU-accelerated flux calculations using JAX

  • Flexible bandpass management for various astronomical filters

  • Efficient data loading and processing routines

  • Support for gradient-based optimization and nested sampling

  • Comprehensive documentation and examples

Package Structure

JAX-bandflux is organized into several key components:

jax_supernovae/
├── salt3.py         # SALT3 model implementation
├── bandpasses.py    # Bandpass management
├── data.py          # Data loading and processing
├── utils.py         # Utility functions
├── constants.py     # Physical constants
├── data/            # Example data files
└── sncosmo-modelfiles/ # Model and bandpass files

The following diagram shows the relationships between key components:

graph TD
    A[JAX-bandflux] --> B[SALT3 Model]
    A --> C[Bandpass Management]
    A --> D[Data Handling]
    A --> E[Optimization Methods]
    B --> F[salt3_bandflux]
    B --> G[salt3_flux]
    C --> H[Built-in Bandpasses]
    C --> I[Custom Bandpasses]
    C --> J[SVO Integration]
    D --> K[HSF Data Format]
    D --> L[Redshift Handling]
    E --> M[L-BFGS-B]
    E --> N[Nested Sampling]
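Redshift handling appears under data handling in the diagram. The sketch below shows one way an observer-frame model can be built from a rest-frame SED, assuming the convention used by SNCosmo's `Model`: observer time and wavelength are divided by 1 + z, and the flux density per unit wavelength is scaled by 1/(1 + z). Whether JAX-bandflux follows exactly this convention is an assumption here, and `rest_frame_sed` and `observer_frame_flux` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

def rest_frame_sed(phase, wave):
    """Hypothetical rest-frame SED (arbitrary units): Gaussian in phase and wavelength."""
    return jnp.exp(-0.5 * (phase / 15.0) ** 2) * jnp.exp(-0.5 * ((wave - 4000.0) / 1000.0) ** 2)

def observer_frame_flux(time, wave, z, t0, amplitude):
    """Observer-frame flux density under an SNCosmo-style redshift convention (assumed)."""
    a = 1.0 / (1.0 + z)
    rest_phase = (time - t0) * a  # time dilation
    rest_wave = wave * a          # wavelength stretching
    return amplitude * a * rest_frame_sed(rest_phase, rest_wave)

# Redshift is just another differentiable input.
dflux_dz = jax.grad(observer_frame_flux, argnums=2)(30.0, 8000.0, 0.5, 0.0, 1.0)
```

At z = 1, the rest-frame peak at 4000 Å appears at 8000 Å in the observer frame, and a 15-day rest-frame phase is observed 30 days after t0.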


Support and Contributing

  • See the top-level CONTRIBUTING.md for issue/PR guidelines.

  • For help, open a GitHub issue with a minimal example and details of your environment (Python, JAX, and jaxlib versions; CPU or GPU; CUDA version).
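Most of the environment details requested for an issue can be collected programmatically. The helper below, `environment_report`, is a hypothetical convenience function (not part of JAX-bandflux) built only from the standard library and public JAX calls; the CUDA version is not included and still needs to be reported separately.

```python
import importlib.metadata
import platform
import sys

import jax

def environment_report():
    """Collect Python, JAX and jaxlib versions plus the active backend and devices."""
    return {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        "jax": jax.__version__,
        "jaxlib": importlib.metadata.version("jaxlib"),
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(d) for d in jax.devices()],
    }

for key, value in environment_report().items():
    print(f"{key}: {value}")
```

Recent JAX releases also provide `jax.print_environment_info()`, which prints similar information and, on GPU machines, output from `nvidia-smi`.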

Getting Help

If you encounter any issues or have questions about JAX-bandflux, please:

  • Check the Installation and Quickstart guides

  • Review the documentation sections for specific functionality

  • Look for similar issues in the GitHub repository

  • Open a new issue if your problem hasn’t been addressed
