jax.numpy.triu

jax.numpy.triu(m, k=0)

Return the upper triangle of an array.

JAX implementation of numpy.triu().

Parameters:
  • m (ArrayLike) – input array. Must have m.ndim >= 2.

  • k (int) – optional, default=0. Index of the diagonal below which the elements of the array are set to zero. k=0 refers to the main diagonal, k<0 refers to a diagonal below the main diagonal, and k>0 refers to a diagonal above the main diagonal.

Returns:

An array with the same shape as the input, containing the upper triangle of the given array, with the elements below the diagonal specified by k set to zero.

Return type:

Array

See also

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9],
...                [10, 11, 12]])
>>> jnp.triu(x)
Array([[1, 2, 3],
       [0, 5, 6],
       [0, 0, 9],
       [0, 0, 0]], dtype=int32)
>>> jnp.triu(x, k=1)
Array([[0, 2, 3],
       [0, 0, 6],
       [0, 0, 0],
       [0, 0, 0]], dtype=int32)
>>> jnp.triu(x, k=-1)
Array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 0,  8,  9],
       [ 0,  0, 12]], dtype=int32)
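
The role of k in the examples above can be checked against an explicit index mask: an element at row i and column j is kept when j - i >= k. The following is a minimal sketch of that rule; triu_via_mask is a hypothetical helper written for illustration, not part of JAX and not necessarily how jnp.triu is implemented.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9],
               [10, 11, 12]])

def triu_via_mask(m, k=0):
    # Hypothetical helper for illustration only; not a JAX API.
    rows = jnp.arange(m.shape[-2])[:, None]  # row index i, as a column vector
    cols = jnp.arange(m.shape[-1])[None, :]  # column index j, as a row vector
    keep = cols - rows >= k                  # True on and above the k-th diagonal
    return jnp.where(keep, m, jnp.zeros_like(m))

# The mask-based result agrees with jnp.triu for the k values shown above.
for k in (-1, 0, 1):
    assert jnp.array_equal(triu_via_mask(x, k), jnp.triu(x, k=k))
```

Because the mask broadcasts against the last two axes, the same sketch also applies to inputs with more than two dimensions.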

When m.ndim > 2, jnp.triu treats the last two axes as the matrix dimensions and operates batch-wise over the leading axes.

>>> x1 = jnp.array([[[1, 2],
...                  [3, 4]],
...                 [[5, 6],
...                  [7, 8]]])
>>> jnp.triu(x1)
Array([[[1, 2],
        [0, 4]],

       [[5, 6],
        [0, 8]]], dtype=int32)
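
As a further illustration of the batched behaviour, the sketch below uses a batch of non-square matrices and checks two properties: each slice of the result matches jnp.triu applied to that slice, and adding the strictly lower triangle from jnp.tril(..., k=-1) recovers the input. The array batch and its shape are chosen here purely for illustration.

```python
import jax.numpy as jnp

# A batch of two 3x4 matrices (illustrative data).
batch = jnp.arange(24).reshape(2, 3, 4)

upper = jnp.triu(batch)
assert upper.shape == batch.shape

# Each matrix in the batch is processed independently.
for i in range(batch.shape[0]):
    assert jnp.array_equal(upper[i], jnp.triu(batch[i]))

# The upper triangle (k=0) and the strictly lower triangle (k=-1)
# partition the elements, so their sum reconstructs the input.
strict_lower = jnp.tril(batch, k=-1)
assert jnp.array_equal(upper + strict_lower, batch)
```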