jax.lax.axis_index
- jax.lax.axis_index(axis_name)
Return the index along the mapped axis `axis_name`.
- Parameters:
axis_name (AxisName) – hashable Python object used to name the mapped axis.
- Returns:
An integer representing the index.
- Return type:
Array
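`axis_index` does not depend on `pmap` specifically. It reads the position along whichever mapped axis `axis_name` refers to, so it also works under `jax.vmap` when an `axis_name` is passed. The following is a minimal sketch that assumes only a single default device. The function name `scale_by_position` and the axis name `'batch'` are illustrative and not part of the API:

```python
import jax
import jax.numpy as jnp
from jax import lax

# jax.vmap also binds an axis name, so no extra devices are needed here.
def scale_by_position(x):
    # 'batch' names the axis mapped by the vmap below; axis_index returns
    # this element's position along that axis (0, 1, 2, ...).
    i = lax.axis_index('batch')
    return x * i

xs = jnp.ones(4)
out = jax.vmap(scale_by_position, axis_name='batch')(xs)
print(out)  # [0. 1. 2. 3.]
```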
For example, with 8 XLA devices available:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import lax
>>> from functools import partial
>>> @partial(jax.pmap, axis_name='i')
... def f(_):
...   return lax.axis_index('i')
...
>>> f(jnp.zeros(4))
Array([0, 1, 2, 3], dtype=int32)
>>> f(jnp.zeros(8))
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
>>> @partial(jax.pmap, axis_name='i')
... @partial(jax.pmap, axis_name='j')
... def f(_):
...   return lax.axis_index('i'), lax.axis_index('j')
...
>>> x, y = f(jnp.zeros((4, 2)))
>>> print(x)
[[0 0]
 [1 1]
 [2 2]
 [3 3]]
>>> print(y)
[[0 1]
 [0 1]
 [0 1]
 [0 1]]
```
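The nested `pmap` example requires 8 devices. As a sketch that runs on a single device, the same index pattern can be reproduced with nested `jax.vmap` calls. This swaps in vectorization for multi-device execution, but the axis names are bound in the same way. The one-dimensional `pmap` case can also be sized to whatever devices are present by using `jax.local_device_count()`. The helper name `both_indices` is hypothetical:

```python
import jax
import jax.numpy as jnp
from jax import lax

def both_indices(_):
    # Position along the outer axis 'i' and along the inner axis 'j'.
    return lax.axis_index('i'), lax.axis_index('j')

# Inner vmap maps axis 'j' (size 2), outer vmap maps axis 'i' (size 4),
# mirroring the nested pmap example without needing 8 devices.
f = jax.vmap(jax.vmap(both_indices, axis_name='j'), axis_name='i')
x, y = f(jnp.zeros((4, 2)))
print(x)  # x[a, b] == a
print(y)  # y[a, b] == b

# pmap sized to the devices actually available (1 on a plain CPU install).
n = jax.local_device_count()
g = jax.pmap(lambda _: lax.axis_index('i'), axis_name='i')
print(g(jnp.zeros(n)))  # indices 0 .. n-1
```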