jax.experimental.pallas.num_programs

jax.experimental.pallas.num_programs(axis)

Returns the size of the grid along the given axis, that is, the number of kernel programs launched along that grid dimension.

Parameters:

axis (int) – Index of the grid axis whose size is returned.

Return type:

int | jax.Array
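The following is a minimal sketch of calling `num_programs` from inside a kernel. It assumes a recent JAX version in which `pl.BlockSpec` accepts `block_shape` and `index_map` keyword arguments, and it runs with `interpret=True` so that it also works on CPU. The kernel name, shapes, and the `10 * i + n` encoding are illustrative choices, not part of the API. Each of the 4 programs in a 1-D grid writes `10 * program_id + num_programs` into its own row.

```python
import jax.numpy as jnp
from jax.experimental import pallas as pl

GRID = 4   # illustrative: number of programs along grid axis 0
COLS = 8   # illustrative: row width

def kernel(x_ref, o_ref):
    # Index of this program along grid axis 0 (0 .. GRID - 1).
    i = pl.program_id(0)
    # Total number of programs along grid axis 0 (equals GRID here).
    n = pl.num_programs(0)
    # Each program owns one row of the output and tags it with i and n.
    o_ref[...] = x_ref[...] + (i * 10 + n)

# One (1, COLS) block per program; block i maps to row i.
row_spec = pl.BlockSpec(block_shape=(1, COLS), index_map=lambda i: (i, 0))

x = jnp.zeros((GRID, COLS), dtype=jnp.int32)
out = pl.pallas_call(
    kernel,
    out_shape=jax_struct if False else None,  # placeholder replaced below
    grid=(GRID,),
) if False else pl.pallas_call(
    kernel,
    out_shape=__import__("jax").ShapeDtypeStruct((GRID, COLS), jnp.int32),
    grid=(GRID,),
    in_specs=[row_spec],
    out_specs=row_spec,
    interpret=True,  # run without a GPU/TPU
)(x)

print(out[:, 0])
```

A common use of `num_programs` is pairing it with `pl.program_id` to detect the last step along an axis, for example guarding a final write with `pl.when(pl.program_id(0) == pl.num_programs(0) - 1)`.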