jax.make_jaxpr
- jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[False] = False, abstracted_axes: Any | None = None) → Callable[..., core.ClosedJaxpr]
- jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[True] = False, abstracted_axes: Any | None = None) → Callable[..., tuple[core.ClosedJaxpr, Any]]
Creates a function that produces its jaxpr given example args.
- Parameters:
  - fun – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.
  - static_argnums – See the jax.jit() docstring.
  - axis_env – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap().
  - return_shape – Optional boolean, defaults to False. If True, the wrapped function returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree with the same structure as the output of fun, whose leaves are objects with shape and dtype attributes representing the corresponding types of the output leaves.
- Returns:
  A wrapped version of fun that, when applied to example arguments, returns a ClosedJaxpr representation of fun on those arguments. If the argument return_shape is True, then the returned function instead returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of fun.
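The parameters above can be exercised together in a short sketch. It assumes a recent JAX release with 64-bit mode disabled (so arrays default to float32). The functions `repeated_sin`, `centered`, and `stats` and the axis name `"i"` are hypothetical illustrations, not part of the API; `ClosedJaxpr.jaxpr.invars` and `ClosedJaxpr.jaxpr.eqns` are internal data structures used here only for inspection, and the printed jaxpr text may differ between versions.

```python
import jax
import jax.numpy as jnp


def repeated_sin(x, n):
    # Hypothetical example: `n` drives a Python loop, so it must be a
    # concrete int at trace time.
    for _ in range(n):
        x = jnp.sin(x)
    return x


# static_argnums=1 bakes `n` into the trace instead of making it a jaxpr
# input, so the loop unrolls into three `sin` equations.
closed = jax.make_jaxpr(repeated_sin, static_argnums=1)(2.0, 3)
print(closed)
print(len(closed.jaxpr.invars))  # only `x` remains an input


def centered(x):
    # Hypothetical example: psum needs a named axis; outside of pmap,
    # axis_env declares the axis "i" with size 4.
    return x - jax.lax.psum(x, "i")


closed_psum = jax.make_jaxpr(centered, axis_env=[("i", 4)])(2.0)
print(closed_psum)


def stats(x):
    # Hypothetical example: returns a dict pytree whose leaves differ in shape.
    return {"total": x.sum(), "doubled": x * 2}


# With return_shape=True the wrapped function returns (ClosedJaxpr, shapes),
# where `shapes` mirrors the output pytree and each leaf has .shape/.dtype.
stats_jaxpr, out_shape = jax.make_jaxpr(stats, return_shape=True)(jnp.ones((2, 3)))
print(out_shape["total"].shape, out_shape["total"].dtype)
print(out_shape["doubled"].shape, out_shape["doubled"].dtype)
```

Without `static_argnums=1`, `n` would be traced as an abstract integer and `range(n)` would fail during tracing, which is why loop bounds and similar Python-level control values are a typical use for static arguments.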
A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr() adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally. The jaxpr returned is a trace of fun abstracted to the ShapedArray level. Other levels of abstraction exist internally.

We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.

>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a:f32[]. let
    b:f32[] = cos a
    c:f32[] = sin a
    _:f32[] = sin b
    d:f32[] = cos b
    e:f32[] = mul 1.0 d
    f:f32[] = neg e
    g:f32[] = mul f c
  in (g,) }
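Besides printing it, the returned ClosedJaxpr can be inspected from Python. The sketch below assumes the `in_avals`, `out_avals`, and `jaxpr.eqns` attributes of `ClosedJaxpr` (internal structures that may change between releases). It illustrates the ShapedArray-level abstraction described above: the trace depends only on the shape and dtype of the example argument, not on its value.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(jnp.cos(x))


closed = jax.make_jaxpr(f)(3.0)

# Inputs and outputs are described only by abstract types (shape and dtype).
print(closed.in_avals)
print(closed.out_avals)

# A different value of the same type yields an identical jaxpr.
same = str(jax.make_jaxpr(f)(3.0)) == str(jax.make_jaxpr(f)(-1.5))
print(same)

# A different shape yields a different jaxpr, here over f32[4].
vec = jax.make_jaxpr(f)(jnp.ones(4))
print(vec.out_avals[0].shape)

# Walk the top-level equations to list the primitives that were recorded.
primitives = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(primitives)
```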