jax.tree_util.Partial#
- class jax.tree_util.Partial(func, *args, **kw)#
A version of functools.partial that works in pytrees.
Use it for partial function evaluation in a way that is compatible with JAX’s transformations, e.g.,
Partial(func, *args, **kwargs)
.(You need to explicitly opt-in to this behavior because we didn’t want to give functools.partial different semantics than normal function closures.)
For example, here is a basic usage of
Partial
in a manner similar tofunctools.partial
:>>> import jax.numpy as jnp >>> add_one = Partial(jnp.add, 1) >>> add_one(2) Array(3, dtype=int32, weak_type=True)
Pytree compatibility means that the resulting partial function can be passed as an argument within transformed JAX functions, which is not possible with a standard
functools.partial
function:>>> from jax import jit >>> @jit ... def call_func(f, *args): ... return f(*args) ... >>> call_func(add_one, 2) Array(3, dtype=int32, weak_type=True)
Passing zero arguments to
Partial
effectively wraps the original function, making it a valid argument in JAX transformed functions:>>> call_func(Partial(jnp.add), 1, 2) Array(3, dtype=int32, weak_type=True)
Had we passed
jnp.add
tocall_func
directly, it would have resulted in aTypeError
.Note that if the result of
Partial
is used in the context where the value is traced, it results in all bound arguments being traced when passed to the partially-evaluated function:>>> print_zero = Partial(print, 0) >>> print_zero() 0 >>> call_func(print_zero) Traced<~int32[]>with<DynamicJaxprTrace...>
- __init__()#
Methods
__init__
()Attributes
args
tuple of arguments to future partial calls
func
function object to use in future partial calls
keywords
dictionary of keyword arguments to future partial calls