jax.experimental.pallas.pallas_call
- jax.experimental.pallas.pallas_call(kernel, out_shape, *, grid_spec=None, grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, scratch_shapes=(), input_output_aliases={}, debug=False, interpret=False, name=None, compiler_params=None, cost_estimate=None, backend=None)
Invokes a Pallas kernel on some inputs.
See Pallas Quickstart.
- Parameters:
kernel (Callable[..., None]) – the kernel function, which receives a Ref for each input and output. The shapes of the Refs are given by the block_shape in the corresponding in_specs and out_specs.
out_shape (Any) – a PyTree of jax.ShapeDtypeStruct describing the shapes and dtypes of the outputs.
grid_spec (GridSpec | None) – an alternative way to specify grid, in_specs, out_specs and scratch_shapes. If given, those other parameters must not also be given.
grid (TupleGrid) – the iteration space, as a tuple of integers. The kernel is executed prod(grid) times. See details at grid, a.k.a. kernels in a loop.
in_specs (BlockSpecTree) – a PyTree of jax.experimental.pallas.BlockSpec with a structure matching that of the positional arguments. The default value for in_specs specifies the whole array for all inputs, e.g., as pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim). See details at BlockSpec, a.k.a. how to chunk up inputs.
out_specs (BlockSpecTree) – a PyTree of jax.experimental.pallas.BlockSpec with a structure matching that of the outputs. The default value for out_specs specifies the whole array, e.g., as pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim). See details at BlockSpec, a.k.a. how to chunk up inputs.
scratch_shapes (ScratchShapeTree) – a PyTree of backend-specific temporary objects required by the kernel, such as temporary buffers, synchronization primitives, etc.
input_output_aliases (Mapping[int, int]) – a dictionary mapping the index of some inputs to the index of the output that aliases them. These indices are in the flattened inputs and outputs.
debug (bool) – if True, Pallas prints various intermediate forms of the kernel as it is being processed.
interpret (bool) – runs the pallas_call as a jax.jit of a scan over the grid whose body is the kernel lowered as a JAX function. This does not require a TPU or a GPU, and is the only way to run Pallas kernels on CPU. This is useful for debugging.
name (str | None) – if present, specifies the name to use for this kernel call in debugging and error messages. To this name we append the file and line where the kernel function is defined, e.g., {name} for kernel function {kernel_name} at {file}:{line}. If missing, then we use {kernel_name} at {file}:{line}.
compiler_params (Mapping[Backend, CompilerParams] | CompilerParams | None) – optional compiler parameters. The value should be either a backend-specific dataclass (jax.experimental.pallas.tpu.TPUCompilerParams, jax.experimental.pallas.triton.TritonCompilerParams, jax.experimental.pallas.mosaic_gpu.GPUCompilerParams) or a dict mapping backend name to the corresponding platform-specific dataclass.
backend (Backend | None) – optional string literal, one of "mosaic_tpu", "triton" or "mosaic_gpu", determining the backend to be used. None means let Pallas decide.
cost_estimate (CostEstimate | None)
- Returns:
A function that can be called on a number of positional array arguments to invoke the Pallas kernel.
- Return type:
Callable[..., Any]
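Because pallas_call returns a function rather than a result, the returned callable can be built once and then applied to several inputs. The sketch below assumes the default in_specs and out_specs (whole arrays, no grid); scale_kernel and make_scale are hypothetical names used only for illustration, and interpret=True is again used so the example runs on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def scale_kernel(x_ref, o_ref):
    # With the default specs, each Ref covers the whole array.
    o_ref[...] = x_ref[...] * 2.0


def make_scale(shape, dtype):
    # Hypothetical factory: returns the callable produced by pallas_call.
    return pl.pallas_call(
        scale_kernel,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
        interpret=True,  # no TPU or GPU required
    )


scale = make_scale((8, 128), jnp.float32)

a = jnp.ones((8, 128), jnp.float32)
b = jnp.full((8, 128), 3.0, jnp.float32)

out_a = scale(a)           # call the returned function directly
out_b = jax.jit(scale)(b)  # or compose it with other JAX transformations
```

The out_shape passed at construction time fixes the output shape and dtype, so in this sketch the same callable is only applied to inputs of matching shape.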