jax.make_array_from_callback

jax.make_array_from_callback(shape, sharding, data_callback, dtype=None)

Returns a jax.Array built from data fetched by data_callback.

data_callback is used to fetch the data for each addressable shard of the returned jax.Array. The callback must return concrete arrays (not tracers), so make_array_from_callback has limited compatibility with JAX transformations such as jit() or vmap().
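A minimal sketch of the per-shard behavior, assuming only the devices JAX finds locally (on a default CPU install that is a single device, so there is one shard covering the whole array). The names `global_data`, `seen_indices`, and `cb` are illustrative. The sketch checks each entry of `arr.addressable_shards` against the block its `index` selects, and shows the usual way to combine this function with jit(): build the array outside the transformation and pass it in as an argument.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Use whatever devices are available (1 on a plain CPU install).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))
sharding = NamedSharding(mesh, P("x"))  # split rows across the "x" mesh axis

shape = (2 * devices.size, 4)
global_data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

seen_indices = []  # illustrative: records the indices the callback receives


def cb(index):
    # `index` is a tuple of slices selecting this shard's block of the global array.
    seen_indices.append(index)
    # Must return a concrete host array such as a numpy.ndarray, not a tracer.
    return global_data[index]


arr = jax.make_array_from_callback(shape, sharding, cb)

# Each addressable shard holds exactly the block its index selects.
for shard in arr.addressable_shards:
    assert np.array_equal(np.asarray(shard.data), global_data[shard.index])

# Create the array outside of jit, then hand it to jitted code as an argument.
doubled = jax.jit(lambda x: 2 * x)(arr)
```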

Parameters:
  • shape (Shape) – Shape of the jax.Array.

  • sharding (Sharding | Layout) – A Sharding (or a Layout carrying one) that describes how the jax.Array is laid out across devices.

  • data_callback (Callable[[Index | None], ArrayLike]) – Callback that takes an index into the global array (typically a tuple of slices, one per dimension, selecting one shard) and returns the data of the global array at that index. The data can be returned as any array-like object, e.g. a numpy.ndarray.

  • dtype (DTypeLike | None) – The dtype of the output jax.Array. If not provided, the dtype of the data for the first addressable shard is used. If there are no addressable shards, the dtype argument must be provided.
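Because the callback receives only an index, it does not need a pre-built global array: it can generate or load just the requested block. The sketch below is a hypothetical example (`synth_cb` is not part of JAX) that assumes a 2-D array and a tuple-of-slices index, and uses Python's `slice.indices` to turn open-ended slices such as `slice(None)` into concrete bounds. It leaves `dtype` unset, so the output dtype is taken from the first shard's data.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))
sharding = NamedSharding(mesh, P("x"))

shape = (2 * devices.size, 8)


def synth_cb(index):
    # Hypothetical callback: builds only the requested block, so the full
    # global array is never materialized on the host.
    # slice.indices(n) resolves e.g. slice(None) into concrete (start, stop, step).
    rows = np.arange(*index[0].indices(shape[0]))
    cols = np.arange(*index[1].indices(shape[1]))
    # Value at (r, c) is r * ncols + c, i.e. its row-major flat position.
    return (rows[:, None] * shape[1] + cols[None, :]).astype(np.float32)


arr = jax.make_array_from_callback(shape, sharding, synth_cb)
```

The assembled array matches `np.arange(shape[0] * shape[1]).reshape(shape)`, even though no single call ever produced that whole array when more than one device is present.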

Returns:

A jax.Array assembled from the data returned by data_callback.

Return type:

ArrayImpl

Examples

The example below assumes exactly 8 devices, arranged as a 2 x 4 mesh.

>>> import jax
>>> import math
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> input_shape = (8, 8)
>>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
...
>>> def cb(index):
...     return global_input_data[index]
...
>>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb)
>>> arr.addressable_data(0).shape
(4, 2)
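The example above needs 8 devices. One way to try it on a single machine, assuming a CPU-only run, is XLA's `--xla_force_host_platform_device_count` flag, which makes the CPU backend expose several virtual devices. The flag must be set before JAX initializes its backends, so this sketch sets it at the very top of a fresh script; it will not take effect if JAX has already been initialized in the same process.

```python
import os

# Must run before JAX initializes its backends: asks XLA to expose
# 8 virtual CPU devices so the 2 x 4 mesh can be built on one machine.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import math

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

input_shape = (8, 8)
global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)

# Take the 8 virtual CPU devices and arrange them as a 2 x 4 mesh.
cpu_devices = np.array(jax.devices("cpu")[:8]).reshape(2, 4)
global_mesh = Mesh(cpu_devices, ("x", "y"))
inp_sharding = NamedSharding(global_mesh, P("x", "y"))


def cb(index):
    return global_input_data[index]


arr = jax.make_array_from_callback(input_shape, inp_sharding, cb)
print(arr.addressable_data(0).shape)  # (4, 2): 8 rows / 2, 8 columns / 4
```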