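jax.lax.argmax

jax.lax.argmax(operand, axis, index_dtype)

Computes the index of the maximum element along axis.

Parameters:
  operand (ArrayLike): the input array to search.
  axis (int): the dimension along which to find the maximum.
  index_dtype (DTypeLike): the integer dtype of the returned indices.

Return type:
  Array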
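The following is a minimal usage sketch. It assumes a standard JAX installation, and the array values are illustrative only. It shows that the reduced axis is removed from the result's shape and that index_dtype sets the integer type of the returned indices. The data deliberately avoids ties, so the expected indices do not depend on tie-breaking behavior.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative 2x3 float32 input (values chosen to avoid ties).
x = jnp.array([[1.0, 5.0, 3.0],
               [7.0, 2.0, 4.0]], dtype=jnp.float32)

# Position of the maximum within each row (reduce over axis 1).
# The result has shape (2,), and its dtype is the requested int32.
row_argmax = lax.argmax(x, axis=1, index_dtype=jnp.int32)

# Position of the maximum within each column (reduce over axis 0).
# The result has shape (3,).
col_argmax = lax.argmax(x, axis=0, index_dtype=jnp.int32)

print(row_argmax)  # [1 0]
print(col_argmax)  # [1 0 1]
```

Unlike the higher-level jnp.argmax, where axis is optional and the index dtype has a default, lax.argmax requires both axis and index_dtype to be passed explicitly.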