jax.pure_callback

jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=Deprecated, vmap_method=None, **kwargs)

Calls a pure Python callback. Works under jit()/vmap()/etc.

For more explanation, see External Callbacks.

pure_callback enables calling a Python function in JIT-ed JAX functions. The input callback will be passed JAX arrays placed on a local CPU, and it should also return JAX arrays on CPU.
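As a minimal sketch, assuming a CPU backend is available, the snippet below calls a NumPy routine from inside a jit-compiled function. The names host_log1p and log1p_via_host are hypothetical, and np.log1p only stands in for any host-side routine. The result shape and dtype are declared with jax.ShapeDtypeStruct because the callback does not run while jit traces the function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_log1p(x):
    # Hypothetical host function: runs as ordinary Python, x arrives on the CPU.
    x = np.asarray(x)
    # Cast so the output dtype matches the declared ShapeDtypeStruct exactly.
    return np.log1p(x).astype(x.dtype)


@jax.jit
def log1p_via_host(x):
    # Output shape/dtype must be declared up front: the callback is not
    # executed while jit traces this function, only when it runs.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_log1p, out_type, x)


x = jnp.array([0.0, 1.0, 2.0], dtype=jnp.float32)
print(log1p_via_host(x))
```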

The callback is treated as functionally pure, meaning it has no side effects and its output value depends only on its argument values. As a consequence, it may safely be called multiple times (e.g. when transformed by vmap() or pmap()), or not called at all when, e.g., the output of a jit-decorated function has no data dependence on its value. Pure callbacks may also be reordered if data dependence allows.

Warning

In the context of JAX transformations, Python exceptions should be considered side-effects: this means that intentionally raising an error within a pure_callback breaks the API contract, and the behavior of the resulting program is undefined.

When vmap-ed, the behavior depends on the value of vmap_method:

  • Calling vmap() on a callback without an explicit vmap_method is deprecated and it will eventually raise NotImplementedError.

  • vmap_method="sequential" uses jax.lax.map() to loop over the batched arguments, calling callback once for each batch element.

  • vmap_method="sequential_unrolled" is like sequential, but the loop is unrolled.

  • vmap_method="expand_dims" calls callback with new axes of size 1 added as the leading dimension to unbatched inputs.

  • vmap_method="broadcast_all" behaves like expand_dims, but the inputs are tiled to the expected batched shape.

If necessary, the legacy behavior provided by the deprecated vectorized=True argument can be recovered using vmap_method="legacy_vectorized".

The current default behavior is to use vmap_method="sequential" when not specified, but this behavior is deprecated, and in the future, the default will be to raise a NotImplementedError unless vmap_method is explicitly specified.

Parameters:
  • callback (Callable[..., Any]) – function to execute on the host. The callback is assumed to be a pure function (i.e. one without side-effects): if an impure function is passed, it may behave in unexpected ways, particularly under transformation. The callable will be passed PyTrees of arrays as arguments, and should return a PyTree of arrays that matches result_shape_dtypes.

  • result_shape_dtypes (Any) – pytree whose leaves have shape and dtype attributes, whose structure matches the expected output of the callback function at runtime. jax.ShapeDtypeStruct is often used to define leaf values.

  • *args (Any) – arguments to be passed to the callback function

  • sharding (SingleDeviceSharding | None) – optional sharding that specifies the device from which the callback should be invoked.

  • vmap_method (str | None) – string specifying how the callback transforms under vmap() as described above.

  • **kwargs (Any) – keyword arguments to be passed to the callback function

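  • vectorized (bool | None | DeprecatedArg) – deprecated; use vmap_method instead. The legacy vectorized=True behavior corresponds to vmap_method="legacy_vectorized".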
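The sketch below illustrates how result_shape_dtypes describes a pytree result and how **kwargs reach the callback. The names host_stats, stats, and scale are hypothetical. One assumption worth flagging: keyword arguments appear to be flattened into array leaves just like positional ones, so the callback converts scale with float() rather than assuming it is still a Python float.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_stats(x, *, scale):
    # Assumption: scale may arrive as a 0-d array, not a Python float;
    # float() handles either case.
    x = np.asarray(x)
    scale = float(scale)
    # The returned dict must match the result_shape_dtypes pytree declared
    # in stats() below: same keys, same leaf shapes, same dtypes.
    return {
        "mean": np.asarray(x.mean() * scale, dtype=x.dtype),
        "sorted": (np.sort(x) * scale).astype(x.dtype),
    }


@jax.jit
def stats(x):
    result_shape_dtypes = {
        "mean": jax.ShapeDtypeStruct((), x.dtype),
        "sorted": jax.ShapeDtypeStruct(x.shape, x.dtype),
    }
    # scale=2.0 is forwarded to host_stats through **kwargs.
    return jax.pure_callback(host_stats, result_shape_dtypes, x, scale=2.0)


out = stats(jnp.array([3.0, 1.0, 2.0], dtype=jnp.float32))
print(out["mean"], out["sorted"])
```

The returned value is a dict of jax.Array objects with the same structure as result_shape_dtypes.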

Returns:

a pytree of jax.Array objects whose structure matches that of result_shape_dtypes.

Return type:

result

See also

Examples

The behavior of pure_callback under vmap() is controlled by the vmap_method argument as described above. It is useful to consider some explicit examples that demonstrate the semantics. For example, consider the following function:

>>> import jax
>>> import jax.numpy as jnp
>>> def callback(x, y):
...   print(jnp.shape(x), jnp.shape(y))
...   return x + y
>>> def fun(x, y, *, vmap_method):
...   shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
...   dtype = jnp.result_type(x, y)
...   out_type = jax.ShapeDtypeStruct(shape, dtype)
...   return jax.pure_callback(callback, out_type, x, y,
...                            vmap_method=vmap_method)

Calling this with vmap_method="expand_dims" adds a new axis of size 1 to y:

>>> from functools import partial
>>> x = jnp.arange(4)
>>> y = 1.0
>>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y)
(4,) (1,)
Array([1., 2., 3., 4.], dtype=float32)

By contrast, vmap_method="broadcast_all" adds an axis of size 4 to y:

>>> jax.vmap(partial(fun, vmap_method="broadcast_all"),
...          in_axes=(0, None))(x, y)
(4,) (4,)
Array([1., 2., 3., 4.], dtype=float32)