jax.experimental.pallas.pallas_call

jax.experimental.pallas.pallas_call(kernel, out_shape, *, grid_spec=None, grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, scratch_shapes=(), input_output_aliases={}, debug=False, interpret=False, name=None, compiler_params=None, cost_estimate=None, backend=None)

Invokes a Pallas kernel on some inputs.

See Pallas Quickstart.
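As a minimal sketch, the snippet below adds two vectors elementwise with a single kernel invocation. It assumes a recent JAX release; the names add_kernel and add are illustrative only. With no grid and the default in_specs/out_specs, each Ref covers the whole array, and interpret=True is set only so the example also runs on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Input Refs come first, then the output Ref. With no grid and the
    # default BlockSpecs, each Ref spans the entire array.
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # illustration only: lets the sketch run on CPU
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
print(add(x, y))  # [1. 2. 3. 4. 5. 6. 7. 8.]
```

On a TPU or GPU, dropping interpret=True compiles the kernel for the selected backend instead.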

Parameters:
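  • kernel (Callable[..., None]) – the kernel function, which receives a Ref for each input and output: first the input Refs, then the output Refs, then one Ref per entry in scratch_shapes. The shapes of the input and output Refs are given by the block_shape of the corresponding entries in in_specs and out_specs.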

  • out_shape (Any) – a PyTree of jax.ShapeDtypeStruct describing the shape and dtypes of the outputs.

  • grid_spec (GridSpec | None) – An alternative way to specify grid, in_specs, out_specs and scratch_shapes. If given, none of those other parameters may also be given.

  • grid (TupleGrid) – the iteration space, as a tuple of integers. The kernel is invoked prod(grid) times, once per point of the grid. See details at "grid, a.k.a. kernels in a loop".

  • in_specs (BlockSpecTree) – a PyTree of jax.experimental.pallas.BlockSpec whose structure matches that of the positional arguments. By default, each input is covered by a single block spanning the whole array, i.e., pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim) for each input x. See details at "BlockSpec, a.k.a. how to chunk up inputs".

  • out_specs (BlockSpecTree) – a PyTree of jax.experimental.pallas.BlockSpec whose structure matches that of the outputs. By default, each output is covered by a single block spanning the whole array, i.e., pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim) for each output x. See details at "BlockSpec, a.k.a. how to chunk up inputs".

  • scratch_shapes (ScratchShapeTree) – a PyTree of backend-specific temporary objects required by the kernel, such as temporary buffers, synchronization primitives, etc.

  • input_output_aliases (Mapping[int, int]) – a dictionary mapping input indices to the indices of the outputs that alias them. Indices refer to positions in the flattened inputs and outputs.

  • debug (bool) – if True, Pallas prints various intermediate forms of the kernel as it is being processed.

  • interpret (bool) – runs the pallas_call as a jax.jit of a scan over the grid, whose body is the kernel lowered as a JAX function. This mode does not require a TPU or a GPU, is the only way to run Pallas kernels on CPU, and is useful for debugging.

  • name (str | None) – if present, the name to use for this kernel call in debugging and error messages. Pallas appends the file and line where the kernel function is defined, e.g., {name} for kernel function {kernel_name} at {file}:{line}. If missing, Pallas uses {kernel_name} at {file}:{line}.

  • compiler_params (Mapping[Backend, CompilerParams] | CompilerParams | None) – Optional compiler parameters. The value should be either a backend-specific dataclass (jax.experimental.pallas.tpu.TPUCompilerParams, jax.experimental.pallas.triton.TritonCompilerParams, jax.experimental.pallas.mosaic_gpu.GPUCompilerParams) or a dict mapping a backend name to the corresponding backend-specific dataclass.

  • backend (Backend | None) – Optional string literal, one of "mosaic_tpu", "triton" or "mosaic_gpu", that selects the backend to use. None lets Pallas decide.

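  • cost_estimate (CostEstimate | None) – Optional estimate of the kernel's cost (flops, transcendentals and bytes accessed), passed on to the compiler.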

Returns:

A function that, when called with positional array arguments, invokes the Pallas kernel on them.

Return type:

Callable[…, Any]
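Because pallas_call returns an ordinary callable, it can be used inside jax.jit like any other JAX function. The sketch below, with the illustrative names add_one_kernel and add_one, also passes input_output_aliases={0: 0} so that output 0 may reuse the buffer of input 0. Unless the caller donates the argument (for example via donate_argnums on jax.jit), XLA may still copy the input to preserve it; interpret=True is again used only so the sketch runs on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_one_kernel(x_ref, o_ref):
    # Reads the (aliased) input and writes the incremented values.
    o_ref[...] = x_ref[...] + 1.0


@jax.jit
def add_one(x: jax.Array) -> jax.Array:
    # pallas_call(...) returns a callable; calling it on x runs the kernel.
    kernel_fn = pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        input_output_aliases={0: 0},  # output 0 may reuse input 0's buffer
        interpret=True,               # illustration only: runs on CPU
    )
    return kernel_fn(x)


print(add_one(jnp.zeros(4, dtype=jnp.float32)))  # [1. 1. 1. 1.]
```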