jax.numpy.linspace

jax.numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, *, device=None)

Return evenly-spaced numbers within an interval.

JAX implementation of numpy.linspace().

Parameters:
  • start (ArrayLike) – scalar or array of starting values.

  • stop (ArrayLike) – scalar or array of stop values.

  • num (int) – number of values to generate. Default: 50.

  • endpoint (bool) – if True (default) then include the stop value in the result. If False, then exclude the stop value.

  • retstep (bool) – If True, then return a (result, step) tuple, where step is the interval between adjacent values in result.

  • axis (int) – integer axis along which to generate the linspace. Defaults to zero.

  • device (xc.Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.

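  • dtype (DTypeLike | None) – optional dtype of the output array. If None, an inexact (floating-point or complex) dtype is inferred from start and stop.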
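As a rough mental model of how num and endpoint determine the spacing, the plain-Python sketch below computes the same values for scalar start and stop. The helper name linspace_ref is hypothetical. It ignores dtype, axis and device, and it is a simplified illustration, not JAX's actual implementation.

```python
import math


def linspace_ref(start, stop, num=50, endpoint=True):
    """Hypothetical scalar-only sketch of the linspace spacing rule."""
    if num < 0:
        raise ValueError("num must be non-negative")
    # endpoint=True splits [start, stop] into num - 1 gaps;
    # endpoint=False splits it into num gaps and leaves stop out.
    div = num - 1 if endpoint else num
    if div <= 0:
        # No values, or a single value with endpoint=True:
        # this sketch reports the step as undefined (NaN).
        return [float(start)] * num, math.nan
    step = (stop - start) / div
    values = [start + i * step for i in range(num)]
    if endpoint:
        # Pin the last value to stop, avoiding any rounding error
        # in start + (num - 1) * step.
        values[-1] = float(stop)
    return values, step


print(linspace_ref(0, 10, 5))
# ([0.0, 2.5, 5.0, 7.5, 10.0], 2.5)
print(linspace_ref(0, 10, 8, endpoint=False))
# ([0.0, 1.25, 2.5, 3.75, 5.0, 6.25, 7.5, 8.75], 1.25)
```

The only difference between the two modes is the divisor: excluding the endpoint divides the interval by num instead of num - 1, which is why the endpoint=False example below has a step of 1.25 rather than 10/7.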

Returns:

An array values, or a tuple (values, step) if retstep is True, where:

  • values is an array of evenly-spaced values from start to stop.

  • step is the interval between adjacent values.
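To check the description of step, the sketch below compares it with the differences between adjacent entries of values, computed with jnp.diff. It assumes JAX's default single-precision floats, so it uses jnp.allclose rather than exact equality.

```python
import jax.numpy as jnp

# Request both the values and the spacing between them.
values, step = jnp.linspace(0.0, 1.0, 11, retstep=True)

# Every gap between neighbouring values should match the reported
# step, up to float32 rounding.
gaps = jnp.diff(values)
print(step, bool(jnp.allclose(gaps, step)))
```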

See also

Examples

List of 5 values between 0 and 10:

>>> import jax.numpy as jnp
>>> jnp.linspace(0, 10, 5)
Array([ 0. ,  2.5,  5. ,  7.5, 10. ], dtype=float32)

List of 8 values between 0 and 10, excluding the endpoint:

>>> jnp.linspace(0, 10, 8, endpoint=False)
Array([0.  , 1.25, 2.5 , 3.75, 5.  , 6.25, 7.5 , 8.75], dtype=float32)

List of values and the step size between them:

>>> vals, step = jnp.linspace(0, 10, 9, retstep=True)
>>> vals
Array([ 0.  ,  1.25,  2.5 ,  3.75,  5.  ,  6.25,  7.5 ,  8.75, 10.  ],
      dtype=float32)
>>> step
Array(1.25, dtype=float32)

Multi-dimensional linspace:

>>> start = jnp.array([0, 5])
>>> stop = jnp.array([5, 10])
>>> jnp.linspace(start, stop, 5)
Array([[ 0.  ,  5.  ],
       [ 1.25,  6.25],
       [ 2.5 ,  7.5 ],
       [ 3.75,  8.75],
       [ 5.  , 10.  ]], dtype=float32)
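The axis and dtype parameters are not exercised by the examples above. The sketch below reuses the multi-dimensional start and stop arrays: axis=1 places the generated samples along the second dimension, and dtype requests an integer result. How non-integral values are converted to integers is not described on this page, so the sketch only checks the resulting dtype.

```python
import jax.numpy as jnp

start = jnp.array([0, 5])
stop = jnp.array([5, 10])

# axis=1 puts the 5 samples along the last dimension,
# giving shape (2, 5) instead of the default (5, 2).
rows = jnp.linspace(start, stop, 5, axis=1)
print(rows.shape)  # (2, 5)

# The result is the transpose of the default axis=0 layout.
print(bool(jnp.allclose(rows, jnp.linspace(start, stop, 5).T)))

# dtype overrides the inferred floating-point result type.
ints = jnp.linspace(0, 10, 5, dtype=jnp.int32)
print(ints.dtype)
```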