jax.numpy.array_equal
- jax.numpy.array_equal(a1, a2, equal_nan=False)
Check whether two arrays have the same shape and are element-wise equal.
JAX implementation of numpy.array_equal().
- Parameters:
a1 (ArrayLike) – first input array to compare.
a2 (ArrayLike) – second input array to compare.
equal_nan (bool) – if True, NaNs in a1 will be considered equal to NaNs at the same positions in a2. Default is False.
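The effect of equal_nan can be pictured as an element-wise comparison in which a position also counts as equal when both inputs hold NaN there, followed by an all-reduction. The sketch below illustrates those documented semantics under that assumption; the helper name `array_equal_sketch` is hypothetical and is not JAX's actual implementation. It is checked against jnp.array_equal, and shows that a NaN paired with a non-NaN value still compares unequal even when equal_nan=True.

```python
import jax.numpy as jnp

def array_equal_sketch(a1, a2, equal_nan=False):
    # Hypothetical helper illustrating the documented semantics;
    # it is not JAX's actual implementation.
    a1, a2 = jnp.asarray(a1), jnp.asarray(a2)
    if a1.shape != a2.shape:
        # Arrays of different shapes are never equal (no broadcasting).
        return jnp.bool_(False)
    eq = a1 == a2
    if equal_nan:
        # Also treat positions where both inputs are NaN as equal.
        eq = eq | (jnp.isnan(a1) & jnp.isnan(a2))
    return jnp.all(eq)

nan = float('nan')
a = jnp.array([1.0, nan, 3.0])
b = jnp.array([1.0, nan, 3.0])
c = jnp.array([1.0, 2.0, nan])  # NaN in a different position than in `a`

print(array_equal_sketch(a, b, equal_nan=True), jnp.array_equal(a, b, equal_nan=True))  # True True
print(array_equal_sketch(a, c, equal_nan=True), jnp.array_equal(a, c, equal_nan=True))  # False False
```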
- Returns:
Boolean scalar array indicating whether the input arrays are element-wise equal.
- Return type:
Array
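Because the result is a 0-d boolean JAX array rather than a Python bool, how it can be used depends on context. The sketch below shows converting it with bool() outside of tracing, and, inside jax.jit, selecting on it with jnp.where instead of a Python if. The function name `mismatch_or_zero` is a hypothetical example, not part of the JAX API.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([1.0, 2.0, 3.0])

result = jnp.array_equal(x, y)
# The result is a 0-d (scalar) JAX array with a boolean dtype.
print(result.shape, result.dtype)  # () bool
# Outside of tracing, it converts to a plain Python bool.
print(bool(result))  # True

@jax.jit
def mismatch_or_zero(a, b):
    # Hypothetical helper: under jit the comparison result is a tracer,
    # so a Python `if` cannot branch on it; jnp.where selects on it instead.
    same = jnp.array_equal(a, b)
    return jnp.where(same, 0.0, jnp.sum(jnp.abs(a - b)))

print(mismatch_or_zero(x, y))                            # 0.0
print(mismatch_or_zero(x, jnp.array([1.0, 2.0, 4.0])))  # 1.0
```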
Examples
>>> import jax.numpy as jnp
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))
Array(True, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]), equal_nan=True)
Array(True, dtype=bool)
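Unlike reducing an element-wise == with jnp.all, array_equal does not broadcast its inputs: arrays whose shapes differ compare unequal even when broadcasting would make every element match. The sketch below contrasts the two, and also shows that values are compared after ordinary type promotion, so an integer array and an equal-valued float array compare equal. The variable names are illustrative only.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [1, 2]])
b = jnp.array([1, 2])

# `a == b` broadcasts b across the rows of a, so every element matches.
print(jnp.all(a == b))        # True
# array_equal requires identical shapes, so (2, 2) vs (2,) is not equal.
print(jnp.array_equal(a, b))  # False

# Values are compared after type promotion: int32 [1, 2, 3] equals float32 [1., 2., 3.].
print(jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1.0, 2.0, 3.0])))  # True
```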