jax.numpy.frombuffer
- jax.numpy.frombuffer(buffer, dtype=float, count=-1, offset=0)
Convert a buffer into a 1-D JAX array.
JAX implementation of numpy.frombuffer().
- Parameters:
buffer (bytes | Any) – an object containing the data. It must be either a bytes object or an object exporting the Python buffer interface. When count is -1, the number of bytes remaining after offset must be an integer multiple of the dtype's item size.
dtype (DTypeLike) – optional. Desired data type for the array. Default is float64. This specifies the dtype used to parse the buffer, but note that after parsing, 64-bit values will be cast to 32-bit JAX arrays if the jax_enable_x64 flag is set to False.
count (int) – optional integer specifying the number of items to read from the buffer. If -1 (default), all items from the buffer are read.
offset (int) – optional integer specifying the number of bytes to skip at the beginning of the buffer. Default is 0.
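The interplay of dtype, count, and offset can be seen in a minimal sketch, assuming a standard JAX install with default configuration (jax_enable_x64 disabled). The buffer is built with NumPy's tobytes() so that it uses the machine's native byte order, which is also what frombuffer assumes; the variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# Four float32 values serialized in the machine's native byte order (16 bytes).
raw = np.array([1.5, 2.5, 3.5, 4.5], dtype=np.float32).tobytes()

# dtype controls how bytes are grouped: every 4 bytes become one float32 item.
all_items = jnp.frombuffer(raw, dtype=jnp.float32)

# offset is measured in bytes (4 bytes skips one float32),
# while count is measured in items (read two values after the offset).
middle = jnp.frombuffer(raw, dtype=jnp.float32, count=2, offset=4)

print(all_items)  # [1.5 2.5 3.5 4.5]
print(middle)     # [2.5 3.5]

# With the default dtype the bytes are parsed as float64, but the resulting
# JAX array is float32 unless the jax_enable_x64 flag is enabled.
raw64 = np.array([0.25, 0.75], dtype=np.float64).tobytes()
parsed = jnp.frombuffer(raw64)
print(parsed.dtype)
```

Note the asymmetry in units: offset counts bytes, count counts items of the chosen dtype.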
- Returns:
A 1-D JAX array representing the interpreted data from the buffer.
- Return type:
Array
See also
jax.numpy.fromstring(): convert a string of text into a 1-D JAX array.
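The difference between the two functions is easiest to see on the same input. The sketch below assumes jnp.fromstring accepts a keyword-only sep argument for text parsing, mirroring numpy.fromstring: frombuffer reinterprets the raw bytes, whereas fromstring parses the characters as numbers.

```python
import jax.numpy as jnp

data = b"1 2"

# frombuffer reinterprets raw bytes: each byte becomes one uint8 item,
# so the ASCII codes of '1', ' ', and '2' are returned.
as_bytes = jnp.frombuffer(data, dtype=jnp.uint8)
print(as_bytes)  # [49 32 50]

# fromstring instead parses the text as separator-delimited numbers.
as_text = jnp.fromstring(data.decode("ascii"), dtype=jnp.uint8, sep=" ")
print(as_text)  # [1 2]
```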
Examples
Using a bytes buffer:
>>> import jax.numpy as jnp
>>> buf = b"\x00\x01\x02\x03\x04"
>>> jnp.frombuffer(buf, dtype=jnp.uint8)
Array([0, 1, 2, 3, 4], dtype=uint8)
>>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1)
Array([1, 2, 3, 4], dtype=uint8)
Constructing a JAX array via the Python buffer interface, using Python's built-in array module:
>>> from array import array
>>> pybuffer = array('i', [0, 1, 2, 3, 4])
>>> jnp.frombuffer(pybuffer, dtype=jnp.int32)
Array([0, 1, 2, 3, 4], dtype=int32)
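Other objects that export the buffer interface behave the same way. The following is a minimal sketch using a bytearray and a memoryview over part of it; it also shows the error raised, as a ValueError from the underlying NumPy parsing, when the byte length is not a multiple of the dtype's item size. The variable names are illustrative.

```python
import jax.numpy as jnp

# A bytearray and a memoryview slice both export the buffer protocol.
ba = bytearray([0, 1, 2, 3, 4, 5])
from_bytearray = jnp.frombuffer(ba, dtype=jnp.uint8)
from_view = jnp.frombuffer(memoryview(ba)[2:], dtype=jnp.uint8)

print(from_bytearray)  # [0 1 2 3 4 5]
print(from_view)       # [2 3 4 5]

# 6 bytes cannot be split evenly into 4-byte int32 items, so parsing fails.
raised = False
try:
    jnp.frombuffer(ba, dtype=jnp.int32)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```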