jax.lax.pbroadcast
- jax.lax.pbroadcast(x, axis_name, source)
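Perform a collective broadcast that replicates x from the source index of axis_name to every index along that axis. This is equivalent to:

```python
import jax.numpy as jnp
from jax import lax

def pbroadcast(x, axis_name, source):
  # Zero out x everywhere except at the source index, then sum over the axis.
  masked = jnp.where(lax.axis_index(axis_name) == source, x, jnp.zeros_like(x))
  return lax.psum(masked, axis_name)
```

but is implemented in a hardware-optimized way.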
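To see the equivalence above in action, here is a minimal sketch that runs the masked-psum formulation on a single CPU device. Two assumptions apply: jax.vmap with axis_name="i" stands in for a pmapped axis, and pbroadcast_reference is a hypothetical helper name. The sketch does not call jax.lax.pbroadcast itself, so it makes no assumption about which backends lower the CollectiveBroadcast HLO.

```python
import jax
import jax.numpy as jnp
from jax import lax


def pbroadcast_reference(x, axis_name, source):
  # Hypothetical helper: masked-psum formulation of a broadcast from `source`.
  # Keep x only at the source index; every other index contributes zeros.
  is_source = lax.axis_index(axis_name) == source
  masked = jnp.where(is_source, x, jnp.zeros_like(x))
  # Summing over the named axis copies the source slice to every index.
  return lax.psum(masked, axis_name)


# Four mapped indices simulated with vmap; each holds a row of length 3.
xs = jnp.arange(12.0).reshape(4, 3)
out = jax.vmap(lambda x: pbroadcast_reference(x, "i", 2), axis_name="i")(xs)
print(out)  # every row equals xs[2]
```

Because only the source index contributes a nonzero term, the sum is exact. Every row of out is therefore [6., 7., 8.], which is the slice xs[2].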
If x is a pytree, the result is equivalent to mapping this function over each leaf in the tree.
This function is an analog of the CollectiveBroadcast HLO.
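The pytree rule can be illustrated by applying a single-array broadcast to every leaf with jax.tree_util.tree_map. The sketch below is minimal and rests on a few assumptions. It reuses the masked-psum formulation instead of calling jax.lax.pbroadcast, and jax.vmap with a named axis stands in for a pmapped axis. The helpers pbroadcast_reference and pbroadcast_tree, and the params dictionary, are hypothetical illustrations.

```python
import jax
import jax.numpy as jnp
from jax import lax


def pbroadcast_reference(x, axis_name, source):
  # Hypothetical helper: masked-psum broadcast of a single array from `source`.
  masked = jnp.where(lax.axis_index(axis_name) == source, x, jnp.zeros_like(x))
  return lax.psum(masked, axis_name)


def pbroadcast_tree(tree, axis_name, source):
  # Pytree behaviour: apply the single-array broadcast to every leaf.
  return jax.tree_util.tree_map(
      lambda leaf: pbroadcast_reference(leaf, axis_name, source), tree)


# Hypothetical parameters, stacked along the mapped axis (4 indices).
params = {
    "w": jnp.arange(8.0).reshape(4, 2),
    "b": jnp.array([10.0, 20.0, 30.0, 40.0]),
}
out = jax.vmap(lambda p: pbroadcast_tree(p, "i", 0), axis_name="i")(params)
print(out)  # every index now holds the leaves from index 0
```

After the call, each leaf holds the index-0 value at every position. Every row of out["w"] is [0., 1.], and every entry of out["b"] is 10.0.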
- Parameters:
  - x – array(s) with a mapped axis named axis_name.
  - axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
  - source – int, the index into axis_name whose value should be copied.
- Returns:
Array(s) with x copied from the source index slice of axis_name.