jax.numpy.astype

jax.numpy.astype(x, dtype, /, *, copy=False, device=None)

Convert an array to a specified dtype.

JAX implementation of numpy.astype().

This is implemented via jax.lax.convert_element_type(), which may behave slightly differently from numpy.astype() in some cases. In particular, the details of float-to-int and int-to-float casts are implementation-dependent.
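As a minimal sketch of the relationship described above, the snippet below converts the same array with jnp.astype and with jax.lax.convert_element_type and compares the results. The input values are illustrative assumptions, chosen to be small and in range so the cast is well defined; the snippet does not demonstrate the implementation-dependent cases.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative input: small, in-range float32 values.
x = jnp.array([0.0, 0.5, 1.0], dtype=jnp.float32)

# Function form of the conversion documented on this page.
a = jnp.astype(x, jnp.int32)

# Lower-level primitive that the text above says astype is built on.
b = lax.convert_element_type(x, jnp.int32)

# For these in-range values the two conversions agree element-wise.
same = bool(jnp.array_equal(a, b))
print(a.dtype, b.dtype, same)
```

Out-of-range or non-finite inputs are deliberately left out, since their results may vary by backend.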

Parameters:
  • x (ArrayLike) – input array to convert

  • dtype (DTypeLike | None) – output dtype

  • copy (bool) – if True, then always return a copy. If False (default) then only return a copy if necessary.

  • device (xc.Device | Sharding | None) – optionally specify the device to which the output will be committed.
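The parameters above can be exercised with the function form of the API. This is a small sketch under the assumption of a JAX version recent enough to accept the copy keyword; the variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(4)  # integer input array

# Positional-only x and dtype: convert to float32.
y = jnp.astype(x, jnp.float32)

# copy=True requests a copy even though the dtype is unchanged.
z = jnp.astype(x, x.dtype, copy=True)

print(y.dtype, y.shape)
print(z.dtype, bool(jnp.array_equal(z, x)))
```

Because JAX arrays are immutable, the copy flag mainly matters for code that relies on the NumPy-style contract rather than for observable aliasing.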

Returns:

An array with the same shape as x, containing values of the specified dtype.

Return type:

Array

See also

  • jax.lax.convert_element_type(): lower-level function for XLA-style dtype conversions.

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([0, 1, 2, 3])
>>> x
Array([0, 1, 2, 3], dtype=int32)
>>> x.astype('float32')
Array([0., 1., 2., 3.], dtype=float32)
>>> y = jnp.array([0.0, 0.5, 1.0])
>>> y.astype(int)  # truncates fractional values
Array([0, 0, 1], dtype=int32)