jax.numpy.load

jax.numpy.load(file, *args, **kwargs)

Load JAX arrays from .npy files.

JAX wrapper of numpy.load().

This function is a thin wrapper around numpy.load(). For .npy files created with numpy.save() or jax.numpy.save(), the output is returned as a jax.Array and bfloat16 data types are restored. For .npz files, the results are returned as ordinary NumPy arrays.
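A minimal sketch contrasting plain numpy.load() with jax.numpy.load() on the same in-memory .npy data, and showing that a .npz archive (created here with np.savez()) is not converted. It assumes, as JAX's own handling does, that a bfloat16 array is recorded in the .npy header as a 2-byte void type, which is why numpy.load() alone returns void data.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

# Write a bfloat16 array to an in-memory .npy file.
x = jnp.array([1.0, 2.5, -3.0], dtype=jnp.bfloat16)
buf = io.BytesIO()
jnp.save(buf, x)

# Assumption: the .npy header stores bfloat16 as 2-byte void,
# so plain numpy.load hands back raw void records.
buf.seek(0)
raw = np.load(buf)
print("numpy.load:", type(raw).__name__, raw.dtype)

# jax.numpy.load reinterprets the void data as bfloat16 and returns a jax.Array.
buf.seek(0)
restored = jnp.load(buf)
print("jnp.load:", type(restored).__name__, restored.dtype)

# A .npz archive is passed through unchanged: its entries stay NumPy arrays.
npz_buf = io.BytesIO()
np.savez(npz_buf, a=np.arange(3), b=np.ones(2))
npz_buf.seek(0)
with jnp.load(npz_buf) as archive:
    entry = archive["a"]
print("npz entry:", type(entry).__name__)
```

If bfloat16 values need to survive a round trip through disk, jax.numpy.load() is the reader to use; numpy.load() alone would require a manual view of the void data.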

This function requires concrete array inputs and is not compatible with transformations like jax.jit() or jax.vmap().
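A common way to respect this restriction is to do the file I/O eagerly and pass the resulting concrete array into the transformed function. A minimal sketch; the function name scale is hypothetical.

```python
import io

import jax
import jax.numpy as jnp

# Prepare an in-memory .npy file to read back.
buf = io.BytesIO()
jnp.save(buf, jnp.arange(4.0))


@jax.jit
def scale(x):
    # Hypothetical pure array computation: safe to trace and compile.
    return 2.0 * x


# Load eagerly, outside any transformation...
buf.seek(0)
x = jnp.load(buf)
# ...then hand the concrete jax.Array to the jitted function.
y = scale(x)
print(y)
```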

Parameters:
  • file (IO[bytes] | str | os.PathLike[Any]) – the file to read: a filename given as a string or path-like object, or a binary file-like object. File-like objects must support the read() and seek() methods.

  • args (Any) – additional positional arguments, passed through to numpy.load().

  • kwargs (Any) – additional keyword arguments (for example allow_pickle or mmap_mode), passed through to numpy.load().
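A minimal sketch of both kinds of file argument and of keyword pass-through, assuming a temporary directory and the hypothetical filename weights.npy. allow_pickle is an existing numpy.load() keyword; here it is simply forwarded unchanged.

```python
import pathlib
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical file name inside a throwaway directory.
    path = pathlib.Path(tmp) / "weights.npy"
    np.save(path, np.linspace(0.0, 1.0, 5, dtype=np.float32))

    # An os.PathLike object works as `file`...
    a = jnp.load(path)
    # ...as does a plain string; keyword arguments go straight to numpy.load.
    b = jnp.load(str(path), allow_pickle=False)

print(a)
print(b)
```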

Returns:

the array stored in the file.

Return type:

Array

See also

  • jax.numpy.save(): save an array to a file.

Examples

>>> import io
>>> import jax.numpy as jnp
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)