jax.scipy.cluster.vq.vq
jax.scipy.cluster.vq.vq(obs, code_book, check_finite=True)
Assign codes from a code book to a set of observations.

JAX implementation of scipy.cluster.vq.vq(). Assigns each observation vector in obs to a code from code_book based on the nearest Euclidean distance.

- Parameters:
  - obs (ArrayLike) – array of observation vectors of shape (M, N). Each row represents a single observation. If obs is one-dimensional, then each entry is treated as a length-1 observation.
  - code_book (ArrayLike) – array of codes with shape (K, N). Each row represents a single code vector. If code_book is one-dimensional, then each entry is treated as a length-1 code.
  - check_finite (bool) – unused in JAX.
- Returns:
  A tuple of arrays (code, dist):
  - code is an integer array of shape (M,) containing indices 0 <= i < K of the closest entry in code_book for the given entry in obs.
  - dist is a float array of shape (M,) containing the Euclidean distance between each observation and the nearest code.
- Return type:
  tuple[Array, Array]
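The assignment rule described above amounts to a brute-force nearest-neighbour search. The sketch below restates it with plain jax.numpy broadcasting for 2-D inputs. `vq_reference` is a hypothetical helper written for illustration, not the code JAX actually runs, and it assumes floating-point inputs of shape (M, N) and (K, N).

```python
import jax.numpy as jnp
from jax.scipy.cluster.vq import vq


def vq_reference(obs, code_book):
    """Hypothetical illustrative helper, not the JAX source.

    Assumes floating-point 2-D inputs: obs of shape (M, N) and
    code_book of shape (K, N).
    """
    obs = jnp.asarray(obs)
    code_book = jnp.asarray(code_book)
    # Pairwise differences via broadcasting: (M, 1, N) - (1, K, N) -> (M, K, N).
    diffs = obs[:, None, :] - code_book[None, :, :]
    # Euclidean distance from every observation to every code: shape (M, K).
    all_dists = jnp.linalg.norm(diffs, axis=-1)
    # Index of the closest code for each observation: shape (M,).
    code = jnp.argmin(all_dists, axis=-1)
    # Distance to that closest code: shape (M,).
    dist = jnp.min(all_dists, axis=-1)
    return code, dist


obs = jnp.array([[1.1, 2.1, 3.1],
                 [5.9, 4.8, 6.2]])
code_book = jnp.array([[1., 2., 3.],
                       [2., 3., 4.],
                       [3., 4., 5.],
                       [4., 5., 6.]])

# Compare the sketch against the library function on the same inputs.
ref_code, ref_dist = vq_reference(obs, code_book)
jax_code, jax_dist = vq(obs, code_book)
```

Materializing the full (M, K, N) difference tensor keeps the sketch short but can be memory-hungry for large inputs; expanding ||x - c||^2 = ||x||^2 - 2 x·c + ||c||^2 avoids that intermediate at some cost in numerical precision.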
Examples
>>> import jax.numpy as jnp
>>> import jax.scipy.cluster.vq
>>> obs = jnp.array([[1.1, 2.1, 3.1],
...                  [5.9, 4.8, 6.2]])
>>> code_book = jnp.array([[1., 2., 3.],
...                        [2., 3., 4.],
...                        [3., 4., 5.],
...                        [4., 5., 6.]])
>>> codes, distances = jax.scipy.cluster.vq.vq(obs, code_book)
>>> print(codes)
[0 3]
>>> print(distances)
[0.17320499 1.9209373 ]
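The one-dimensional case from the parameter descriptions can be checked directly. In this sketch both arguments are 1-D (the values are made up for illustration), so each entry acts as a length-1 vector. It also passes check_finite=False to show that the flag, which is unused in JAX, leaves the result unchanged. Mixing a 1-D obs with a 2-D code_book is not covered here.

```python
import jax.numpy as jnp
from jax.scipy.cluster.vq import vq

# Both inputs are 1-D, so each entry is treated as a length-1 vector.
obs_1d = jnp.array([1.0, 4.0, 9.0])
code_book_1d = jnp.array([0.0, 5.0, 10.0])

codes_1d, dists_1d = vq(obs_1d, code_book_1d)
# codes_1d -> [0 1 2], dists_1d -> [1. 1. 1.]

# check_finite is unused in JAX, so turning it off should not change anything.
codes_nf, dists_nf = vq(obs_1d, code_book_1d, check_finite=False)
```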