jax.experimental.multihost_utils.broadcast_one_to_all
- jax.experimental.multihost_utils.broadcast_one_to_all(in_tree, is_source=None)
Broadcast data from a source host (host 0 by default) to all other hosts.
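A minimal sketch of the default behaviour follows. It assumes that on a real multi-host setup `jax.distributed.initialize()` has already been called on every host; in a single-process run the output simply equals the input. The dictionary keys and values are illustrative only. Every process passes a pytree with identical shapes, and only process 0's values survive.

```python
import numpy as np
import jax
from jax.experimental import multihost_utils

# Illustrative pytree: values differ per process, shapes are identical everywhere.
# (On multi-host runs, jax.distributed.initialize() is assumed to have been called.)
local = {
    "seed": np.array(jax.process_index() + 1234, dtype=np.int32),
    "weights": np.full((2, 3), float(jax.process_index()), dtype=np.float32),
}

# Default is_source=None: process 0 is the source, so every process
# receives process 0's "seed" (1234) and "weights" (all zeros).
synced = multihost_utils.broadcast_one_to_all(local)
```

A common use is agreeing on a random seed or a config value that was chosen on one host before training starts.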
- Parameters:
in_tree (Any) – pytree of arrays; each array must have the same shape on every host.
is_source (bool | None) – optional bool indicating whether the calling process is the source. Only the source host contributes data to the broadcast. If None, host 0 is used as the source.
- Returns:
A pytree matching in_tree whose leaves all contain the data from the source host (host 0 by default).
- Return type:
Any
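The is_source argument lets a process other than host 0 act as the source. The sketch below makes the highest-index process the source; that choice is hypothetical and only for illustration. It also assumes exactly one process passes True, which is how "the source host" reads; the behaviour with several sources is not specified above.

```python
import numpy as np
import jax
from jax.experimental import multihost_utils

# Hypothetical choice: the last process (highest index) is the source.
is_last = jax.process_index() == jax.process_count() - 1

# Same shape and dtype on every host; only the source's value is kept.
step = np.array(100 * (jax.process_index() + 1), dtype=np.int32)

synced_step = multihost_utils.broadcast_one_to_all(step, is_source=is_last)
# In a single-process run the only process is also the last, so this is 100.
```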