jax.experimental.pallas.triton.elementwise_inline_asm
- jax.experimental.pallas.triton.elementwise_inline_asm(asm, *, args, constraints, pack, result_shape_dtypes)
Applies an elementwise operation, written as inline assembly, to the given arguments.
- Parameters:
asm (str) – The assembly code to run.
args (Sequence[jax.Array]) – The arguments to pass to the assembly code.
constraints (str) – LLVM inline assembly constraints.
pack (int) – The number of elements from each argument expected by a single instance of the assembly code.
result_shape_dtypes (Sequence[jax.ShapeDtypeStruct]) – The shapes and dtypes of the results produced by the assembly code.
- Returns:
The results produced by the assembly code.
- Return type:
Sequence[jax.Array]