jax.lax.broadcast_to_rank

jax.lax.broadcast_to_rank(x, rank)

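Adds leading dimensions of size 1 to x so that the result has rank rank. When rank >= x.ndim, the output has shape (1,) * (rank - x.ndim) + x.shape and contains the same elements as x.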

Parameters:
  • x (ArrayLike) – the input array (or scalar) to which leading size-1 dimensions are added.

  • rank (int) – the desired rank (number of dimensions) of the returned array.

Return type:

Array
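To make the shape contract concrete, here is a sketch of an equivalent operation written with jax.lax.reshape. The helper name prepend_unit_dims is hypothetical, it is not how the library implements broadcast_to_rank, and it assumes rank >= x.ndim (it raises otherwise).

```python
import jax.numpy as jnp
from jax import lax

def prepend_unit_dims(x, rank):
    """Hypothetical helper: prepend size-1 axes until the result has rank `rank`.

    Illustrative equivalent only; assumes rank >= x.ndim.
    """
    x = jnp.asarray(x)
    if rank < x.ndim:
        raise ValueError(f"rank {rank} is smaller than x.ndim {x.ndim}")
    # New leading unit axes followed by the original shape.
    new_shape = (1,) * (rank - x.ndim) + x.shape
    return lax.reshape(x, new_shape)

x = jnp.ones((4, 5))
a = lax.broadcast_to_rank(x, 3)
b = prepend_unit_dims(x, 3)
print(a.shape, b.shape)  # (1, 4, 5) (1, 4, 5)
print(bool(jnp.array_equal(a, b)))  # True
```

Because only size-1 axes are prepended, a reshape and a broadcast yield the same result here; broadcast_to_rank is a convenient way to express "match this rank" without computing the target shape by hand.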