jax.numpy.result_type
- jax.numpy.result_type(*args)
Return the result of applying JAX promotion rules to the inputs.
JAX implementation of numpy.result_type().
JAX’s dtype promotion behavior is described in Type promotion semantics.
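JAX’s promotion rules do not always agree with NumPy’s. The sketch below compares the two functions on one input pair. It assumes default settings and only uses `numpy.result_type` and `jax.numpy.result_type`: NumPy widens `int32` combined with `float16` to `float64`, while JAX’s promotion lattice returns `float16`.

```python
import numpy as np
import jax.numpy as jnp

# NumPy picks a float wide enough to hold every int32 value exactly.
numpy_result = np.result_type(np.int32, np.float16)

# JAX's lattice promotes any integer type combined with a float type
# to that float type, so no wider dtype is introduced.
jax_result = jnp.result_type(np.int32, np.float16)

print(numpy_result)  # float64
print(jax_result)    # float16
```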
- Parameters:
args (Any) – one or more arrays or dtype-like objects.
- Returns:
A numpy.dtype instance representing the result of type promotion for the inputs.
- Return type:
DType
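Because the return value is an ordinary `numpy.dtype`, it can be passed anywhere a dtype is accepted. A small sketch, assuming two arrays of different dtypes, that uses the result with `astype`:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(3, dtype=jnp.int16)
y = jnp.ones(3, dtype=jnp.float32)

# The promoted dtype is a plain numpy.dtype object...
common = jnp.result_type(x, y)
assert isinstance(common, np.dtype)

# ...so it can be used directly to cast one input to the common type.
x_cast = x.astype(common)
print(common, x_cast.dtype)  # float32 float32
```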
Examples
Inputs can be dtype specifiers:
>>> jnp.result_type('int32', 'float32')
dtype('float32')
>>> jnp.result_type(np.uint16, np.dtype('int32'))
dtype('int32')
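Since `result_type` takes any number of arguments, one call can compute a single common dtype for several inputs. The helper `promote_all` below is a hypothetical name used only for illustration; it is not part of `jax.numpy`.

```python
import numpy as np
import jax.numpy as jnp

def promote_all(*arrays):
    """Hypothetical helper: cast every input to their common JAX dtype."""
    common = jnp.result_type(*arrays)
    return [jnp.asarray(a, dtype=common) for a in arrays]

a = jnp.array([1, 2], dtype=jnp.uint8)
b = jnp.array([3, 4], dtype=jnp.int8)
c = jnp.array([0.5, 1.5], dtype=jnp.float32)

pair = jnp.result_type(a, b)  # uint8 with int8 promotes to int16
out = promote_all(a, b, c)    # adding float32 makes everything float32
print(pair, [str(o.dtype) for o in out])
```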
Inputs may also be scalars or arrays:
>>> jnp.result_type(1.0, jnp.bfloat16(2))
dtype(bfloat16)
>>> jnp.result_type(jnp.arange(4), jnp.zeros(4))
dtype('float32')
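The first example above returns `bfloat16` because Python scalars are weakly typed in JAX: they defer to the dtype of the array they are combined with. A sketch contrasting a weakly typed Python scalar with a strongly typed JAX scalar, assuming default settings:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.zeros(3, dtype=jnp.int8)

# A Python int is weakly typed, so the int8 array's dtype wins.
weak = jnp.result_type(2, x)

# jnp.int32(2) is a strongly typed int32 value, so int32 wins.
strong = jnp.result_type(jnp.int32(2), x)

# A Python float with an integer array yields the default float dtype
# (float32 unless jax_enable_x64 is set).
mixed = jnp.result_type(2.0, x)

print(weak, strong, mixed)
```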
Be aware that the result type will be canonicalized based on the state of the jax_enable_x64 configuration flag, meaning that 64-bit types may be downcast to 32-bit:
>>> jnp.result_type('float64')
dtype('float32')
For details on 64-bit values, refer to Sharp bits - double precision.
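The effect of the flag can be observed by toggling it with `jax.config.update`. The helper `result_type_with_x64` below is hypothetical and only for illustration. It infers the current flag state from `result_type` itself, sets the flag temporarily, and restores it afterwards. In real programs the flag is usually set once at startup rather than flipped back and forth.

```python
import numpy as np
import jax
import jax.numpy as jnp

def result_type_with_x64(enabled, *args):
    """Hypothetical helper: evaluate result_type with jax_enable_x64 set temporarily."""
    # If 'float64' survives canonicalization, the flag is currently on.
    was_enabled = jnp.result_type('float64') == np.dtype('float64')
    jax.config.update("jax_enable_x64", enabled)
    try:
        return jnp.result_type(*args)
    finally:
        jax.config.update("jax_enable_x64", was_enabled)

print(result_type_with_x64(False, 'float64'))           # float32
print(result_type_with_x64(True, 'float64'))            # float64
print(result_type_with_x64(False, np.int64, np.int32))  # int32
print(result_type_with_x64(True, np.int64, np.int32))   # int64
```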