jax.random.generalized_normal


jax.random.generalized_normal(key, p, shape=(), dtype=None, *, out_sharding=None)

Sample from the generalized normal distribution.

The values are returned according to the probability density function:

\[f(x;p) \propto e^{-|x|^p}\]

on the domain \(-\infty < x < \infty\), where \(p > 0\) is the shape parameter.
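One standard way to draw from this density uses the fact that if \(G \sim \mathrm{Gamma}(1/p, 1)\), then \(G^{1/p}\) has density proportional to \(e^{-y^p}\) on \(y > 0\). Multiplying by an independent random sign makes the result symmetric on the whole real line. The sketch below builds such a sampler from jax.random.gamma and jax.random.rademacher. The helper name gen_normal_sketch is hypothetical, and the code illustrates the technique rather than reproducing JAX's internal implementation.

```python
import jax
import jax.numpy as jnp


def gen_normal_sketch(key, p, shape=(), dtype=jnp.float32):
    """Hypothetical helper: samples X with density proportional to exp(-|x|**p).

    If G ~ Gamma(1/p, 1), then G**(1/p) has density proportional to
    exp(-y**p) on y > 0. An independent random sign (+1 or -1) then
    makes the distribution symmetric about zero.
    """
    key_gamma, key_sign = jax.random.split(key)
    g = jax.random.gamma(key_gamma, 1.0 / p, shape, dtype)  # G ~ Gamma(1/p, 1)
    sign = jax.random.rademacher(key_sign, shape, dtype)     # +1 or -1, equally likely
    return sign * g ** (1.0 / p)


# With p = 2 the density is proportional to exp(-x**2): a normal with variance 1/2.
samples = gen_normal_sketch(jax.random.key(0), 2.0, (200_000,))
print(float(jnp.var(samples)))  # should be close to 0.5
```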

Parameters:
  • key (ArrayLike) – a PRNG key (for example, one created with jax.random.key) used as the source of randomness.

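  • p (float) – the shape parameter \(p > 0\) of the density above. For example, \(p = 2\) gives a normal distribution with variance 1/2, and \(p = 1\) gives a standard Laplace distribution.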

  • shape (Shape) – optional, the batch dimensions of the result. Default ().

  • dtype (DTypeLikeFloat | None) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding (NamedSharding | P | None) – optional, how the output array is sharded across devices in a multi-device computation: a NamedSharding, a PartitionSpec (P), or None (the default). Primarily used in explicit sharding mode; see the explicit sharding tutorial for details.
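A minimal usage sketch of the key, p, shape, and dtype parameters follows. It assumes a single-device setup, so out_sharding is left at its default. It also shows that JAX keys are functional: calling the function again with the same key reproduces the same samples, so independent draws should use keys obtained from jax.random.split.

```python
import jax
import jax.numpy as jnp

# Split one root key so each draw gets its own independent key.
root = jax.random.key(42)
key_a, key_b = jax.random.split(root)

# p=2.0: density proportional to exp(-x**2), i.e. a normal with variance 1/2.
x = jax.random.generalized_normal(key_a, 2.0, shape=(3, 4))

# p=1.0: a standard Laplace distribution, with the float dtype requested explicitly.
y = jax.random.generalized_normal(key_b, 1.0, shape=(1000,), dtype=jnp.float32)

# Reusing the same key reproduces exactly the same samples.
x_again = jax.random.generalized_normal(key_a, 2.0, shape=(3, 4))

print(x.shape, x.dtype)  # default dtype is float32 unless jax_enable_x64 is enabled
print(y.shape, y.dtype)
print(bool(jnp.array_equal(x, x_again)))
```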

Returns:

A random array with the specified shape and dtype.

Return type:

Array
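As a sanity check on the returned samples, normalizing the density gives \(f(x;p) = \frac{p}{2\Gamma(1/p)} e^{-|x|^p}\), so the variance is \(\operatorname{Var}[X] = \Gamma(3/p)/\Gamma(1/p)\) (for example, 1/2 at \(p = 2\) and 2 at \(p = 1\)). The sketch below compares that formula against the sample variance of a large draw. The helper names theoretical_variance and empirical_variance are hypothetical, and the sample size is an arbitrary choice that gives a few percent accuracy for moderate \(p\).

```python
import math

import jax
import jax.numpy as jnp


def theoretical_variance(p: float) -> float:
    # Var[X] = Gamma(3/p) / Gamma(1/p) for density proportional to exp(-|x|**p).
    return math.gamma(3.0 / p) / math.gamma(1.0 / p)


def empirical_variance(p: float, n: int = 200_000, seed: int = 0) -> float:
    # Sample variance of n draws from jax.random.generalized_normal.
    samples = jax.random.generalized_normal(jax.random.key(seed), p, shape=(n,))
    return float(jnp.var(samples))


for p in (1.0, 2.0, 4.0):
    print(f"p={p}: empirical {empirical_variance(p):.3f}, "
          f"theoretical {theoretical_variance(p):.3f}")
```

Smaller values of \(p\) produce heavier tails, so the sample variance converges more slowly and needs more samples for the same relative accuracy.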