jax.lax.linalg.symmetric_product

jax.lax.linalg.symmetric_product(a_matrix, c_matrix, *, alpha=1.0, beta=0.0, symmetrize_output=False)

Symmetric product.

Computes the symmetric product

\[\alpha \, A \, A^T + \beta \, C\]

where \(A\) is a rectangular matrix and \(C\) is a symmetric matrix.

Parameters:
  • a_matrix (ArrayLike) – A batch of matrices with shape [..., m, n].

  • c_matrix (ArrayLike) – A batch of symmetric matrices with shape [..., m, m].

  • alpha (float) – Scalar multiplier applied to \(A \, A^T\).

  • beta (float) – Scalar multiplier applied to \(C\).

  • symmetrize_output (bool) – If True, the upper triangle of the output is overwritten with the transpose of the lower triangle, so the full output is symmetric.

Returns:

A batch of matrices with shape [..., m, m]. Only the lower triangle (including the diagonal) is guaranteed to contain correct values on all platforms. If symmetrize_output is True, the upper triangle is filled with the transpose of the lower triangle, so the whole matrix is valid.