jax.numpy.delete
- jax.numpy.delete(arr, obj, axis=None, *, assume_unique_indices=False)
Delete entry or entries from an array.
JAX implementation of numpy.delete().
- Parameters:
arr (ArrayLike) – array from which entries will be deleted.
obj (ArrayLike | slice) – index, indices, or slice to be deleted.
axis (int | None) – axis along which entries will be deleted. If None (the default), arr is flattened before entries are deleted.
assume_unique_indices (bool) – for array-like integer (not boolean) indices, assume the indices are unique and perform the deletion in a way that is compatible with JIT and other JAX transformations.
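The parameter descriptions can be made concrete with a short sketch, assuming the NumPy-compatible semantics that jnp.delete follows: with axis=None the input is flattened first, a negative integer counts from the end of the axis, and a boolean obj with one flag per entry along the axis marks (with True) the entries to drop. The variable names are illustrative only.

```python
import jax.numpy as jnp

a2 = jnp.array([[4, 5, 6],
                [7, 8, 9]])

# axis=None (the default): a2 is flattened to [4, 5, 6, 7, 8, 9] first,
# so integer index 1 refers to the value 5.
flat_result = jnp.delete(a2, 1)

# A negative integer counts from the end of the chosen axis.
last_col_removed = jnp.delete(a2, -1, axis=1)

# A boolean obj (1D, one flag per entry along the axis):
# True marks the entries to delete.
mask = jnp.array([True, False, False])
masked_result = jnp.delete(a2, mask, axis=1)

print(flat_result.tolist())       # [4, 6, 7, 8, 9]
print(last_col_removed.tolist())  # [[4, 5], [7, 8]]
print(masked_result.tolist())     # [[5, 6], [8, 9]]
```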
- Returns:
Copy of arr with the specified entries deleted.
- Return type:
Array
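Since the function is described as the JAX implementation of numpy.delete(), a quick consistency check against NumPy is a reasonable sanity test. The sketch below assumes NumPy is installed alongside JAX; it compares shapes and values for a few obj/axis combinations and confirms that the input array keeps its shape, since JAX arrays are immutable and a new array is returned.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(12).reshape(3, 4)
x = jnp.asarray(x_np)

# (obj, axis) pairs to compare: integer, slice, integer array, flattened.
cases = [
    (1, 0),
    (slice(0, 4, 2), 1),
    (np.array([0, 2]), 1),
    (5, None),
]

matches = []
for obj, axis in cases:
    expected = np.delete(x_np, obj, axis=axis)
    result = jnp.delete(x, obj, axis=axis)
    # Compare shape and values against NumPy's reference behavior.
    matches.append(result.shape == expected.shape
                   and bool((np.asarray(result) == expected).all()))

print(matches)
# x itself is unchanged; jnp.delete returns a new array.
print(x.shape)
```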
Note
delete() usually requires the index specification to be static. If the index is an integer array that is guaranteed to contain unique entries, you may specify assume_unique_indices=True to perform the operation in a manner that does not require static indices.
See also
jax.numpy.insert(): insert entries into an array.
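One way to see how jax.numpy.insert() relates to delete() is a round trip for a single integer index: deleting an entry and then inserting the same value at the same position recovers the original array. This is a minimal sketch for the single-index case only.

```python
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])

removed_value = a[2]                # the entry that will be deleted (6)
shorter = jnp.delete(a, 2)          # values [4, 5, 7, 8, 9]
# Insert the removed value back in front of index 2 of the shorter array.
restored = jnp.insert(shorter, 2, removed_value)

print(bool((restored == a).all()))  # True
```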
Examples
Delete entries from a 1D array:
>>> a = jnp.array([4, 5, 6, 7, 8, 9])
>>> jnp.delete(a, 2)
Array([4, 5, 7, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(1, 4))  # delete a[1:4]
Array([4, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(None, None, 2))  # delete a[::2]
Array([5, 7, 9], dtype=int32)
Delete entries from a 2D array along a specified axis:
>>> a2 = jnp.array([[4, 5, 6],
...                 [7, 8, 9]])
>>> jnp.delete(a2, 1, axis=1)
Array([[4, 6],
       [7, 9]], dtype=int32)
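With axis=0 the same call removes rows instead of columns, and an integer array removes several columns at once. A short sketch reusing the values of a2 (redefined so the block runs on its own):

```python
import jax.numpy as jnp

a2 = jnp.array([[4, 5, 6],
                [7, 8, 9]])

# Remove the first row.
without_first_row = jnp.delete(a2, 0, axis=0)

# Remove columns 0 and 2, keeping only the middle column.
without_outer_cols = jnp.delete(a2, jnp.array([0, 2]), axis=1)

print(without_first_row.tolist())   # [[7, 8, 9]]
print(without_outer_cols.tolist())  # [[5], [8]]
```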
Delete multiple entries via a sequence of indices:
>>> indices = jnp.array([0, 1, 3])
>>> jnp.delete(a, indices)
Array([6, 8, 9], dtype=int32)
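Outside of transformations, the order of the entries in the index array does not matter, and negative indices count from the end, as in NumPy. A small sketch:

```python
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])

# Same positions as [0, 1, 3], listed in a different order.
unordered = jnp.delete(a, jnp.array([3, 0, 1]))

# -1 refers to the last entry (9), 0 to the first (4).
with_negative = jnp.delete(a, jnp.array([-1, 0]))

print(unordered.tolist())      # [6, 8, 9]
print(with_negative.tolist())  # [5, 6, 7, 8]
```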
This will fail under jit() and other transformations: when duplicate indices are possible, the output shape depends on the index values, which are not known at trace time:
>>> jax.jit(jnp.delete)(a, indices)
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3].
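The shape dependence can be observed without jit as well. A repeated index removes its entry only once (matching NumPy), so two index arrays of the same shape can produce outputs of different lengths. A traced index array carries only its shape and dtype, not its values, so the output length cannot be determined during tracing.

```python
import jax.numpy as jnp

a = jnp.array([4, 5, 6, 7, 8, 9])

# Both index arrays have shape (2,), but the results differ in length.
distinct = jnp.delete(a, jnp.array([1, 2]))   # two entries removed
repeated = jnp.delete(a, jnp.array([1, 1]))   # index 1 removed only once

print(distinct.shape, repeated.shape)  # (4,) (5,)
```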
If you can ensure that the indices are unique, pass assume_unique_indices=True to allow this to be executed under JIT. The flag is marked as a static argument so that its Python value is available during tracing:
>>> jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices'])
>>> jit_delete(a, indices, assume_unique_indices=True)
Array([6, 8, 9], dtype=int32)
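Uniqueness is enough for a static output shape: if the k indices are distinct, exactly k entries are removed, so the result always has length n - k, which depends only on shapes. The helper delete_unique_1d below is a hypothetical sketch of one way to exploit this (a boolean keep-mask plus jnp.nonzero with a static size). It is not how jnp.delete is implemented internally, and it assumes distinct, non-negative, in-range indices; duplicates silently give a wrong result.

```python
import jax
import jax.numpy as jnp


@jax.jit
def delete_unique_1d(x, indices):
    """Hypothetical helper: remove x[indices] from a 1D array x.

    Assumes `indices` holds distinct, non-negative, in-range integers.
    """
    n = x.shape[0]
    k = indices.shape[0]  # known at trace time: only the shape is used
    # True for entries to keep, False for entries to delete.
    keep = jnp.ones(n, dtype=bool).at[indices].set(False)
    # With unique indices exactly n - k entries survive, so size is static.
    (positions,) = jnp.nonzero(keep, size=n - k)
    return x[positions]


a = jnp.array([4, 5, 6, 7, 8, 9])
print(delete_unique_1d(a, jnp.array([0, 1, 3])).tolist())  # [6, 8, 9]
```

The index order does not matter here because the keep-mask records positions rather than a sequence of removals.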