jax.experimental.shard_map module

API

shard_map(f, mesh, in_specs, out_specs[, ...])
Map a function over shards of data.
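The summary above gives only the signature. The following is a minimal sketch of calling `shard_map`, not an official example. Assumptions: a 1-D `Mesh` is built over whatever devices `jax.devices()` reports, and its single axis is named `"x"`. The function name `block_sum_and_scale` is hypothetical. If the experimental import is unavailable, the sketch assumes a newer JAX release where the same function is exposed as `jax.shard_map`. All arguments are passed by keyword so the call works with either import. `in_specs` says how each input is split across mesh axes, and `out_specs` says how the per-shard results are reassembled.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    # The module documented on this page.
    from jax.experimental.shard_map import shard_map
except ImportError:
    # Assumption: newer JAX releases expose the same function as jax.shard_map.
    from jax import shard_map

# A 1-D mesh over all available devices, with one named axis "x".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
n = devices.size


def block_sum_and_scale(x_block):
    # Hypothetical example function. x_block is this device's shard of the
    # global array; with the specs below it has shape (4,).
    scaled = 2.0 * x_block                       # purely local computation
    total = jax.lax.psum(jnp.sum(x_block), "x")  # sum across the "x" mesh axis
    return scaled, total


f = jax.jit(
    shard_map(
        block_sum_and_scale,
        mesh=mesh,
        in_specs=(P("x"),),         # split the single argument's leading dim over "x"
        out_specs=(P("x"), P()),    # `scaled` stays sharded; `total` is replicated
    )
)

x = jnp.arange(4 * n, dtype=jnp.float32)  # 4 elements per device
scaled, total = f(x)
```

On a single device, the mesh has exactly one shard, so `f` behaves like the unsharded computation. With `n` devices, each device sees a 4-element block. The `psum` makes `total` identical on every device, which is why `P()` (fully replicated) is a valid output spec for it.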