jax.numpy.linspace
jax.numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, *, device=None)
Return evenly-spaced numbers within an interval.
JAX implementation of numpy.linspace().
- Parameters:
start (ArrayLike) – scalar or array of starting values.
stop (ArrayLike) – scalar or array of stop values.
num (int) – number of values to generate. Default: 50.
endpoint (bool) – if True (default) then include the stop value in the result. If False, then exclude the stop value.
retstep (bool) – if True, then return a (result, step) tuple, where step is the interval between adjacent values in result.
axis (int) – integer axis along which to generate the linspace. Defaults to zero.
device (xc.Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.
dtype (DTypeLike | None)
- Returns:
An array values, or a tuple (values, step) if retstep is True, where:
values is an array of evenly-spaced values from start to stop.
step is the interval between adjacent values.
See also
jax.numpy.arange(): Generate N evenly-spaced values given a starting point and a step.
jax.numpy.logspace(): Generate logarithmically-spaced values.
jax.numpy.geomspace(): Generate geometrically-spaced values.
Examples
List of 5 values between 0 and 10:
>>> import jax.numpy as jnp
>>> jnp.linspace(0, 10, 5)
Array([ 0. ,  2.5,  5. ,  7.5, 10. ], dtype=float32)
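The dtype parameter is listed above without a description. As a hedged sketch, assuming it selects the element type of the returned array as in numpy.linspace(), the call below requests 32-bit integers. The variable name ints is illustrative only.

```python
import jax.numpy as jnp

# Request an integer result instead of the default floating-point dtype.
ints = jnp.linspace(0, 10, 5, dtype=jnp.int32)

print(ints.dtype)  # int32
print(ints.shape)  # (5,)
```

Samples such as 2.5 have no integer representation, so an integer dtype cannot preserve exactly even spacing.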
List of 8 values between 0 and 10, excluding the endpoint:
>>> jnp.linspace(0, 10, 8, endpoint=False)
Array([0.  , 1.25, 2.5 , 3.75, 5.  , 6.25, 7.5 , 8.75], dtype=float32)
List of values and the step size between them:
>>> vals, step = jnp.linspace(0, 10, 9, retstep=True)
>>> vals
Array([ 0.  ,  1.25,  2.5 ,  3.75,  5.  ,  6.25,  7.5 ,  8.75, 10.  ], dtype=float32)
>>> step
Array(1.25, dtype=float32)
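The returned step depends on whether the endpoint is included. The sketch below assumes the usual linspace convention: the interval is split into num - 1 steps when endpoint is True and into num steps when it is False. The values and variable names are illustrative only.

```python
import jax.numpy as jnp

start, stop, num = 0.0, 10.0, 8  # illustrative values

# endpoint=True: stop is the last sample, so there are num - 1 gaps.
with_end, step_with = jnp.linspace(start, stop, num, retstep=True)

# endpoint=False: stop is left out, so the interval is split into num gaps.
without_end, step_without = jnp.linspace(
    start, stop, num, endpoint=False, retstep=True
)

print(step_with)      # approximately (stop - start) / (num - 1)
print(step_without)   # approximately (stop - start) / num
print(without_end[-1] < stop)  # the last sample stays below stop
```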
Multi-dimensional linspace:
>>> start = jnp.array([0, 5])
>>> stop = jnp.array([5, 10])
>>> jnp.linspace(start, stop, 5)
Array([[ 0.  ,  5.  ],
       [ 1.25,  6.25],
       [ 2.5 ,  7.5 ],
       [ 3.75,  8.75],
       [ 5.  , 10.  ]], dtype=float32)
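The axis parameter controls where the generated samples are placed in a multi-dimensional result. As a minimal sketch using the same start and stop arrays, the default axis=0 puts the samples along the leading axis, while axis=1 (chosen here purely for illustration) puts them along the trailing axis.

```python
import jax.numpy as jnp

start = jnp.array([0, 5])
stop = jnp.array([5, 10])

# Default axis=0: samples run along the leading axis,
# giving one column per (start, stop) pair.
by_rows = jnp.linspace(start, stop, 5)

# axis=1: samples run along the trailing axis,
# giving one row per (start, stop) pair.
by_cols = jnp.linspace(start, stop, 5, axis=1)

print(by_rows.shape)  # (5, 2)
print(by_cols.shape)  # (2, 5)
```

In this two-dimensional case the second result is the transpose of the first.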