jax.lax.optimization_barrier

jax.lax.optimization_barrier(operand, /)

Prevents the compiler from moving operations across the barrier.

Optimization barriers have a number of possible uses:

  • An optimization barrier ensures that all inputs are evaluated before any operators that depend on the barrier’s outputs. This can be used to enforce a particular order of operations.

  • An optimization barrier prevents common-subexpression elimination. JAX relies on this to implement rematerialization (for example, in jax.checkpoint).

  • Optimization barriers prevent compiler fusions. That is, the compiler may not fuse operations before the barrier into the same kernel as operations after it.

JAX does not define derivative or batching rules for an optimization barrier.

Optimization barriers have no effect outside a compiled function.

Parameters:

operand – a pytree of JAX values.

Returns:

A pytree of JAX values, with the same structure and contents as operand.

Examples

Prevents common-subexpression elimination between the two calls to sin:

>>> import jax
>>> def f(x):
...   return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x)
>>> jax.jit(f)(0.)
Array(0., dtype=float32, weak_type=True)