jax.numpy.astype
- jax.numpy.astype(x, dtype, /, *, copy=False, device=None)
Convert an array to a specified dtype.
JAX implementation of numpy.astype(). This is implemented via jax.lax.convert_element_type(), which may behave slightly differently from numpy.astype() in some cases. In particular, the details of float-to-int and int-to-float casts are implementation-dependent.
- Parameters:
x (ArrayLike) – input array to convert
dtype (DTypeLike | None) – output dtype
copy (bool) – if True, always return a copy. If False (default), only return a copy if necessary.
device (xc.Device | Sharding | None) – optionally specify the device to which the output will be committed.
- Returns:
An array with the same shape as x, containing values of the specified dtype.
- Return type:
Array
See also
jax.lax.convert_element_type(): lower-level function for XLA-style dtype conversions.
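Since jax.numpy.astype is implemented via jax.lax.convert_element_type(), the two calls in the sketch below should give identical results for an ordinary float-to-int cast. The input values are arbitrary and deliberately kept within the int32 range, because out-of-range float-to-int behavior is implementation-dependent and is not exercised here.

```python
import jax
import jax.numpy as jnp

# In-range values only: out-of-range float-to-int casts are implementation-dependent.
x = jnp.array([-1.5, -0.5, 0.5, 1.5], dtype=jnp.float32)

# NumPy-style API documented on this page.
a = jnp.astype(x, jnp.int32)

# Lower-level lax primitive that jnp.astype is implemented via.
b = jax.lax.convert_element_type(x, jnp.int32)

# Both truncate the fractional part toward zero and should agree element-wise.
print(a, a.dtype)
print(b, b.dtype)
print("identical:", bool(jnp.array_equal(a, b)))
```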
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([0, 1, 2, 3])
>>> x
Array([0, 1, 2, 3], dtype=int32)
>>> x.astype('float32')
Array([0., 1., 2., 3.], dtype=float32)
>>> y = jnp.array([0.0, 0.5, 1.0])
>>> y.astype(int)  # truncates fractional values
Array([0, 0, 1], dtype=int32)
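The keyword-only copy and device parameters can be exercised as in the minimal sketch below. The array contents and the choice of jax.devices()[0] as the target device are illustrative assumptions; any device reported by jax.devices() should be usable, and Array.devices() is used only to inspect where the result was placed.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)

# copy=True: always return a new array, even though the dtype is unchanged.
y = jnp.astype(x, jnp.int32, copy=True)

# device=...: commit the converted output to a chosen device.
# Assumption: the first device reported by jax.devices() is a valid target.
dev = jax.devices()[0]
z = jnp.astype(x, jnp.float32, device=dev)

print(y.shape, y.dtype)
print(z.shape, z.dtype, z.devices())
```

With the default copy=False a copy is made only when necessary, so code should not rely on whether a same-dtype conversion returns the input object itself.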