jax.numpy.load
- jax.numpy.load(file, *args, **kwargs)
Load JAX arrays from npy files.
JAX wrapper of numpy.load().
This function is a simple wrapper of numpy.load(), but in the case of .npy files created with numpy.save() or jax.numpy.save(), the output will be returned as a jax.Array, and bfloat16 data types will be restored. For .npz files, results will be returned as normal NumPy arrays.
This function requires concrete (non-traced) inputs and is not compatible with transformations like jax.jit() or jax.vmap(), so it should be called outside any transformed function.
- Parameters:
file (IO[bytes] | str | os.PathLike[Any]) – string, bytes, or path-like object containing the array data.
args (Any) – for additional arguments, see numpy.load()
kwargs (Any) – for additional arguments, see numpy.load()
- Returns:
the array stored in the file.
- Return type:
Array
See also
jax.numpy.save(): save an array to a file.
Examples
>>> import io
>>> import jax.numpy as jnp
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)
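The difference between .npy and .npz handling can be checked directly. The following is a minimal sketch, assuming NumPy and JAX are installed; the array values and the archive keys a and b are arbitrary choices for illustration, and in-memory buffers stand in for files on disk.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(4, dtype=jnp.float32)

# A single array written with jnp.save() is read back as a jax.Array.
npy_file = io.BytesIO()
jnp.save(npy_file, x)
npy_file.seek(0)
loaded = jnp.load(npy_file)
print(type(loaded), isinstance(loaded, jax.Array))

# Several arrays written with np.savez() come back as an NpzFile archive
# whose entries are ordinary NumPy arrays rather than jax.Array.
npz_file = io.BytesIO()
np.savez(npz_file, a=np.asarray(x), b=np.ones(3))
npz_file.seek(0)
with jnp.load(npz_file) as archive:
    a = archive["a"]  # reads the entry fully into memory
print(type(a), isinstance(a, jax.Array))
```

If a jax.Array is needed from an .npz entry, it can be converted explicitly, for example with jnp.asarray(a).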
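Because jnp.load() cannot run under jax.jit() or jax.vmap(), one common pattern is to load data eagerly and pass the resulting array into the transformed function. This is a sketch under that assumption: the function scale and the variable params are hypothetical names, and an in-memory buffer stands in for a real file.

```python
import io

import jax
import jax.numpy as jnp

# Prepare an in-memory .npy file (stands in for a file on disk).
buf = io.BytesIO()
jnp.save(buf, jnp.array([1.0, 2.0, 3.0]))
buf.seek(0)

# Load eagerly, outside of any transformation.
params = jnp.load(buf)

@jax.jit
def scale(p, factor):
    # Only array computation happens inside the jitted function.
    return p * factor

result = scale(params, 2.0)
print(result)

# vmap likewise operates on the already-loaded array, not on the file.
batched = jax.vmap(lambda row: row.sum())(jnp.stack([params, params]))
print(batched)
```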
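The file parameter also accepts path-like objects, and extra arguments are forwarded to numpy.load(). This sketch assumes a writable temporary directory; it saves a bfloat16 array to a hypothetical x.npy path and loads it back, passing allow_pickle=False through to numpy.load() (that value is already NumPy's default and is shown only to illustrate forwarding).

```python
import pathlib
import tempfile

import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3], dtype=jnp.bfloat16)

with tempfile.TemporaryDirectory() as tmpdir:
    path = pathlib.Path(tmpdir) / "x.npy"  # hypothetical file name
    jnp.save(path, x)
    # allow_pickle is not handled by jnp.load itself; it goes to numpy.load().
    y = jnp.load(path, allow_pickle=False)

# The bfloat16 dtype is restored on load.
print(y.dtype, y)
```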