jax.jvp
- jax.jvp(fun, primals, tangents, has_aux=False)
Computes a (forward-mode) Jacobian-vector product of fun.
- Parameters:
fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should equal the number of positional parameters of fun.
tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.
has_aux (bool) – Optional. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Defaults to False.
- Returns:
If has_aux is False, returns a (primals_out, tangents_out) pair, where primals_out is fun(*primals) and tangents_out is the Jacobian-vector product of fun evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out. If has_aux is True, returns a (primals_out, tangents_out, aux) tuple, where primals_out is the first element of the pair returned by fun and aux is the auxiliary data (the second element).
- Return type:
tuple[Any, …]
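The structural requirements on primals and tangents, together with the three-element return when has_aux is True, can be illustrated with a small sketch. The function model, its parameter dict, and the aux key mean_activation are hypothetical, chosen only for illustration. The tangent perturbs only the first element of x and holds the parameters fixed, so tangents_out is the directional derivative of the scalar output along that direction.

```python
import jax
import jax.numpy as jnp

# Hypothetical model with two positional parameters: an input array x and a
# dict of parameters. It returns an (output, aux) pair for use with has_aux=True.
def model(x, params):
    h = jnp.tanh(params["w"] * x + params["b"])
    aux = {"mean_activation": jnp.mean(h)}  # passed through, not differentiated
    return jnp.sum(h ** 2), aux

x = jnp.array([0.1, -0.3, 0.7])
params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(0.0)}

# primals has one entry per positional parameter of model; tangents mirrors
# the tree structure and array shapes of primals. Only x[0] is perturbed,
# and the parameter tangents are all zeros.
primals = (x, params)
tangents = (
    jnp.array([1.0, 0.0, 0.0]),
    jax.tree_util.tree_map(jnp.zeros_like, params),
)

# With has_aux=True, jax.jvp returns three values instead of two.
primals_out, tangents_out, aux = jax.jvp(model, primals, tangents, has_aux=True)
```

Here tangents_out is a scalar, matching the structure of primals_out, and aux is returned unchanged alongside it.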
For example:
>>> import jax
>>>
>>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(primals)
0.09983342
>>> print(tangents)
0.19900084
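To make "Jacobian-vector product" concrete, the sketch below compares jax.jvp against an explicitly materialized Jacobian from jax.jacfwd multiplied by the same vector. The function f is a made-up example. jax.jvp computes the product without forming the Jacobian J, which is why it is usually preferred over the explicit J @ v form when only one direction is needed.

```python
import jax
import jax.numpy as jnp

# Made-up function from R^2 to R^3.
def f(x):
    return jnp.stack([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])

x = jnp.array([1.0, 2.0])
v = jnp.array([0.5, -1.0])

# Forward-mode JVP: pushes v through f without building the Jacobian.
_, jvp_out = jax.jvp(f, (x,), (v,))

# Explicit route: materialize the full 3x2 Jacobian, then multiply by v.
J = jax.jacfwd(f)(x)
jac_times_v = J @ v
```

Both routes should agree up to floating-point rounding. By hand, J at x = (1, 2) is [[2, 1], [cos 1, 0], [0, 4]], so J @ v = [0, 0.5 cos 1, -4].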