jax.lax.convert_element_type
- jax.lax.convert_element_type(operand, new_dtype)
Elementwise cast.
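A minimal sketch of an elementwise cast, assuming a default JAX install (x64 mode off). Converting floating-point values to a signed integer type drops the fractional part, rounding toward zero as a C++ static_cast does for in-range values. Out-of-range and NaN inputs are deliberately left out, because this page does not describe their results.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.7, -1.7, 2.5, 0.0], dtype=jnp.float32)

# float32 -> int32: each element is converted independently and the
# fractional part is discarded (rounding toward zero).
y = lax.convert_element_type(x, jnp.int32)

print(y.dtype)     # int32
print(y.tolist())  # [1, -1, 2, 0]
```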
This function lowers directly to the stablehlo.convert operation, which performs an elementwise conversion from one type to another, similar to a C++ static_cast.
- Parameters:
operand (ArrayLike) – an array or scalar value to be cast.
new_dtype (DTypeLike | dtypes.ExtendedDType) – a dtype-like object (e.g. a numpy.dtype, a scalar type, or a valid dtype name) representing the target dtype.
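The sketch below, assuming default JAX settings, exercises the parameter forms described above: new_dtype is given as a numpy.dtype, as a scalar type, and as a dtype name string, and each produces float16 while the operand's shape is unchanged. It also passes a plain Python scalar as operand, which comes back as a 0-d array.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)

# Three dtype-like spellings of the same target type.
targets = [np.dtype("float16"), jnp.float16, "float16"]
results = [lax.convert_element_type(x, t) for t in targets]

for r in results:
    # Only the element type changes; the shape (2, 3) is preserved.
    print(r.shape, r.dtype)

# A Python scalar is a valid operand and yields a 0-d array.
s = lax.convert_element_type(3, jnp.float32)
print(s.shape, s.dtype, float(s))
```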
- Returns:
An array with the same shape as operand, cast elementwise to new_dtype.
- Return type:
Array
Note
If new_dtype is a 64-bit type and x64 mode is not enabled, the corresponding 32-bit type (for example, float32 in place of float64) is used in its place.
If the input is a JAX array and the input and output dtypes match, the input array is returned unmodified.
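Both behaviors in the note can be checked directly. This is a minimal sketch assuming the default configuration, in which x64 mode is off (jax_enable_x64 not set); with x64 enabled, the first cast would produce float64 instead of float32.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3, dtype=jnp.int32)

# Requesting float64 while x64 mode is disabled falls back to float32.
y = lax.convert_element_type(x, jnp.float64)
print(y.dtype)

# y is already a float32 JAX array, so casting it to float32 again is a
# no-op: per the note, the input array object itself is returned.
z = lax.convert_element_type(y, jnp.float32)
print(z is y)
```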
See also
jax.numpy.astype(): NumPy-style dtype casting API.
jax.Array.astype(): dtype casting as an array method.
jax.lax.bitcast_convert_type(): cast bits directly to a new dtype.
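To contrast the related APIs listed above, this sketch (assuming default JAX settings) casts float32 values to int32 in four ways. convert_element_type, jax.numpy.astype() and jax.Array.astype() all convert the numeric value, whereas jax.lax.bitcast_convert_type() reinterprets the same 32 bits, so 1.0 becomes 1065353216 (0x3F800000, the IEEE 754 single-precision encoding of 1.0).

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0], dtype=jnp.float32)

by_value = lax.convert_element_type(x, jnp.int32)  # numeric conversion
by_astype = jnp.astype(x, jnp.int32)               # NumPy-style function
by_method = x.astype(jnp.int32)                    # array method
by_bits = lax.bitcast_convert_type(x, jnp.int32)   # reinterpret raw bits

print(by_value.tolist())  # [1, 2]
print(by_bits.tolist())   # [1065353216, 1073741824]
```

The first three agree because they perform the same value conversion; only the bitcast exposes the underlying floating-point representation.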