jax.experimental.pallas.triton.elementwise_inline_asm

jax.experimental.pallas.triton.elementwise_inline_asm(asm, *, args, constraints, pack, result_shape_dtypes)

Applies an elementwise operation, implemented in inline assembly, to the given arguments.

Parameters:
  • asm (str) – The assembly code to run. Operands are referenced as $0, $1, and so on, in the order they appear in constraints.

  • args (Sequence[jax.Array]) – The arguments to pass to the assembly code.

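  • constraints (str) – The LLVM inline assembly constraint string: a comma-separated list with output constraints (prefixed with =) first, followed by one input constraint per element of args, for example "=r,r".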

  • pack (int) – The number of elements from each argument expected by a single instance of the assembly code.

  • result_shape_dtypes (Sequence[jax.ShapeDtypeStruct]) – The shapes and dtypes of the results produced by the assembly code.

Returns:

The results produced by the assembly code, one array per entry in result_shape_dtypes.

Return type:

Sequence[jax.Array]