jax.numpy.ix_
- jax.numpy.ix_(*args)
Return a multi-dimensional grid (open mesh) from N one-dimensional sequences.
JAX implementation of numpy.ix_().
- Parameters:
*args (ArrayLike) – N one-dimensional arrays
- Returns:
Tuple of JAX arrays forming an open mesh, each with N dimensions.
- Return type:
tuple[Array, ...]
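An "open mesh" here means that the i-th returned array holds the i-th input along axis i and has length 1 on every other axis, so the N arrays broadcast against each other into the full N-dimensional grid. The sketch below illustrates this using only standard `jax.numpy` calls; `open_mesh_sketch` is a hypothetical helper written for illustration, not part of JAX, and it ignores details such as dtype handling for empty inputs.

```python
import jax.numpy as jnp

def open_mesh_sketch(*args):
    """Hypothetical reshape-based open mesh, for illustration only."""
    n = len(args)
    out = []
    for i, a in enumerate(args):
        a = jnp.asarray(a)
        if a.ndim != 1:
            raise ValueError(f"expected a 1-D input, got shape {a.shape}")
        # Place the i-th input along axis i; every other axis has length 1.
        shape = [1] * n
        shape[i] = a.shape[0]
        out.append(a.reshape(shape))
    return tuple(out)

a = jnp.array([0, 1, 2])
b = jnp.array([10, 20])
mesh = open_mesh_sketch(a, b)
print([m.shape for m in mesh])  # [(3, 1), (1, 2)]

# Same shapes and values as jnp.ix_.
ref = jnp.ix_(a, b)
same_as_ix = all(m.shape == r.shape and bool((m == r).all()) for m, r in zip(mesh, ref))
print(same_as_ix)  # True

# Broadcasting the open mesh yields the dense grid that meshgrid(indexing="ij") builds.
dense = jnp.broadcast_arrays(*mesh)
full = jnp.meshgrid(a, b, indexing="ij")
same_as_meshgrid = all(bool((d == f).all()) for d, f in zip(dense, full))
print(same_as_meshgrid)  # True
```

Keeping the singleton axes instead of materializing the dense grid is what makes the result cheap to build and directly usable as an index tuple.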
Examples
>>> import jax.numpy as jnp
>>> rows = jnp.array([0, 2])
>>> cols = jnp.array([1, 3])
>>> open_mesh = jnp.ix_(rows, cols)
>>> open_mesh
(Array([[0],
       [2]], dtype=int32), Array([[1, 3]], dtype=int32))
>>> [grid.shape for grid in open_mesh]
[(2, 1), (1, 2)]
>>> x = jnp.array([[10, 20, 30, 40],
...                [50, 60, 70, 80],
...                [90, 100, 110, 120],
...                [130, 140, 150, 160]])
>>> x[open_mesh]
Array([[ 20,  40],
       [100, 120]], dtype=int32)
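JAX arrays are immutable, so the NumPy idiom `x[np.ix_(rows, cols)] = value` does not work on them. A minimal sketch, assuming the standard `Array.at[...].set` indexed-update API, of updating the same submatrix that the open mesh selects:

```python
import jax.numpy as jnp

x = jnp.arange(16).reshape(4, 4)
rows = jnp.array([0, 2])
cols = jnp.array([1, 3])

# Read the 2x2 submatrix at the intersection of rows {0, 2} and columns {1, 3}.
sub = x[jnp.ix_(rows, cols)]
print(sub)
# [[ 1  3]
#  [ 9 11]]

# .at[...].set returns a new array with those four entries replaced; x is unchanged.
y = x.at[jnp.ix_(rows, cols)].set(0)
print(y)
# [[ 0  0  2  0]
#  [ 4  5  6  7]
#  [ 8  0 10  0]
#  [12 13 14 15]]
```

The same open-mesh index works with the other `.at` methods, such as `.add` or `.multiply`.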