jax.typing module
The JAX typing module is where JAX-specific static type annotations live. This submodule is a work in progress; for the proposal behind the types exported here, see https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html.
The currently available types are:

- jax.Array: annotation for any JAX array or tracer (i.e. representations of arrays within JAX transforms).
- jax.typing.ArrayLike: annotation for any value that is safe to implicitly cast to a JAX array; this includes jax.Array, numpy.ndarray, as well as Python builtin numeric values (e.g. int, float, etc.) and numpy scalar values (e.g. numpy.int32, numpy.float64, etc.).
- jax.typing.DTypeLike: annotation for any value that can be cast to a JAX-compatible dtype; this includes strings (e.g. 'float32', 'int32'), scalar types (e.g. float, np.float32), dtypes (e.g. np.dtype('float32')), or objects with a dtype attribute (e.g. jnp.float32, jnp.int32).
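As a rough illustration of how the three annotations fit together, the sketch below annotates a small wrapper around jnp.full: the fill value is typed as ArrayLike, the dtype as DTypeLike, and the return value as Array. The function name make_filled is hypothetical and not part of JAX; the calls only exercise the input kinds listed above (a Python int, a NumPy scalar, a string dtype, a NumPy scalar type, and a jnp scalar type).

```python
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike, DTypeLike


def make_filled(shape: tuple[int, ...], fill_value: ArrayLike,
                dtype: DTypeLike) -> Array:
  # Hypothetical helper: jnp.full accepts any array-like fill value and
  # any dtype-like object, and always returns a jax.Array.
  return jnp.full(shape, fill_value, dtype=dtype)


a = make_filled((2,), 1, "float32")               # Python int, string dtype
b = make_filled((2,), np.float64(2.5), np.float32) # NumPy scalar, NumPy type
c = make_filled((2,), 3, jnp.int32)                # jnp scalar type as dtype
```

Because the annotations are only hints, a static type checker uses them to flag mismatched calls, while the actual conversion is still performed at runtime by jnp.full.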
We may add additional types here in future releases.
JAX Typing Best Practices
When annotating JAX arrays in public API functions, we recommend using ArrayLike for array inputs, and Array for array outputs.
For example, your function might look like this:
```python
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def my_function(x: ArrayLike) -> Array:
  # Runtime type validation, Python 3.10 or newer:
  if not isinstance(x, ArrayLike):
    raise TypeError(f"Expected arraylike input; got {x}")
  # Runtime type validation, any Python version:
  if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
    raise TypeError(f"Expected arraylike input; got {x}")
  # Convert input to jax.Array:
  x_arr = jnp.asarray(x)
  # ... do some computation; JAX functions will return Array types:
  result = x_arr.sum(0) / x_arr.shape[0]
  # return an Array
  return result
```
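One detail worth knowing about the "any Python version" check in the example: np.isscalar also returns True for strings, so that check lets a value like "abc" through even though it is not array-like. The sketch below is one possible stricter, version-independent alternative. It assumes ArrayLike is the union of Array, np.ndarray, NumPy bool/number scalars, and Python bool/int/float/complex, as in current JAX sources; the names is_arraylike and mean_along_first_axis are hypothetical and not part of JAX.

```python
import numpy as np
import jax.numpy as jnp
from jax import Array

# Assumption: mirrors the members of jax.typing.ArrayLike in current JAX.
_ARRAYLIKE_TYPES = (Array, np.ndarray, np.bool_, np.number,
                    bool, int, float, complex)


def is_arraylike(x) -> bool:
  # Hypothetical helper: explicit tuple check that works on any Python
  # version and, unlike np.isscalar, rejects str and bytes.
  return isinstance(x, _ARRAYLIKE_TYPES)


def mean_along_first_axis(x) -> Array:
  # Same pattern as my_function: validate, convert, compute.
  if not is_arraylike(x):
    raise TypeError(f"Expected arraylike input; got {x}")
  x_arr = jnp.asarray(x)
  return x_arr.sum(0) / x_arr.shape[0]


result = mean_along_first_axis(np.array([1.0, 2.0, 3.0]))
```

With this check, NumPy arrays, JAX arrays, and numeric scalars are accepted, while lists, tuples, and strings raise a TypeError.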
Most of JAX's public APIs follow this pattern. In particular, we recommend that JAX functions not accept sequences such as list or tuple in place of arrays, as this can cause extra overhead in JAX transforms like jit() and can behave in unexpected ways with batch-wise transforms like vmap() or jax.pmap(). For more information on this, see Non-array inputs: NumPy vs JAX.
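To make the overhead point above concrete, the following sketch traces the same function once with a Python list and once with a JAX array, using jax.make_jaxpr to inspect the traced program. Transforms treat a list as a pytree, so each element becomes its own traced input; an array is a single input. The function name total is hypothetical and it deliberately calls jnp.asarray on whatever it receives, only so that both inputs can be traced.

```python
import jax
import jax.numpy as jnp


def total(x):
  # Hypothetical function that converts its input itself; used here only
  # to show how tracing sees a list versus an array.
  return jnp.sum(jnp.asarray(x))


list_jaxpr = jax.make_jaxpr(total)([1.0, 2.0, 3.0])
array_jaxpr = jax.make_jaxpr(total)(jnp.array([1.0, 2.0, 3.0]))

# A list is flattened into one traced input per element...
n_list_inputs = len(list_jaxpr.jaxpr.invars)
# ...while an array is passed as a single traced input.
n_array_inputs = len(array_jaxpr.jaxpr.invars)
```

For long lists this means one traced scalar per element, which is why converting explicitly at the call site, e.g. total(jnp.asarray(my_list)), is the cheaper and more predictable choice.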
List of Members
- jax.typing.ArrayLike: Type annotation for JAX array-like objects.