Pallas: a JAX kernel language

Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. It aims to provide fine-grained control over the generated code, combined with the high-level ergonomics of JAX tracing and the jax.numpy API.
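The following is a minimal sketch of what a Pallas kernel looks like. It assumes a recent JAX installation that provides `jax.experimental.pallas` and uses `interpret=True` so it can run on CPU without a GPU or TPU. The kernel body is ordinary Python that JAX traces. It reads from and writes to refs (mutable buffer views) with `jax.numpy` operations, and `pl.pallas_call` turns it into a function that can be called on arrays. The names `add_vectors_kernel` and `add_vectors` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Inputs and outputs arrive as refs; [...] loads or stores the whole block.
    x, y = x_ref[...], y_ref[...]
    # Kernels write results into the output ref instead of returning them.
    o_ref[...] = x + y


@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(
        add_vectors_kernel,
        # Declares the shape and dtype of the single output buffer.
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # Interpret mode emulates the kernel on CPU; omit it on supported hardware.
        interpret=True,
    )(x, y)


x = jnp.arange(8.0)
print(add_vectors(x, x))  # elementwise sum: 0., 2., 4., ..., 14.
```

Without `interpret=True`, the same kernel is compiled for the available backend. On TPU, block shapes must also satisfy the hardware's layout constraints.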

This section contains tutorials, guides and examples for using Pallas. See also the jax.experimental.pallas module API documentation.

Warning

Pallas is experimental and is changing frequently. See the Pallas Changelog for recent changes.

You can expect to encounter errors and unimplemented cases, e.g., when lowering high-level JAX concepts that would require emulation, or simply because Pallas is still under development.