jax.numpy.result_type

jax.numpy.result_type(*args)

Return the result of applying JAX promotion rules to the inputs.

JAX implementation of numpy.result_type().

JAX’s dtype promotion behavior is described in Type promotion semantics.
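One practical consequence of these rules is that JAX avoids widening to 64-bit types in cases where NumPy would widen. The sketch below compares the two libraries for an int32/float32 pair. The NumPy result assumes NumPy's standard promotion table.

```python
import numpy as np
import jax.numpy as jnp

# JAX: combining int32 with float32 keeps the 32-bit float width.
jax_dtype = jnp.result_type('int32', 'float32')

# NumPy: widens to float64 so that every int32 value stays exactly representable.
np_dtype = np.result_type('int32', 'float32')

print(jax_dtype, np_dtype)  # float32 float64
```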

Parameters:

args (Any) – one or more arrays or dtype-like objects.

Returns:

A numpy.dtype instance representing the result of type promotion for the inputs.

Return type:

DType
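Because the return value is an ordinary numpy.dtype, it can be passed directly to astype to bring several arrays to a common dtype before combining them. A minimal sketch, with arbitrary example arrays:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(3, dtype='int16')
y = jnp.ones(3, dtype='float16')

# Compute the promoted dtype once, then cast both inputs to it explicitly.
common = jnp.result_type(x, y)
x_cast = x.astype(common)
y_cast = y.astype(common)

print(common)  # float16
```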

Examples

Inputs can be dtype specifiers:

>>> jnp.result_type('int32', 'float32')
dtype('float32')
>>> jnp.result_type(np.uint16, np.dtype('int32'))
dtype('int32')

Inputs may also be scalars or arrays:

>>> jnp.result_type(1.0, jnp.bfloat16(2))
dtype(bfloat16)
>>> jnp.result_type(jnp.arange(4), jnp.zeros(4))
dtype('float32')
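The bfloat16 result above reflects how JAX treats Python scalars: they are weakly typed, so a bare Python number defers to the dtype of the other operands instead of imposing a default width of its own. The sketch below contrasts a Python int with an explicitly typed NumPy scalar of the same value. The int8/int32 choice is only for illustration.

```python
import numpy as np
import jax.numpy as jnp

# A Python int is weakly typed, so the int8 operand determines the result.
weak = jnp.result_type(np.int8, 1)

# A NumPy int32 scalar is strongly typed, so it widens the result to int32.
strong = jnp.result_type(np.int8, np.int32(1))

print(weak, strong)  # int8 int32
```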

Be aware that the result type is canonicalized according to the jax_enable_x64 configuration flag. When the flag is disabled (the default), 64-bit types are downcast to 32-bit:

>>> jnp.result_type('float64')
dtype('float32')
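To get 64-bit results, jax_enable_x64 must be enabled, typically once at program startup before any JAX arrays are created. A minimal sketch using jax.config.update (setting the JAX_ENABLE_X64 environment variable is an alternative):

```python
import numpy as np
import jax
import jax.numpy as jnp

# Enable 64-bit types; do this at startup, before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

# With the flag on, 64-bit dtypes are no longer canonicalized to 32-bit.
f = jnp.result_type('float64')
i = jnp.result_type('int64')

print(f, i)  # float64 int64
```

Because the flag is global, it also affects the other examples on this page. For instance, with it enabled the jnp.arange(4) / jnp.zeros(4) example yields float64.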

For details on 64-bit values, refer to Sharp bits - double precision.