jax.scipy.special.logsumexp
- jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: Literal[False] = False, where: ArrayLike | None = None) -> Array
- jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, *, return_sign: Literal[True], where: ArrayLike | None = None) -> tuple[Array, Array]
- jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]
Log-sum-exp reduction.
JAX implementation of scipy.special.logsumexp().
\[\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})\]
where the \(j\) indices range over one or more dimensions to be reduced.
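As a minimal sketch of the reduction (the array values here are illustrative assumptions, not taken from the library documentation), the following compares a direct evaluation of the formula with logsumexp on inputs large enough to overflow in the default float32 precision, and then reduces over a single axis:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Illustrative input: exp(1000) overflows to inf in float32.
a = jnp.array([1000.0, 1000.0, 1000.0])

# Direct evaluation of log(sum(exp(a))) overflows.
naive = jnp.log(jnp.sum(jnp.exp(a)))

# logsumexp is expected to stay finite: 1000 + log(3).
stable = logsumexp(a)

# Reducing over axis=1 of a 2-D array gives one value per row.
x = jnp.array([[0.0, 0.0],
               [1.0, 1.0]])
row_lse = logsumexp(x, axis=1)  # shape (2,)
```

With axis=None (the default), the reduction runs over every axis and returns a scalar array.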
- Parameters:
a – the input array
axis – int or sequence of ints, default=None. Axis or axes along which the sum is computed. If None, the sum is computed over all axes.
b – scaling factors for \(\mathrm{exp}(a)\). Must be broadcastable to the shape of a.
keepdims – If True, the axes that are reduced are left in the output as dimensions of size 1.
return_sign – If True, the output will be a (result, sign) pair, where sign is the sign of the sums and result contains the logarithms of their absolute values. If False, only result is returned, and it will contain NaN values if the sums are negative.
where – Elements to include in the reduction.
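The keepdims, b, and where parameters can be sketched as follows. The input values, weights, and mask are assumptions chosen so the expected results are easy to check by hand:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([[0.0, 1.0, 2.0],
               [3.0, 4.0, 5.0]])

# keepdims=True leaves the reduced axis in place with size 1.
kept = logsumexp(a, axis=1, keepdims=True)   # shape (2, 1)
dropped = logsumexp(a, axis=1)               # shape (2,)

# b scales each exp(a) term: log(2*exp(0) + 3*exp(0)) = log(5).
weighted = logsumexp(jnp.array([0.0, 0.0]), b=jnp.array([2.0, 3.0]))

# where selects the elements that take part in the reduction.
# Only the two zeros are included here: log(exp(0) + exp(0)) = log(2).
masked = logsumexp(jnp.array([0.0, 5.0, 0.0]),
                   where=jnp.array([True, False, True]))
```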
- Returns:
Either an array result or a pair of arrays (result, sign), depending on the value of the return_sign argument.
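The two return forms can be illustrated with a weighted sum that comes out negative. This is a small sketch; the weights below are assumed values picked so that the sum is exactly -1:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([0.0, 0.0])
b = jnp.array([1.0, -2.0])  # illustrative weights: 1*exp(0) - 2*exp(0) = -1

# return_sign=True returns a (result, sign) pair:
# result is log(|-1|) = 0 and sign is -1.
result, sign = logsumexp(a, b=b, return_sign=True)

# return_sign=False (the default) returns a single array,
# which is NaN here because the sum is negative.
only_result = logsumexp(a, b=b)
```

Requesting the sign is the safer choice whenever b may contain negative entries, since the default form cannot represent the logarithm of a negative sum.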