jax.nn.initializers.constant
- jax.nn.initializers.constant(value, dtype=<class 'jax.numpy.float64'>)
Builds an initializer that returns arrays full of a constant value.
- Parameters:
value (ArrayLike) – the constant value with which to fill the arrays returned by the initializer.
dtype (DTypeLikeInexact) – optional; the initializer’s default dtype.
- Return type:
Initializer
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[-7., -7., -7.],
       [-7., -7., -7.]], dtype=float32)
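As a further illustration, the following minimal sketch reuses one initializer for several shapes and dtypes. The name `init_bias` and the fill value `0.1` are hypothetical, chosen only for this example. Because the fill value is constant, the PRNG key is expected to have no effect on the result; it is passed only to match the common initializer call signature.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: a constant initializer for bias terms.
init_bias = jax.nn.initializers.constant(0.1)

# Two different keys; the output should not depend on the key,
# since every element is filled with the same constant.
a = init_bias(jax.random.key(0), (4,), jnp.float32)
b = init_bias(jax.random.key(1), (4,), jnp.float32)

# The dtype passed at call time overrides the initializer's default dtype.
c = init_bias(jax.random.key(0), (2, 2), jnp.float16)

print(a.shape, a.dtype)
print(c.shape, c.dtype)
print(bool(jnp.all(a == b)))
```

Passing the dtype explicitly at call time, as in the example above, avoids relying on the default dtype, which may be canonicalized differently depending on whether 64-bit mode is enabled in your JAX configuration.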