jax.random.wrap_key_data
- jax.random.wrap_key_data(key_bits_array, *, impl=None)
Wrap an array of key data bits into a PRNG key array.
- Parameters:
key_bits_array (Array) – a uint32 array with trailing shape corresponding to the key shape of the PRNG implementation specified by impl.
impl (PRNGSpecDesc | None) – optional, specifies a PRNG implementation, as in random.key.
- Returns:
- A PRNG key array, whose dtype is a subdtype of jax.dtypes.prng_key corresponding to impl, and whose shape equals the leading shape of key_bits_array.shape up to the key bit dimensions.
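The shape relationship described above can be illustrated with a minimal sketch. It assumes the "threefry2x32" implementation, whose key data has trailing shape (2,), and passes it explicitly through impl so the example does not depend on the configured default. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

# Assumption: the "threefry2x32" implementation, whose key data is two uint32 words.
IMPL = "threefry2x32"

# Start from a typed key and extract its raw uint32 bits.
key = jax.random.key(0, impl=IMPL)
bits = jax.random.key_data(key)  # uint32 array with shape (2,)

# Wrap the raw bits back into a typed PRNG key array (a scalar key, shape ()).
rewrapped = jax.random.wrap_key_data(bits, impl=IMPL)

# Leading dimensions of the bits array become the shape of the key array:
# bits of shape (3, 2) wrap into a key array of shape (3,).
batch_bits = jnp.stack([bits, bits, bits])
batch_keys = jax.random.wrap_key_data(batch_bits, impl=IMPL)

# The rewrapped key carries the same bits, so it produces the same draws.
same_draw = jax.random.uniform(key) == jax.random.uniform(rewrapped)
```

Here wrap_key_data acts as the inverse of jax.random.key_data: the trailing key-bit dimension is consumed, and whatever leading dimensions remain give the shape of the resulting key array.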