jax.numpy.transpose
- jax.numpy.transpose(a, axes=None)
Return a transposed version of an N-dimensional array.
JAX implementation of numpy.transpose(), implemented in terms of jax.lax.transpose().
- Parameters:
a (ArrayLike) – input array
axes (Sequence[int] | None) – optionally specify the permutation using a sequence of length a.ndim of integers i satisfying 0 <= i < a.ndim. Defaults to range(a.ndim)[::-1], i.e. reverses the order of all axes.
- Returns:
transposed copy of the array.
- Return type:
Array
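The rule behind axes is that output axis k is input axis axes[k]. The sketch below checks that rule element by element on a small array and compares the result with jax.lax.transpose(), which takes the permutation explicitly. The array, the permutation (1, 2, 0), and the checking loop are illustrative choices, not part of the API.

```python
import jax
import jax.numpy as jnp
import numpy as np

# A small 3D array with distinct values, so every element can be identified.
a = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)
axes = (1, 2, 0)  # an arbitrary permutation of range(a.ndim)

y = jnp.transpose(a, axes)

# Output axis k is input axis axes[k], so the shape is permuted the same way.
assert y.shape == tuple(a.shape[ax] for ax in axes)  # (3, 4, 2)

# Element-wise: y[idx] == a[src], where src[axes[k]] = idx[k].
a_np, y_np = np.asarray(a), np.asarray(y)
for idx in np.ndindex(*y.shape):
    src = [0] * a.ndim
    for k, ax in enumerate(axes):
        src[ax] = idx[k]
    assert y_np[idx] == a_np[tuple(src)]

# The lower-level primitive takes the permutation explicitly and agrees.
y_lax = jax.lax.transpose(a, axes)
assert jnp.array_equal(y, y_lax)

# axes=None behaves like reversing all axes.
assert jnp.array_equal(jnp.transpose(a), jnp.transpose(a, tuple(range(a.ndim))[::-1]))
```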
See also
- jax.Array.transpose(): equivalent function via an Array method.
- jax.Array.T: equivalent function via an Array property.
- jax.numpy.matrix_transpose(): transpose the last two axes of an array. This is suitable for working with batched 2D matrices.
- jax.numpy.swapaxes(): swap any two axes in an array.
- jax.numpy.moveaxis(): move an axis to another position in the array.
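For a fixed array rank, each of the related functions listed above reorders axes in a way that jax.numpy.transpose() can reproduce with a particular permutation. A minimal sketch for a 3D array; the permutations shown are worked out by hand for this shape and checked by the assertions, not taken from the library's implementation:

```python
import jax.numpy as jnp

a = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)

# swapaxes(a, 0, 2) exchanges axes 0 and 2: permutation (2, 1, 0).
assert jnp.array_equal(jnp.swapaxes(a, 0, 2), jnp.transpose(a, (2, 1, 0)))

# moveaxis(a, 0, -1) moves axis 0 to the end: permutation (1, 2, 0).
assert jnp.array_equal(jnp.moveaxis(a, 0, -1), jnp.transpose(a, (1, 2, 0)))

# matrix_transpose swaps only the last two axes: permutation (0, 2, 1).
assert jnp.array_equal(jnp.matrix_transpose(a), jnp.transpose(a, (0, 2, 1)))
```

The dedicated functions are usually clearer at the call site, because they state the intent (swap, move, matrix transpose) instead of spelling out a full permutation.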
Note
Unlike numpy.transpose(), jax.numpy.transpose() returns a copy rather than a view of the input array. However, under JIT the compiler will optimize away such copies when possible, so in practice this does not have a performance impact.
Examples
For a 1D array, the transpose is the identity:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.transpose(x)
Array([1, 2, 3, 4], dtype=int32)
For a 2D array, the transpose is a matrix transpose:
>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.transpose(x)
Array([[1, 3],
       [2, 4]], dtype=int32)
For an N-dimensional array, the transpose reverses the order of the axes:
>>> x = jnp.zeros(shape=(3, 4, 5))
>>> jnp.transpose(x).shape
(5, 4, 3)
The axes argument can be specified to change this default behavior:
>>> jnp.transpose(x, (0, 2, 1)).shape
(3, 5, 4)
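Because axes is a permutation, its effect can be undone by transposing again with the inverse permutation, which numpy.argsort() computes. A minimal sketch; it uses its own array arr of distinct values (rather than the zeros above) so that the round trip can be checked element-wise, and the permutation (2, 0, 1) is arbitrary:

```python
import jax.numpy as jnp
import numpy as np

arr = jnp.arange(3 * 4 * 5).reshape(3, 4, 5)
axes = (2, 0, 1)

y = jnp.transpose(arr, axes)                        # shape (5, 3, 4)
inverse = tuple(int(i) for i in np.argsort(axes))   # (1, 2, 0)

# Applying the inverse permutation restores the original axis order.
z = jnp.transpose(y, inverse)
assert z.shape == arr.shape
assert jnp.array_equal(z, arr)
```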
Since swapping the last two axes is a common operation, it can be done via its own API, jax.numpy.matrix_transpose():
>>> jnp.matrix_transpose(x).shape
(3, 5, 4)
For convenience, transposes may also be performed using the jax.Array.transpose() method or the jax.Array.T property:
>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> x.transpose()
Array([[1, 3],
       [2, 4]], dtype=int32)
>>> x.T
Array([[1, 3],
       [2, 4]], dtype=int32)
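The Note on copies applies mainly to eager code; inside a jitted function the transpose is one step of a larger compiled program. A minimal sketch of that usage pattern, assuming a hypothetical helper gram that computes transpose(a) @ a. Whether XLA actually avoids materializing the transposed copy is a compiler decision that this example does not try to observe; it only checks the numerical result against NumPy.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def gram(a):
    # Hypothetical helper: the transpose feeds directly into a matmul within
    # one compiled program, which gives XLA the chance to avoid a separate copy.
    return jnp.transpose(a) @ a

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])
g = gram(a)

a_np = np.asarray(a)
assert g.shape == (2, 2)
assert np.allclose(np.asarray(g), a_np.T @ a_np)
```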