jax.pure_callback
- jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=Deprecated, vmap_method=None, **kwargs)
Calls a pure Python callback. Works under jit(), vmap(), etc. For more explanation, see External Callbacks.
pure_callback enables calling a Python function in JIT-ed JAX functions. The input callback will be passed JAX arrays placed on a local CPU, and it should also return JAX arrays on CPU.
The callback is treated as functionally pure, meaning it has no side-effects and its output value depends only on its argument values. As a consequence, it is safe to be called multiple times (e.g. when transformed by vmap() or pmap()), or not to be called at all when e.g. the output of a jit-decorated function has no data dependence on its value. Pure callbacks may also be reordered if data-dependence allows.
Warning
In the context of JAX transformations, Python exceptions should be considered side-effects: this means that intentionally raising an error within a pure_callback breaks the API contract, and the behavior of the resulting program is undefined.
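As a minimal sketch of the basic call pattern described above, the following wraps a NumPy routine so that it can be used inside a jit-compiled function. The names host_log1p and f are hypothetical, and np.log1p merely stands in for any pure host-side computation; the only requirement illustrated is that result_shape_dtypes describes the callback's output exactly.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_log1p(x):
    # Hypothetical host-side function: pure, no side effects, no exceptions.
    # np.asarray makes the input a NumPy array regardless of how it arrives.
    return np.log1p(np.asarray(x))


@jax.jit
def f(x):
    # Declare the output shape and dtype up front so JAX can trace f
    # without running the callback.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_log1p, out_type, x)


x = jnp.arange(4, dtype=jnp.float32)
result = f(x)
```

Because the callback is assumed pure, JAX may skip it entirely if result is never used, so it should not be relied on for logging or other effects.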
When vmap-ed, the behavior depends on the value of vmap_method:
- Calling vmap() on a callback without an explicit vmap_method is deprecated, and it will eventually raise NotImplementedError.
- vmap_method="sequential" uses jax.lax.map() to loop over the batched arguments, calling callback once for each batch element.
- vmap_method="sequential_unrolled" is like sequential, but the loop is unrolled.
- vmap_method="expand_dims" calls callback with new axes of size 1 added as the leading dimension of unbatched inputs.
- vmap_method="broadcast_all" behaves like expand_dims, but the inputs are tiled to the expected batched shape.
If necessary, the legacy behavior provided by the deprecated vectorized=True argument can be recovered using vmap_method="legacy_vectorized".
The current default behavior is to use vmap_method="sequential" when not specified, but this behavior is deprecated, and in the future, the default will be to raise a NotImplementedError unless vmap_method is explicitly specified.
- Parameters:
  - callback (Callable[..., Any]) – function to execute on the host. The callback is assumed to be a pure function (i.e. one without side-effects): if an impure function is passed, it may behave in unexpected ways, particularly under transformation. The callable will be passed PyTrees of arrays as arguments, and should return a PyTree of arrays that matches result_shape_dtypes.
  - result_shape_dtypes (Any) – pytree whose leaves have shape and dtype attributes, whose structure matches the expected output of the callback function at runtime. jax.ShapeDtypeStruct is often used to define leaf values.
  - *args (Any) – arguments to be passed to the callback function.
  - sharding (SingleDeviceSharding | None) – optional sharding that specifies the device from which the callback should be invoked.
  - vmap_method (str | None) – string specifying how the callback transforms under vmap() as described above.
  - **kwargs (Any) – keyword arguments to be passed to the callback function.
  - vectorized (bool | None | DeprecatedArg) – deprecated; use vmap_method instead.
- Returns:
  a pytree of jax.Array objects whose structure matches that of result_shape_dtypes.
- Return type:
  result
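The pytree contract between result_shape_dtypes and the callback's return value can be sketched as follows. The names host_stats and stats are hypothetical, and a dict is used only as one convenient pytree; any structure would do as long as the callback returns the same structure with matching leaf shapes and dtypes.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_stats(x):
    # Hypothetical host function returning a dict (a pytree) of scalars.
    x = np.asarray(x)
    return {"mean": np.mean(x), "std": np.std(x)}


@jax.jit
def stats(x):
    # One ShapeDtypeStruct per leaf, arranged in the same dict structure
    # that host_stats returns.
    scalar = jax.ShapeDtypeStruct((), x.dtype)
    out_type = {"mean": scalar, "std": scalar}
    return jax.pure_callback(host_stats, out_type, x)


out = stats(jnp.array([1.0, 2.0, 3.0, 4.0]))
```

Here out is a dict of jax.Array scalars with the keys "mean" and "std", mirroring the structure passed as result_shape_dtypes.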
See also
- jax.experimental.io_callback(): callback designed for impure functions.
- jax.debug.callback(): callback designed for general-purpose debugging.
- jax.debug.print(): callback designed for printing.
Examples
The behavior of pure_callback under vmap() is controlled by the vmap_method argument as described above. It is useful to consider some explicit examples that demonstrate the semantics. For example, consider the following function:
>>> import jax
>>> import jax.numpy as jnp
>>> def callback(x, y):
...   print(jnp.shape(x), jnp.shape(y))
...   return x + y
>>> def fun(x, y, *, vmap_method):
...   shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
...   dtype = jnp.result_type(x, y)
...   out_type = jax.ShapeDtypeStruct(shape, dtype)
...   return jax.pure_callback(callback, out_type, x, y,
...                            vmap_method=vmap_method)
Calling this with vmap_method="expand_dims" adds a new axis of size 1 to y:
>>> from functools import partial
>>> x = jnp.arange(4)
>>> y = 1.0
>>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y)
(4,) (1,)
Array([1., 2., 3., 4.], dtype=float32)
Whereas vmap_method="broadcast_all" adds an axis of size 4 to y:
>>> jax.vmap(partial(fun, vmap_method="broadcast_all"),
...          in_axes=(0, None))(x, y)
(4,) (4,)
Array([1., 2., 3., 4.], dtype=float32)