jax.lax.linalg.lu_pivots_to_permutation

jax.lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)

Converts the pivots (row swaps) returned by LU to a permutation.
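To make the conversion concrete, here is a minimal pure-Python sketch of the semantics. It assumes that pivots are 0-based and that step i of the factorization swapped row i with row pivots[i]. The helper name pivots_to_permutation_reference is hypothetical and is not part of JAX; it only replays the swaps on an identity permutation.

```python
def pivots_to_permutation_reference(pivots, permutation_size):
    """Hypothetical reference helper (not a JAX API).

    Assumes 0-based pivots where step i swapped row i with row pivots[i].
    """
    if permutation_size < len(pivots):
        raise ValueError("permutation_size must be >= len(pivots)")
    # Start from the identity permutation [0, 1, ..., permutation_size - 1].
    perm = list(range(permutation_size))
    for i, p in enumerate(pivots):
        # Replay swap i: exchange entries i and p of the running permutation.
        perm[i], perm[p] = perm[p], perm[i]
    return perm


print(pivots_to_permutation_reference([2, 2, 2], 3))  # [2, 0, 1]
print(pivots_to_permutation_reference([1, 2], 4))     # [1, 2, 0, 3]
```

Entries of the permutation beyond k are only moved when some pivot points at them, which is why permutation_size may exceed the number of pivots.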

We build a permutation rather than applying pivots directly to the rows of a matrix because lax loops aren’t differentiable.
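Once the pivots are turned into a permutation, applying it is a single gather (a[perm]), which JAX can differentiate like any other indexing operation. The sketch below illustrates this with small hand-picked pivots; the weight vector w and the function permuted_weighted_sum are illustrative assumptions, not part of the API.

```python
import jax
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

# Three 0-based row swaps converted into a length-3 permutation.
perm = lax_linalg.lu_pivots_to_permutation(jnp.array([2, 2, 2], dtype=jnp.int32), 3)

# Illustrative weights so each permuted position contributes a distinct amount.
w = jnp.array([1.0, 10.0, 100.0])


def permuted_weighted_sum(a):
    # a[perm] is a plain gather; no loop is traced, so jax.grad applies directly.
    return jnp.sum(a[perm] * w)


a = jnp.array([1.0, 2.0, 3.0])
grad = jax.grad(permuted_weighted_sum)(a)
# perm holds 2, 0, 1, so the gradient with respect to a is 10, 100, 1.
print(perm, grad)
```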

Parameters:
  • pivots (ArrayLike) – an int32 array of shape (…, k) of row swaps to perform, where step i swaps row i with row pivots[…, i] (0-based)

  • permutation_size (int) – the size of the output permutation. Must be >= k.

Returns:

An int32 array of shape (…, permutation_size).
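As a sketch of how the shapes work out, the example below passes a batch of two pivot vectors (k = 2) and requests permutation_size = 4, as would arise for the pivots of a 4x2 (tall) matrix. The pivot values are hand-picked for illustration.

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

# Batch of two pivot vectors, each with k = 2 swaps: shape (2, 2).
pivots = jnp.array([[1, 2],
                    [0, 3]], dtype=jnp.int32)

# permutation_size = 4 >= k, so each output row is a permutation of 0..3.
perms = lax_linalg.lu_pivots_to_permutation(pivots, permutation_size=4)

# Leading batch dimensions are preserved: shape (2, 4), dtype int32.
print(perms.shape, perms.dtype)
print(perms.tolist())  # [[1, 2, 0, 3], [0, 3, 2, 1]]
```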

Return type:

Array
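A typical use is pairing this function with jax.lax.linalg.lu, which returns (lu, pivots, permutation). The sketch below rebuilds the permutation from the pivots alone and checks that permuting the rows of the input reproduces L @ U. The unpacking of L and U from the packed lu output mirrors what jax.scipy.linalg.lu does; the test matrix is arbitrary.

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

# An arbitrary 4x2 matrix, so there are k = min(4, 2) = 2 pivots.
a = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0],
               [7.0, 9.0]])

lu, pivots, permutation = lax_linalg.lu(a)

# Rebuild the permutation from the pivots; permutation_size is the row count.
perm = lax_linalg.lu_pivots_to_permutation(pivots, a.shape[0])

# Unpack unit lower-trapezoidal L (4x2) and upper-triangular U (2x2).
k = min(a.shape)
l = jnp.tril(lu[:, :k], -1) + jnp.eye(a.shape[0], k)
u = jnp.triu(lu[:k, :])

# Expected: perm matches lu's own permutation, and a[perm] is close to L @ U.
print(jnp.array_equal(perm, permutation))
print(jnp.allclose(a[perm], l @ u, atol=1e-5))
```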