jax.lax.bitcast_convert_type
- jax.lax.bitcast_convert_type(operand, new_dtype)
Elementwise bitcast.
This function lowers directly to the stablehlo.bitcast_convert operation.
The output shape depends on the size of the input and output dtypes with the following logic:
    if new_dtype.itemsize == operand.dtype.itemsize:
      output_shape = operand.shape
    if new_dtype.itemsize < operand.dtype.itemsize:
      output_shape = (*operand.shape, operand.dtype.itemsize // new_dtype.itemsize)
    if new_dtype.itemsize > operand.dtype.itemsize:
      assert operand.shape[-1] * operand.dtype.itemsize == new_dtype.itemsize
      output_shape = operand.shape[:-1]
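The three shape cases above can be illustrated with a minimal sketch. The array `x` and the dtype choices (`float32`, `int32`, `uint8`) are assumptions picked for illustration; 64-bit dtypes are avoided because JAX disables them by default.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 3), dtype=jnp.float32)

# Same item size (float32 -> int32): the shape is unchanged.
same = lax.bitcast_convert_type(x, jnp.int32)

# Smaller item size (float32 -> uint8): a trailing axis of
# length 4 // 1 == 4 is appended.
narrower = lax.bitcast_convert_type(x, jnp.uint8)

# Larger item size (uint8 -> float32): the trailing axis must span
# exactly one float32 (4 bytes), and it is consumed.
wider = lax.bitcast_convert_type(narrower, jnp.float32)

print(same.shape, narrower.shape, wider.shape)
```

Because the bits are preserved in every case, converting to `uint8` and back should reproduce `x` exactly.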
- Parameters:
operand (ArrayLike) – an array or scalar value to be cast
new_dtype (DTypeLike) – the new type. Should be a NumPy type.
- Returns:
An array of shape output_shape (see above) and type new_dtype, constructed from the same bits as operand.
- Return type:
Array
See also
jax.lax.convert_element_type(): value-preserving dtype conversion.
jax.Array.view(): NumPy-style API for bitcast type conversion.
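As a rough illustration of how these related functions differ, the sketch below applies all three to the same input. The input values are assumptions chosen so that the bit patterns are easy to recognize.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0], dtype=jnp.float32)

# Value-preserving conversion: the numbers 1.0 and 2.0 become 1 and 2.
converted = lax.convert_element_type(x, jnp.int32)

# Bitcast: the IEEE 754 bit patterns are reinterpreted as integers,
# so 1.0 becomes 0x3F800000 and 2.0 becomes 0x40000000.
bitcast = lax.bitcast_convert_type(x, jnp.int32)

# NumPy-style view of the same bits; for dtypes of equal size this
# matches the bitcast result.
viewed = x.view(jnp.int32)

print(converted, bitcast, viewed)
```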