Getting Started with JAX
Welcome to JAX! The JAX documentation contains a number of useful resources for getting started. Quickstart is the easiest place to jump in and get an overview of the JAX project.
If you’re accustomed to writing NumPy code and are starting to explore JAX, you might find the following resources helpful:
- How to think in JAX is a conceptual walkthrough of JAX’s execution model.
- 🔪 JAX - The Sharp Bits 🔪 lists some of JAX’s sharp corners.
- Frequently asked questions (FAQ) answers some common questions about JAX.
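Here is a quick look at how NumPy habits carry over, and where they don't. This is a minimal sketch that assumes JAX and NumPy are both installed. `jax.numpy` mirrors much of the NumPy API. JAX arrays, however, are immutable, which is one of the sharp corners covered in the Sharp Bits page. Element updates instead go through `.at[...].set(...)`, which returns a new array.

```python
import numpy as np
import jax.numpy as jnp

# jax.numpy mirrors much of the NumPy API
x_np = np.arange(5.0)
x_jnp = jnp.arange(5.0)
print(np.sum(x_np ** 2), jnp.sum(x_jnp ** 2))

# Sharp bit: JAX arrays are immutable, so NumPy-style in-place assignment fails
immutable = False
try:
    x_jnp[0] = 10.0
except TypeError:
    immutable = True
print("in-place assignment rejected:", immutable)

# The functional alternative returns a new array and leaves x_jnp unchanged
y = x_jnp.at[0].set(10.0)
print(x_jnp)
print(y)
```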
Tutorials
If you’re ready to explore JAX more deeply, the JAX tutorials go into much more detail:
- Tutorials
- Quickstart
- Key concepts
- Just-in-time compilation
- Automatic vectorization
- Automatic differentiation
- Introduction to debugging
- Pseudorandom numbers
- Working with pytrees
- Introduction to parallel programming
- Stateful computations
- Control flow and logical operators with JIT
- Advanced automatic differentiation
- External callbacks
- Gradient checkpointing with jax.checkpoint (jax.remat)
- JAX Internals: primitives
- JAX internals: The jaxpr language
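Several of the tutorial topics above are designed to compose: just-in-time compilation, automatic vectorization, automatic differentiation and pseudorandom numbers. The sketch below stacks `jax.grad`, `jax.vmap` and `jax.jit` to compute per-example gradients for a batch. It rests on our own assumptions: the linear model, the squared-error `loss` function and the array shapes are hypothetical choices, not taken from the tutorials.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Hypothetical example: squared error of a linear model for one example
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# Automatic differentiation: gradient with respect to the first argument, w
grad_loss = jax.grad(loss)

# Automatic vectorization: map over a batch of (x, y), sharing the same w
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# Just-in-time compilation of the composed function
fast_batched_grad = jax.jit(batched_grad)

# Pseudorandom numbers: JAX uses explicit keys that are split, not a global seed
key = jax.random.PRNGKey(0)
key_x, key_w = jax.random.split(key)
xs = jax.random.normal(key_x, (8, 3))
w = jax.random.normal(key_w, (3,))
ys = xs @ jnp.array([1.0, -2.0, 0.5])

grads = fast_batched_grad(w, xs, ys)
print(grads.shape)  # one gradient of shape (3,) per example
```

Here `in_axes=(None, 0, 0)` tells `vmap` to share `w` across the batch while mapping over the leading axis of `xs` and `ys`. Each row of `grads` is then the gradient for one example.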
If you prefer a video introduction, here is one from JAX contributor Jake VanderPlas:
Building on JAX
JAX provides the core numerical computing primitives for a number of tools developed by the larger community. For example, if you’re interested in using JAX for training neural networks, two well-supported options are Flax and Haiku.
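Libraries such as Flax and Haiku package up patterns that can also be written directly with JAX's core transformations. The following is a plain-JAX sketch, not code from either library. The helper names (`init_params`, `predict`, `mse_loss`, `sgd_step`), the synthetic data and the learning rate are all hypothetical. Parameters live in a pytree (a dict of arrays). `jax.value_and_grad` returns the loss together with a gradient pytree of the same structure, and `jax.tree_util.tree_map` applies a gradient-descent update leaf by leaf.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size for this toy problem

def init_params(key, n_in, n_out):
    # Hypothetical helper: parameters stored as a pytree (a dict of arrays)
    return {
        "w": 0.1 * jax.random.normal(key, (n_in, n_out)),
        "b": jnp.zeros((n_out,)),
    }

def predict(params, x):
    # Affine model: x @ w + b
    return x @ params["w"] + params["b"]

def mse_loss(params, x, y):
    # Mean squared error over the batch
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # Loss and gradients; grads has the same pytree structure as params
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    # Plain gradient descent applied to every leaf of the pytree
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss

# Synthetic regression data: y = x @ [2, -1]^T + 0.5
key_params, key_data = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_data, (64, 2))
y = x @ jnp.array([[2.0], [-1.0]]) + 0.5

params = init_params(key_params, n_in=2, n_out=1)
initial_loss = mse_loss(params, x, y)
for _ in range(200):
    params, loss = sgd_step(params, x, y)
final_loss = mse_loss(params, x, y)
print(initial_loss, final_loss)
```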
For a community-curated list of JAX-related projects across a wide set of domains, check out Awesome JAX.
Finding Help
If you have questions about JAX, we’d love to answer them! Two good places to get your questions answered are: