jax.make_mesh
- jax.make_mesh(axis_shapes, axis_names, *, devices=None, axis_types=None)
Creates an efficient mesh with the shape and axis names specified.
This function attempts to automatically compute a good mapping from a set of logical axes to a physical mesh. For example, on a TPU v3 with 8 devices:
>>> mesh = jax.make_mesh((8,), ('x',))
>>> [d.id for d in mesh.devices.flat]
[0, 1, 2, 3, 6, 7, 4, 5]
The ordering above accounts for the physical topology of TPU v3: it arranges the devices in a ring, which makes all-reduces efficient on that hardware.
Here is another example, with 16 TPU v3 devices:
>>> mesh = jax.make_mesh((2, 8), ('x', 'y'))
>>> [d.id for d in mesh.devices.flat]
[0, 1, 2, 3, 6, 7, 4, 5, 8, 9, 10, 11, 14, 15, 12, 13]
>>> mesh = jax.make_mesh((4, 4), ('x', 'y'))
>>> [d.id for d in mesh.devices.flat]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
As these examples show, the logical mesh shape (axis_shapes) affects how the devices are ordered.
Use jax.experimental.mesh_utils.create_device_mesh instead if you need the extra arguments it provides, such as contiguous_submeshes and allow_split_physical_axes.
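A sketch of the create_device_mesh route, assuming the CPU backend with 8 devices simulated via `XLA_FLAGS`. `create_device_mesh` returns a NumPy array of devices rather than a `Mesh`, and it expects the product of the mesh shape to equal the number of devices passed, so the sketch uses every device and then wraps the array in `jax.sharding.Mesh` to attach axis names. The two extra arguments are aimed at TPU topologies; the values here are for illustration only, and on CPU they are unlikely to change the result.

```python
import os

# Assumption: simulate 8 CPU devices (must be set before importing jax).
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

devices = jax.devices()
n = len(devices)

# Returns a NumPy array of devices with shape (n,); the product of the mesh
# shape must equal len(devices).
device_array = mesh_utils.create_device_mesh(
    (n,),
    devices=devices,
    contiguous_submeshes=False,      # TPU-oriented option (illustrative value)
    allow_split_physical_axes=True,  # TPU-oriented option (illustrative value)
)

# Attach axis names by constructing the Mesh explicitly.
mesh = Mesh(device_array, ('x',))
print(mesh.shape)
```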
- Parameters:
axis_shapes (Sequence[int]) – Shape of the mesh. For example, axis_shapes=(4, 2).
axis_names (Sequence[str]) – Names of the mesh axes, one per entry of axis_shapes. For example, axis_names=('x', 'y').
devices (Sequence[xc.Device] | None) – Optional keyword-only argument specifying the devices to build the mesh from. If None, jax.devices() is used.
axis_types (tuple[mesh_lib.AxisType, ...] | None) – Optional keyword-only argument giving the mesh_lib.AxisType of each mesh axis.
- Returns:
A jax.sharding.Mesh object.
- Return type:
mesh_lib.Mesh