jax.numpy.frompyfunc
- jax.numpy.frompyfunc(func, /, nin, nout, *, identity=None)
Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
- Parameters:
func (Callable[..., Any]) – a callable that takes nin scalar arguments and returns nout outputs.
nin (int) – the number of scalar inputs that func accepts.
nout (int) – the number of scalar outputs that func returns.
identity (Any) – (optional) a scalar specifying the identity of the operation, if any.
- Returns:
A jax.numpy.ufunc wrapping func.
- Return type:
jax.numpy.ufunc
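To illustrate how nin and nout map onto the wrapped scalar function, here is a minimal sketch using a hypothetical two-output function, quot_rem, which returns a quotient and a remainder. It assumes that calling a ufunc created with nout=2 returns a tuple with one array per output, and that the nin, nout, and identity arguments are exposed as attributes of the resulting ufunc.

```python
import jax.numpy as jnp


# Hypothetical scalar function: two inputs (nin=2), two outputs (nout=2).
def quot_rem(a, b):
  return a // b, a % b


# nin and nout must match the scalar signature of quot_rem.
quot_rem_ufunc = jnp.frompyfunc(quot_rem, nin=2, nout=2)

# Calling the ufunc vectorizes quot_rem over the broadcast inputs
# and returns one array per output.
q, r = quot_rem_ufunc(jnp.arange(6), 4)
print(q)  # [0 0 0 0 1 1]
print(r)  # [0 1 2 3 0 1]

# The constructor arguments are available as attributes;
# identity was not given, so it is None.
print(quot_rem_ufunc.nin, quot_rem_ufunc.nout, quot_rem_ufunc.identity)
```

Methods such as reduce and accumulate combine pairs of elements, so they apply to binary ufuncs with a single output (nin=2, nout=1), like the add example below.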
Examples
Here is an example of creating a ufunc similar to jax.numpy.add:
>>> import operator
>>> import jax.numpy as jnp
>>> add = jnp.frompyfunc(operator.add, nin=2, nout=1, identity=0)
Now all the standard jax.numpy.ufunc methods are available:
>>> x = jnp.arange(4)
>>> add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
>>> add.outer(x, x)
Array([[0, 1, 2, 3],
       [1, 2, 3, 4],
       [2, 3, 4, 5],
       [3, 4, 5, 6]], dtype=int32)
>>> add.reduce(x)
Array(6, dtype=int32)
>>> add.accumulate(x)
Array([0, 1, 3, 6], dtype=int32)
>>> add.at(x, 1, 10, inplace=False)
Array([ 0, 11,  2,  3], dtype=int32)
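The same methods work for any binary JAX-compatible scalar function, not only those from the operator module. As a sketch, the example below wraps jnp.maximum to compute a running maximum with accumulate, and wraps operator.mul with identity=1 (the multiplicative identity) to take row-wise products with reduce. The names max_ufunc, mul, and total_product are hypothetical, and the example assumes that ufunc methods can be traced inside jax.jit like other JAX functions.

```python
import operator

import jax
import jax.numpy as jnp

# A ufunc built from a JAX function; no identity is supplied.
max_ufunc = jnp.frompyfunc(jnp.maximum, nin=2, nout=1)

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6])
running_max = max_ufunc.accumulate(x)  # cumulative maximum along axis 0
overall_max = max_ufunc.reduce(x)      # maximum over all elements

# A product ufunc; identity=1 is the multiplicative identity.
mul = jnp.frompyfunc(operator.mul, nin=2, nout=1, identity=1)

m = jnp.arange(1, 7).reshape(2, 3)     # [[1, 2, 3], [4, 5, 6]]
row_products = mul.reduce(m, axis=1)   # product of each row


@jax.jit
def total_product(a):
  # Flatten, then reduce along the single remaining axis.
  return mul.reduce(a.ravel())


print(running_max)       # [3 3 4 4 5 9 9 9]
print(overall_max)       # 9
print(row_products)      # [  6 120]
print(total_product(m))  # 720
```

Because JAX traces func, it should be written with JAX operations (such as jnp.maximum or jnp.where) rather than Python control flow that depends on array values.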