jax.experimental.custom_dce module

API

custom_dce(fun, *[, static_argnums])

Customize the dead code elimination (DCE) behavior of a JAX-transformable function.

custom_dce.def_dce(dce_rule)

Define a custom DCE rule for this function.