jax.numpy.ndim

jax.numpy.ndim(a)

Return the number of dimensions of an array.

JAX implementation of numpy.ndim(). Unlike numpy.ndim(), this function raises a TypeError if the input is a Python collection such as a list or tuple.
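Because this function does not accept lists or tuples, a nested Python sequence has to be converted to an array before its dimensionality can be queried. The following is a minimal sketch that compares this with NumPy, assuming NumPy is installed alongside JAX; the variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

nested = [[1, 2, 3], [4, 5, 6]]  # a plain Python list of lists

# NumPy infers the dimensionality of nested sequences directly.
np_dims = np.ndim(nested)

# For jax.numpy, convert the collection to an array explicitly first.
jnp_dims = jnp.ndim(jnp.asarray(nested))

print(np_dims, jnp_dims)  # 2 2
```

Explicit conversion with jnp.asarray makes it clear where a host-side Python object becomes a JAX array.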

Parameters:

a (ArrayLike) – array-like object.

Returns:

An integer specifying the number of dimensions of a.

Return type:

int
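Since the result is a plain Python int and an array's rank is static metadata, it remains a concrete value even while a function is being traced by jax.jit. That means it can drive ordinary Python control flow. The sketch below uses a hypothetical helper, flatten_matrices, to illustrate this.

```python
import jax
import jax.numpy as jnp

@jax.jit
def flatten_matrices(x):
    # Hypothetical helper: the rank of a traced array is static, so
    # jnp.ndim returns a plain Python int here and can be used in an if.
    if jnp.ndim(x) == 2:
        return x.reshape(-1)
    return x

flat = flatten_matrices(jnp.ones((2, 3)))
vec = flatten_matrices(jnp.arange(4))
print(flat.shape, vec.shape)  # (6,) (4,)
```

Branching on a traced array's values would fail under jit, but branching on its rank is safe. Each distinct input rank simply triggers a separate compilation.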

Examples

Number of dimensions for arrays:

>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
>>> jnp.ndim(x)
1
>>> y = jnp.ones((2, 3))
>>> jnp.ndim(y)
2

This also works for scalars:

>>> jnp.ndim(3.14)
0
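Python scalars, NumPy scalars and 0-d JAX arrays all report zero dimensions. The sketch below checks several scalar kinds at once. The particular values in the list are arbitrary choices.

```python
import numpy as np
import jax.numpy as jnp

# Python int, float, bool and complex; a NumPy scalar; and a 0-d JAX array.
scalars = [3, 2.5, True, 1j, np.float32(1.0), jnp.array(7)]

# Every one of these has rank 0.
dims = [jnp.ndim(s) for s in scalars]
print(dims)  # [0, 0, 0, 0, 0, 0]
```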

For arrays, this can also be accessed via the jax.Array.ndim property:

>>> x.ndim
1
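For a jax.Array, the property, the function and len(x.shape) all give the same value. The function form remains useful when an input might be a bare Python scalar, which has no ndim attribute. A short sketch follows, with an arbitrarily chosen shape.

```python
import jax.numpy as jnp

x = jnp.zeros((4, 5, 6))

# For a jax.Array, the property, the function, and len(shape) agree.
prop_dims = x.ndim
func_dims = jnp.ndim(x)
shape_dims = len(x.shape)

# Python scalars have no .ndim attribute, so the function form is the
# safer choice when an input may be either a scalar or an array.
has_attr = hasattr(3.14, "ndim")

print(prop_dims, func_dims, shape_dims, has_attr)  # 3 3 3 False
```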