jax.make_array_from_callback
- jax.make_array_from_callback(shape, sharding, data_callback, dtype=None)
Returns a jax.Array via data fetched from data_callback.

data_callback is used to fetch the data for each addressable shard of the returned jax.Array. The callback must return concrete arrays, meaning that make_array_from_callback has limited compatibility with JAX transformations like jit() or vmap().

- Parameters:
  - shape (Shape) – Shape of the jax.Array.
  - sharding (Sharding | Layout) – A Sharding instance which describes how the jax.Array is laid out across devices.
  - data_callback (Callable[[Index | None], ArrayLike]) – Callback that takes indices into the global array value as input and returns the corresponding data of the global array value. The data can be returned as any array-like object, e.g. a numpy.ndarray.
  - dtype (DTypeLike | None) – The dtype of the output jax.Array. If not provided, the dtype of the data for the first addressable shard is used. If there are no addressable shards, the dtype argument must be provided.
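To make the callback contract concrete, the sketch below records the index that data_callback receives for each addressable shard (a tuple of slices into the global array) and returns the matching slice of a host-side NumPy array. It assumes a 1-D mesh over all visible devices, so it also runs on a single CPU device; the names `global_data`, `seen_indices`, and `cb` are illustrative only. Because dtype is omitted, the result takes the dtype of the callback's data (float32 here).

```python
import math

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over every visible device (assumption: works for 1 or more devices).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))
sharding = NamedSharding(mesh, P("x"))  # split axis 0 across mesh axis "x"

# Illustrative global value, built on the host.
shape = (2 * devices.size, 3)
global_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

seen_indices = []

def cb(index):
    # `index` selects this shard's region of the global array.
    seen_indices.append(index)
    return global_data[index]

# dtype is not passed, so it is taken from the first addressable shard's data.
arr = jax.make_array_from_callback(shape, sharding, cb)

# Each addressable shard holds one (2, 3) block of the global array.
for shard in arr.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```

Splitting along axis 0 only keeps every per-device index distinct, which makes the printed mapping easy to read.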
- Returns:
  A jax.Array via data fetched from data_callback.
- Return type:
ArrayImpl
Examples
>>> import math
>>> import jax
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> input_shape = (8, 8)
>>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
>>> # The 2x4 mesh below requires exactly 8 devices.
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
...
>>> def cb(index):
...     return global_input_data[index]
...
>>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb)
>>> arr.addressable_data(0).shape
(4, 2)
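Because data_callback has to produce concrete data, one way to stay clear of the limited compatibility with jit() is to build the array eagerly, outside any transformation, and pass it into the compiled function as an ordinary argument. This is a minimal sketch under the assumption of a 1-D mesh over all visible devices; `host_data` and `double` are hypothetical names.

```python
import math

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P("x"))

shape = (4 * len(jax.devices()), 2)
host_data = np.arange(math.prod(shape), dtype=np.int32).reshape(shape)

# Build the sharded array eagerly: the callback indexes concrete NumPy data.
arr = jax.make_array_from_callback(shape, sharding, lambda idx: host_data[idx])

@jax.jit
def double(x):
    # Traced code only sees the already-materialized, sharded jax.Array.
    return x * 2

out = double(arr)
print(out.sharding)
```

The callback never runs under tracing here, so it is free to use arbitrary host-side Python and NumPy.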