jax.lax.reshape
- jax.lax.reshape(operand, new_sizes, dimensions=None, *, out_sharding=None)
Wraps XLA’s Reshape operator.
For inserting or removing dimensions of size 1, prefer lax.expand_dims or lax.squeeze, respectively. These preserve information about axis identity that may be useful for advanced transformation rules.
- Parameters:
operand (ArrayLike) – array to be reshaped.
new_sizes (Shape) – sequence of integers specifying the resulting shape. The size of the final array must match the size of the input.
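dimensions (Sequence[int] | None) – optional sequence of integers specifying the permutation order of the input shape; the input axes are permuted in this order before the reshape is applied. If specified, its length must match the number of dimensions of operand, i.e. len(operand.shape).
out_sharding (NamedSharding | P | None) – optional sharding for the output array.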
- Returns:
reshaped array.
- Return type:
Array
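The size constraint on new_sizes means every target shape must have the same product of dimensions as operand.shape. The sketch below, assuming a standard JAX install, lists a few valid targets for a 6-element array and then shows an invalid one being rejected. The exact exception type is an implementation detail (typically a TypeError from the shape check), so the sketch catches ValueError as well rather than rely on one type.

```python
import math

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6)  # shape (6,), 6 elements

# Each valid target shape holds exactly x.size elements.
for shape in [(2, 3), (3, 2), (1, 6)]:
    assert math.prod(shape) == x.size
    print(shape, "->", lax.reshape(x, shape).shape)

# (4, 2) would need 8 elements, so lax.reshape rejects it.
try:
    lax.reshape(x, (4, 2))
except (TypeError, ValueError) as err:  # typically TypeError; exact type not guaranteed
    print("rejected:", type(err).__name__)
```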
Examples
Simple reshaping from one to two dimensions:
>>> import jax.numpy as jnp
>>> from jax.lax import reshape
>>> x = jnp.arange(6)
>>> y = reshape(x, (2, 3))
>>> y
Array([[0, 1, 2],
       [3, 4, 5]], dtype=int32)
Reshaping back to one dimension:
>>> reshape(y, (6,))
Array([0, 1, 2, 3, 4, 5], dtype=int32)
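Reshaping can also insert or remove size-1 axes, which is the case the note on lax.squeeze and lax.expand_dims covers. The sketch below, assuming a standard JAX install, performs the same insertion and removal both ways. The resulting shapes and values are identical; the axis-identity information mentioned above lives in which operation JAX records, not in the arrays themselves.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)  # shape (2, 3)

# Insert a leading size-1 axis: generic reshape vs. lax.expand_dims.
a = lax.reshape(x, (1, 2, 3))
b = lax.expand_dims(x, (0,))  # sequence of axis positions to insert
print(a.shape, b.shape)  # (1, 2, 3) (1, 2, 3)

# Remove it again: generic reshape vs. lax.squeeze.
c = lax.reshape(a, (2, 3))
d = lax.squeeze(b, (0,))  # sequence of size-1 axes to drop
print(c.shape, d.shape)  # (2, 3) (2, 3)

# Same values either way; the difference is in the operation JAX records.
print(bool(jnp.array_equal(a, b)), bool(jnp.array_equal(c, d)))  # True True
```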
Reshaping to one dimension after permuting the input dimensions:
>>> reshape(y, (6,), (1, 0))
Array([0, 3, 1, 4, 2, 5], dtype=int32)
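The permuted example is consistent with reading dimensions as a transpose applied before the reshape. The sketch below, assuming a standard JAX install, checks that reading on the same input by comparing the one-call form against an explicit lax.transpose followed by a plain lax.reshape. Treat it as an illustration of the semantics for this input rather than a specification.

```python
import jax.numpy as jnp
from jax import lax

y = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
perm = (1, 0)

# One call: permute the input axes by `perm`, then reshape to (6,).
fused = lax.reshape(y, (6,), perm)

# Two calls: explicit transpose, then a plain reshape.
two_step = lax.reshape(lax.transpose(y, perm), (6,))

print(fused)  # [0 3 1 4 2 5]
print(bool(jnp.array_equal(fused, two_step)))  # True
```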