jax.experimental.pjit module

API
- jax.experimental.pjit.pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None, compiler_options=None)[source]
Makes fun compiled and automatically partitioned across multiple devices.

NOTE: This function is now equivalent to jax.jit; please use that instead.

The returned function has semantics equivalent to those of fun, but is compiled to an XLA computation that runs across multiple devices (e.g. multiple GPUs or multiple TPU cores). This can be useful if the jitted version of fun would not fit in a single device’s memory, or to speed up fun by running each operation in parallel across multiple devices.

The partitioning over devices happens automatically based on the propagation of the input partitioning specified in in_shardings and the output partitioning specified in out_shardings. The resources specified in those two arguments must refer to mesh axes, as defined by the jax.sharding.Mesh() context manager. Note that the mesh definition at pjit() application time is ignored, and the returned function will use the mesh definition available at each call site.

Inputs to a pjit()’d function will be automatically partitioned across devices if they are not already correctly partitioned based on in_shardings. In some scenarios, ensuring that the inputs are already correctly pre-partitioned can increase performance. For example, if passing the output of one pjit()’d function to another pjit()’d function (or the same pjit()’d function in a loop), make sure the relevant out_shardings match the corresponding in_shardings.
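A minimal sketch of this pattern (the one-dimensional mesh, its axis name 'x', and the doubling function are purely illustrative assumptions): reusing the same PartitionSpec for out_shardings and in_shardings lets the second call consume the first call’s output without repartitioning.

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh, PartitionSpec
>>> from jax.experimental.pjit import pjit
>>>
>>> spec = PartitionSpec('x')  # shard the leading dimension over mesh axis 'x'
>>> step = pjit(lambda a: a * 2, in_shardings=spec, out_shardings=spec)
>>> with Mesh(np.array(jax.devices()), ('x',)):
...   y = step(jnp.arange(8, dtype=jnp.float32))  # input gets partitioned per in_shardings
...   z = step(y)  # y already matches in_shardings, so no repartitioning is needed
...   print(z)
[ 0.  4.  8. 12. 16. 20. 24. 28.]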
Note

Multi-process platforms: On multi-process platforms such as TPU pods, pjit() can be used to run computations across all available devices across processes. To achieve this, pjit() is designed to be used in SPMD Python programs, where every process is running the same Python code such that all processes run the same pjit()’d function in the same order.

When running in this configuration, the mesh should contain devices across all processes. All input arguments must be globally shaped. fun will still be executed across all devices in the mesh, including those from other processes, and will be given a global view of the data spread across multiple processes as a single array.

The SPMD model also requires that the same multi-process pjit()’d functions must be run in the same order on all processes, but they can be interspersed with arbitrary operations running in a single process.

- Parameters:
  - fun (Callable) – Function to be compiled. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.
  - in_shardings (Any) – Pytree of structure matching that of arguments to fun, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree.
    The in_shardings argument is optional. JAX will infer the shardings from the input jax.Array’s, and defaults to replicating the input if the sharding cannot be inferred.
    The valid resource assignment specifications are (see the sharding sketch after the parameter list below):
    - Sharding, which will decide how the value will be partitioned. With this, using a mesh context manager is not required.
    - None is a special case whose semantics are:
      - If the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants. For in_shardings, JAX will mark it as replicated, but this behavior can change in the future. For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
      - If the mesh context manager is provided, None will imply that the value will be replicated on all devices of the mesh.
    - For backwards compatibility, in_shardings still supports ingesting PartitionSpec. This option can only be used with the mesh context manager.
      - PartitionSpec, a tuple of length at most equal to the rank of the partitioned value. Each element can be None, a mesh axis, or a tuple of mesh axes, and specifies the set of resources assigned to partition the value’s dimension matching its position in the spec.
      - The size of every dimension has to be a multiple of the total number of resources assigned to it.
  - out_shardings (Any) – Like in_shardings, but specifies resource assignment for function outputs. The out_shardings argument is optional. If not specified, jax.jit() will use GSPMD’s sharding propagation to determine how to shard the outputs.
  - static_argnums (int | Sequence[int] | None) – An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object.
    Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static (see the sketch at the end of this section).
    If static_argnums is not provided, no arguments are treated as static.
  - static_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.
  - donate_argnums (int | Sequence[int] | None) – Specify which positional argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no argument buffers are donated (see the sketch at the end of this section).
    If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.
    For more details on buffer donation see the FAQ.
  - donate_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on donate_argnums for details. If not provided but donate_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.
  - keep_unused (bool) – If False (the default), arguments that JAX determines to be unused by fun may be dropped from resulting compiled XLA executables. Such arguments will not be transferred to the device nor provided to the underlying executable. If True, unused arguments will not be pruned.
  - device (xc.Device | None) – This argument is deprecated. Please put your arguments on the device you want before passing them to jit. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].
  - backend (str | None) – This argument is deprecated. Please put your arguments on the backend you want before passing them to jit. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
  - inline (bool)
  - abstracted_axes (Any | None)
- Returns:
  A wrapped version of fun, set up for just-in-time compilation and automatically partitioned by the mesh available at each call site.
- Return type:
  JitWrapped
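A hedged sketch of the Sharding-based form of in_shardings described above (the mesh axis name 'x', the increment function, and the array shape are illustrative assumptions): an explicit NamedSharding carries its own mesh, so no mesh context manager is needed at the call site.

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec
>>> from jax.experimental.pjit import pjit
>>>
>>> mesh = Mesh(np.array(jax.devices()), ('x',))
>>> s = NamedSharding(mesh, PartitionSpec('x'))  # shard dimension 0 over mesh axis 'x'
>>> g = pjit(lambda a: a + 1, in_shardings=s, out_shardings=s)  # no mesh context manager
>>> out = g(jnp.arange(8, dtype=jnp.float32))
>>> print(out)
[1. 2. 3. 4. 5. 6. 7. 8.]

Because out_shardings was given explicitly, the result carries that sharding and can be fed directly to another call with the same in_shardings without repartitioning.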
For example, a convolution operator can be automatically partitioned over an arbitrary set of devices by a single pjit() application:

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh, PartitionSpec
>>> from jax.experimental.pjit import pjit
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
...          in_shardings=None, out_shardings=PartitionSpec('devices'))
>>> with Mesh(np.array(jax.devices()), ('devices',)):
...   print(f(x))
[ 0.5  2.   4.   6.   8.  10.  12.  10. ]
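Finally, the sketch referenced in the static_argnums and donate_argnums entries above, combining both options in one call (the function, argument positions, and values are illustrative assumptions, not part of the pjit API): the static argument is baked into the compiled program, so calling with a different scale triggers a recompilation, and the buffer donated for argument 0 must not be reused after the call.

>>> import jax.numpy as jnp
>>> from jax.experimental.pjit import pjit
>>>
>>> def scale_and_shift(x, shift, scale):
...   return x * scale + shift
>>> f = pjit(scale_and_shift, static_argnums=2, donate_argnums=0)
>>> x = jnp.ones(4)
>>> y = f(x, 1.0, 3.0)  # scale=3.0 is a compile-time constant; x's buffer is donated
>>> print(y)            # x must not be reused after this call
[4. 4. 4. 4.]

On backends that do not implement donation (e.g. CPU), JAX emits a warning and ignores the donation request.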