jax.lax.split

jax.lax.split(operand, sizes, axis=0)

Splits an array along axis.

Parameters:
operand (ArrayLike) – the array to split.
sizes (Sequence[int]) – the sizes of the split arrays. The sum of the sizes must equal the size of the axis dimension of operand.
axis (int) – the axis along which to split the array. Defaults to 0.

Returns:
A sequence of len(sizes) arrays. If sizes is [s1, s2, ...], this function returns chunks of sizes s1, s2, ..., taken along axis.

Return type:
Sequence[Array]