jax.numpy.atleast_1d

jax.numpy.atleast_1d(*arys)

Convert inputs to arrays with at least 1 dimension.

JAX implementation of numpy.atleast_1d().
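As a quick sanity check of the NumPy correspondence, the sketch below passes the same NumPy scalar to both `numpy.atleast_1d` and `jax.numpy.atleast_1d` and compares the results. The input value is an arbitrary illustrative choice. The expected behavior is that both produce shape `(1,)`, with JAX returning a `jax.Array` instead of a NumPy `ndarray`.

```python
import numpy as np
import jax
import jax.numpy as jnp

# An arbitrary NumPy scalar (shape ()) used purely for illustration.
host_value = np.float32(2.5)

np_result = np.atleast_1d(host_value)    # NumPy's version
jnp_result = jnp.atleast_1d(host_value)  # JAX's version

# Both promote the 0-d input to a length-1, 1-D array.
print(np_result.shape, jnp_result.shape)  # (1,) (1,)
# The JAX result is a jax.Array rather than a numpy.ndarray.
print(isinstance(jnp_result, jax.Array))  # True
```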

Parameters:
  • arys (ArrayLike): zero or more array-like arguments.

Returns:

An array if exactly one argument is passed; otherwise, a list of arrays with one entry per input. Arrays of shape () are converted to shape (1,), and arrays with one or more dimensions are returned unchanged.

Return type:

Array | list[Array]
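The shape rule described under Returns can be checked directly. This is a minimal sketch, assuming three illustrative inputs of rank 0, 1 and 2. Only the 0-d input changes shape; the others pass through as-is.

```python
import jax.numpy as jnp

# Illustrative inputs of rank 0, 1 and 2.
scalar = jnp.array(5)      # shape ()
vector = jnp.arange(3)     # shape (3,)
matrix = jnp.ones((2, 3))  # shape (2, 3)

for value in (scalar, vector, matrix):
    result = jnp.atleast_1d(value)
    # Only the 0-d input is promoted, to shape (1,).
    print(value.shape, "->", result.shape)
```

This should print `() -> (1,)`, `(3,) -> (3,)` and `(2, 3) -> (2, 3)`.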

Examples

Scalar arguments are converted to 1D, length-1 arrays:

>>> import jax.numpy as jnp
>>> x = jnp.float32(1.0)
>>> jnp.atleast_1d(x)
Array([1.], dtype=float32)

Inputs with one or more dimensions are returned unchanged:

>>> y = jnp.arange(4)
>>> jnp.atleast_1d(y)
Array([0, 1, 2, 3], dtype=int32)

Multiple arguments can be passed to the function at once, in which case a list of results is returned:

>>> jnp.atleast_1d(x, y)
[Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)]
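Because multiple arguments produce a list, the results can be unpacked directly. A common reason to call `atleast_1d` is to normalize inputs that may be either scalars or vectors before an operation that requires at least one dimension, such as `jnp.concatenate`, which rejects 0-d arrays. The sketch below illustrates both patterns; `concat_values` is a hypothetical helper, not part of JAX, and the `float32` inputs are an illustrative choice.

```python
import jax.numpy as jnp

x = jnp.float32(1.0)
y = jnp.arange(4, dtype=jnp.float32)

# With two or more arguments the result is a list, so it can be unpacked.
x1, y1 = jnp.atleast_1d(x, y)
print(x1.shape, y1.shape)  # (1,) (4,)

def concat_values(*values):
    """Hypothetical helper: join scalars and 1-D arrays into a single vector."""
    # jnp.concatenate rejects 0-d arrays, so promote each input to at least 1-D first.
    return jnp.concatenate([jnp.atleast_1d(v) for v in values])

print(concat_values(x, y))
```

Calling `jnp.atleast_1d` once per value inside the helper sidesteps the single-argument case, in which a bare array rather than a list is returned.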