jax.numpy.reshape
- jax.numpy.reshape(a, shape, order='C', *, copy=None)
Return a reshaped copy of an array.
JAX implementation of numpy.reshape(), implemented in terms of jax.lax.reshape().
- Parameters:
a (ArrayLike) – input array to reshape.
shape (DimSize | Shape) – integer or sequence of integers giving the new shape, which must match the size of the input array. One dimension may be given as -1, in which case it is replaced with the value that makes the output size match the input size.
order (str) – either 'F' or 'C'; specifies whether the reshape uses column-major (Fortran-style, "F") or row-major (C-style, "C") order. The default is "C". JAX does not support order="A".
copy (bool | None) – unused by JAX; JAX always returns a copy, though under JIT the compiler may optimize such copies away.
- Returns:
reshaped copy of input array with the specified shape.
- Return type:
Array
Notes
Unlike numpy.reshape(), which returns a view of the input array when possible, jax.numpy.reshape() always returns a copy. However, under JIT the compiler optimizes away such copies when possible, so in practice this generally has no performance impact.
See also
jax.Array.reshape(): equivalent functionality via an array method.
jax.numpy.ravel(): flatten an array into a 1D shape.
jax.numpy.squeeze(): remove one or more length-1 axes from an array’s shape.
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.reshape(x, 6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2))
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
You can use -1 to automatically compute a shape that is consistent with the input size:
>>> jnp.reshape(x, -1)  # -1 is inferred to be 6
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (-1, 2))  # -1 is inferred to be 3
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
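The -1 rule can be sketched in plain Python. `infer_shape` below is a hypothetical helper written for illustration, not part of JAX, and its error messages are assumptions. It only mirrors the documented behavior: divide the input size by the product of the known dimensions, and reject any shape whose total size does not match.

```python
import math

def infer_shape(size, shape):
    """Hypothetical helper mirroring the documented -1 rule (not part of JAX)."""
    # Accept a bare integer, as jnp.reshape does.
    if isinstance(shape, int):
        shape = (shape,)
    shape = tuple(shape)
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    if -1 in shape:
        # Product of the explicitly given dimensions (1 if there are none).
        known = math.prod(d for d in shape if d != -1)
        if known == 0 or size % known != 0:
            raise ValueError(f"cannot reshape size {size} into {shape}")
        shape = tuple(size // known if d == -1 else d for d in shape)
    if math.prod(shape) != size:
        raise ValueError(f"cannot reshape size {size} into {shape}")
    return shape

print(infer_shape(6, -1))       # (6,)
print(infer_shape(6, (-1, 2)))  # (3, 2)
```

The helper raises rather than guessing when the known dimensions do not divide the input size evenly, matching the requirement that the new shape must match the size of the input array.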
The default ordering of elements in the reshape is C-style row-major ordering. To use Fortran-style column-major ordering, specify order='F':
>>> jnp.reshape(x, 6, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2), order='F')
Array([[1, 5],
       [4, 3],
       [2, 6]], dtype=int32)
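One way to see what order='F' does is to express it with C-order operations: reversing all axes turns column-major traversal into row-major traversal. The sketch below uses a hypothetical helper, `reshape_fortran`, and checks that identity against jnp.reshape(..., order='F'). It is an equivalence for illustration, not a claim about how JAX implements the option internally, and it assumes `shape` is given as a tuple.

```python
import jax.numpy as jnp

def reshape_fortran(a, shape):
    """Hypothetical helper: F-order reshape built from C-order operations.

    jnp.transpose with no axes argument reverses all axes, so reading the
    transposed array in C order visits elements in the original F order.
    A C-order reshape into the reversed target shape, followed by another
    full transpose, then places elements exactly as order='F' would.
    """
    return jnp.transpose(jnp.reshape(jnp.transpose(a), tuple(reversed(shape))))

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
f = reshape_fortran(x, (3, 2))
print(jnp.array_equal(f, jnp.reshape(x, (3, 2), order='F')))  # True
```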
For convenience, this functionality is also available via the jax.Array.reshape() method:
>>> x.reshape(3, 2)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
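The method form above passes the new shape as separate integers. A short sketch, assuming the same (2, 3) input, checks that a single tuple works too and that the method composes inside jax.jit. `pair_sums` is a hypothetical function for illustration only.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# Separate integers and a single tuple describe the same target shape.
a = x.reshape(3, 2)
b = x.reshape((3, 2))
c = jnp.reshape(x, (3, 2))
print(jnp.array_equal(a, b), jnp.array_equal(b, c))  # True True

# The method also works on traced values inside jax.jit:
# here it groups the flattened input into pairs and sums each pair.
@jax.jit
def pair_sums(v):
    return v.reshape(-1, 2).sum(axis=1)

print(pair_sums(x))  # sums of (1, 2), (3, 4), (5, 6)
```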