jax.random.rademacher
- jax.random.rademacher(key, shape=(), dtype=None, *, out_sharding=None)
Sample from a Rademacher distribution.
The values are distributed according to the probability mass function:
\[f(k) = \frac{1}{2}(\delta(k - 1) + \delta(k + 1))\]
on the domain \(k \in \{-1, 1\}\), where \(\delta(x)\) is the Kronecker delta (1 when \(x = 0\), 0 otherwise), so \(f(-1) = f(1) = \frac{1}{2}\).
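The probability mass function can be checked empirically. The sketch below assumes a JAX version that provides typed keys via jax.random.key; the sample size of 100,000 is an arbitrary choice for illustration. It confirms that every sample lies in the support {-1, 1} and that the observed frequency of 1 is close to f(1) = 1/2.

```python
import jax
import jax.numpy as jnp

# Draw many Rademacher samples from a single key.
key = jax.random.key(0)
samples = jax.random.rademacher(key, shape=(100_000,))

# Every sample should lie in the support {-1, 1}.
support_ok = bool(jnp.all((samples == 1) | (samples == -1)))

# Empirical estimate of f(1); it should be close to 1/2.
p_plus_one = float(jnp.mean(samples == 1))
print(support_ok, p_plus_one)
```

The estimate fluctuates with the key; its standard error at this sample size is about 0.0016.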
- Parameters:
key (ArrayLike) – a PRNG key.
shape (Shape) – The shape of the returned samples. Default ().
dtype (DTypeLikeInt | None) – The type used for samples.
out_sharding (NamedSharding | P | None) – optional; specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output is sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.
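A short sketch of the shape and dtype parameters, assuming typed keys from jax.random.key. The choice of jnp.int8 is only an example of a narrower integer dtype; the default dtype is not asserted here beyond being an integer type. out_sharding is omitted because it requires a multi-device mesh.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)

# shape controls the output shape: here a 3x4 matrix of +1/-1 entries.
m = jax.random.rademacher(key, shape=(3, 4))

# dtype selects the integer type of the samples (int8 here to save memory).
m8 = jax.random.rademacher(key, shape=(3, 4), dtype=jnp.int8)

print(m.shape, m.dtype, m8.dtype)
```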
- Returns:
A jnp.array of samples with the given shape. Each element is independently 1 or -1, each with probability 1/2.
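Because each element is 1 or -1 with equal probability, its mean is \(\frac{1}{2}(1) + \frac{1}{2}(-1) = 0\) and its variance is \(E[X^2] - E[X]^2 = 1\). The sketch below checks this empirically; it assumes typed keys, uses jax.random.split and jax.vmap to draw one independent row per subkey, and the batch size and tolerances are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

# Split one key into several independent subkeys, one per batch row.
keys = jax.random.split(jax.random.key(7), 4)

# vmap draws a length-50_000 Rademacher vector per subkey -> shape (4, 50_000).
draws = jax.vmap(lambda k: jax.random.rademacher(k, shape=(50_000,)))(keys)

# Cast to float before computing moments; each row should have mean ~0, variance ~1.
x = draws.astype(jnp.float32)
row_means = jnp.mean(x, axis=1)
row_vars = jnp.var(x, axis=1)
print(draws.shape, row_means, row_vars)
```

Since every sample squares to 1, the sample variance equals 1 minus the squared sample mean, so it sits very close to 1.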
- Return type: