jax.experimental.pallas.triton.TritonCompilerParams
- class jax.experimental.pallas.triton.TritonCompilerParams(num_warps=None, num_stages=None, serialized_metadata=None)
Compiler parameters for Pallas kernels lowered through Triton.
- num_warps
The number of warps to use for the kernel. Each warp consists of 32 threads, so num_warps=4 runs 128 threads per kernel instance.
- Type: int | None
- num_stages
The number of stages the compiler should use when software-pipelining loops.
- Type: int | None
- serialized_metadata
Additional compiler metadata. This field is unstable and may be removed in the future.
- Type: bytes | None
- __init__(num_warps=None, num_stages=None, serialized_metadata=None)
Methods
__init__([num_warps, num_stages, ...])
Attributes
PLATFORM