jax.numpy.logspace

jax.numpy.logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0)

Generate logarithmically-spaced values.

JAX implementation of numpy.logspace().
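A quick sanity check that the JAX result agrees with NumPy's for the same arguments. One visible difference is the default dtype: JAX produces float32 unless 64-bit mode is enabled, while NumPy produces float64, so the comparison below uses a tolerance rather than exact equality.

```python
import numpy as np
import jax.numpy as jnp

# Identical arguments passed to NumPy and to JAX.
np_result = np.logspace(0, 2, 5)
jnp_result = jnp.logspace(0, 2, 5)

print(np_result.dtype)   # float64
print(jnp_result.dtype)  # float32 under JAX's default (64-bit mode disabled)

# Values agree up to float32 rounding.
print(np.allclose(np_result, jnp_result))  # True
```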

Parameters:
  • start (ArrayLike) – scalar or array. Exponent of the first value: the sequence starts at base ** start.

  • stop (ArrayLike) – scalar or array. Exponent of the last value: the sequence ends at base ** stop when endpoint=True; when endpoint=False, base ** stop itself is excluded.

  • num (int) – int, optional, default=50. Number of values to generate.

  • endpoint (bool) – bool, optional, default=True. If True, then include the stop value in the result. If False, then exclude the stop value.

  • base (ArrayLike) – scalar or array, optional, default=10. Specifies the base of the logarithm.

  • dtype (DTypeLike | None) – optional, default=None. Specifies the dtype of the output.

  • axis (int) – int, optional, default=0. Axis of the result along which the generated values are laid out. Only relevant when start, stop, or base is an array.
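A small sketch of the axis parameter with array-valued start and stop (the scalar base and the specific values are chosen only for illustration). With the default axis=0 the samples run down the rows; with axis=1 they run along the columns, giving the transposed layout.

```python
import numpy as np
import jax.numpy as jnp

# Two sequences generated at once: 10**0..10**2 and 10**1..10**3.
start = jnp.array([0.0, 1.0])
stop = jnp.array([2.0, 3.0])

# Default axis=0: shape (num, 2), one sequence per column.
rows = jnp.logspace(start, stop, 3)
# axis=1: shape (2, num), one sequence per row.
cols = jnp.logspace(start, stop, 3, axis=1)

print(rows.shape)  # (3, 2)
print(cols.shape)  # (2, 3)
print(np.allclose(cols, rows.T))  # True
```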

Returns:

An array of logarithmically-spaced values.

Return type:

Array
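Assuming the NumPy definition, in which logspace is power(base, linspace(start, stop, num, endpoint)), the k-th returned element can be written out as below (for endpoint=True this form assumes num > 1). For example, start=0, stop=2, num=5, endpoint=True gives a step of 0.5, so the second element is 10^0.5 ≈ 3.162.

```latex
y_k = \mathrm{base}^{\,\mathrm{start} + k\,\Delta},
\qquad k = 0, 1, \dots, \mathrm{num} - 1,
\qquad
\Delta =
\begin{cases}
  \dfrac{\mathrm{stop} - \mathrm{start}}{\mathrm{num} - 1} & \text{if endpoint=True} \\[4pt]
  \dfrac{\mathrm{stop} - \mathrm{start}}{\mathrm{num}} & \text{if endpoint=False}
\end{cases}
```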

See also

Examples

List 5 logarithmically spaced values between 1 (10 ** 0) and 100 (10 ** 2):

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.logspace(0, 2, 5)
Array([  1.   ,   3.162,  10.   ,  31.623, 100.   ], dtype=float32)

List 5 logarithmically-spaced values between 1 (10 ** 0) and 100 (10 ** 2), excluding the endpoint:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.logspace(0, 2, 5, endpoint=False)
Array([ 1.   ,  2.512,  6.31 , 15.849, 39.811], dtype=float32)
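With endpoint=False the exponent step is (stop - start) / num rather than (stop - start) / (num - 1). A consequence, checked in this sketch, is that the result equals the first num values of an endpoint=True sequence with num + 1 points.

```python
import numpy as np
import jax.numpy as jnp

# Exponents 0.0, 0.4, 0.8, 1.2, 1.6 (step 2 / 5).
open_interval = jnp.logspace(0, 2, 5, endpoint=False)
# Exponents 0.0, 0.4, ..., 2.0 (step 2 / 5 as well, but 6 points).
closed_interval = jnp.logspace(0, 2, 6)

# Dropping the final point of the 6-point sequence recovers the 5-point one.
print(np.allclose(open_interval, closed_interval[:5]))  # True
```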

List 7 logarithmically-spaced values between 1 (2 ** 0) and 4 (2 ** 2) with base 2:

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.logspace(0, 2, 7, base=2)
Array([1.   , 1.26 , 1.587, 2.   , 2.52 , 3.175, 4.   ], dtype=float32)
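The same sequence can also be produced with jax.numpy.geomspace, which takes the endpoint values (1 and 4) directly instead of their exponents, so no base is needed. This sketch just compares the two numerically.

```python
import numpy as np
import jax.numpy as jnp

# logspace: exponents 0..2 in base 2, i.e. values 1..4.
via_logspace = jnp.logspace(0, 2, 7, base=2)
# geomspace: the endpoint values themselves, 1 and 4.
via_geomspace = jnp.geomspace(1, 4, 7)

print(np.allclose(via_logspace, via_geomspace))  # True
```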

Multi-dimensional logspace, with array-valued start, stop, and base (each column is a separate sequence):

>>> start = jnp.array([0, 5])
>>> stop = jnp.array([5, 0])
>>> base = jnp.array([2, 3])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.logspace(start, stop, 5, base=base)
Array([[  1.   , 243.   ],
       [  2.378,  61.547],
       [  5.657,  15.588],
       [ 13.454,   3.948],
       [ 32.   ,   1.   ]], dtype=float32)
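To see how the batched call decomposes, this sketch rebuilds each column of the multi-dimensional result from a 1-D logspace call using the matching elements of start, stop, and base.

```python
import numpy as np
import jax.numpy as jnp

start = jnp.array([0, 5])
stop = jnp.array([5, 0])
base = jnp.array([2, 3])

batched = jnp.logspace(start, stop, 5, base=base)

# Column j is an ordinary 1-D logspace built from start[j], stop[j], base[j].
for j in range(2):
    column = jnp.logspace(start[j], stop[j], 5, base=base[j])
    print(np.allclose(batched[:, j], column))  # True
```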