jax.experimental.sparse.bcoo_dot_general_sampled

jax.experimental.sparse.bcoo_dot_general_sampled(A, B, indices, *, dimension_numbers)

A contraction operation, with dimension_numbers in the same format as jax.lax.dot_general, whose output is computed only at the given sparse (BCOO) indices.
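As an illustration of the semantics, the sketch below compares the sampled result with a dense reference: compute the full product and read it at the sampled coordinates. It assumes the simple case of unbatched 2-D indices with no dense dimensions, where indices has shape (nse, 2) and the returned data has shape (nse,). The arrays A, B and the output pattern mask are illustrative values, and the indices are taken from a BCOO matrix built with BCOO.fromdense.

```python
import jax.numpy as jnp
from jax.experimental import sparse

A = jnp.arange(6.0).reshape(2, 3)   # left-hand side, shape (2, 3)
B = jnp.arange(12.0).reshape(3, 4)  # right-hand side, shape (3, 4)

# Illustrative output pattern on the (2, 4) result: keep (0, 0), (0, 3), (1, 1).
mask = jnp.array([[1, 0, 0, 1],
                  [0, 1, 0, 0]])
indices = sparse.BCOO.fromdense(mask).indices  # shape (nse, 2) = (3, 2)

# Ordinary matrix product: contract axis 1 of A with axis 0 of B, no batch dims.
dimension_numbers = (((1,), (0,)), ((), ()))
data = sparse.bcoo_dot_general_sampled(A, B, indices,
                                       dimension_numbers=dimension_numbers)

# Dense reference: full product, then gather it at the sampled coordinates.
dense = A @ B
reference = dense[indices[:, 0], indices[:, 1]]
```

The dense reference is only for checking; the point of the sampled operation is that the caller receives values at the requested indices rather than a full dense output.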

Parameters:
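  • A (Array) – An ndarray; the left-hand side (lhs) of the contraction.

  • B (Array) – An ndarray; the right-hand side (rhs) of the contraction.

  • indices (Array) – BCOO indices specifying the entries of the output at which the contraction is computed.

  • dimension_numbers (DotDimensionNumbers) – a tuple of tuples of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)), as in jax.lax.dot_general, where lhs refers to A and rhs refers to B.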
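dimension_numbers is not limited to the standard matrix product. The sketch below contracts the last axis of both operands, which samples entries of Q @ K.T at a given pattern (a sampled dense-dense product). Q, K and mask are hypothetical illustrative arrays, and the unbatched 2-D index layout is an assumption of this example.

```python
import jax.numpy as jnp
from jax.experimental import sparse

Q = jnp.arange(8.0).reshape(2, 4)   # shape (m, k) = (2, 4)
K = jnp.arange(12.0).reshape(3, 4)  # shape (n, k) = (3, 4)

# Illustrative pattern on the (m, n) = (2, 3) output.
mask = jnp.array([[1, 1, 0],
                  [0, 0, 1]])
indices = sparse.BCOO.fromdense(mask).indices  # shape (3, 2)

# Contract axis 1 of Q with axis 1 of K: the entries of Q @ K.T at `indices`.
dimension_numbers = (((1,), (1,)), ((), ()))
scores = sparse.bcoo_dot_general_sampled(Q, K, indices,
                                         dimension_numbers=dimension_numbers)

# Dense reference for comparison.
reference = (Q @ K.T)[indices[:, 0], indices[:, 1]]
```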

Returns:

BCOO data: an ndarray containing the values of the contraction at the given indices.

Return type:

Array
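Because the function returns only the data array, a full sparse result can be assembled by pairing that data with the same indices and the shape of the dense contraction output. The sketch below assumes unbatched 2-D indices and uses the BCOO((data, indices), shape=...) constructor; the all-ones inputs and the diagonal pattern are illustrative.

```python
import jax.numpy as jnp
from jax.experimental import sparse

A = jnp.ones((2, 3))
B = jnp.ones((3, 4))  # A @ B is a (2, 4) array of 3.0

# Illustrative pattern: keep only the diagonal entries (0, 0) and (1, 1).
pattern = sparse.BCOO.fromdense(jnp.eye(2, 4))

data = sparse.bcoo_dot_general_sampled(
    A, B, pattern.indices, dimension_numbers=(((1,), (0,)), ((), ())))

# Reuse the sampling indices and the output shape to build a BCOO matrix.
result = sparse.BCOO((data, pattern.indices), shape=pattern.shape)
dense = result.todense()  # 3.0 on the diagonal, 0.0 elsewhere
```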