jax.experimental.pallas.triton.TritonCompilerParams

class jax.experimental.pallas.triton.TritonCompilerParams(num_warps=None, num_stages=None, serialized_metadata=None)

Compiler parameters for Triton.

Parameters:
  • num_warps (int | None)

  • num_stages (int | None)

  • serialized_metadata (bytes | None)

num_warps

The number of warps to use for the kernel. Each warp consists of 32 threads.

Type:

int | None

num_stages

The number of pipeline stages the compiler should use when software-pipelining loops in the kernel.

Type:

int | None

serialized_metadata

Additional compiler metadata. This field is unstable and may be removed in the future.

Type:

bytes | None

__init__(num_warps=None, num_stages=None, serialized_metadata=None)
Parameters:
  • num_warps (int | None)

  • num_stages (int | None)

  • serialized_metadata (bytes | None)

Return type:

None

Methods

__init__([num_warps, num_stages, ...])

Attributes

PLATFORM

num_stages

num_warps

serialized_metadata