jax.experimental.pallas.mosaic_gpu.as_torch_kernel
- jax.experimental.pallas.mosaic_gpu.as_torch_kernel(fn)
Makes a Mosaic GPU kernel callable with PyTorch tensors.
- Parameters:
fn – A JAX function that invokes a Mosaic GPU kernel. Note that the implementation currently supports only functions that contain a single Mosaic GPU kernel invocation and no other JAX API calls (e.g. from jax.numpy).
- Returns:
A wrapper function that accepts PyTorch tensors as inputs and returns PyTorch tensors as outputs. The output tensors are allocated on the same device as the input tensors.
Example:
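    import jax
    import jax.numpy as jnp
    import torch
    from jax.experimental.pallas import mosaic_gpu as plgpu

    @plgpu.kernel(out_type=jax.ShapeDtypeStruct([128], jnp.int32))
    def kernel(x_gmem_ref, y_gmem_ref, o_gmem_ref):
        ...

    # Inputs are ordinary PyTorch tensors on a CUDA device.
    x = torch.arange(128, dtype=torch.int32, device="cuda")
    y = x * x
    # Wrap the kernel, then call it directly with the PyTorch tensors.
    out = plgpu.as_torch_kernel(kernel)(x, y)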
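Because fn may contain only the kernel invocation itself, any surrounding work (casts, reshapes, reductions) has to happen on the PyTorch side of the call. The following is a minimal sketch of one way to organize that, not part of the Pallas API: `with_torch_pre_post` is a hypothetical helper, and `torch_kernel` is assumed to be the callable returned by `plgpu.as_torch_kernel(fn)`.

```python
def with_torch_pre_post(torch_kernel, pre=None, post=None):
    """Hypothetical helper: surround a converted kernel with host-side steps.

    torch_kernel: assumed to be the result of plgpu.as_torch_kernel(fn).
    pre:  optional callable mapping the input tensors to a tuple of tensors.
    post: optional callable applied to the kernel's output.
    """
    def call(*tensors):
        if pre is not None:
            # Preprocessing runs in PyTorch, outside the JAX function.
            tensors = pre(*tensors)
        # The wrapped function still performs exactly one kernel invocation.
        out = torch_kernel(*tensors)
        # Postprocessing also stays on the PyTorch side.
        return post(out) if post is not None else out
    return call
```

For instance, `with_torch_pre_post(plgpu.as_torch_kernel(kernel), post=lambda o: o.to(torch.float32))` would keep the dtype conversion in PyTorch instead of adding a `jax.numpy` call inside the JAX function.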