jax.tree_util.tree_map_with_path

jax.tree_util.tree_map_with_path(f, tree, *rest, is_leaf=None, is_leaf_takes_path=False)

Alias of jax.tree.map_with_path().

Parameters:
- f (Callable[..., Any])
- tree (Any)
- rest (Any)
- is_leaf (Callable[..., bool] | None)
- is_leaf_takes_path (bool)

Return type:
Any
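The entry above records only the signature. The sketch below shows typical usage. `f` receives each leaf's key path as its first argument, followed by the leaf itself and the matching leaves from any extra trees passed in `rest`. The parameter tree `params`, the helper `decay_weights`, and the 0.1 decay rate are illustrative assumptions, not part of the API. The example does not exercise `is_leaf` or `is_leaf_takes_path`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import DictKey, keystr, tree_map_with_path

# Hypothetical parameter tree used only for illustration.
params = {
    "dense": {"w": jnp.ones((2, 3)), "b": jnp.ones((3,))},
    "out": {"w": jnp.full((3,), 2.0), "b": jnp.zeros(())},
}

def decay_weights(path, leaf, rate=0.1):
    """Hypothetical helper: shrink leaves whose final key is "w", pass others through."""
    last = path[-1]  # final entry of the key path; a DictKey for dict nodes
    if isinstance(last, DictKey) and last.key == "w":
        return leaf * (1.0 - rate)
    return leaf

# f is called as f(path, leaf) for every leaf of params.
decayed = tree_map_with_path(decay_weights, params)

# Extra trees in *rest are traversed in lockstep: f receives (path, x, g).
grads = jax.tree_util.tree_map(jnp.ones_like, params)
stepped = tree_map_with_path(lambda path, x, g: x - 0.5 * g, params, grads)

# keystr renders a key path as a readable string, e.g. "['dense']['w']".
names = tree_map_with_path(lambda path, _: keystr(path), params)
```

Because the path is passed alongside each leaf, one function can treat leaves differently by name, such as applying decay only to weights. With plain `tree_map`, you would need to build a separate mask tree to do the same thing.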