jax.experimental.multihost_utils.process_allgather
- jax.experimental.multihost_utils.process_allgather(in_tree, tiled=False)
Gather data from all processes.
- Parameters:
in_tree (Any) – pytree of arrays. Each array _must_ have the same shape on every process.
tiled (bool) – Whether to stack or concatenate the gathered arrays. Defaults to False, i.e. stack them along a new leading axis (index 0). If True, concatenate them along axis 0.
- Returns:
- A pytree of NumPy arrays.
If the input is a non-fully-addressable jax.Array, the data is fully replicated: every process receives the entire global array, and tiled has no effect.
If the input is a NumPy array or a fully addressable jax.Array, the output shape depends on tiled. If it is False, the per-process arrays are stacked along a new axis 0; otherwise they are concatenated along axis 0.
If the input is a scalar, the output is always stacked.
- Return type:
Any
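The shape rules for NumPy and fully addressable inputs can be modeled with plain NumPy. The sketch below assumes a hypothetical job with four processes. `simulate_process_allgather` is a hypothetical helper, not part of JAX, and it assumes results are ordered by process index. It does not model the non-fully-addressable case, where the output is simply the global array.

```python
import numpy as np


def simulate_process_allgather(per_process, tiled=False):
    """Hypothetical NumPy model of process_allgather's output-shape rules.

    per_process: one array (or scalar) per process, assumed ordered by
    process index. All entries must share the same shape.
    """
    arrays = [np.asarray(a) for a in per_process]
    shapes = {a.shape for a in arrays}
    if len(shapes) != 1:
        raise ValueError(f"shapes differ across processes: {shapes}")
    # tiled=False, or a scalar (0-d, so no axis to concatenate): stack on a new axis 0.
    if not tiled or arrays[0].ndim == 0:
        return np.stack(arrays, axis=0)
    # tiled=True: concatenate along the existing axis 0.
    return np.concatenate(arrays, axis=0)


num_processes = 4  # hypothetical process count
local = [np.full((2, 3), rank) for rank in range(num_processes)]
scalars = [float(rank) for rank in range(num_processes)]

print(simulate_process_allgather(local).shape)                # (4, 2, 3)
print(simulate_process_allgather(local, tiled=True).shape)    # (8, 3)
print(simulate_process_allgather(scalars, tiled=True).shape)  # (4,)
```

Stacking keeps each process's contribution addressable as `out[rank]`. Concatenation instead produces the array one would get by laying the per-process shards end to end along axis 0.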