jax.numpy.array_equal

jax.numpy.array_equal(a1, a2, equal_nan=False)

Check whether two arrays have the same shape and equal elements.

JAX implementation of numpy.array_equal().
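As a quick illustration of that correspondence, the sketch below compares `jnp.array_equal` with `np.array_equal` on a few input pairs mixing NumPy and JAX arrays. The specific pairs are illustrative assumptions, not part of the API. One visible difference is the return type: JAX returns a 0-d boolean array, so it is wrapped in `bool()` for display here.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative input pairs mixing NumPy and JAX arrays.
pairs = [
    (np.array([1, 2, 3]), jnp.array([1, 2, 3])),      # same shape, same values
    (np.array([1, 2, 3]), jnp.array([1, 2, 4])),      # same shape, one value differs
    (np.array([[1.0, 2.0]]), jnp.array([1.0, 2.0])),  # shapes (1, 2) vs (2,)
]

for x, y in pairs:
    # jnp.array_equal returns a 0-d bool Array; np.array_equal returns a Python bool.
    print(bool(jnp.array_equal(x, y)), np.array_equal(x, y))
```

Each line should print matching values: `True True`, then `False False` twice.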

Parameters:
  • a1 (ArrayLike) – first input array to compare.

  • a2 (ArrayLike) – second input array to compare.

  • equal_nan (bool) – If True, a NaN in a1 is considered equal to a NaN at the same position in a2. Default is False.
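To make the parameter semantics concrete, here is a minimal sketch that expresses the same comparison with elementwise operations. `array_equal_sketch` is a hypothetical helper written for illustration. It is not JAX's source code, and it assumes the documented behavior: no broadcasting across differing shapes, and with `equal_nan=True` a position also matches when both inputs hold NaN there.

```python
import jax.numpy as jnp


def array_equal_sketch(a1, a2, equal_nan=False):
    """Hypothetical illustration of jnp.array_equal semantics (not the library source)."""
    a1, a2 = jnp.asarray(a1), jnp.asarray(a2)
    # Arrays with different shapes are never equal; no broadcasting is applied.
    if a1.shape != a2.shape:
        return jnp.bool_(False)
    eq = a1 == a2
    if equal_nan:
        # Treat a position as equal when both inputs are NaN at that position.
        eq = eq | (jnp.isnan(a1) & jnp.isnan(a2))
    # Reduce the elementwise result to a single boolean scalar.
    return jnp.all(eq)
```

Under this sketch, NaNs only match position by position. For example, `[nan, 1.0]` and `[1.0, nan]` compare as unequal even with `equal_nan=True`.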

Returns:

Boolean scalar array that is True if the input arrays have the same shape and all corresponding elements are equal, and False otherwise.

Return type:

Array
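The return value is a 0-dimensional boolean JAX array rather than a Python `bool`. The sketch below shows how to inspect it and convert it with `bool()`, and it calls the function inside `jax.jit`. The function name `same` is a hypothetical example. Inside a jitted function the result is a traced value, so it can be returned or combined with other array operations but should not drive Python `if` statements.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([1.0, 2.0, 3.0])

result = jnp.array_equal(x, y)
# A 0-d array of dtype bool, not a Python bool.
print(result.shape, result.dtype)  # () bool
# Convert explicitly when a Python bool is needed (outside of tracing).
print(bool(result))  # True


@jax.jit
def same(a, b):
    # Inside jit the comparison result is traced; return it instead of
    # branching on it with Python control flow.
    return jnp.array_equal(a, b)


print(bool(same(x, y)))  # True
```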

Examples

>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))
Array(True, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]), equal_nan=True)
Array(True, dtype=bool)
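Two consequences of these semantics are worth contrasting with a plain elementwise check, as in the sketch below. First, `jnp.all(a == b)` broadcasts its operands, while `jnp.array_equal` requires identical shapes. Second, the dtypes of the inputs do not need to match, because values are compared after the usual type promotion. The array values are illustrative.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3))
b = jnp.ones(3)

# Elementwise == broadcasts b against each row of a, so every comparison is True.
print(jnp.all(a == b))        # True
# array_equal does not broadcast: shapes (2, 3) and (3,) differ.
print(jnp.array_equal(a, b))  # False

# Integer and float arrays with equal values compare as equal.
print(jnp.array_equal(jnp.array([1, 2]), jnp.array([1.0, 2.0])))  # True
```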