jax.lax.with_sharding_constraint

jax.lax.with_sharding_constraint(x, shardings)

Mechanism to constrain the sharding of a jax.Array inside a jitted computation.

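This is a strict constraint for the GSPMD partitioner, not a hint: the partitioner must produce the requested sharding at that point rather than treating it as a suggestion. For examples of how to use this function, see "Distributed arrays and automatic parallelization".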
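A minimal sketch of the basic pattern, assuming a 1-D mesh built from all available devices with an axis named "data" (the mesh, the axis name, and the function f are illustrative, not part of this API). The intermediate h is pinned to a row-sharded layout using a NamedSharding, one of the valid sharding specifications:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 1-D mesh over every available device, axis named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n_dev = len(jax.devices())

# Split rows of the intermediate across "data"; leave columns unsharded.
row_sharding = NamedSharding(mesh, P("data", None))

@jax.jit
def f(x, w):
    h = x @ w
    # Pin the layout of the intermediate h; the partitioner must honor it.
    h = jax.lax.with_sharding_constraint(h, row_sharding)
    return jnp.tanh(h)

# Row count is a multiple of n_dev, so the row split is even.
x = jnp.ones((4 * n_dev, 8))
w = jnp.ones((8, 16))
out = f(x, w)
print(out.shape)
```

Because the constraint is strict, h is row-sharded at that point in the compiled program, and the partitioner inserts any communication needed to get it there.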

Inside a jitted computation, with_sharding_constraint can constrain intermediate values to an uneven sharding, i.e. one in which the size of a sharded dimension is not divisible by the number of devices along its mesh axis. However, if such an unevenly sharded value is returned as an output of the jitted computation, it comes out fully replicated, regardless of the sharding annotation given.
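The following sketch constrains an intermediate to an uneven sharding. It assumes a CPU-only machine on which XLA_FLAGS can simulate four host devices; this setting is optional and only takes effect if it is set before jax is first imported. With a single device the sharding is trivially even, but the code still runs. The names g and "x" are illustrative. The unevenly sharded value b is reduced to a scalar, so it is never itself an output of the jitted function:

```python
import os

# Optional, CPU only: simulate 4 host devices so the split below is uneven.
# Must run before jax is first imported; it has no effect afterwards.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
n_dev = len(jax.devices())
sharding = NamedSharding(mesh, P("x"))

@jax.jit
def g(a):
    # n_dev + 1 elements cannot be split evenly over n_dev > 1 devices.
    b = jnp.arange(n_dev + 1, dtype=jnp.float32) * a
    # Constraining the intermediate to an uneven sharding is allowed here.
    b = jax.lax.with_sharding_constraint(b, sharding)
    # Only the scalar reduction leaves the jitted computation.
    return b.sum()

print(g(2.0))
```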

Parameters:
  • x – PyTree of jax.Arrays which will have their shardings constrained

  • shardings – PyTree of sharding specifications. Valid values are the same as for the in_shardings argument of jax.experimental.pjit().

Returns:

x_with_shardings: a PyTree of jax.Arrays, with the same structure as x, carrying the specified sharding constraints.
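Because x and shardings are both PyTrees, a whole structure of arrays can be constrained in one call by passing a shardings PyTree with the same structure. A minimal sketch, assuming a dict of parameters with keys "w" and "b" and a mesh axis named "model" (all illustrative). It also shows that the function returns a new PyTree rather than modifying its input, so the returned value is the one that carries the constraint:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
n_dev = len(jax.devices())

# The shardings PyTree mirrors the params PyTree leaf for leaf.
param_shardings = {
    "w": NamedSharding(mesh, P(None, "model")),  # split columns of w
    "b": NamedSharding(mesh, P()),               # replicate b on every device
}

@jax.jit
def update(params):
    new = jax.tree_util.tree_map(lambda p: p * 0.5, params)
    # with_sharding_constraint is functional: it returns a new PyTree with
    # the same structure, and that returned value carries the constraint.
    return jax.lax.with_sharding_constraint(new, param_shardings)

params = {"w": jnp.ones((8, 2 * n_dev)), "b": jnp.ones((2 * n_dev,))}
out = update(params)
```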