jax.numpy.add

jax.numpy.add = <jnp.ufunc 'add'>

Add two arrays element-wise.

JAX implementation of numpy.add. This is a universal function and supports the additional APIs described at jax.numpy.ufunc. It also implements the + operator for JAX arrays.

Parameters:
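  • x – first array to add. Must be broadcastable to a common shape with y.

  • y – second array to add. Must be broadcastable to a common shape with x.

  • args (ArrayLike) – the positional operands (x and y) passed to the ufunc call.

  • out (None) – not supported; must be None.

  • where (None) – not supported; must be None.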

Returns:

Array containing the result of the element-wise addition.

Return type:

Any

Examples

Calling add explicitly:

>>> import jax.numpy as jnp
>>> x = jnp.arange(4)
>>> jnp.add(x, 10)
Array([10, 11, 12, 13], dtype=int32)

Calling add via the + operator:

>>> x + 10
Array([10, 11, 12, 13], dtype=int32)