jax.numpy.atleast_1d
- jax.numpy.atleast_1d(*arys)
Convert inputs to arrays with at least 1 dimension.
JAX implementation of numpy.atleast_1d().
- Parameters:
*arys (ArrayLike): zero or more array-like arguments.
- Returns:
a single array if one argument is given, or a list of arrays (one per argument) if several are given. Arrays of shape () are converted to shape (1,), and arrays with any other shape are returned unchanged.
- Return type:
Array | list[Array]
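The conversion rule above can be restated in a few lines. The following is a minimal sketch, not JAX's actual source: the helper name `atleast_1d_sketch` is hypothetical, and it assumes that converting with `jnp.asarray` and reshaping 0-d arrays to `(1,)` is enough to reproduce the documented shapes.

```python
import jax.numpy as jnp


def atleast_1d_sketch(*arys):
    """Hypothetical restatement of the documented rule (not JAX's source)."""
    results = []
    for a in arys:
        arr = jnp.asarray(a)  # accept scalars and arrays alike
        # Only 0-d arrays (shape ()) are reshaped; all other shapes pass through.
        results.append(arr.reshape((1,)) if arr.ndim == 0 else arr)
    # One input gives a single array; any other count gives a list.
    return results[0] if len(results) == 1 else results


print(atleast_1d_sketch(jnp.float32(1.0)).shape)  # (1,)
print(atleast_1d_sketch(jnp.ones((2, 3))).shape)  # (2, 3)
```

The single-versus-list return mirrors the Return type above, which is why callers passing one argument get an array rather than a one-element list.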
Examples
Scalar arguments are converted to 1D, length-1 arrays:
>>> import jax.numpy as jnp
>>> x = jnp.float32(1.0)
>>> jnp.atleast_1d(x)
Array([1.], dtype=float32)
Inputs that already have one or more dimensions are returned unchanged:
>>> y = jnp.arange(4)
>>> jnp.atleast_1d(y)
Array([0, 1, 2, 3], dtype=int32)
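The example above uses a 1-D input; the same pass-through applies to arrays with more dimensions, and a length-1 vector is not wrapped a second time. A small check using only `jnp.ones`, `jnp.array`, and `jnp.atleast_1d` (the variable names are illustrative):

```python
import jax.numpy as jnp

m = jnp.ones((2, 3))            # a 2-D input
print(jnp.atleast_1d(m).shape)  # (2, 3): returned unchanged

v = jnp.array([5.0])            # already 1-D with length 1
print(jnp.atleast_1d(v).shape)  # (1,): not wrapped again
```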
Multiple arguments can be passed to the function at once, in which case a list of results is returned:
>>> jnp.atleast_1d(x, y)
[Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)]
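One plausible use of the multi-argument form is normalizing a mix of scalars and vectors before joining them with `jnp.concatenate`, which expects inputs with at least one dimension. A short sketch; `scale`, `offsets`, and `params` are illustrative names, not part of any API:

```python
import jax.numpy as jnp

scale = jnp.float32(2.0)                            # 0-d scalar
offsets = jnp.array([1.0, 3.0], dtype=jnp.float32)  # 1-D vector

# With several inputs a list comes back, so it can be unpacked directly.
a, b = jnp.atleast_1d(scale, offsets)
params = jnp.concatenate([a, b])
print(params.shape)  # (3,)
```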