jax.experimental.pallas.triton module

Triton-specific Pallas APIs.

Classes

TritonCompilerParams([num_warps, ...])

Compiler parameters for the Triton backend, such as the number of warps and software-pipelining stages.

Functions

approx_tanh(x)

Elementwise approximate hyperbolic tangent: \(\mathrm{tanh}(x)\).

debug_barrier()

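Synchronizes all threads within the current kernel program instance (a block-level barrier); it does not synchronize separate program instances across the grid.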

elementwise_inline_asm(asm, *, args, ...)

Inline assembly applying an elementwise operation.