jax.lax.select_n
- jax.lax.select_n(which, *cases)
Selects array values from multiple cases.
Generalizes XLA’s Select operator. Unlike XLA’s version, this operator is variadic and can select from many cases using an integer predicate.
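A minimal sketch of the variadic selection, assuming illustrative array names (`case0`, `case1`, `case2`, `which`, `mask`) that are not part of the API. The integer selector is given an explicit `int32` dtype here as a conservative choice:

```python
import jax.numpy as jnp
from jax import lax

# Three illustrative cases with identical shape and dtype.
case0 = jnp.array([10.0, 11.0, 12.0, 13.0])
case1 = jnp.array([20.0, 21.0, 22.0, 23.0])
case2 = jnp.array([30.0, 31.0, 32.0, 33.0])

# Integer selector: element i of the result comes from cases[which[i]].
which = jnp.array([0, 2, 1, 2], dtype=jnp.int32)
picked = lax.select_n(which, case0, case1, case2)
# picked is [10., 31., 22., 33.]

# Boolean selector with two cases: False takes the first case, True the second.
mask = jnp.array([True, False, False, True])
picked_bool = lax.select_n(mask, case0, case1)
# picked_bool is [20., 11., 12., 23.]

# Scalar selector: every element is taken from the single chosen case.
picked_scalar = lax.select_n(jnp.int32(1), case0, case1, case2)
# picked_scalar is [20., 21., 22., 23.]
```

Note that with a boolean selector the argument order is the reverse of a typical "if true, else" form: the case for `False` comes first.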
- Parameters:
which (ArrayLike) – determines which case should be returned. Must be an array containing either boolean or integer values. May either be a scalar or have shape matching `cases`. For each array element, the value of `which` determines which of `cases` is taken. `which` must be in the range `[0 .. len(cases))`; for values outside that range the behavior is implementation-defined.
*cases (ArrayLike) – a non-empty list of array cases. All must have equal dtypes and equal shapes.
- Returns:
An array with shape and dtype equal to the cases, whose values are chosen according to `which`.
- Return type: