jax.experimental.pallas.kernel

jax.experimental.pallas.kernel(body=<not-specified>, out_type=(), *, mesh, scratch_types=(), compiler_params=None, interpret=False, cost_estimate=None, debug=False, name=None, metadata=None)

Entry point for creating a Pallas kernel.

This is a convenience wrapper around mpmd_map for executing a kernel over a mesh.

If body is provided, this function behaves as a decorator:

def kernel_body(in_ref, out_ref):
  ...
kernel = pl.kernel(kernel_body, mesh=..., out_type=...)

If body is omitted, this function behaves as a decorator factory and will return a decorator that can be used to annotate a kernel body:

@pl.kernel(mesh=..., out_type=...)
def kernel(in_ref, out_ref):
  ...
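The two calling conventions above can be illustrated with a small pure-Python sketch. The function toy_kernel and the NOT_SPECIFIED sentinel below are hypothetical stand-ins, not Pallas code. They only mimic how a single function can act as a decorator when body is passed and as a decorator factory when body is omitted. No kernel is launched: the returned callable just reports what it would run.

```python
from typing import Any, Callable


class _NotSpecified:
    """Hypothetical sentinel type meaning "argument was not passed"."""

    def __repr__(self) -> str:
        return "<not-specified>"


NOT_SPECIFIED = _NotSpecified()


def toy_kernel(body: Any = NOT_SPECIFIED, out_type: Any = (), *, mesh: Any) -> Callable:
    """Toy stand-in for the pl.kernel calling convention; not Pallas code."""

    def decorator(fn: Callable) -> Callable:
        def run(*operands):
            # A real kernel would execute `fn` over `mesh`; this toy only
            # reports what would be launched.
            return {"body": fn.__name__, "mesh": mesh, "out_type": out_type, "operands": operands}

        return run

    if body is NOT_SPECIFIED:
        # Decorator-factory form: @toy_kernel(mesh=..., out_type=...)
        return decorator
    # Direct form: toy_kernel(kernel_body, mesh=..., out_type=...)
    return decorator(body)


def kernel_body(in_ref, out_ref):
    ...


direct = toy_kernel(kernel_body, mesh="m", out_type="f32[8]")


@toy_kernel(mesh="m", out_type="f32[8]")
def decorated(in_ref, out_ref):
    ...
```

A dedicated sentinel is used instead of None so that "not passed" can never collide with a legitimate argument value. The body=<not-specified> default in the signature suggests that Pallas follows the same idea.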

For MPMD kernels, you can pass parallel lists of bodies and meshes:

my_kernel = pl.kernel(
    body=[vector_fn, scalar_fn],
    mesh=[v_mesh, s_mesh],
    out_type=...
)
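Parallel lists of bodies and meshes can be matched up as in the plain-Python sketch below. The helper pair_bodies_with_meshes is hypothetical and is not part of Pallas. It relies on the rule stated for the mesh parameter, namely that a sequence of bodies requires a sequence of meshes. For illustration, it also assumes that the two sequences must have equal length so that they can be paired position by position.

```python
from typing import Any, Callable, Sequence, Union


def pair_bodies_with_meshes(
    body: Union[Callable, Sequence[Callable]], mesh: Any
) -> list[tuple[Callable, Any]]:
    """Hypothetical helper: turn body/mesh arguments into (body, mesh) pairs."""
    if callable(body):
        # Single-body case: the one body runs on the mesh as given.
        return [(body, mesh)]
    bodies = list(body)
    # Assumption: lists and tuples stand in for "a sequence of meshes".
    if not isinstance(mesh, (list, tuple)):
        raise TypeError("a sequence of bodies requires a sequence of meshes")
    if len(bodies) != len(mesh):
        raise ValueError(f"got {len(bodies)} bodies but {len(mesh)} meshes")
    # Position i of `body` is paired with position i of `mesh`.
    return list(zip(bodies, mesh))


# Hypothetical bodies and placeholder mesh names, mirroring the example above.
def vector_fn(*refs):
    ...


def scalar_fn(*refs):
    ...


pairs = pair_bodies_with_meshes([vector_fn, scalar_fn], ["v_mesh", "s_mesh"])
```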

JAX Refs can be closed over by the kernel body or passed in as arguments. Any such Ref is treated as both read from and written to, and it is aliased into and out of the kernel.

@pl.kernel(mesh=...)
def kernel(in_ref, out_ref):
  ...
x_ref = jax.new_ref(...)
y_ref = jax.new_ref(...)
kernel(x_ref, y_ref)  # Can now mutate x_ref and y_ref

Parameters:
  • body (Callable | Sequence[Callable] | NotSpecified) – The body of the kernel. If provided, this function behaves as a decorator, and if omitted, this function behaves as a decorator factory. Can also be a sequence of callables to be paired with a sequence of meshes.

  • out_type (object | None) – The type of the output. Should be a PyTree of jax.ShapeDtypeStruct or JAX types.

  • mesh (Mesh | Sequence[Mesh]) – The mesh to run the kernel on. Must be a sequence of meshes if body is a sequence of callables.

  • scratch_types (Sequence[ScratchShape | ScratchShapeTree | None] | Mapping[str, ScratchShape | ScratchShapeTree]) – The shapes of the scratch arrays.

  • compiler_params (CompilerParams | None) – The compiler parameters to pass to the backend.

  • interpret (bool) – Whether to run the function in interpret mode.

  • debug (bool) – Whether to print helpful debugging information.

  • cost_estimate (CostEstimate | None) – The cost estimate of the function.

  • name (str | None) – The (optional) name of the kernel.

  • metadata (dict[str, str] | None) – Optional dictionary of information about the kernel that will be serialized as JSON in the HLO. Can be used for debugging and analysis.

Returns:

If body is provided, returns a function that runs the kernel. That function takes any number of input operands and returns an output with the same PyTree structure as out_type. If body is omitted, returns a decorator that can be used to annotate a kernel body.
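The sketch below illustrates the out_type contract. It builds an out_type PyTree for a hypothetical kernel with two outputs, then checks that tree against jax.eval_shape applied to a plain-JAX reference function. The names reference, doubled, and row_sums are illustrative assumptions, and no Pallas kernel is invoked. According to the description above, the callable returned by pl.kernel would produce outputs with this same PyTree structure.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((8, 128), dtype=jnp.float32)

# out_type for a hypothetical two-output kernel: a dict PyTree whose leaves
# are jax.ShapeDtypeStruct values describing each output's shape and dtype.
out_type = {
    "doubled": jax.ShapeDtypeStruct(x.shape, x.dtype),
    "row_sums": jax.ShapeDtypeStruct((x.shape[0],), jnp.float32),
}


# Plain-JAX reference computing what the hypothetical kernel would compute.
def reference(x):
    return {"doubled": x * 2, "row_sums": x.sum(axis=1)}


# eval_shape traces `reference` abstractly and returns ShapeDtypeStruct leaves
# without running the computation.
ref_shapes = jax.eval_shape(reference, x)

# Same tree structure (same dict keys) on both sides?
same_treedef = jax.tree_util.tree_structure(ref_shapes) == jax.tree_util.tree_structure(out_type)

# Same shape and dtype for every corresponding leaf?
leaves_match = jax.tree_util.tree_all(
    jax.tree_util.tree_map(
        lambda a, b: a.shape == b.shape and a.dtype == b.dtype, ref_shapes, out_type
    )
)
```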