jax.numpy.linalg.norm
- jax.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)
Compute the norm of a matrix or vector.
JAX implementation of numpy.linalg.norm().
- Parameters:
x (ArrayLike) – N-dimensional array for which the norm will be computed.
ord (int | str | None) – specifies the kind of norm to take. Defaults to the Frobenius norm for matrices and the 2-norm for vectors. For other options, see Notes below.
axis (None | tuple[int, ...] | int) – integer or sequence of integers specifying the axes over which the norm will be computed. For a single axis, compute a vector norm. For two axes, compute a matrix norm. Defaults to all axes of x.
keepdims (bool) – if True, the output array will have the same number of dimensions as the input, with the size of reduced axes replaced by 1 (default: False).
- Returns:
array containing the specified norm of x.
- Return type:
Array
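The axis and keepdims parameters together determine the shape of the result. The sketch below uses an illustrative (2, 3, 4) array chosen for this example (it is not from the original page) to show that a single axis gives a vector norm, a pair of axes gives a matrix norm, and keepdims=True leaves size-1 axes that broadcast back against the input:

```python
import jax.numpy as jnp

# Illustrative batch of 2 matrices, each of shape (3, 4).
x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)

# One integer axis: a vector 2-norm along the last axis.
v = jnp.linalg.norm(x, axis=-1)
print(v.shape)  # (2, 3)

# A pair of axes: the Frobenius norm (the default matrix norm) of each 3x4 matrix.
m = jnp.linalg.norm(x, axis=(1, 2))
print(m.shape)  # (2,)

# keepdims=True keeps the reduced axes as size-1 dimensions,
# so the result broadcasts against x.
m_kept = jnp.linalg.norm(x, axis=(1, 2), keepdims=True)
print(m_kept.shape)  # (2, 1, 1)

# Each 3x4 matrix of `normalized` has unit Frobenius norm.
normalized = x / m_kept
```

Without keepdims=True, dividing x by m would require manually re-inserting the reduced axes (e.g. m[:, None, None]) before broadcasting.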
Notes
The flavor of norm computed depends on the value of ord and the number of axes being reduced.
For vector norms (i.e. a single-axis reduction):
- ord=None (default) computes the 2-norm.
- ord=inf computes max(abs(x)).
- ord=-inf computes min(abs(x)).
- ord=0 computes sum(x != 0).
- For other numerical values, computes sum(abs(x) ** ord) ** (1/ord).
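The vector-norm rules above can be checked directly against the formulas they name. A minimal sketch, using an illustrative vector (chosen here, not taken from the original) that contains a zero and a negative entry so that the ord=0 and ord=-inf cases are non-trivial; ord=3 stands in for "other numerical values":

```python
import jax.numpy as jnp

# Illustrative vector with a negative entry and a zero entry.
x = jnp.array([3., -4., 0., 12.])

# Each pair is (result of jnp.linalg.norm, the formula listed in the Notes).
vector_checks = {
    "None": (jnp.linalg.norm(x), jnp.sqrt(jnp.sum(jnp.abs(x) ** 2))),
    "inf": (jnp.linalg.norm(x, ord=jnp.inf), jnp.max(jnp.abs(x))),
    "-inf": (jnp.linalg.norm(x, ord=-jnp.inf), jnp.min(jnp.abs(x))),
    "0": (jnp.linalg.norm(x, ord=0), jnp.sum(x != 0).astype(x.dtype)),
    "3": (jnp.linalg.norm(x, ord=3), jnp.sum(jnp.abs(x) ** 3) ** (1 / 3)),
}

for name, (got, expected) in vector_checks.items():
    print(f"ord={name}: norm={got}, formula={expected}")
```

Here the 2-norm is sqrt(9 + 16 + 0 + 144) = 13, ord=inf gives 12, ord=-inf gives 0 and ord=0 counts the 3 nonzero entries.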
For matrix norms (i.e. a two-axis reduction):
- ord='fro' or ord=None (default) computes the Frobenius norm.
- ord='nuc' computes the nuclear norm, i.e. the sum of the singular values.
- ord=1 computes max(abs(x).sum(0)).
- ord=-1 computes min(abs(x).sum(0)).
- ord=2 computes the 2-norm, i.e. the largest singular value.
- ord=-2 computes the smallest singular value.
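The matrix-norm rules can be cross-checked in the same way, with the singular values obtained from jnp.linalg.svd(..., compute_uv=False). A sketch under the assumption that comparing with float32 tolerances is adequate; the matrix is the one used in the Examples:

```python
import jax.numpy as jnp

a = jnp.array([[1., 2., 3.],
               [4., 5., 7.]])

# Singular values only (no U or V^T); there are min(2, 3) = 2 of them.
s = jnp.linalg.svd(a, compute_uv=False)

# Each pair is (result of jnp.linalg.norm, the definition listed in the Notes).
matrix_checks = {
    "fro": (jnp.linalg.norm(a, ord="fro"), jnp.sqrt(jnp.sum(a ** 2))),
    "nuc": (jnp.linalg.norm(a, ord="nuc"), jnp.sum(s)),
    "1": (jnp.linalg.norm(a, ord=1), jnp.max(jnp.abs(a).sum(0))),    # max column sum
    "-1": (jnp.linalg.norm(a, ord=-1), jnp.min(jnp.abs(a).sum(0))),  # min column sum
    "2": (jnp.linalg.norm(a, ord=2), jnp.max(s)),
    "-2": (jnp.linalg.norm(a, ord=-2), jnp.min(s)),
}

for name, (got, expected) in matrix_checks.items():
    print(f"ord={name}: norm={got}, definition={expected}")
```

The absolute column sums of this matrix are 5, 7 and 10, so ord=1 gives 10 and ord=-1 gives 5.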
In the special case of ord=None and axis=None, this function accepts an array of any dimension and computes the vector 2-norm of the flattened array.
Examples
Vector norms:
>>> import jax.numpy as jnp
>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)
Matrix norms:
>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)
Batched vector norm:
>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)
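The special case described in the Notes (ord=None with axis=None) also covers arrays with more than two dimensions. A minimal sketch with an illustrative 3-D array (chosen for this example, not from the original page) showing that the result matches the vector 2-norm of the flattened array:

```python
import jax.numpy as jnp

# Illustrative 3-D array holding the values 1..24.
x = jnp.arange(1., 25.).reshape(2, 3, 4)

full = jnp.linalg.norm(x)          # ord=None, axis=None: all axes reduced
flat = jnp.linalg.norm(x.ravel())  # vector 2-norm of the flattened array
print(full, flat)  # both equal 70, since 1^2 + 2^2 + ... + 24^2 = 4900

# keepdims=True still applies: every axis is reduced to size 1.
print(jnp.linalg.norm(x, keepdims=True).shape)  # (1, 1, 1)
```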