Installation
This section provides essential instructions for installing JAX-bandflux. Python 3.10+ is required. Core dependencies include JAX (0.4.20+), NumPy (1.24.0+), Astropy (5.0+), and SNCosmo (2.9.0+). SALT3/SALT3-NIR model files ship with the package; no extra download is needed.
CPU vs CUDA wheels
JAX-bandflux does not force a CUDA dependency. Choose the JAX wheel that matches your hardware:
CPU:
pip install jax-bandflux
pip install --upgrade "jax[cpu]"
CUDA (example for CUDA 12):
pip install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install jax-bandflux
or in one go:
pip install "jax-bandflux[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
For other CUDA versions, see the JAX installation guide and pick the matching wheel for your driver/toolkit.
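After installing either wheel, it can be useful to confirm which backend JAX actually selected. The following is a minimal sketch that relies only on the public jax.default_backend() and jax.devices() calls; the helper name describe_backend is hypothetical and not part of JAX-bandflux. A CPU-only install is expected to report the cpu backend, while a working CUDA install should report gpu.

```python
def describe_backend():
    """Return a short description of the backend JAX selected.

    describe_backend is an illustrative helper, not part of JAX-bandflux.
    If JAX cannot be imported, a hint is returned instead of raising.
    """
    try:
        import jax
    except ImportError:
        return "JAX is not importable in this environment"

    # default_backend() names the platform JAX will run on by default;
    # devices() lists the devices visible on that platform.
    backend = jax.default_backend()
    devices = [str(d) for d in jax.devices()]
    return f"backend={backend}, devices={devices}"


print(describe_backend())
```

If a CUDA wheel was installed but the output still shows the cpu backend, the wheel most likely does not match the local driver/toolkit.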
Nested sampling extras
Optional dependencies for the nested sampling examples:
pip install "jax-bandflux[nested]"
Development install
git clone https://github.com/samleeney/JAX-bandflux.git
cd JAX-bandflux
pip install -e ".[dev,nested,docs]"
Verification
To verify that JAX-bandflux is installed correctly:
python -c "import jax_supernovae; print('JAX-bandflux successfully installed')"
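If the import succeeds, the command prints the success message; an ImportError means the package is not installed in the active Python environment. Note that the package is installed as jax-bandflux but imported as jax_supernovae.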
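Beyond the import check, the installed versions can be compared against the minimum versions listed at the top of this section. The sketch below uses only the standard library (importlib.metadata); the names MINIMUMS, parse_version, meets_minimum, and check_environment are hypothetical helpers introduced for illustration, and the version parsing is deliberately simple (leading numeric components only), so treat it as a rough check rather than a full version-specifier implementation.

```python
import sys
from importlib import metadata

# Minimum versions taken from the requirements stated in this section.
MINIMUMS = {
    "jax": "0.4.20",
    "numpy": "1.24.0",
    "astropy": "5.0",
    "sncosmo": "2.9.0",
}


def parse_version(text):
    """Turn '0.4.20' or '2.0.0rc1' into a tuple of leading integers."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component, e.g. 'dev0'
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare two version strings after zero-padding to equal length."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a = a + (0,) * (width - len(a))
    b = b + (0,) * (width - len(b))
    return a >= b


def check_environment(minimums=MINIMUMS):
    """Map each package name to (installed_version_or_None, meets_minimum)."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
            continue
        report[name] = (installed, meets_minimum(installed, minimum))
    return report


if __name__ == "__main__":
    print("Python >= 3.10:", sys.version_info >= (3, 10))
    for name, (installed, ok) in check_environment().items():
        status = "ok" if ok else "missing or too old"
        print(f"{name}: {installed} ({status}, needs >= {MINIMUMS[name]})")
```

Run it in the same environment where JAX-bandflux was installed; a package reported as missing or too old can usually be fixed with pip install --upgrade for that package.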