jax.numpy.frombuffer

jax.numpy.frombuffer(buffer, dtype=<class 'float'>, count=-1, offset=0)

Convert a buffer into a 1-D JAX array.

JAX implementation of numpy.frombuffer().
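As a quick sanity check of that correspondence, the sketch below parses the same bytes with NumPy and with JAX and compares the results. It is only an illustration for one int32 buffer. The bytes are produced with NumPy's `tobytes()`, so they use the platform's native byte order, and the variable names are hypothetical.

```python
import numpy as np
import jax.numpy as jnp

# Three int32 values serialized in native byte order (4 bytes each).
raw = np.array([10, 20, 30], dtype=np.int32).tobytes()

# Parse the same bytes with NumPy and with JAX.
np_result = np.frombuffer(raw, dtype=np.int32)
jnp_result = jnp.frombuffer(raw, dtype=jnp.int32)

# Values and dtype agree; only the array type differs (ndarray vs. jax.Array).
print(np.array_equal(np_result, np.asarray(jnp_result)))
```
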

Parameters:
  • buffer (bytes | Any) – an object containing the data. It must be either a bytes object with a length that is an integer multiple of the dtype element size, or it must be an object exporting the Python buffer interface.

  • dtype (DTypeLike) – optional. Desired data type for the array. Default is float, which is parsed as float64. This specifies the dtype used to parse the buffer, but note that after parsing, 64-bit values are cast to their 32-bit counterparts (for example, float64 to float32) if the jax_enable_x64 flag is set to False, which is the default.

  • count (int) – optional integer specifying the number of items to read from the buffer. If -1 (default), all items from the buffer are read.

  • offset (int) – optional integer specifying the number of bytes to skip at the beginning of the buffer. Default is 0.
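Note that offset is measured in bytes while count is measured in items of the given dtype. The sketch below illustrates the difference with a multi-byte dtype; the buffer is built with NumPy's `tobytes()` (native byte order), and the variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# Five int32 values serialized in native byte order: 4 bytes each, 20 bytes total.
raw = np.arange(5, dtype=np.int32).tobytes()

# offset=4 skips 4 bytes, i.e. exactly one int32 element;
# count=3 then reads the next three int32 items.
subset = jnp.frombuffer(raw, dtype=jnp.int32, count=3, offset=4)
print(subset)  # [1 2 3]
```

With the default count=-1, the bytes remaining after offset must form a whole number of elements; for example, a 19-byte buffer cannot be parsed as int32 and raises a ValueError.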

Returns:

A 1-D JAX array representing the interpreted data from the buffer.

Return type:

Array

See also

Examples

Using a bytes buffer:

>>> import jax.numpy as jnp
>>> buf = b"\x00\x01\x02\x03\x04"
>>> jnp.frombuffer(buf, dtype=jnp.uint8)
Array([0, 1, 2, 3, 4], dtype=uint8)
>>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1)
Array([1, 2, 3, 4], dtype=uint8)
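The count parameter can be combined with offset to read a fixed number of items from the middle of the buffer. The short sketch below reuses the same five-byte buffer.

```python
import jax.numpy as jnp

buf = b"\x00\x01\x02\x03\x04"

# Skip the first byte, then read exactly two uint8 items.
head = jnp.frombuffer(buf, dtype=jnp.uint8, count=2, offset=1)
print(head)  # [1 2]
```
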

Constructing a JAX array via the Python buffer interface, using Python’s built-in array module. This example assumes the platform's C int (array type code 'i') is 32 bits wide, which is typical.

>>> from array import array
>>> pybuffer = array('i', [0, 1, 2, 3, 4])
>>> jnp.frombuffer(pybuffer, dtype=jnp.int32)
Array([0, 1, 2, 3, 4], dtype=int32)
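When dtype is omitted, the buffer is parsed as float64, and the result is then cast to float32 unless jax_enable_x64 is enabled. The sketch below shows this for two values that are exactly representable in both precisions; the buffer is built with NumPy's `tobytes()`, and the names are illustrative.

```python
import numpy as np
import jax.numpy as jnp

# Two float64 values serialized by NumPy (8 bytes each, 16 bytes total).
raw = np.array([1.5, -2.25], dtype=np.float64).tobytes()

# No dtype argument: the bytes are interpreted as float64 ...
x = jnp.frombuffer(raw)

# ... but the JAX array is float32 under the default configuration
# (jax_enable_x64=False), and float64 when 64-bit mode is enabled.
print(x, x.dtype)
```
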