jax.experimental.enable_x64

jax.experimental.enable_x64(new_val=True)

Experimental context manager to temporarily enable X64 mode.

Warning

This context manager remains experimental because it is fundamentally broken and can result in unexpected behavior, particularly when used in conjunction with JAX transformations like jax.jit(), jax.vmap(), jax.grad(), and others. See jax-ml/jax#5982 for details.

Usage:

>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax.experimental import enable_x64
>>> x = np.arange(5, dtype='float64')
>>> with enable_x64():
...   print(jnp.asarray(x).dtype)
...
float64
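The following is a slightly longer sketch. It assumes the default configuration, where `jax_enable_x64` is off. It records the dtype that `jnp.asarray` produces before, inside, and after the `with` block. This shows that X64 mode applies only within the block and that the previous setting is restored on exit. The variable names are illustrative, and the example deliberately avoids `jax.jit` and other transformations because of the warning above.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import enable_x64

x = np.arange(5, dtype="float64")

# dtype under the currently active global setting (X64 is off by default,
# in which case float64 inputs are canonicalized to float32).
before = jnp.asarray(x).dtype

with enable_x64():
    # X64 mode is on only inside this block, so float64 is preserved.
    inside = jnp.asarray(x).dtype

# On exit, the previous setting is restored.
after = jnp.asarray(x).dtype

print(before, inside, after)
```

Under the default configuration, this should print `float32 float64 float32`. If X64 mode was already enabled globally, all three values are `float64`.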

See also

jax.experimental.disable_x64

temporarily disable X64 mode.

Parameters:

new_val (bool) – the value X64 mode is set to for the duration of the context; defaults to True.
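The `new_val` parameter selects the value X64 mode takes inside the block. The sketch below relies only on the signature shown above. First, it passes `False`, which for that block behaves like jax.experimental.disable_x64. Second, it nests two contexts to show that the innermost setting wins and the outer one resumes when the inner block exits. As above, it stays clear of JAX transformations.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import enable_x64

x = np.arange(5, dtype="float64")

# new_val=False turns X64 mode off for the duration of the block,
# regardless of the global setting, so float64 becomes float32.
with enable_x64(new_val=False):
    off = jnp.asarray(x).dtype

# Nested contexts: the inner block's value applies inside it, and the
# outer block's value resumes once the inner block exits.
with enable_x64(True):
    with enable_x64(False):
        inner = jnp.asarray(x).dtype
    outer = jnp.asarray(x).dtype

print(off, inner, outer)
```

This should print `float32 float32 float64`.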