jax.lax.split

jax.lax.split(operand, sizes, axis=0)

Splits an array along axis.

Parameters:
  • operand (ArrayLike) – an array to split

  • sizes (Sequence[int]) – the sizes of the split arrays. The sizes must sum to the length of operand along axis, i.e. operand.shape[axis].

  • axis (int) – the axis along which to split the array. Defaults to 0.

Returns:

A sequence of len(sizes) arrays. If sizes is [s1, s2, ...], this function returns chunks of sizes s1, s2, ... taken in order along axis.

Return type:

Sequence[Array]