JAX reference documentation¶
JAX is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, just-in-time compile to GPU/TPU, and more.
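The transformations above compose freely; a minimal sketch (assuming `jax` is installed) of differentiating, vectorizing, and compiling one function:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 3 + 2.0 * x

# differentiate: jax.grad builds df/dx = 3x^2 + 2
df = jax.grad(f)
print(df(2.0))  # 14.0

# vectorize: jax.vmap maps df over a batch without a Python loop
batched = jax.vmap(df)
print(batched(jnp.array([0.0, 1.0, 2.0])))  # gradients 2, 5, 14

# compile: jax.jit compiles the whole composition with XLA
fast = jax.jit(jax.vmap(jax.grad(f)))
print(fast(jnp.array([0.0, 1.0, 2.0])))
```

Because the transformations are composable, `jit(vmap(grad(f)))` is itself just another function of the same shape as `f`.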
Getting Started
Reference Documentation
Advanced JAX Tutorials
- The Autodiff Cookbook
- Autobatching log-densities example
- Training a Simple Neural Network, with tensorflow/datasets Data Loading
- Custom derivative rules for JAX-transformable Python functions
- How JAX primitives work
- Writing custom Jaxpr interpreters in JAX
- Training a Simple Neural Network, with PyTorch Data Loading
- XLA in Python
- MAML Tutorial with JAX
- Generative Modeling by Estimating Gradients of the Data Distribution in JAX
- Named axes and easy-to-revise parallelism
- Using JAX in multi-host and multi-process environments
Notes
Developer documentation