jax.lax.broadcast_shapes

jax.lax.broadcast_shapes(*shapes)

Returns the shape that results from NumPy broadcasting of shapes.

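This follows the rules of NumPy broadcasting: shapes are aligned at their trailing dimensions, shorter shapes are treated as if padded on the left with 1s, and each aligned dimension must either have the same size in every shape or be 1, in which case it is stretched to match.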
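
To make the rule concrete, here is a minimal pure-Python sketch of the shape computation, assuming plain tuples of non-negative integers (no symbolic or traced dimensions). The helper name broadcast_shapes_sketch is hypothetical and not part of JAX; it illustrates the rule rather than reproducing JAX's implementation.

```python
from itertools import zip_longest


def broadcast_shapes_sketch(*shapes: tuple[int, ...]) -> tuple[int, ...]:
    """Hypothetical helper: NumPy-style broadcast shape of integer shapes.

    A sketch of the rule only; this is not how JAX implements it.
    """
    result = []
    # Walk the dimensions from the trailing (rightmost) axis. Shorter shapes
    # behave as if left-padded with 1s, which fillvalue=1 models.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        # Size-1 dimensions stretch; every other size must agree.
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(
                f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
        result.append(sizes.pop() if sizes else 1)
    # Results were collected right-to-left, so restore the original order.
    return tuple(reversed(result))
```

For example, broadcast_shapes_sketch((3, 1), (1, 4), (5, 1, 1)) returns (5, 3, 4), and broadcast_shapes_sketch((3, 1), (4, 1)) raises ValueError because the leading dimensions 3 and 4 disagree.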

Parameters:

shapes – one or more tuples of integers containing the shapes of arrays to be broadcast.

Returns:

A tuple of integers giving the broadcast shape.

Raises:

ValueError – if shapes are not broadcast-compatible.

Examples

Some examples of broadcast-compatible shapes, computed with jax.numpy.broadcast_shapes (imported as jnp):

>>> import jax.numpy as jnp
>>> jnp.broadcast_shapes((1,), (4,))
(4,)
>>> jnp.broadcast_shapes((3, 1), (4,))
(3, 4)
>>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
(5, 3, 4)
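
One way to check such a result is to compare it with the shape of an actual elementwise operation on arrays of those shapes. The sketch below assumes JAX is installed and calls jax.lax.broadcast_shapes directly; the arrays and their values are arbitrary illustrations.

```python
import jax.numpy as jnp
from jax import lax

# Arrays with the same shapes as the last example above; values are arbitrary.
a = jnp.ones((3, 1))
b = jnp.ones((1, 4))
c = jnp.ones((5, 1, 1))

# The broadcast shape, computed from the shapes alone (no array data needed).
shape = lax.broadcast_shapes(a.shape, b.shape, c.shape)

# It matches the shape produced by broadcasting elementwise arithmetic.
assert (a + b + c).shape == shape

# It can also be passed to jnp.broadcast_to to expand one operand explicitly.
a_full = jnp.broadcast_to(a, shape)
```

Computing the shape up front is useful when you need it before (or without) materializing any broadcast arrays.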

Attempting to broadcast incompatible shapes raises a ValueError. Here the leading dimensions, 3 and 4, differ and neither is 1:

>>> jnp.broadcast_shapes((3, 1), (4, 1))  
Traceback (most recent call last):
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]
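
Because incompatibility is reported through ValueError, the function can also serve as a compatibility check. A minimal sketch, assuming JAX is installed; the wrapper name is_broadcast_compatible is hypothetical and not a JAX API.

```python
from jax import lax


def is_broadcast_compatible(*shapes: tuple[int, ...]) -> bool:
    """Hypothetical helper: True if `shapes` can be broadcast together.

    Relies on jax.lax.broadcast_shapes raising ValueError for
    incompatible shapes.
    """
    try:
        lax.broadcast_shapes(*shapes)
    except ValueError:
        # Raised when some aligned dimensions differ and neither is 1.
        return False
    return True
```

Catching only ValueError, rather than every exception, keeps other kinds of errors from being silently reported as an incompatibility.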