jax.sharding
module#
Classes#
- class jax.sharding.Sharding#
Describes how a jax.Array is laid out across devices.
- property addressable_devices: set[Device]#
The set of devices in the Sharding that are addressable by the current process.
- addressable_devices_indices_map(global_shape)[source]#
A mapping from addressable devices to the slice of array data each contains.
addressable_devices_indices_map contains the part of devices_indices_map that applies to the addressable devices.
- Parameters:
global_shape (Shape)
- Return type:
Mapping[Device, Index | None]
- property device_set: set[Device][source]#
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- Parameters:
global_shape (Shape)
- Return type:
Mapping[Device, Index]
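As a minimal illustration of both index maps (a sketch assuming a single-process run; SingleDeviceSharding, documented below, stands in for any concrete Sharding):

```python
import jax
from jax.sharding import SingleDeviceSharding

# A concrete sharding that places the whole array on one device.
dev = jax.devices()[0]
sharding = SingleDeviceSharding(dev)

# devices_indices_map covers every global device; here the single device
# holds the full array, expressed as one full slice per array dimension.
indices = sharding.devices_indices_map((4, 2))

# addressable_devices_indices_map restricts the map to devices the current
# process can address; in a single-process run the two maps agree.
addressable = sharding.addressable_devices_indices_map((4, 2))
print(indices[dev])
```

For a single device the index is a tuple of full slices, one per array dimension, covering the whole array.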
- is_equivalent_to(other, ndim)[source]#
Returns True if two shardings are equivalent.
Two shardings are equivalent if they place the same logical array shards on the same devices.
For example, a NamedSharding may be equivalent to a PositionalSharding if both place the same shards of the array on the same devices.
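For instance, a device-count-agnostic sketch (it builds a one-axis mesh from whatever local devices exist and uses NamedSharding, documented below, as the concrete sharding):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over all available devices (a single CPU device also works).
mesh = Mesh(np.array(jax.devices()), ('x',))

# Two shardings built from the same mesh and the same spec place the same
# shards on the same devices, so they are equivalent.
s1 = NamedSharding(mesh, P('x'))
s2 = NamedSharding(mesh, P('x'))
print(s1.is_equivalent_to(s2, 2))  # ndim=2: compare as shardings of a 2-D array
```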
- property is_fully_addressable: bool[source]#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool[source]#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
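A quick sketch of both flags (assuming a single-process run; an empty PartitionSpec replicates the array across the mesh):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))

# An empty spec partitions no dimension, so every device in the mesh
# holds a complete copy of the data: fully replicated.
replicated = NamedSharding(mesh, P())

# A single-process run can address all of its own devices.
print(replicated.is_fully_replicated, replicated.is_fully_addressable)
```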
- class jax.sharding.SingleDeviceSharding#
Bases: Sharding
A Sharding that places its data on a single device.
- Parameters:
device – A single Device.
Examples
>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0])
- property device_set: set[Device][source]#
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- Parameters:
global_shape (Shape)
- Return type:
Mapping[Device, Index]
- property is_fully_addressable: bool[source]#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool[source]#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
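Putting the pieces together, a small usage sketch (assuming the default single-device setup):

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

sharding = SingleDeviceSharding(jax.devices()[0])

# Commit an array to the chosen device; the result carries the sharding.
x = jax.device_put(jnp.arange(8), sharding)

# With one device holding all the data, the sharding spans exactly that
# device and is trivially fully addressable and fully replicated.
print(x.sharding.device_set)
print(x.sharding.is_fully_addressable, x.sharding.is_fully_replicated)
```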
- class jax.sharding.NamedSharding#
Bases: Sharding
A NamedSharding expresses sharding using named axes.
A NamedSharding is a pair of a Mesh of devices and a PartitionSpec which describes how to shard an array across that mesh.
A Mesh is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, e.g. 'x' or 'y'.
A PartitionSpec is a tuple, whose elements can be a None, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. For example, PartitionSpec('x', 'y') says that the first dimension of data is sharded across the x axis of the mesh, and the second dimension is sharded across the y axis of the mesh.
The Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) has more details and diagrams that explain how Mesh and PartitionSpec are used.
- Parameters:
mesh – A jax.sharding.Mesh object.
spec – A jax.sharding.PartitionSpec object.
Examples
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
- property addressable_devices: set[Device][source]#
The set of devices in the Sharding that are addressable by the current process.
- property device_set: set[Device][source]#
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- property is_fully_addressable: bool[source]#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- property mesh#
(self) -> object
- property spec#
(self) -> object
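A sketch of how a NamedSharding drives data placement. The XLA_FLAGS line is an assumption used to emulate eight CPU devices on a single host; on real multi-device hardware it is unnecessary, and the shapes below adapt to whatever device count is actually available:

```python
import os
# Hypothetical test setup: fake 8 CPU devices (must run before importing jax).
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ('x',))
sharding = NamedSharding(mesh, P('x'))   # shard dimension 0 across axis 'x'

# Rows of the global (n, 4) array are spread across the n devices.
x = jax.device_put(jnp.arange(4 * n).reshape(n, 4), sharding)

# Each device holds one (1, 4) row of the global array.
for shard in x.addressable_shards:
    print(shard.device, shard.data.shape)
```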
- class jax.sharding.PositionalSharding(devices, *, memory_kind=None)[source]#
Bases:
Sharding
- Parameters:
devices (Sequence[xc.Device] | np.ndarray)
memory_kind (str | None)
- property device_set: set[xc.Device]#
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- property is_fully_addressable: bool#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- class jax.sharding.PmapSharding#
Bases: Sharding
Describes a sharding used by jax.pmap().
- classmethod default(shape, sharded_dim=0, devices=None)[source]#
Creates a PmapSharding which matches the default placement used by jax.pmap().
- Parameters:
shape (Shape) – The shape of the input array.
sharded_dim (int | None) – Dimension the input array is sharded on. Defaults to 0.
devices (Sequence[xc.Device] | None) – Optional sequence of devices to use. If omitted, the implicit device order used by pmap is used, which is the order of jax.local_devices().
- Return type:
PmapSharding
- property device_set: set[Device]#
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- property devices#
(self) -> ndarray
- devices_indices_map(global_shape)[source]#
Returns a mapping from devices to the array slices each contains.
The mapping includes all global devices, i.e., including non-addressable devices from other processes.
- Parameters:
global_shape (Shape)
- Return type:
Mapping[Device, Index]
- is_equivalent_to(other, ndim)[source]#
Returns True if two shardings are equivalent.
Two shardings are equivalent if they place the same logical array shards on the same devices.
For example, a NamedSharding may be equivalent to a PositionalSharding if both place the same shards of the array on the same devices.
- Parameters:
self (PmapSharding)
other (PmapSharding)
ndim (int)
- Return type:
bool
- property is_fully_addressable: bool#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- shard_shape(global_shape)[source]#
Returns the shape of the data on each device.
The shard shape returned by this function is calculated from global_shape and the properties of the sharding.
- Parameters:
global_shape (Shape)
- Return type:
Shape
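For example, a sketch using PmapSharding.default. The XLA_FLAGS line is an assumption used to emulate several CPU devices on one host; the rest of the code is device-count agnostic:

```python
import os
# Hypothetical test setup: fake 8 CPU devices (must run before importing jax).
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import math
import jax
from jax.sharding import PmapSharding

n = jax.device_count()

# Default pmap placement: split dimension 0 of an (n, 4) array across
# the n local devices, one row per device.
sharding = PmapSharding.default((n, 4), sharded_dim=0)

# shard_shape reports the per-device shape: 4 of the 4*n elements each.
per_device = sharding.shard_shape((n, 4))
print(per_device)
```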
- property sharding_spec#
(self) -> jax::ShardingSpec
- class jax.sharding.GSPMDSharding#
Bases:
Sharding
- property device_set: set[Device]#
The set of devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- property is_fully_addressable: bool#
Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to “is_local” in multi-process JAX.
- property is_fully_replicated: bool#
Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the entire data.
- class jax.sharding.PartitionSpec(*partitions)[source]#
Tuple describing how to partition an array across a mesh of devices.
Each element is either None, a string, or a tuple of strings. See the documentation of jax.sharding.NamedSharding for more details.
This class exists so JAX’s pytree utilities can distinguish a partition specification from tuples that should be treated as pytrees.
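A short sketch of the tuple-like behaviour (no devices are needed just to build a spec):

```python
from jax.sharding import PartitionSpec as P

# One entry per array dimension: dim 0 sharded over mesh axis 'x',
# dim 1 replicated (None).
spec = P('x', None)
print(spec[0], spec[1], len(spec))

# A tuple entry shards a single array dimension over several mesh axes.
combined = P(('x', 'y'))
print(combined[0])
```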
- class jax.sharding.Mesh(devices, axis_names, *, axis_types=None)[source]#
Declare the hardware resources available in the scope of this manager.
In particular, all axis_names become valid resource names inside the managed block and can be used e.g. in the in_axis_resources argument of jax.experimental.pjit.pjit(). Also see JAX’s multi-process programming model (https://jax.readthedocs.io/en/latest/multi_process.html) and the Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).
If you are compiling in multiple threads, make sure that the with Mesh context manager is inside the function that the threads will execute.
- Parameters:
devices (np.ndarray) – A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).
axis_names (tuple[MeshAxisName, ...]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
axis_types (tuple[AxisType, ...] | None)
Examples
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)

>>> # Initialize the Mesh and use it as the context manager in one step.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)

>>> # You can also bind the mesh with `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)

>>> # Or use it anonymously with `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
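To complement the pjit examples, a sketch inspecting the mesh object itself. The XLA_FLAGS line is an assumption used to emulate eight CPU devices on one host; the mesh shape adapts to the real device count:

```python
import os
# Hypothetical test setup: fake 8 CPU devices (must run before importing jax).
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
from jax.sharding import Mesh

# A 2-D mesh with a trivial 'y' axis so it works for any device count.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ('x', 'y'))

# Axis names, per-axis sizes, and total size are exposed on the mesh.
print(mesh.axis_names)
print(dict(mesh.shape))
print(mesh.size)

# The mesh also works as a context manager, as in the examples above.
with mesh:
    pass
```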