jax.experimental.pallas module

Module for Pallas, a JAX extension for custom kernels.

See the Pallas documentation at https://jax.readthedocs.io/en/latest/pallas.html.

Backends

Classes

BlockSpec([block_shape, index_map, ...])

Specifies how an array should be sliced for each invocation of a kernel.
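A minimal sketch of blocking with `BlockSpec`. It passes `block_shape` and `index_map` as keywords and uses `interpret=True` so the kernel runs on CPU. The `index_map` receives grid indices and returns block indices (not element offsets), so program `i` sees elements `[4*i, 4*i + 4)`. The kernel name `double_kernel` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_kernel(x_ref, o_ref):
    # Each invocation sees only its own (4,)-shaped block.
    o_ref[...] = x_ref[...] * 2

x = jnp.arange(16, dtype=jnp.float32)
# Grid index i maps to block i, i.e. elements [4*i, 4*i + 4).
block = pl.BlockSpec(block_shape=(4,), index_map=lambda i: (i,))

out = pl.pallas_call(
    double_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(4,),          # 16 elements / 4 elements per block
    in_specs=[block],
    out_specs=block,
    interpret=True,     # CPU interpreter, for illustration only
)(x)
```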

GridSpec([grid, in_specs, out_specs, ...])

Encodes the grid parameters for jax.experimental.pallas.pallas_call().

Slice(start, size[, stride])

A slice with a start index, a size, and an optional stride.
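In kernels, such slices are usually built with the `pl.ds(start, size)` helper (an alias of `pl.dslice`), which returns a `Slice` covering `[start, start + size)`. This minimal sketch folds an 8-element input into 4 elements by adding its two halves. The kernel name is hypothetical, and `interpret=True` is for CPU illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def fold_halves_kernel(x_ref, o_ref):
    # pl.ds(start, size) selects elements [start, start + size).
    first = x_ref[pl.ds(0, 4)]
    second = x_ref[pl.ds(4, 4)]
    o_ref[...] = first + second

x = jnp.arange(8, dtype=jnp.float32)
out = pl.pallas_call(
    fold_halves_kernel,
    out_shape=jax.ShapeDtypeStruct((4,), x.dtype),
    interpret=True,
)(x)
# out == [4., 6., 8., 10.]
```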

MemoryRef(shape, dtype, memory_space)

Like jax.ShapeDtypeStruct but with memory spaces.

Functions

pallas_call(kernel, out_shape, *[, ...])

Invokes a Pallas kernel on some inputs.
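A minimal element-wise addition, sketched under the assumption that no grid is given, so the kernel runs once and each Ref covers the whole array. The kernel reads and writes Refs by indexing, and `out_shape` declares the output's shape and dtype. `interpret=True` runs the kernel on CPU for illustration; the names `add_kernel` and `add` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Inputs come first, then outputs; all are Refs.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # CPU interpreter, for illustration only
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
z = add(x, y)
```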

program_id(axis)

Returns the kernel execution position along the given axis of the grid.

num_programs(axis)

Returns the size of the grid along the given axis.
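A sketch showing `program_id` and `num_programs` together. Each of the 3 programs fills its `(2,)` output block with `10 * program_id + num_programs`, so the result records which program wrote each block. The kernel has no inputs, and `interpret=True` is used only for CPU illustration. The kernel name is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def ids_kernel(o_ref):
    i = pl.program_id(0)      # position along grid axis 0
    n = pl.num_programs(0)    # grid size along axis 0 (3 here)
    o_ref[...] = jnp.full(o_ref.shape, i * 10 + n, dtype=jnp.int32)

out = pl.pallas_call(
    ids_kernel,
    out_shape=jax.ShapeDtypeStruct((6,), jnp.int32),
    grid=(3,),
    out_specs=pl.BlockSpec(block_shape=(2,), index_map=lambda i: (i,)),
    interpret=True,
)()
# out == [3, 3, 13, 13, 23, 23]
```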

load(x_ref_or_view, idx, *[, mask, other, ...])

Returns an array loaded from the given index.

store(x_ref_or_view, idx, val, *[, mask, ...])

Stores a value at the given index.
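A sketch of a masked load followed by a store. It assumes a JAX version whose `pl.load` and `pl.store` accept the index tuple and the `mask`/`other` keywords listed above. Lanes where the mask is false take the value of `other` instead of the stored data. The kernel name and the "first five lanes are valid" rule are hypothetical, and `interpret=True` is for CPU illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def masked_copy_kernel(x_ref, o_ref):
    lanes = jnp.arange(8)
    valid = lanes < 5  # treat only the first five lanes as valid
    # Masked-off lanes are filled with `other` (-1.0 here).
    x = pl.load(x_ref, (pl.ds(0, 8),), mask=valid, other=-1.0)
    pl.store(o_ref, (pl.ds(0, 8),), x)

x = jnp.arange(8, dtype=jnp.float32)
out = pl.pallas_call(
    masked_copy_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    interpret=True,
)(x)
# out == [0., 1., 2., 3., 4., -1., -1., -1.]
```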

swap(x_ref_or_view, idx, val, *[, mask, ...])

Swaps the value at the given index and returns the old value.
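A sketch of `swap`. The kernel first fills an output with 7s, then swaps the input into it. The value returned by `pl.swap` (the old 7s) is written to a second output. The kernel name is hypothetical, and `interpret=True` is for CPU illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def swap_kernel(x_ref, o_ref, old_ref):
    o_ref[...] = jnp.full(o_ref.shape, 7.0, dtype=o_ref.dtype)
    # Write x into o_ref and get back what o_ref held before.
    old = pl.swap(o_ref, (pl.ds(0, 4),), x_ref[...])
    old_ref[...] = old

x = jnp.arange(4, dtype=jnp.float32)
new, old = pl.pallas_call(
    swap_kernel,
    out_shape=(
        jax.ShapeDtypeStruct((4,), jnp.float32),
        jax.ShapeDtypeStruct((4,), jnp.float32),
    ),
    interpret=True,
)(x)
# new == [0., 1., 2., 3.]; old == [7., 7., 7., 7.]
```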

atomic_and(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] &= val.

atomic_add(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] += val.

atomic_cas(ref, cmp, val)

Atomically compares the value in ref with cmp and, if they are equal, replaces it with val. Returns the old value.

atomic_max(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] = max(x_ref_or_view[idx], val).

atomic_min(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] = min(x_ref_or_view[idx], val).

atomic_or(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] |= val.

atomic_xchg(x_ref_or_view, idx, val, *[, mask])

Atomically exchanges the given value with the value at the given index.

atomic_xor(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] ^= val.
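A sketch of an atomic reduction. Each of 4 programs adds one element of the input into a shared scalar output with `pl.atomic_add`. The accumulator starts at zero because `input_output_aliases={1: 0}` aliases the second input (a zero scalar) to the output. The index `()` addresses the scalar Ref. The kernel name is hypothetical, and `interpret=True` is used only so the sketch runs on CPU. Backend support for atomics on real hardware is not covered here.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def sum_kernel(x_ref, init_ref, o_ref):
    del init_ref  # aliased to o_ref; only provides the initial value 0
    pid = pl.program_id(0)
    # Each program atomically adds its element to the shared scalar.
    pl.atomic_add(o_ref, (), x_ref[pid])

values = jnp.array([1, 2, 3, 4], dtype=jnp.int32)
total = pl.pallas_call(
    sum_kernel,
    out_shape=jax.ShapeDtypeStruct((), jnp.int32),
    grid=(4,),
    input_output_aliases={1: 0},
    interpret=True,
)(values, jnp.zeros((), jnp.int32))
# total == 10
```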

broadcast_to(a, shape)

Broadcasts the array a to the given shape.

debug_print(fmt, *args)

Prints values from inside a Pallas kernel.

dot(a, b[, trans_a, trans_b, allow_tf32, ...])

Computes the matrix product of a and b, optionally transposing either operand first.
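A sketch of `pl.dot` on 2-D operands. It uses `trans_b=True` to compute `a @ b.T`. With `a` all ones and `b` all twos over a contraction length of 8, every output entry is `8 * 2 = 16`. The output is declared as float32, matching the accumulation type. The kernel name is hypothetical, and `interpret=True` is for CPU illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(a_ref, b_ref, o_ref):
    # trans_b=True multiplies by the transpose of b, i.e. a @ b.T.
    o_ref[...] = pl.dot(a_ref[...], b_ref[...], trans_b=True)

a = jnp.ones((4, 8), dtype=jnp.float32)
b = jnp.full((2, 8), 2.0, dtype=jnp.float32)
out = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((4, 2), jnp.float32),
    interpret=True,
)(a, b)
# every entry of out is 16.0
```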

max_contiguous(x, values)

Compiler hint that the first values elements of x are contiguous.

multiple_of(x, values)

Compiler hint that x is a multiple of values.

run_scoped(f, *types, **kw_types)

Calls the function with allocated references and returns the result.

when(condition)

Decorator that executes the decorated function only when condition is true.
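A sketch of `pl.when` used as a decorator. Every program copies its block, and only program 0 then overwrites its block with zeros. The decorated function runs in place when the condition holds; it is not called later. The kernel name is hypothetical, and `interpret=True` is for CPU illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def zero_first_block_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...]

    # Runs only for the first program along grid axis 0.
    @pl.when(pl.program_id(0) == 0)
    def _():
        o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)

x = jnp.arange(8, dtype=jnp.float32)
block = pl.BlockSpec(block_shape=(4,), index_map=lambda i: (i,))
out = pl.pallas_call(
    zero_first_block_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(2,),
    in_specs=[block],
    out_specs=block,
    interpret=True,
)(x)
# out == [0., 0., 0., 0., 4., 5., 6., 7.]
```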