jax.numpy.shape

jax.numpy.shape(a)

Return the shape of an array.

JAX implementation of numpy.shape(). Unlike np.shape, this function raises a TypeError if the input is a collection such as a list or tuple.

Parameters:

a (ArrayLike) – array-like object.

Returns:

A tuple of integers representing the shape of a.

Return type:

tuple[int, …]
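Because array shapes are static in JAX, the returned tuple contains plain Python integers even inside a `jax.jit`-compiled function, where the input is a tracer. This makes the result usable for ordinary Python control flow. The function `flatten` below is a hypothetical illustration, not a JAX API.

```python
import jax
import jax.numpy as jnp

@jax.jit
def flatten(x):
    # x is a tracer here, but jnp.shape(x) is still a concrete tuple of ints,
    # so a regular Python loop can compute the total number of elements.
    size = 1
    for dim in jnp.shape(x):
        size *= dim
    return jnp.reshape(x, (size,))

y = jnp.ones((2, 3))
flat = flatten(y)
```
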

Examples

Shape for arrays:

>>> x = jnp.arange(10)
>>> jnp.shape(x)
(10,)
>>> y = jnp.ones((2, 3))
>>> jnp.shape(y)
(2, 3)

This also works for scalars:

>>> jnp.shape(3.14)
()

For arrays, this can also be accessed via the jax.Array.shape property:

>>> x.shape
(10,)
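For JAX arrays, the function and the property return the same tuple. The function form also accepts inputs without a `shape` attribute, such as Python scalars. The short sketch below compares the two forms.

```python
import jax.numpy as jnp

x = jnp.arange(10)

# For JAX arrays, jnp.shape and the .shape property agree.
same = jnp.shape(x) == x.shape

# A Python float has no .shape attribute, but jnp.shape still accepts it.
has_attr = hasattr(3.14, "shape")
scalar_shape = jnp.shape(3.14)
```
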