jax.lax.convert_element_type

jax.lax.convert_element_type(operand, new_dtype)

Elementwise cast.

This function lowers directly to the stablehlo.convert operation, which performs an elementwise conversion from one type to another, similar to a C++ static_cast.
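As a rough illustration of the static_cast analogy, the sketch below casts a small float32 array to int32. The input values are made up for the example; for in-range values like these, the fractional part is expected to be dropped rather than rounded.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative input values (not from the original documentation).
x = jnp.array([1.7, -2.3, 3.0], dtype=jnp.float32)

# Elementwise cast from float32 to int32.
y = lax.convert_element_type(x, jnp.int32)

print(y.dtype)  # int32
print(y)        # fractional parts are dropped for these inputs
```

Values that cannot be represented in the target type are not covered by this sketch.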

Parameters:
  • operand (ArrayLike) – an array or scalar value to be cast.

  • new_dtype (DTypeLike | dtypes.ExtendedDType) – a dtype-like object (e.g. a numpy.dtype, a scalar type, or a valid dtype name) representing the target dtype.

Returns:

An array with the same shape as operand, cast elementwise to new_dtype.

Return type:

Array
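The parameter description above lists several accepted forms for new_dtype. A minimal sketch, using an illustrative 2x3 integer array, passes a scalar type, a numpy.dtype, and a dtype name, and checks that each call keeps the shape and yields the same result dtype.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Illustrative operand: a 2x3 array of int32 values.
x = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)

# Three dtype-like spellings of the same target dtype.
targets = [jnp.float32, np.dtype("float32"), "float32"]

results = [lax.convert_element_type(x, t) for t in targets]

for out in results:
    # The shape is unchanged; only the element type differs from the input.
    print(out.shape, out.dtype)
```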

Note

If new_dtype is a 64-bit type and x64 mode is not enabled, the appropriate 32-bit type will be used in its place.

If the input is a JAX array and the input dtype and output dtype match, then the input array will be returned unmodified.
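The two behaviors in this note can be checked with a short sketch. It assumes the jax_enable_x64 configuration flag is readable as jax.config.jax_enable_x64, and it computes the expected dtype from that flag rather than assuming a particular setting. Depending on the JAX version, requesting a 64-bit dtype without x64 mode may also emit a warning.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)

# Request a 64-bit dtype; the result depends on whether x64 mode is enabled.
w = lax.convert_element_type(x, jnp.float64)
expected = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32
print(w.dtype, "expected:", jnp.dtype(expected))

# Casting to the dtype the array already has leaves it unchanged.
same = lax.convert_element_type(x, x.dtype)
print(same.dtype == x.dtype)
print(same is x)  # identity check; the note above suggests the input is returned as-is
```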

See also