jax.numpy.diag_indices
- jax.numpy.diag_indices(n, ndim=2)
Return indices for accessing the main diagonal of a multidimensional array.
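JAX implementation of numpy.diag_indices().
- Parameters:
  - n (int): the size of each dimension of the array, i.e. the length of each returned index array.
  - ndim (int): optional, default 2. The number of dimensions of the array, i.e. the number of index arrays returned.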
- Returns:
A tuple of arrays, each of length n, containing the indices to access the main diagonal.
- Return type:
tuple[Array, ...]
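A minimal sketch of what the returned tuple contains, based only on the behaviour described above: each of the `ndim` index arrays is just `0, 1, ..., n-1`, so diagonal element `i` is `x[i, i, ..., i]`. The helper `diag_indices_sketch` is hypothetical and not part of JAX; it is checked against `jnp.diag_indices` rather than presented as its actual implementation.

```python
import jax.numpy as jnp

def diag_indices_sketch(n: int, ndim: int = 2):
    # Hypothetical helper (not a JAX API): the same 0..n-1 index array,
    # repeated once per dimension, selects x[0, ..., 0], x[1, ..., 1], ...
    idx = jnp.arange(n)
    return (idx,) * ndim

# Compare against the library function for a couple of shapes.
for n, ndim in [(3, 2), (4, 3)]:
    ours = diag_indices_sketch(n, ndim)
    ref = jnp.diag_indices(n, ndim=ndim)
    assert len(ours) == len(ref) == ndim
    assert all(jnp.array_equal(a, b) for a, b in zip(ours, ref))
```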
Examples
>>> import jax.numpy as jnp
>>> jnp.diag_indices(3)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> jnp.diag_indices(4, ndim=3)
(Array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32))
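The returned tuple is meant to be used directly as an index. This sketch reads a diagonal, then writes one. Because JAX arrays are immutable, NumPy's in-place `x[idx] = ...` is replaced here by JAX's functional `.at[idx].set(...)`, which returns a new array. The array values are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)
idx = jnp.diag_indices(3)

# Reading: the tuple acts as an advanced index, gathering x[0, 0], x[1, 1], x[2, 2].
diag = x[idx]
print(diag)  # [0 4 8]
# For a square 2-D array this matches jnp.diagonal(x).
assert jnp.array_equal(diag, jnp.diagonal(x))

# Writing: .at[idx].set() returns a new array with the diagonal replaced;
# x itself is left unchanged.
y = x.at[idx].set(0)
print(y)
# [[0 1 2]
#  [3 0 5]
#  [6 7 0]]

# The same pattern works in higher dimensions: ndim=3 selects cube[i, i, i].
cube = jnp.zeros((2, 2, 2))
cube = cube.at[jnp.diag_indices(2, ndim=3)].set(1.0)
print(cube.sum())  # 2.0
```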