jax.vjp
- jax.vjp(fun: Callable[..., T], *primals: Any, has_aux: Literal[False] = False, reduce_axes: Sequence[AxisName] = ()) → tuple[T, Callable]
- jax.vjp(fun: Callable[..., tuple[T, U]], *primals: Any, has_aux: Literal[True], reduce_axes: Sequence[AxisName] = ()) → tuple[T, Callable, U]
Compute a (reverse-mode) vector-Jacobian product of fun. grad() is implemented as a special case of vjp().
- Parameters:
fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The number of primals should be equal to the number of positional parameters of fun. Each primal value should be an array, a scalar, or a pytree (standard Python containers) thereof.
has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- Returns:
If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). If has_aux is True, returns a (primals_out, vjpfun, aux) tuple, where aux is the auxiliary data returned by fun.
vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same number and shapes as primals, representing the vector-Jacobian product of fun evaluated at primals.
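Because vjpfun maps a cotangent v (shaped like primals_out) to v^T J, where J is the Jacobian of fun at primals, pulling back each standard basis vector recovers J one row at a time. The sketch below does this with jax.vmap over the rows of an identity matrix and compares against jax.jacrev; it then pulls back the cotangent 1.0 through a scalar-valued function, which is the sense in which grad() is a special case of vjp(). The functions g and loss are made up for illustration, and building Jacobians this way is shown only to make the VJP contract concrete, not as a recommended recipe.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Example map from R^3 to R^2 (illustrative only).
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 0.5])
y, g_vjp = jax.vjp(g, x)

# vjpfun maps a cotangent v (shaped like y) to v^T J, returned as a tuple
# with one entry per primal. Pulling back basis vector e_i yields row i of J.
basis = jnp.eye(y.shape[0], dtype=y.dtype)
(J_rows,) = jax.vmap(g_vjp)(basis)
print(J_rows.shape)                            # (2, 3)
print(jnp.allclose(J_rows, jax.jacrev(g)(x)))  # True

# For a scalar-valued function, pulling back the cotangent 1.0 gives the
# gradient, matching jax.grad.
def loss(w):
    # Scalar-valued example function (illustrative only).
    return jnp.sum(jnp.sin(w) ** 2)

w = jnp.array([0.1, 0.2, 0.3])
value, loss_vjp = jax.vjp(loss, w)
(grad_via_vjp,) = loss_vjp(jnp.ones_like(value))
print(jnp.allclose(grad_via_vjp, jax.grad(loss)(w)))  # True
```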
>>> import jax
>>>
>>> def f(x, y):
...   return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((-0.7, 0.3))
>>> print(xbar)
-0.61430776
>>> print(ybar)
-0.2524413
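The example above uses scalar primals with has_aux=False. The following sketch exercises the other documented options: a pytree (dict) primal, whose cotangent comes back with the same structure and shapes, and has_aux=True, which adds the auxiliary data as a third return value. The predict function, its parameter names, and the out_norm auxiliary entry are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # Affine map with a dict of parameters (illustrative only).
    out = params["w"] @ x + params["b"]
    # Auxiliary data: returned alongside the output but not differentiated.
    aux = {"out_norm": jnp.linalg.norm(out)}
    return out, aux

params = {"w": jnp.array([[1.0, 2.0], [3.0, 4.0]]), "b": jnp.array([0.5, -0.5])}
x = jnp.array([1.0, -1.0])

# has_aux=True: the result is a (primals_out, vjpfun, aux) triple.
out, predict_vjp, aux = jax.vjp(predict, params, x, has_aux=True)

# The cotangent matches out only (not aux). The result has one entry per
# positional primal, and params_bar mirrors the dict structure of params.
params_bar, x_bar = predict_vjp(jnp.array([1.0, 0.0]))

print(params_bar["w"])  # outer product of the cotangent and x
print(params_bar["b"])  # equal to the cotangent
print(x_bar)            # w.T @ cotangent
```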