jax.experimental.pallas.mosaic_gpu.kernel
- jax.experimental.pallas.mosaic_gpu.kernel(body=<not-specified>, out_shape=<not-specified>, *, out_type=<not-specified>, scratch_types=<not-specified>, scratch_shapes=<not-specified>, compiler_params=None, grid=(), grid_names=(), cluster=(), cluster_names=(), num_threads=None, thread_name=None, interpret=None, debug=False, **mesh_kwargs)
Entry point for defining a Mosaic GPU kernel.
- Parameters:
body (Callable[..., None] | api.NotSpecified) – The kernel body, which should take as arguments the input, output, and scratch Refs, in that order. The number of input Refs is determined by the number of arguments passed to the function returned by this function. The number of output and scratch Refs is determined by out_type and scratch_types, respectively.
out_shape (object | api.NotSpecified) – A deprecated alias for out_type.
out_type (object | api.NotSpecified) – The type of the output. Should be a PyTree of jax.ShapeDtypeStruct or JAX types.
scratch_shapes (ScratchShapeTree | api.NotSpecified) – A deprecated alias for scratch_types.
scratch_types (ScratchShapeTree | api.NotSpecified) – The types of the scratch Refs to allocate. Should be a PyTree of jax.ShapeDtypeStruct or JAX types.
compiler_params (pallas_core.CompilerParams | None) – Additional compiler options. See the CompilerParams dataclass for more details.
grid (tuple[int, ...]) – A tuple of integers specifying the size of the kernel grid.
grid_names (tuple[str, ...]) – The axis names of the grid. Must be the same length as grid.
cluster (tuple[int, ...]) – A tuple of integers specifying the size of the kernel cluster.
cluster_names (tuple[str, ...]) – The axis names of the cluster. Must be the same length as cluster.
num_threads (int | None) – The number of threads to launch per block. Note that these do not correspond to CUDA threads, but rather to warpgroups on Hopper and Blackwell GPUs.
thread_name (str | None) – The axis name used to query the thread index.
debug (bool) – Whether or not to output helpful debugging information.
**mesh_kwargs (Any) – Additional mesh kwargs. See Mesh for more details.
interpret (Any)
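To make the argument order of body concrete, here is a minimal CPU-only toy emulation that uses NumPy arrays as stand-ins for Refs. It is not how Mosaic GPU compiles or launches kernels. The names ShapeDtype (a stand-in for jax.ShapeDtypeStruct), emulate_kernel_call, and add_body are hypothetical, and the sketch simplifies out_type and scratch_types from PyTrees to flat sequences, assuming one Ref per entry.

```python
from typing import Callable, NamedTuple, Sequence

import numpy as np


class ShapeDtype(NamedTuple):
    """Hypothetical stand-in for jax.ShapeDtypeStruct."""
    shape: tuple
    dtype: type


def emulate_kernel_call(body: Callable[..., None],
                        inputs: Sequence[np.ndarray],
                        out_types: Sequence[ShapeDtype],
                        scratch_types: Sequence[ShapeDtype] = ()):
    """Toy emulation of the documented order: body(*inputs, *outputs, *scratch)."""
    # One input "Ref" per operand passed to the returned function.
    in_refs = [np.array(x) for x in inputs]
    # One output buffer per out_types entry (simplification of a PyTree).
    out_refs = [np.zeros(t.shape, t.dtype) for t in out_types]
    # Scratch buffers are visible to the body but are not returned.
    scratch_refs = [np.zeros(t.shape, t.dtype) for t in scratch_types]
    # The body communicates only by writing into the output Refs.
    body(*in_refs, *out_refs, *scratch_refs)
    return out_refs


def add_body(x_ref, y_ref, o_ref, tmp_ref):
    # Two inputs, one output, one scratch Ref, in that order.
    tmp_ref[...] = x_ref[...] + y_ref[...]
    o_ref[...] = tmp_ref[...] * 2


x = np.arange(4, dtype=np.float32)
y = np.ones(4, dtype=np.float32)
(out,) = emulate_kernel_call(add_body, [x, y],
                             out_types=[ShapeDtype((4,), np.float32)],
                             scratch_types=[ShapeDtype((4,), np.float32)])
# out holds [2., 4., 6., 8.]
```

The body returns None and writes its results into the output Refs, which is why the emulation allocates the output buffers before calling it.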
- Returns:
If body is provided, returns a function that runs the kernel. That function takes any number of input operands and returns an output with the same PyTree structure as out_type. If body is omitted, returns a decorator that can be used to annotate a kernel body.
- Return type:
Any
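The two return modes follow a common Python pattern: a sentinel default for body decides whether to build the runner immediately or to hand back a decorator. The sketch below models that pattern and the out_shape to out_type alias in plain Python. It is an illustration under assumptions, not JAX's implementation: toy_kernel, NOT_SPECIFIED, and the returned runner are hypothetical, the runner only reports what a real kernel would receive, and raising TypeError when both out_shape and out_type are given is this sketch's own choice.

```python
import functools
from typing import Any


class _NotSpecified:
    """Hypothetical sentinel, modeled on api.NotSpecified in the signature."""

    def __repr__(self) -> str:
        return "<not-specified>"


NOT_SPECIFIED = _NotSpecified()


def toy_kernel(body: Any = NOT_SPECIFIED, out_shape: Any = NOT_SPECIFIED, *,
               out_type: Any = NOT_SPECIFIED) -> Any:
    # Resolve the deprecated alias out_shape -> out_type.
    if out_shape is not NOT_SPECIFIED:
        if out_type is not NOT_SPECIFIED:
            # This sketch's own choice; the real API may behave differently.
            raise TypeError("pass only one of out_shape and out_type")
        out_type = out_shape
    if body is NOT_SPECIFIED:
        # No body yet: return a decorator that finishes the call later.
        return functools.partial(toy_kernel, out_type=out_type)

    def run(*operands):
        # A real kernel would compile and launch `body`; this toy runner
        # only reports the body name, operand count, and output type.
        return (body.__name__, len(operands), out_type)

    return run


# Decorator form: body omitted, so toy_kernel returns a decorator.
@toy_kernel(out_shape="f32[128]")
def my_body(x_ref, o_ref):
    pass


# Direct form: body provided, so toy_kernel returns the runner.
def other_body(a_ref, b_ref, o_ref):
    pass


runner = toy_kernel(other_body, out_type="f32[64]")
```

Here my_body(1) returns ("my_body", 1, "f32[128]") and runner(1, 2) returns ("other_body", 2, "f32[64]"), showing that both forms end in the same kind of callable.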