jax.experimental.pallas.num_programs

jax.experimental.pallas.num_programs(axis)

Returns the size of the grid along the given axis.

Parameters:
    axis (int) – the grid axis to query.

Return type:
    int | jax.Array
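The following is a minimal usage sketch, not an example from the JAX documentation. Inside a kernel, `num_programs(axis)` is typically paired with `pl.program_id(axis)`, which gives the current program's index along the same axis. Together they let a kernel work out its position relative to the whole grid. The names `reversed_block_index_kernel` and `reversed_block_index` are hypothetical. The sketch assumes a JAX release where `pl.BlockSpec` accepts `block_shape` and `index_map` keywords, and it passes `interpret=True` so it runs on CPU without a TPU or GPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def reversed_block_index_kernel(o_ref):
    # Index of this program along grid axis 0.
    i = pl.program_id(0)
    # Total number of programs (grid size) along grid axis 0.
    n = pl.num_programs(0)
    # Fill this output block with its position counted from the end of the grid.
    o_ref[...] = jnp.full(o_ref.shape, n - 1 - i, dtype=o_ref.dtype)


def reversed_block_index(num_blocks: int = 4, block_size: int = 8) -> jax.Array:
    # Hypothetical helper: one program per output block of `block_size` elements.
    return pl.pallas_call(
        reversed_block_index_kernel,
        out_shape=jax.ShapeDtypeStruct((num_blocks * block_size,), jnp.int32),
        grid=(num_blocks,),
        out_specs=pl.BlockSpec(block_shape=(block_size,), index_map=lambda i: (i,)),
        interpret=True,  # assumption: interpreter mode, so no accelerator is required
    )()


if __name__ == "__main__":
    print(reversed_block_index(4, 8))
```

With `grid=(4,)`, each of the four programs sees `num_programs(0) == 4`. The first block of eight elements is filled with 3, the next with 2, and so on down to 0. A common real-world use of the same pairing is detecting the final grid step, for example `pl.when(pl.program_id(0) == pl.num_programs(0) - 1)`, to finalize an accumulator.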