jax.random.rademacher

jax.random.rademacher(key, shape=(), dtype=None, *, out_sharding=None)

Sample from a Rademacher distribution.

The values are distributed according to the probability mass function:

\[f(k) = \frac{1}{2}(\delta(k - 1) + \delta(k + 1))\]

on the domain \(k \in \{-1, 1\}\), where \(\delta(x)\) is the Kronecker delta (1 if \(x = 0\), 0 otherwise), so that \(f(-1) = f(1) = \tfrac{1}{2}\).
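
As a quick empirical check of this distribution, one can draw many samples and confirm that each value appears about half the time. This is a minimal sketch, assuming a recent JAX release where jax.random.key is available; the helper name rademacher_frequencies is hypothetical and not part of the JAX API.

```python
import jax
import jax.numpy as jnp


def rademacher_frequencies(seed: int = 0, n: int = 100_000) -> tuple[float, float]:
    """Hypothetical helper: fraction of n Rademacher samples equal to +1 and to -1."""
    key = jax.random.key(seed)
    samples = jax.random.rademacher(key, shape=(n,))
    # Averaging a boolean mask over n samples gives an empirical frequency.
    frac_pos = float(jnp.mean(samples == 1))
    frac_neg = float(jnp.mean(samples == -1))
    return frac_pos, frac_neg


frac_pos, frac_neg = rademacher_frequencies()
print(frac_pos, frac_neg)  # both close to 0.5; they sum to 1 because no other values occur
```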

Parameters:
  • key (ArrayLike) – a PRNG key.

  • shape (Shape) – The shape of the returned samples. Default ().

  • dtype (DTypeLikeInt | None) – The integer type used for samples. If None, the default integer type is used.

  • out_sharding (NamedSharding | P | None) – Optional. Specifies how the output array is sharded across devices in multi-device computation: a NamedSharding, a PartitionSpec (P), or None (the default). When given, the output is sharded according to this specification. Primarily used in explicit sharding mode; see the explicit sharding tutorial for details.
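
The shape and dtype parameters can be exercised as follows. This is a minimal sketch, assuming a recent JAX release with jax.random.key; it leaves out_sharding at its default, since sharding requires a multi-device setup. A fresh key is split off for each draw, following the usual JAX convention of not reusing keys.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)
key_scalar, key_array = jax.random.split(key)  # use a fresh key for each draw

# Defaults: shape=() gives a single scalar sample; dtype=None falls back to an integer dtype.
scalar_sample = jax.random.rademacher(key_scalar)

# An explicit shape and a narrower integer dtype.
int8_samples = jax.random.rademacher(key_array, shape=(3, 4), dtype=jnp.int8)

print(scalar_sample.shape)                     # ()
print(int8_samples.shape, int8_samples.dtype)  # (3, 4) int8
```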

Returns:

A jnp.array of samples with the given shape. Each element has a 50% chance of being 1 and a 50% chance of being -1.

Return type:

Array
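
Because every element of the result is either 1 or -1, the returned array can serve as a mask of independent random signs. The sketch below multiplies an input array elementwise by such a mask; random_sign_flip and the input x are hypothetical, chosen only for illustration, and the sketch assumes JAX's usual promotion of int32 times float32 to float32.

```python
import jax
import jax.numpy as jnp


def random_sign_flip(key, x):
    """Hypothetical helper: multiply each element of x by an independent random sign."""
    signs = jax.random.rademacher(key, shape=x.shape, dtype=jnp.int32)
    # int32 signs promote to x's float dtype, so magnitudes are preserved exactly.
    return x * signs


x = jnp.arange(1.0, 7.0).reshape(2, 3)
flipped = random_sign_flip(jax.random.key(0), x)
print(flipped)
```

Only the signs change, so jnp.abs(flipped) equals x.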