jax.numpy.empty

jax.numpy.empty(shape, dtype=None, *, device=None, out_sharding=None)

Create an empty array.

JAX implementation of numpy.empty(). Starting in JAX v0.11.0, this returns an uninitialized array on platforms that support doing so. Prior to v0.11.0, this function returned an array filled with zeros on all platforms.
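The sketch below is one way to use jnp.empty given that its contents may be uninitialized. It treats the result as a buffer whose shape and dtype can be trusted and whose values must be written before they are read. It uses the standard `.at[...].set(...)` update, because JAX arrays are immutable, and jnp.zeros for cases where zeros are actually required. The variable names are illustrative only.

```python
import jax.numpy as jnp

# Contents of jnp.empty may be uninitialized, so rely only on shape and dtype.
buf = jnp.empty((2, 3), dtype=jnp.float32)
print(buf.shape, buf.dtype)  # (2, 3) float32

# JAX arrays are immutable: "filling" the buffer returns a new array.
filled = buf.at[:].set(1.0)

# When zeros are actually needed, request them explicitly.
zeros = jnp.zeros((2, 3), dtype=jnp.float32)
```

Writing every element before reading it keeps the code correct whether or not the platform returns uninitialized memory.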

Parameters:
  • shape (Any) – int or sequence of ints specifying the shape of the created array.

  • dtype (str | type[Any] | dtype | SupportsDType | None) – optional dtype for the created array; defaults to float32 or float64 depending on the X64 configuration (see Default dtypes and the X64 flag).

  • device (Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed. This argument exists for compatibility with the Python Array API standard.

  • out_sharding (NamedSharding | P | None) – (optional) PartitionSpec or NamedSharding representing the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying both out_sharding and device will result in an error.
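The following sketch shows the shape, dtype, and device parameters in use. It assumes that the first entry of jax.devices() is an acceptable target device. The out_sharding parameter is not shown because it needs a mesh set up for explicit sharding.

```python
import jax
import jax.numpy as jnp

# shape may be a single int or a sequence of ints.
a = jnp.empty(5)        # shape (5,)
b = jnp.empty((2, 4))   # shape (2, 4)

# With dtype=None the floating dtype follows the X64 flag:
# float32 by default, float64 when jax_enable_x64 is set.
c = jnp.empty((3,), dtype=jnp.int32)

# device commits the result to a specific device (here the first available one).
dev = jax.devices()[0]
d = jnp.empty((2, 2), device=dev)
```

Passing both device and out_sharding in the same call is an error, so choose one of them.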

Returns:

Array of the specified shape and dtype, with the given device/sharding if specified.

Return type:

Array
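As a minimal sketch of the "with the given device/sharding" part of the return value, the example below passes a replicated NamedSharding through the device parameter, which the parameter list says accepts a Sharding. The one-axis mesh named "x" is an illustrative assumption. It works with any number of available devices because P() does not partition the array.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A one-axis mesh over every available device (a single device also works).
mesh = Mesh(np.array(jax.devices()), ("x",))

# P() requests no partitioning: the array is replicated on every mesh device.
replicated = NamedSharding(mesh, P())

arr = jnp.empty((4,), dtype=jnp.float32, device=replicated)
print(arr.shape, arr.dtype)  # (4,) float32
```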

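Examples

The printed contents below are illustrative. On platforms where jnp.empty returns uninitialized memory, the values may differ, so rely only on the shape and dtype.

>>> jnp.empty(4)
Array([0., 0., 0., 0.], dtype=float32)
>>> jnp.empty((2, 3), dtype=bool)
Array([[False, False, False],
       [False, False, False]], dtype=bool)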