jax.experimental.pallas module
Module for Pallas, a JAX extension for custom kernels.
See the Pallas documentation at https://jax.readthedocs.io/en/latest/pallas.html.
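Here is a minimal sketch of a custom kernel that adds two vectors. It makes these assumptions:

- The kernel entry point is importable as `pl.pallas_call`.
- It accepts `interpret=True`, which runs the kernel on CPU through the Pallas interpreter.

`add_kernel` and `add` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Kernels receive refs (mutable views of inputs/outputs), not arrays.
    # `[...]` reads or writes the whole block the ref points at.
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x: jax.Array, y: jax.Array) -> jax.Array:
    # With no grid, the kernel runs once and sees the full arrays.
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # assumption: interpreter mode, so no GPU/TPU is needed
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
result = add(x, x)
```

The kernel writes its result into `o_ref` instead of returning a value. This is the basic contract Pallas kernels follow.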
Backends
Classes
Specifies how an array should be sliced for each invocation of a kernel.
Encodes the grid parameters for
A slice with a start index and a size.
Like jax.ShapeDtypeStruct but with memory spaces.
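The first class above specifies how an array is sliced for each kernel invocation. The sketch below splits an 8-element vector into blocks of 4 over a one-dimensional grid. It makes these assumptions:

- The class is exposed as `pl.BlockSpec` with `block_shape` and `index_map` keyword arguments. Keywords are used because the positional order may differ between JAX versions.
- `index_map` returns block indices rather than element offsets.
- `pl.program_id` returns the current grid position.

`BLOCK`, `scale_kernel`, and `scale_by_block` are hypothetical. Blocks this small work under `interpret=True` but may violate hardware tiling constraints on a real TPU or GPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

BLOCK = 4  # hypothetical block size, chosen for illustration


def scale_kernel(x_ref, o_ref):
    # Each invocation sees only its BLOCK-sized slice of the input and output.
    i = pl.program_id(0)  # position of this invocation along grid axis 0
    o_ref[...] = x_ref[...] * (i + 1)


def scale_by_block(x: jax.Array) -> jax.Array:
    # Grid step i maps to block i, i.e. elements [i * BLOCK, (i + 1) * BLOCK).
    spec = pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,))
    return pl.pallas_call(
        scale_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.shape[0] // BLOCK,),
        in_specs=[spec],
        out_specs=spec,
        interpret=True,  # assumption: run via the CPU interpreter
    )(x)


out = scale_by_block(jnp.ones(8, dtype=jnp.float32))
```

Each block is multiplied by its 1-based grid position. The first four elements therefore stay at 1 and the last four become 2.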
Functions
Invokes a Pallas kernel on some inputs.
Returns the kernel execution position along the given axis of the grid.
Returns the size of the grid along the given axis.
Returns an array loaded from the given index.
Stores a value at the given index.
Swaps the value at the given index and returns the old value.
Atomically computes
Atomically computes
Performs an atomic compare-and-swap of the value in the ref with the given value.
Atomically computes
Atomically computes
Atomically computes
Atomically exchanges the given value with the value at the given index.
Atomically computes
Prints values from inside a Pallas kernel.
Calls the function with allocated references and returns the result.
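The grid-position and grid-size queries listed above are often combined to reduce across grid steps. The sketch below computes a mean by mapping every step to the same single output block. It zeroes that block on the first step and divides it by the element count on the last step. It makes these assumptions:

- The functions are exposed as `pl.program_id`, `pl.num_programs`, and `pl.when`.
- The kernel runs with `interpret=True`, where grid steps execute sequentially.

This accumulation pattern is not valid on backends whose grid steps run in parallel. `BLOCK`, `mean_kernel`, and `block_mean` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

BLOCK = 4  # hypothetical block size, chosen for illustration


def mean_kernel(x_ref, o_ref):
    i = pl.program_id(0)          # current grid step
    n_steps = pl.num_programs(0)  # total number of grid steps

    @pl.when(i == 0)
    def _init():
        # The output block is revisited on every step, so zero it once.
        o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)

    # Add this block's partial sum (shape (1,)) to the running total.
    o_ref[...] = o_ref[...] + jnp.sum(x_ref[...], keepdims=True)

    @pl.when(i == n_steps - 1)
    def _finalize():
        # Divide the total by the number of elements seen across all steps.
        o_ref[...] = o_ref[...] / (n_steps * BLOCK)


def block_mean(x: jax.Array) -> jax.Array:
    return pl.pallas_call(
        mean_kernel,
        out_shape=jax.ShapeDtypeStruct((1,), x.dtype),
        grid=(x.shape[0] // BLOCK,),
        in_specs=[pl.BlockSpec(block_shape=(BLOCK,), index_map=lambda i: (i,))],
        # Every grid step maps to output block 0, so results accumulate there.
        out_specs=pl.BlockSpec(block_shape=(1,), index_map=lambda i: (0,)),
        interpret=True,  # assumption: sequential CPU interpreter
    )(x)


mean = block_mean(jnp.arange(8, dtype=jnp.float32))
```

Zeroing inside `pl.when` matters because the output buffer is not guaranteed to start at zero. Without the first-step initialization, the running sum would build on whatever the buffer held.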