jax.test_util module

List of Functions

check_grads(f, args, order[, modes, atol, ...])

Check gradients from automatic differentiation against finite differences.
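A minimal sketch of calling check_grads. The function f and its inputs are illustrative. The sketch enables double precision, which is this example's own choice and not something the API requires: finite differences lose accuracy to rounding error in float32, and float64 keeps the comparison tight. check_grads returns None when the derivatives agree and raises AssertionError when they do not.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

# Assumption of this sketch: run in float64 so finite differences are accurate.
jax.config.update("jax_enable_x64", True)


def f(x, y):
    # Illustrative smooth, scalar-valued function of two array arguments.
    return jnp.sum(jnp.sin(x) * jnp.exp(y))


x = jnp.array([0.1, 0.5, 1.2])
y = jnp.array([-0.3, 0.2, 0.7])

# Compare forward-mode (JVP) and reverse-mode (VJP) derivatives, up to
# second order, against central finite differences.
check_grads(f, (x, y), order=2, modes=["fwd", "rev"])
print("check_grads passed")
```

Arguments are passed as a tuple (`args`), not unpacked. With `order=2`, the derivative functions themselves are checked again, so second derivatives are covered too.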

check_jvp(f, f_jvp, args[, atol, rtol, eps, ...])

Check a JVP from automatic differentiation against finite differences.
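A minimal sketch of check_jvp, assuming the calling convention `f_jvp(primals, tangents) -> (primal_out, tangent_out)`, which is what `functools.partial(jax.jvp, f)` provides. The function f and the deliberately wrong JVP (`wrong_jvp`) are hypothetical and exist only for illustration. As in the previous sketch, enabling float64 is this example's choice for accurate finite differences.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.test_util import check_jvp

# Assumption of this sketch: float64 for accurate finite differences.
jax.config.update("jax_enable_x64", True)


def f(x):
    return jnp.sin(x) * x


# jax.jvp(f, primals, tangents) returns (f(*primals), tangent_out),
# so binding f gives a function with the signature check_jvp expects.
correct_jvp = partial(jax.jvp, f)


def wrong_jvp(primals, tangents):
    # Hypothetical buggy JVP: drops the x * cos(x) term of the product rule.
    (x,), (t,) = primals, tangents
    return f(x), jnp.sin(x) * t


x = jnp.array([0.3, 1.0, 2.0])

check_jvp(f, correct_jvp, (x,))  # agrees with finite differences

try:
    check_jvp(f, wrong_jvp, (x,))
    wrong_detected = False
except AssertionError:
    wrong_detected = True
print("wrong JVP detected:", wrong_detected)
```

The usual use is to validate a hand-written rule, for example one registered with `jax.custom_jvp`, before relying on it.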

check_vjp(f, f_vjp, args[, atol, rtol, eps, ...])

Check a VJP from automatic differentiation against finite differences.
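A minimal sketch of check_vjp, assuming the calling convention `f_vjp(*args) -> (primal_out, vjp_fn)`, which is what `functools.partial(jax.vjp, f)` provides. Here `vjp_fn` maps an output cotangent to a tuple with one cotangent per input. The function f and the deliberately wrong VJP (`wrong_vjp`) are hypothetical illustrations. As before, enabling float64 is this sketch's own choice.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.test_util import check_vjp

# Assumption of this sketch: float64 for accurate finite differences.
jax.config.update("jax_enable_x64", True)


def f(x, y):
    return x * jnp.tanh(y)


# jax.vjp(f, x, y) returns (f(x, y), vjp_fn), matching what check_vjp expects.
correct_vjp = partial(jax.vjp, f)


def wrong_vjp(x, y):
    # Hypothetical buggy VJP: the cotangent for y omits the
    # (1 - tanh(y)**2) factor from the derivative of tanh.
    def vjp_fn(ct):
        return ct * jnp.tanh(y), ct * x

    return f(x, y), vjp_fn


x = jnp.array([0.5, -1.0, 2.0])
y = jnp.array([0.5, 1.0, -1.5])

check_vjp(f, correct_vjp, (x, y))  # agrees with finite differences

try:
    check_vjp(f, wrong_vjp, (x, y))
    wrong_detected = False
except AssertionError:
    wrong_detected = True
print("wrong VJP detected:", wrong_detected)
```

A VJP cannot be compared to a finite difference directly, because finite differences give directional (forward) derivatives. A check of this kind compares inner products instead: for a random tangent t and cotangent c, the value ⟨t, VJP(c)⟩ should equal ⟨JVP(t), c⟩, with the JVP estimated numerically.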