jax.lax.with_sharding_constraint
- jax.lax.with_sharding_constraint(x, shardings)
Mechanism to constrain the sharding of an Array inside a jitted computation.
This is a strict constraint for the GSPMD partitioner and not a hint. For examples of how to use this function, see Distributed arrays and automatic parallelization.
Inside of a jitted computation, with_sharding_constraint makes it possible to constrain intermediate values to an uneven sharding. However, if such an unevenly sharded value is output by the jitted computation, it will come out as fully replicated, no matter the sharding annotation given.
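As a minimal sketch of constraining an intermediate value, the snippet below builds a one-dimensional mesh over whatever devices are available and pins an intermediate array to a row-wise sharding inside a jitted function. The mesh axis name "data", the function name scaled_sum, and the array shapes are illustrative assumptions, not part of the API. The row count is chosen to divide evenly by the device count, so this shows the ordinary even case rather than the uneven one described above.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1D mesh over all local devices (a single device on a plain CPU host).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Split the first array axis across the "data" mesh axis; leave the second unsplit.
row_sharding = NamedSharding(mesh, P("data", None))

@jax.jit
def scaled_sum(x):
    y = x * 2.0
    # Constrain the intermediate value y, not just the inputs or outputs.
    y = jax.lax.with_sharding_constraint(y, row_sharding)
    return y.sum(axis=1)

n = 4 * len(devices)  # row count divisible by the number of devices
x = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)
result = scaled_sum(x)
```

The constraint changes only how the computation is partitioned; the numerical result matches the unconstrained function.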
- Parameters:
x – PyTree of jax.Arrays which will have their shardings constrained
shardings – PyTree of sharding specifications. Valid values are the same as for the in_shardings argument of jax.experimental.pjit().
- Returns:
PyTree of jax.Arrays with specified sharding constraints.
- Return type:
x_with_shardings
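Because both x and shardings are PyTrees, a container of arrays can be constrained in one call. The following is a sketch under assumptions: the dictionary keys "w" and "b", the mesh axis name "data", and the function name constrain_and_apply are hypothetical, and the shardings PyTree is written with the same structure as x to keep the pairing explicit.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1D mesh over all local devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n = 2 * len(jax.devices())  # leading dimension divisible by the device count

params = {"w": jnp.ones((n, 4)), "b": jnp.zeros((4,))}

# One sharding per leaf, mirroring the structure of params.
shardings = {
    "w": NamedSharding(mesh, P("data", None)),  # split rows across devices
    "b": NamedSharding(mesh, P()),              # fully replicated
}

@jax.jit
def constrain_and_apply(tree):
    # Returns a PyTree with the same structure, each leaf carrying its constraint.
    tree = jax.lax.with_sharding_constraint(tree, shardings)
    return tree["w"] + tree["b"]

out = constrain_and_apply(params)
```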