jax.lax.stop_gradient
- jax.lax.stop_gradient(x)
Stops gradient computation.
Operationally, stop_gradient is the identity function; that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward- or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them. For some discussion of where this is useful, refer to Stopping gradients.
- Parameters:
x (T) – array or pytree of arrays
- Returns:
the input value, returned unchanged; within autodiff it is treated as a constant.
- Return type:
T
Examples
Consider a simple function that returns the square of the input value:
>>> import jax
>>> import jax.numpy as jnp
>>> def f1(x):
...   return x ** 2
>>> x = jnp.float32(3.0)
>>> f1(x)
Array(9.0, dtype=float32)
>>> jax.grad(f1)(x)
Array(6.0, dtype=float32)
The same function with stop_gradient around x will be equivalent under normal evaluation, but return a zero gradient because x is effectively treated as a constant:

>>> def f2(x):
...   return jax.lax.stop_gradient(x) ** 2
>>> f2(x)
Array(9.0, dtype=float32)
>>> jax.grad(f2)(x)
Array(0.0, dtype=float32)
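Because stop_gradient also stops gradients in nested differentiation, higher-order derivatives of f2 vanish as well; a quick check, assuming the definitions above:

>>> jax.grad(jax.grad(f1))(x)  # second derivative of x ** 2 is 2
Array(2.0, dtype=float32)
>>> jax.grad(jax.grad(f2))(x)  # stop_gradient blocks the nested grad too
Array(0.0, dtype=float32)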
This is used in a number of places within the JAX codebase; for example, jax.nn.softmax() internally normalizes its input by its maximum value, and this maximum value is wrapped in stop_gradient for efficiency. Refer to Stopping gradients for more discussion of the applicability of stop_gradient.
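To illustrate that pattern, here is a simplified sketch of a numerically stable softmax with the max wrapped in stop_gradient; this is an illustrative reimplementation, not the actual jax.nn.softmax source:

>>> def softmax_sketch(x, axis=-1):
...   # Subtracting the max improves numerical stability. The subtraction
...   # cancels in the normalized result, so skipping its gradient via
...   # stop_gradient saves work without changing softmax's gradient.
...   x_max = jax.lax.stop_gradient(jnp.max(x, axis=axis, keepdims=True))
...   unnormalized = jnp.exp(x - x_max)
...   return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)
>>> v = jnp.array([1.0, 2.0, 3.0])
>>> jnp.allclose(softmax_sketch(v), jax.nn.softmax(v))
Array(True, dtype=bool)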