jax.lax.broadcast_shapes
- jax.lax.broadcast_shapes(*shapes)
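Returns the shape that results from applying NumPy broadcasting rules to shapes.
Shapes are aligned on their trailing dimensions, and in each aligned position the sizes must either be equal or one of them must be 1.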
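As a rough illustration of the rule, here is a minimal pure-Python sketch. The name `broadcast_shapes_sketch` is hypothetical and not part of JAX; the code only mirrors the documented behavior and is not the library's implementation.

```python
from itertools import zip_longest


def broadcast_shapes_sketch(*shapes):
    """Hypothetical helper: compute a broadcast shape by the NumPy rule."""
    result = []
    # Align shapes on their trailing dimensions; a missing leading
    # dimension is treated as size 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        # Size-1 dimensions stretch to match, so only the other sizes matter.
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(
                f"Incompatible shapes for broadcasting: shapes={list(shapes)}"
            )
        result.append(sizes.pop() if sizes else 1)
    # Dimensions were collected from last to first, so reverse them.
    return tuple(reversed(result))


print(broadcast_shapes_sketch((3, 1), (1, 4), (5, 1, 1)))
```

The sketch walks the dimensions from the right because that is how broadcasting aligns shapes of different rank.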
- Parameters:
shapes – one or more tuples of integers containing the shapes of arrays to be broadcast.
- Returns:
A tuple of integers representing the broadcast shape.
- Raises:
ValueError – if shapes are not broadcast-compatible.
See also
jax.numpy.broadcast_shapes(): similar API in the JAX NumPy namespace
Examples
Some examples of broadcasting compatible shapes:
>>> import jax.numpy as jnp
>>> jnp.broadcast_shapes((1,), (4,))
(4,)
>>> jnp.broadcast_shapes((3, 1), (4,))
(3, 4)
>>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
(5, 3, 4)
Error when attempting to broadcast incompatible shapes:
>>> jnp.broadcast_shapes((3, 1), (4, 1))
Traceback (most recent call last):
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]
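One possible use, sketched here under the assumption that the goal is to know an output shape before doing any array arithmetic: the tuple returned by `jax.lax.broadcast_shapes` should match the shape of an elementwise operation on arrays with those shapes. The variable names below are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((3, 1))
y = jnp.zeros((4,))

# Compute the broadcast shape from the shapes alone, without touching the data.
expected = lax.broadcast_shapes(x.shape, y.shape)

# An elementwise operation broadcasts its operands to the same shape.
assert (x + y).shape == expected

# The result can also be passed to jnp.broadcast_to to expand an operand explicitly.
x_b = jnp.broadcast_to(x, expected)
print(expected, x_b.shape)
```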