jax.make_mesh

jax.make_mesh(axis_shapes, axis_names, *, devices=None, axis_types=None)

Creates an efficient device mesh with the specified shape and axis names.

This function tries to automatically compute a good mapping from a set of logical mesh axes to the physical device topology. For example, on a TPU v3 with 8 devices:

>>> mesh = jax.make_mesh((8,), ('x',))
>>> [d.id for d in mesh.devices.flat]  
[0, 1, 2, 3, 6, 7, 4, 5]

This ordering takes the physical topology of TPU v3 into account: it arranges the devices in a ring, which makes all-reduces on a TPU v3 efficient.

Here is another example, on a TPU v3 with 16 devices:

>>> mesh = jax.make_mesh((2, 8), ('x', 'y'))  
>>> [d.id for d in mesh.devices.flat]  
[0, 1, 2, 3, 6, 7, 4, 5, 8, 9, 10, 11, 14, 15, 12, 13]
>>> mesh = jax.make_mesh((4, 4), ('x', 'y'))  
>>> [d.id for d in mesh.devices.flat]  
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

As these examples show, the logical mesh shape (axis_shapes) affects the order in which devices are placed in the mesh.

If you need the extra arguments that jax.experimental.mesh_utils.create_device_mesh provides, such as contiguous_submeshes and allow_split_physical_axes, use that function instead.
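Unlike jax.make_mesh, create_device_mesh returns a NumPy array of devices rather than a Mesh, so the result is usually wrapped in jax.sharding.Mesh with axis names. A minimal sketch, assuming 8 simulated CPU devices; allow_split_physical_axes is passed only to show where such options go, and it is expected to have no effect on CPU, where there is no physical topology to split.

```python
import os

# Assumption: simulate 8 CPU devices; must happen before importing JAX.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Returns an ndarray of devices shaped (2, 4), not a Mesh.
# allow_split_physical_axes matters for TPU topologies only.
device_array = mesh_utils.create_device_mesh(
    (2, 4), allow_split_physical_axes=True
)

# Attach axis names to get the same kind of object jax.make_mesh returns.
mesh = Mesh(device_array, ('x', 'y'))
print(mesh.devices.shape)  # (2, 4)
```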

Parameters:
  • axis_shapes (Sequence[int]) – Shape of the mesh. For example, axis_shapes=(4, 2).

  • axis_names (Sequence[str]) – Names of the mesh axes, one per entry in axis_shapes. For example, axis_names=('x', 'y').

  • devices (Sequence[xc.Device] | None) – Optional keyword-only argument specifying the devices to build the mesh from.

  • axis_types (tuple[mesh_lib.AxisType, ...] | None) – Optional keyword-only argument specifying the type of each mesh axis.

Returns:

A jax.sharding.Mesh object.

Return type:

mesh_lib.Mesh