jax.numpy module

Implements the NumPy API, using the primitives in jax.lax.

While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.

  • Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, often JAX is able to provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides an alternative pure indexed update function x.at[i].set(y) (see ndarray.at).

  • Relatedly, some NumPy functions return views of arrays when possible (examples are transpose() and reshape()). The JAX versions of such functions return copies instead, although such copies are often optimized away by XLA when sequences of operations are compiled using jax.jit().

  • NumPy is very aggressive about promoting values to the float64 type. JAX is sometimes less aggressive about type promotion (see Type promotion semantics).

  • Some NumPy routines have data-dependent output shapes (examples include unique() and nonzero()). Because the XLA compiler requires array shapes to be known at compile time, such operations are not compatible with JIT. For this reason, JAX adds an optional size argument to such functions which may be specified statically in order to use them with JIT.
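A minimal sketch of the size argument, assuming a JAX version whose jnp.nonzero accepts size and fill_value as listed below; the function name first_nonzero_indices is hypothetical. Because size is a plain Python integer, the output shape is fixed when the function is traced. If there are fewer nonzero entries than size, the result is padded with fill_value; if there are more, only the first size indices are kept.

```python
import jax
import jax.numpy as jnp


@jax.jit
def first_nonzero_indices(x):
    # Hypothetical helper. size is a static Python int, so the output shape (3,)
    # is known at trace time. nonzero returns one index array per dimension;
    # x is 1-D here, so there is exactly one array to unpack.
    (idx,) = jnp.nonzero(x, size=3, fill_value=-1)
    return idx


x = jnp.array([0, 5, 0, 7])
print(first_nonzero_indices(x))  # indices 1 and 3, then one -1 pad entry
```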

Nearly all applicable NumPy functions are implemented in the jax.numpy namespace; they are listed below.

ndarray.at

Helper property for index update functionality.
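A short sketch of purely functional indexed updates through ndarray.at; the variable names are illustrative. Each method returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# In-place assignment such as x[1] = 10.0 raises a TypeError, because JAX
# arrays are immutable. Each .at[...] method returns a new array instead.
y = x.at[1].set(10.0)    # out-of-place equivalent of x[1] = 10.0
z = y.at[2:4].add(1.0)   # out-of-place equivalent of y[2:4] += 1.0

print(x)  # x itself is unchanged: all zeros
print(z)
```

Inside jax.jit, XLA can often perform such an update in place when the original array is not used afterwards, so the functional style does not necessarily cost a copy.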

abs(x, /)

Alias of jax.numpy.absolute().

absolute(x, /)

Calculate the absolute value element-wise.

acos(x, /)

Alias of jax.numpy.arccos().

acosh(x, /)

Alias of jax.numpy.arccosh().

add

Add two arrays element-wise.

all(a[, axis, out, keepdims, where])

Test whether all array elements along a given axis evaluate to True.

allclose(a, b[, rtol, atol, equal_nan])

Check if two arrays are element-wise approximately equal within a tolerance.

amax(a[, axis, out, keepdims, initial, where])

Alias of jax.numpy.max().

amin(a[, axis, out, keepdims, initial, where])

Alias of jax.numpy.min().

angle(z[, deg])

Return the angle of a complex valued number or array.

any(a[, axis, out, keepdims, where])

Test whether any of the array elements along a given axis evaluate to True.

append(arr, values[, axis])

Return a new array with values appended to the end of the original array.

apply_along_axis(func1d, axis, arr, *args, ...)

Apply a function to 1D array slices along an axis.

apply_over_axes(func, a, axes)

Apply a function repeatedly over specified axes.

arange(start[, stop, step, dtype, device])

Create an array of evenly-spaced values.

arccos(x, /)

Compute element-wise inverse of trigonometric cosine of input.

arccosh(x, /)

Calculate element-wise inverse of hyperbolic cosine of input.

arcsin(x, /)

Compute element-wise inverse of trigonometric sine of input.

arcsinh(x, /)

Calculate element-wise inverse of hyperbolic sine of input.

arctan(x, /)

Compute element-wise inverse of trigonometric tangent of input.

arctan2(x1, x2, /)

Compute the arctangent of x1/x2, choosing the correct quadrant.

arctanh(x, /)

Calculate element-wise inverse of hyperbolic tangent of input.

argmax(a[, axis, out, keepdims])

Return the index of the maximum value of an array.

argmin(a[, axis, out, keepdims])

Return the index of the minimum value of an array.

argpartition(a, kth[, axis])

Returns indices that partially sort an array.

argsort(a[, axis, kind, order, stable, ...])

Return indices that sort an array.

argwhere(a, *[, size, fill_value])

Find the indices of nonzero array elements.

around(a[, decimals, out])

Alias of jax.numpy.round().

array(object[, dtype, copy, order, ndmin, ...])

Convert an object to a JAX array.

array_equal(a1, a2[, equal_nan])

Check if two arrays are element-wise equal.

array_equiv(a1, a2)

Check if two arrays are element-wise equal after broadcasting.

array_repr(arr[, max_line_width, precision, ...])

Return the string representation of an array.

array_split(ary, indices_or_sections[, axis])

Split an array into sub-arrays.

array_str(a[, max_line_width, precision, ...])

Return a string representation of the data in an array.

asarray(a[, dtype, order, copy, device])

Convert an object to a JAX array.

asin(x, /)

Alias of jax.numpy.arcsin().

asinh(x, /)

Alias of jax.numpy.arcsinh().

astype(x, dtype, /, *[, copy, device])

Convert an array to a specified dtype.

atan(x, /)

Alias of jax.numpy.arctan().

atanh(x, /)

Alias of jax.numpy.arctanh().

atan2(x1, x2, /)

Alias of jax.numpy.arctan2().

atleast_1d(*arys)

Convert inputs to arrays with at least 1 dimension.

atleast_2d(*arys)

Convert inputs to arrays with at least 2 dimensions.

atleast_3d(*arys)

Convert inputs to arrays with at least 3 dimensions.

average(a[, axis, weights, returned, keepdims])

Compute the weighted average.

bartlett(M)

Return a Bartlett window of size M.

bincount(x[, weights, minlength, length])

Count the number of occurrences of each value in an integer array.
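The output length of bincount normally depends on the largest value in x, which is not known at compile time. A sketch of using it under jax.jit by passing the static length argument from the signature above; the wrapper name count5 is hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical wrapper. length must be a static Python int under jit: it fixes
# the output shape, which would otherwise depend on the data-dependent x.max().
count5 = jax.jit(lambda x: jnp.bincount(x, length=5))

x = jnp.array([0, 1, 1, 3])
print(count5(x))  # occurrence counts for the values 0, 1, 2, 3, 4
```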

bitwise_and

Compute the bitwise AND operation elementwise.

bitwise_count(x, /)

Counts the number of 1 bits in the binary representation of the absolute value of each element of x.

bitwise_invert(x, /)

Alias of jax.numpy.invert().

bitwise_left_shift(x, y, /)

Alias of jax.numpy.left_shift().

bitwise_not(x, /)

Alias of jax.numpy.invert().

bitwise_or

Compute the bitwise OR operation elementwise.

bitwise_right_shift(x1, x2, /)

Alias of jax.numpy.right_shift().

bitwise_xor

Compute the bitwise XOR operation elementwise.

blackman(M)

Return a Blackman window of size M.

block(arrays)

Create an array from a list of blocks.

bool_

alias of bool

broadcast_arrays(*args)

Broadcast arrays to a common shape.

broadcast_shapes(*shapes)

Broadcast input shapes to a common output shape.

broadcast_to(array, shape)

Broadcast an array to a specified shape.

c_

Concatenate slices, scalars and array-like objects along the last axis.

can_cast(from_, to[, casting])

Returns True if cast between data types can occur according to the casting rule.

cbrt(x, /)

Calculates element-wise cube root of the input array.

cdouble

alias of complex128

ceil(x, /)

Round input to the nearest integer upwards.

character()

Abstract base class of all character string scalar types.

choose(a, choices[, out, mode])

Construct an array by stacking slices of choice arrays.

clip([arr, min, max, a, a_min, a_max])

Clip array values to a specified range.

column_stack(tup)

Stack arrays column-wise.

complex_

alias of complex128

complex128(x)

A JAX scalar constructor of type complex128.

complex64(x)

A JAX scalar constructor of type complex64.

complexfloating()

Abstract base class of all complex number scalar types that are made up of floating-point numbers.

ComplexWarning

The warning raised when casting a complex dtype to a real dtype.

compress(condition, a[, axis, size, ...])

Compress an array along a given axis using a boolean condition.

concat(arrays, /, *[, axis])

Join arrays along an existing axis.

concatenate(arrays[, axis, dtype])

Join arrays along an existing axis.

conj(x, /)

Alias of jax.numpy.conjugate().

conjugate(x, /)

Return element-wise complex-conjugate of the input.

convolve(a, v[, mode, precision, ...])

Convolution of two one-dimensional arrays.

copy(a[, order])

Return a copy of the array.

copysign(x1, x2, /)

Copies the sign of each element in x2 to the corresponding element in x1.

corrcoef(x[, y, rowvar])

Compute the Pearson correlation coefficients.

correlate(a, v[, mode, precision, ...])

Correlation of two one-dimensional arrays.

cos(x, /)

Compute a trigonometric cosine of each element of input.

cosh(x, /)

Calculate element-wise hyperbolic cosine of input.

count_nonzero(a[, axis, keepdims])

Return the number of nonzero elements along a given axis.

cov(m[, y, rowvar, bias, ddof, fweights, ...])

Estimate the weighted sample covariance.

cross(a, b[, axisa, axisb, axisc, axis])

Compute the (batched) cross product of two arrays.

csingle

alias of complex64

cumprod(a[, axis, dtype, out])

Cumulative product of elements along an axis.

cumsum(a[, axis, dtype, out])

Cumulative sum of elements along an axis.

cumulative_prod(x, /, *[, axis, dtype, ...])

Cumulative product along the axis of an array.

cumulative_sum(x, /, *[, axis, dtype, ...])

Cumulative sum along the axis of an array.

deg2rad(x, /)

Convert angles from degrees to radians.

degrees(x, /)

Alias of jax.numpy.rad2deg().

delete(arr, obj[, axis, assume_unique_indices])

Delete entry or entries from an array.

diag(v[, k])

Returns the specified diagonal or constructs a diagonal array.

diag_indices(n[, ndim])

Return indices for accessing the main diagonal of a multidimensional array.

diag_indices_from(arr)

Return indices for accessing the main diagonal of a given array.

diagflat(v[, k])

Return a 2-D array with the flattened input array laid out on the diagonal.

diagonal(a[, offset, axis1, axis2])

Returns the specified diagonal of an array.

diff(a[, n, axis, prepend, append])

Calculate n-th order difference between array elements along a given axis.

digitize(x, bins[, right, method])

Convert an array to bin indices.

divide(x1, x2, /)

Alias of jax.numpy.true_divide().

divmod(x1, x2, /)

Calculates the integer quotient and remainder of x1 by x2 element-wise.

dot(a, b, *[, precision, preferred_element_type])

Compute the dot product of two arrays.

double

alias of float64

dsplit(ary, indices_or_sections)

Split an array into sub-arrays depth-wise.

dstack(tup[, dtype])

Stack arrays depth-wise.

dtype(dtype[, align, copy])

Create a data type object.

ediff1d(ary[, to_end, to_begin])

Compute the differences of the elements of the flattened array.

einsum(subscripts, /, *operands[, out, ...])

Einstein summation.
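A small sketch of einsum expressing a batched matrix product; the shapes and values are arbitrary and chosen only for illustration. The subscript string names each axis, and any index that is absent from the output (here j) is summed over.

```python
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(2, 3, 4)   # batch of two 3x4 matrices
b = jnp.arange(40.0).reshape(2, 4, 5)   # batch of two 4x5 matrices

# "bij,bjk->bik": for each batch index b, sum over the shared index j,
# which is a batched matrix product.
c = jnp.einsum("bij,bjk->bik", a, b)
print(c.shape)
print(jnp.allclose(c, a @ b))           # agrees with batched matmul
```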

einsum_path(subscripts, /, *operands[, optimize])

Evaluates the optimal contraction path without evaluating the einsum.

empty(shape[, dtype, device])

Create an empty array.

empty_like(prototype[, dtype, shape, device])

Create an empty array with the same shape and dtype as an array.

equal(x, y, /)

Returns element-wise truth value of x == y.

exp(x, /)

Calculate element-wise exponential of the input.

exp2(x, /)

Calculate element-wise base-2 exponential of input.

expand_dims(a, axis)

Insert dimensions of length 1 into an array.

expm1(x, /)

Calculate exp(x)-1 of each element of the input.

extract(condition, arr, *[, size, fill_value])

Return the elements of an array that satisfy a condition.

eye(N[, M, k, dtype, device])

Create a square or rectangular identity matrix.

fabs(x, /)

Compute the element-wise absolute values of the real-valued input.

fill_diagonal(a, val[, wrap, inplace])

Return a copy of the array with the diagonal overwritten.
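NumPy's fill_diagonal mutates its argument in place and returns None. A sketch of the JAX counterpart, assuming, as the inplace parameter in the signature suggests, that the NumPy-style in-place mode is not supported and inplace=False must be passed explicitly so that a new array is returned.

```python
import jax.numpy as jnp

a = jnp.zeros((3, 3))
# JAX arrays cannot be mutated, so request an out-of-place result explicitly.
b = jnp.fill_diagonal(a, 2.0, inplace=False)
print(b)  # 2.0 on the diagonal, zeros elsewhere; a is unchanged
```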

finfo(dtype)

Machine limits for floating point types.

fix(x[, out])

Round input to the nearest integer towards zero.

flatnonzero(a, *[, size, fill_value])

Return indices of nonzero elements in a flattened array.

flexible()

Abstract base class of all scalar types without predefined length.

flip(m[, axis])

Reverse the order of elements of an array along the given axis.

fliplr(m)

Reverse the order of elements of an array along axis 1.

flipud(m)

Reverse the order of elements of an array along axis 0.

float_

alias of float64

float_power(x, y, /)

Calculate element-wise base x exponential of y.

float16(x)

A JAX scalar constructor of type float16.

float32(x)

A JAX scalar constructor of type float32.

float64(x)

A JAX scalar constructor of type float64.

floating()

Abstract base class of all floating-point scalar types.

floor(x, /)

Round input to the nearest integer downwards.

floor_divide(x1, x2, /)

Calculates the floor division of x1 by x2 element-wise.

fmax(x1, x2)

Return element-wise maximum of the input arrays.

fmin(x1, x2)

Return element-wise minimum of the input arrays.

fmod(x1, x2, /)

Calculate element-wise floating-point modulo operation.

frexp(x, /)

Split floating-point values into mantissa and base-2 exponent.

frombuffer(buffer[, dtype, count, offset])

Convert a buffer into a 1-D JAX array.

fromfile(*args, **kwargs)

Unimplemented JAX wrapper for jnp.fromfile.

fromfunction(function, shape, *[, dtype])

Create an array from a function applied over indices.

fromiter(*args, **kwargs)

Unimplemented JAX wrapper for jnp.fromiter.

frompyfunc(func, /, nin, nout, *[, identity])

Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
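A sketch of frompyfunc, assuming the returned jnp.ufunc object exposes NumPy-style methods such as reduce and outer. The scalar function add is a hypothetical stand-in for any JAX-compatible function of two scalars.

```python
import jax.numpy as jnp


def add(x, y):
    # Hypothetical scalar function; frompyfunc wraps it as a JAX ufunc.
    return x + y


add_ufunc = jnp.frompyfunc(add, nin=2, nout=1, identity=0)

x = jnp.arange(5)
print(add_ufunc(x, x))             # element-wise: 0, 2, 4, 6, 8
print(add_ufunc.reduce(x))         # 0 + 1 + 2 + 3 + 4
print(add_ufunc.outer(x, x).shape) # all pairwise sums, shape (5, 5)
```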

fromstring(string[, dtype, count])

Convert a string of text into a 1-D JAX array.

from_dlpack(x, /, *[, device, copy])

Construct a JAX array via DLPack.

full(shape, fill_value[, dtype, device])

Create an array full of a specified value.

full_like(a, fill_value[, dtype, shape, device])

Create an array full of a specified value with the same shape and dtype as an array.

gcd(x1, x2)

Compute the greatest common divisor of two arrays.

generic()

Base class for NumPy scalar types.

geomspace(start, stop[, num, endpoint, ...])

Generate geometrically-spaced values.

get_printoptions()

Alias of numpy.get_printoptions().

gradient(f, *varargs[, axis, edge_order])

Compute the numerical gradient of a sampled function.

greater(x, y, /)

Return element-wise truth value of x > y.

greater_equal(x, y, /)

Return element-wise truth value of x >= y.

hamming(M)

Return a Hamming window of size M.

hanning(M)

Return a Hanning window of size M.

heaviside(x1, x2, /)

Compute the heaviside step function.

histogram(a[, bins, range, weights, density])

Compute a 1-dimensional histogram.

histogram_bin_edges(a[, bins, range, weights])

Compute the bin edges for a histogram.

histogram2d(x, y[, bins, range, weights, ...])

Compute a 2-dimensional histogram.

histogramdd(sample[, bins, range, weights, ...])

Compute an N-dimensional histogram.

hsplit(ary, indices_or_sections)

Split an array into sub-arrays horizontally.

hstack(tup[, dtype])

Horizontally stack arrays.

hypot(x1, x2, /)

Return element-wise hypotenuse for the given legs of a right angle triangle.

i0(x)

Calculate modified Bessel function of first kind, zeroth order.

identity(n[, dtype])

Create a square identity matrix.

iinfo(int_type)

Machine limits for integer types.

imag(val, /)

Return element-wise imaginary part of the complex argument.

index_exp

A nicer way to build up index tuples for arrays.

indices(dimensions[, dtype, sparse])

Generate arrays of grid indices.

inexact()

Abstract base class of all numeric scalar types with a (potentially) inexact representation of the values in its range, such as floating-point numbers.

inner(a, b, *[, precision, ...])

Compute the inner product of two arrays.

insert(arr, obj, values[, axis])

Insert entries into an array at specified indices.

int_

alias of int64

int16(x)

A JAX scalar constructor of type int16.

int32(x)

A JAX scalar constructor of type int32.

int64(x)

A JAX scalar constructor of type int64.

int8(x)

A JAX scalar constructor of type int8.

integer()

Abstract base class of all integer scalar types.

interp(x, xp, fp[, left, right, period])

One-dimensional linear interpolation.

intersect1d(ar1, ar2[, assume_unique, ...])

Compute the set intersection of two 1D arrays.

invert(x, /)

Compute the bitwise inversion of an input.

isclose(a, b[, rtol, atol, equal_nan])

Check if the elements of two arrays are approximately equal within a tolerance.

iscomplex(x)

Return boolean array showing where the input is complex.

iscomplexobj(x)

Check if the input is a complex number or an array containing complex elements.

isdtype(dtype, kind)

Returns a boolean indicating whether a provided dtype is of a specified kind.

isfinite(x, /)

Return a boolean array indicating whether each element of input is finite.

isin(element, test_elements[, ...])

Determine whether elements in element appear in test_elements.

isinf(x, /)

Return a boolean array indicating whether each element of input is infinite.

isnan(x, /)

Returns a boolean array indicating whether each element of input is NaN.

isneginf(x, /[, out])

Return boolean array indicating whether each element of input is negative infinity.

isposinf(x, /[, out])

Return boolean array indicating whether each element of input is positive infinity.

isreal(x)

Return boolean array showing where the input is real.

isrealobj(x)

Check if the input is not a complex number or an array containing complex elements.

isscalar(element)

Return True if the input is a scalar.

issubdtype(arg1, arg2)

Return True if arg1 is equal or lower than arg2 in the type hierarchy.

iterable(y)

Check whether or not an object can be iterated over.

ix_(*args)

Return a multi-dimensional grid (open mesh) from N one-dimensional sequences.

kaiser(M, beta)

Return a Kaiser window of size M.

kron(a, b)

Compute the Kronecker product of two input arrays.

lcm(x1, x2)

Compute the least common multiple of two arrays.

ldexp(x1, x2, /)

Compute x1 * 2 ** x2.

left_shift(x, y, /)

Shift bits of x to left by the amount specified in y, element-wise.

less(x, y, /)

Return element-wise truth value of x < y.

less_equal(x, y, /)

Return element-wise truth value of x <= y.

lexsort(keys[, axis])

Sort a sequence of keys in lexicographic order.

linspace(start, stop[, num, endpoint, ...])

Return evenly-spaced numbers within an interval.

load(file, *args, **kwargs)

Load JAX arrays from npy files.

log(x, /)

Calculate element-wise natural logarithm of the input.

log10(x, /)

Calculates the base-10 logarithm of x element-wise.

log1p(x, /)

Calculates element-wise logarithm of one plus input, log(x+1).

log2(x, /)

Calculates the base-2 logarithm of x element-wise.

logaddexp

Compute log(exp(x1) + exp(x2)) avoiding overflow.

logaddexp2

Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow.

logical_and

Compute the logical AND operation elementwise.

logical_not(x, /)

Compute NOT bool(x) element-wise.

logical_or

Compute the logical OR operation elementwise.

logical_xor

Compute the logical XOR operation elementwise.

logspace(start, stop[, num, endpoint, base, ...])

Generate logarithmically-spaced values.

mask_indices(n, mask_func[, k, size])

Return indices of a mask of an (n, n) array.

matmul(a, b, *[, precision, ...])

Perform a matrix multiplication.

matrix_transpose(x, /)

Transpose the last two dimensions of an array.

matvec(x1, x2, /)

Batched matrix-vector product.

max(a[, axis, out, keepdims, initial, where])

Return the maximum of the array elements along a given axis.

maximum(x, y, /)

Return element-wise maximum of the input arrays.

mean(a[, axis, dtype, out, keepdims, where])

Return the mean of array elements along a given axis.

median(a[, axis, out, overwrite_input, keepdims])

Return the median of array elements along a given axis.

meshgrid(*xi[, copy, sparse, indexing])

Construct N-dimensional grid arrays from N 1-dimensional vectors.

mgrid

Return dense multi-dimensional "meshgrid".

min(a[, axis, out, keepdims, initial, where])

Return the minimum of array elements along a given axis.

minimum(x, y, /)

Return element-wise minimum of the input arrays.

mod(x1, x2, /)

Alias of jax.numpy.remainder().

modf(x, /[, out])

Return element-wise fractional and integral parts of the input array.

moveaxis(a, source, destination)

Move an array axis to a new position.

multiply

Multiply two arrays element-wise.

nan_to_num(x[, copy, nan, posinf, neginf])

Replace NaN and infinite entries in an array.

nanargmax(a[, axis, out, keepdims])

Return the index of the maximum value of an array, ignoring NaNs.

nanargmin(a[, axis, out, keepdims])

Return the index of the minimum value of an array, ignoring NaNs.

nancumprod(a[, axis, dtype, out])

Cumulative product of elements along an axis, ignoring NaN values.

nancumsum(a[, axis, dtype, out])

Cumulative sum of elements along an axis, ignoring NaN values.

nanmax(a[, axis, out, keepdims, initial, where])

Return the maximum of the array elements along a given axis, ignoring NaNs.

nanmean(a[, axis, dtype, out, keepdims, where])

Return the mean of the array elements along a given axis, ignoring NaNs.

nanmedian(a[, axis, out, overwrite_input, ...])

Return the median of array elements along a given axis, ignoring NaNs.

nanmin(a[, axis, out, keepdims, initial, where])

Return the minimum of the array elements along a given axis, ignoring NaNs.

nanpercentile(a, q[, axis, out, ...])

Compute the percentile of the data along the specified axis, ignoring NaN values.

nanprod(a[, axis, dtype, out, keepdims, ...])

Return the product of the array elements along a given axis, ignoring NaNs.

nanquantile(a, q[, axis, out, ...])

Compute the quantile of the data along the specified axis, ignoring NaNs.

nanstd(a[, axis, dtype, out, ddof, ...])

Compute the standard deviation along a given axis, ignoring NaNs.

nansum(a[, axis, dtype, out, keepdims, ...])

Return the sum of the array elements along a given axis, ignoring NaNs.

nanvar(a[, axis, dtype, out, ddof, ...])

Compute the variance of array elements along a given axis, ignoring NaNs.

ndarray

alias of Array

ndim(a)

Return the number of dimensions of an array.

negative

Return element-wise negative values of the input.

nextafter(x, y, /)

Return element-wise next floating point value after x towards y.

nonzero(a, *[, size, fill_value])

Return indices of nonzero elements of an array.

not_equal(x, y, /)

Returns element-wise truth value of x != y.

number()

Abstract base class of all numeric scalar types.

object_

Any Python object.

ogrid

Return open multi-dimensional "meshgrid".

ones(shape[, dtype, device])

Create an array full of ones.

ones_like(a[, dtype, shape, device])

Create an array of ones with the same shape and dtype as an array.

outer(a, b[, out])

Compute the outer product of two arrays.

packbits(a[, axis, bitorder])

Pack array of bits into a uint8 array.

pad(array, pad_width[, mode])

Add padding to an array.

partition(a, kth[, axis])

Returns a partially-sorted copy of an array.

percentile(a, q[, axis, out, ...])

Compute the percentile of the data along the specified axis.

permute_dims(a, /, axes)

Permute the axes/dimensions of an array.

piecewise(x, condlist, funclist, *args, **kw)

Evaluate a function defined piecewise across the domain.

place(arr, mask, vals, *[, inplace])

Update array elements based on a mask.

poly(seq_of_zeros)

Returns the coefficients of a polynomial for the given sequence of roots.

polyadd(a1, a2)

Returns the sum of the two polynomials.

polyder(p[, m])

Returns the coefficients of the derivative of specified order of a polynomial.

polydiv(u, v, *[, trim_leading_zeros])

Returns the quotient and remainder of polynomial division.

polyfit(x, y, deg[, rcond, full, w, cov])

Least squares polynomial fit to data.

polyint(p[, m, k])

Returns the coefficients of the integration of specified order of a polynomial.

polymul(a1, a2, *[, trim_leading_zeros])

Returns the product of two polynomials.

polysub(a1, a2)

Returns the difference of two polynomials.

polyval(p, x, *[, unroll])

Evaluates the polynomial at specific values.
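A sketch pairing polyfit with polyval on noise-free synthetic data, so the fitted coefficients should recover the generating quadratic up to floating-point error. As in NumPy, coefficients are ordered from the highest power down.

```python
import jax.numpy as jnp

x = jnp.linspace(-1.0, 1.0, 20)
y = 3.0 * x**2 - 2.0 * x + 1.0   # exact quadratic, no noise

# deg is a static Python int; p holds [c2, c1, c0], highest power first.
p = jnp.polyfit(x, y, 2)
print(p)                         # approximately [3., -2., 1.]
print(jnp.polyval(p, 2.0))       # approximately 3*4 - 2*2 + 1 = 9
```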

positive(x, /)

Return element-wise positive values of the input.

pow(x1, x2, /)

Alias of jax.numpy.power().

power(x1, x2, /)

Calculate element-wise base x1 exponential of x2.

printoptions(*args, **kwargs)

Alias of numpy.printoptions().

prod(a[, axis, dtype, out, keepdims, ...])

Return product of the array elements over a given axis.

promote_types(a, b)

Returns the type to which a binary operation should cast its arguments.

ptp(a[, axis, out, keepdims])

Return the peak-to-peak range along a given axis.

put(a, ind, v[, mode, inplace])

Put elements into an array at given indices.

put_along_axis(arr, indices, values, axis[, ...])

Put values into the destination array by matching 1d index and data slices.

quantile(a, q[, axis, out, overwrite_input, ...])

Compute the quantile of the data along the specified axis.

r_

Concatenate slices, scalars and array-like objects along the first axis.

rad2deg(x, /)

Convert angles from radians to degrees.

radians(x, /)

Alias of jax.numpy.deg2rad().

ravel(a[, order])

Flatten array into a 1-dimensional shape.

ravel_multi_index(multi_index, dims[, mode, ...])

Convert multi-dimensional indices into flat indices.

real(val, /)

Return element-wise real part of the complex argument.

reciprocal(x, /)

Calculate element-wise reciprocal of the input.

remainder(x1, x2, /)

Returns element-wise remainder of the division.

repeat(a, repeats[, axis, total_repeat_length])

Construct an array from repeated elements.
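When repeats is itself a traced array inside jax.jit, the output length is not known at compile time. A sketch, assuming total_repeat_length is set to exactly the sum of repeats; the function name repeat_traced is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def repeat_traced(a, repeats):
    # Hypothetical helper. repeats is traced, so the output length cannot be
    # inferred from it; total_repeat_length supplies that length statically.
    return jnp.repeat(a, repeats, total_repeat_length=4)


a = jnp.array([1, 2, 3])
print(repeat_traced(a, jnp.array([2, 0, 2])))  # 1 twice, 2 dropped, 3 twice
```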

reshape(a[, shape, order, newshape, copy])

Return a reshaped copy of an array.

resize(a, new_shape)

Return a new array with specified shape.

result_type(*args)

Return the result of applying JAX promotion rules to the inputs.
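The type-promotion differences mentioned in the introduction can be inspected with promote_types and result_type. A sketch, assuming JAX's default 32-bit mode (jax_enable_x64 disabled); the NumPy call is included only for comparison.

```python
import numpy as np
import jax.numpy as jnp

i32 = jnp.arange(3, dtype=jnp.int32)
f32 = jnp.ones(3, dtype=jnp.float32)

print(jnp.result_type(i32, f32))                  # float32 under JAX's rules
print(np.result_type(np.int32, np.float32))       # float64 under NumPy's rules
print(jnp.promote_types("int8", "uint8"))         # int16
# Python scalars are "weakly typed", so they do not widen the array's dtype.
print((jnp.ones(3, dtype=jnp.float16) * 2.0).dtype)  # float16
```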

right_shift(x1, x2, /)

Right shift the bits of x1 by the amount specified in x2.

rint(x, /)

Rounds the elements of x to the nearest integer.

roll(a, shift[, axis])

Roll the elements of an array along a specified axis.

rollaxis(a, axis[, start])

Roll the specified axis to a given position.

roots(p, *[, strip_zeros])

Returns the roots of a polynomial given the coefficients p.

rot90(m[, k, axes])

Rotate an array by 90 degrees counterclockwise in the plane specified by axes.

round(a[, decimals, out])

Round input evenly to the given number of decimals.

s_

A nicer way to build up index tuples for arrays.

save(file, arr[, allow_pickle, fix_imports])

Save an array to a binary file in NumPy .npy format.

savez(file, *args[, allow_pickle])

Save several arrays into a single file in uncompressed .npz format.

searchsorted(a, v[, side, sorter, method])

Perform a binary search within a sorted array.
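A sketch of searchsorted on a small sorted array. The side argument controls which insertion point is returned for a value equal to an existing element; the method argument from the signature, which selects the implementation strategy, is left at its default.

```python
import jax.numpy as jnp

a = jnp.array([1, 3, 5, 7])   # must already be sorted in ascending order
v = jnp.array([0, 3, 6, 8])   # values to locate

print(jnp.searchsorted(a, v))                # side="left": before equal elements
print(jnp.searchsorted(a, v, side="right"))  # after equal elements
```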

select(condlist, choicelist[, default])

Select values based on a series of conditions.

set_printoptions(*args, **kwargs)

Alias of numpy.set_printoptions().

setdiff1d(ar1, ar2[, assume_unique, size, ...])

Compute the set difference of two 1D arrays.

setxor1d(ar1, ar2[, assume_unique, size, ...])

Compute the set-wise xor of elements in two arrays.

shape(a)

Return the shape of an array.

sign(x, /)

Return an element-wise indication of sign of the input.

signbit(x, /)

Return the sign bit of array elements.

signedinteger()

Abstract base class of all signed integer scalar types.

sin(x, /)

Compute a trigonometric sine of each element of input.

sinc(x, /)

Calculate the normalized sinc function.

single

alias of float32

sinh(x, /)

Calculate element-wise hyperbolic sine of input.

size(a[, axis])

Return number of elements along a given axis.

sort(a[, axis, kind, order, stable, descending])

Return a sorted copy of an array.

sort_complex(a)

Return a sorted copy of a complex array.

spacing(x, /)

Return the spacing between x and the next adjacent number.

split(ary, indices_or_sections[, axis])

Split an array into sub-arrays.

sqrt(x, /)

Calculates element-wise non-negative square root of the input array.

square(x, /)

Calculate element-wise square of the input array.

squeeze(a[, axis])

Remove one or more length-1 axes from an array.

stack(arrays[, axis, out, dtype])

Join arrays along a new axis.

std(a[, axis, dtype, out, ddof, keepdims, ...])

Compute the standard deviation along a given axis.

subtract

Subtract two arrays element-wise.

sum(a[, axis, dtype, out, keepdims, ...])

Sum of the elements of the array over a given axis.

swapaxes(a, axis1, axis2)

Swap two axes of an array.

take(a, indices[, axis, out, mode, ...])

Take elements from an array.

take_along_axis(arr, indices, axis[, mode, ...])

Take values from the input array by matching 1d index and data slices.

tan(x, /)

Compute a trigonometric tangent of each element of input.

tanh(x, /)

Calculate element-wise hyperbolic tangent of input.

tensordot(a, b[, axes, precision, ...])

Compute the tensor dot product of two N-dimensional arrays.

tile(A, reps)

Construct an array by repeating A along specified dimensions.

trace(a[, offset, axis1, axis2, dtype, out])

Calculate sum of the diagonal of input along the given axes.

trapezoid(y[, x, dx, axis])

Integrate along the given axis using the composite trapezoidal rule.

transpose(a[, axes])

Return a transposed version of an N-dimensional array.

tri(N[, M, k, dtype])

Return an array with ones on and below the diagonal and zeros elsewhere.

tril(m[, k])

Return the lower triangle of an array.

tril_indices(n[, k, m])

Return the indices of the lower triangle of an array of shape (n, m).

tril_indices_from(arr[, k])

Return the indices of the lower triangle of a given array.

trim_zeros(filt[, trim])

Trim leading and/or trailing zeros of the input array.

triu(m[, k])

Return the upper triangle of an array.

triu_indices(n[, k, m])

Return the indices of the upper triangle of an array of shape (n, m).

triu_indices_from(arr[, k])

Return the indices of the upper triangle of a given array.

true_divide(x1, x2, /)

Calculates the division of x1 by x2 element-wise.

trunc(x)

Round input to the nearest integer towards zero.

ufunc(func, /, nin, nout, *[, name, nargs, ...])

Universal functions that operate element-by-element on arrays.

uint

alias of uint64

uint16(x)

A JAX scalar constructor of type uint16.

uint32(x)

A JAX scalar constructor of type uint32.

uint64(x)

A JAX scalar constructor of type uint64.

uint8(x)

A JAX scalar constructor of type uint8.

union1d(ar1, ar2, *[, size, fill_value])

Compute the set union of two 1D arrays.

unique(ar[, return_index, return_inverse, ...])

Return the unique values from an array.

unique_all(x, /, *[, size, fill_value])

Return unique values from x, along with indices, inverse indices, and counts.

unique_counts(x, /, *[, size, fill_value])

Return unique values from x, along with counts.

unique_inverse(x, /, *[, size, fill_value])

Return unique values from x, along with inverse indices.

unique_values(x, /, *[, size, fill_value])

Return unique values from x.

unpackbits(a[, axis, count, bitorder])

Unpack the bits in a uint8 array.

unravel_index(indices, shape)

Convert flat indices into multi-dimensional indices.

unstack(x, /, *[, axis])

Unstack an array along an axis.

unsignedinteger()

Abstract base class of all unsigned integer scalar types.

unwrap(p[, discont, axis, period])

Unwrap a periodic signal.

vander(x[, N, increasing])

Generate a Vandermonde matrix.

var(a[, axis, dtype, out, ddof, keepdims, ...])

Compute the variance along a given axis.

vdot(a, b, *[, precision, ...])

Perform a conjugate multiplication of two 1D vectors.

vecdot(x1, x2, /, *[, axis, precision, ...])

Perform a conjugate multiplication of two batched vectors.

vecmat(x1, x2, /)

Batched conjugate vector-matrix product.

vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.
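A sketch of vectorize with a generalized-ufunc signature: the core function dot1d (a hypothetical name) is written for single vectors, and the signature "(n),(n)->()" tells JAX to map it over any leading batch dimensions, broadcasting them against each other.

```python
import jax.numpy as jnp


def dot1d(u, w):
    # Hypothetical core function, written for two 1-D vectors only.
    return jnp.sum(u * w)


batched_dot = jnp.vectorize(dot1d, signature="(n),(n)->()")

a = jnp.ones((4, 3))       # a batch of four 3-vectors
b = jnp.arange(3.0)        # a single 3-vector, broadcast against every row of a
print(batched_dot(a, b))   # each entry is 0 + 1 + 2 = 3
```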

vsplit(ary, indices_or_sections)

Split an array into sub-arrays vertically.

vstack(tup[, dtype])

Vertically stack arrays.

where(condition[, x, y, size, fill_value])

Select elements from two arrays based on a condition.

zeros(shape[, dtype, device])

Create an array full of zeros.

zeros_like(a[, dtype, shape, device])

Create an array full of zeros with the same shape and dtype as an array.

jax.numpy.fft

fft(a[, n, axis, norm])

Compute a one-dimensional discrete Fourier transform along a given axis.

fft2(a[, s, axes, norm])

Compute a two-dimensional discrete Fourier transform along given axes.

fftfreq(n[, d, dtype, device])

Return sample frequencies for the discrete Fourier transform.

fftn(a[, s, axes, norm])

Compute a multidimensional discrete Fourier transform along given axes.

fftshift(x[, axes])

Shift zero-frequency fft component to the center of the spectrum.

hfft(a[, n, axis, norm])

Compute a 1-D FFT of an array whose spectrum has Hermitian symmetry.

ifft(a[, n, axis, norm])

Compute a one-dimensional inverse discrete Fourier transform.

ifft2(a[, s, axes, norm])

Compute a two-dimensional inverse discrete Fourier transform.

ifftn(a[, s, axes, norm])

Compute a multidimensional inverse discrete Fourier transform.

ifftshift(x[, axes])

The inverse of jax.numpy.fft.fftshift().
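A short sketch showing how fftshift() reorders the output of fftfreq() so that zero frequency sits in the center, and how ifftshift() undoes it. The length 5 is arbitrary.

```python
import jax.numpy as jnp

freqs = jnp.fft.fftfreq(5)            # [0., 0.2, 0.4, -0.4, -0.2]
centered = jnp.fft.fftshift(freqs)    # [-0.4, -0.2, 0., 0.2, 0.4]
restored = jnp.fft.ifftshift(centered)
```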

ihfft(a[, n, axis, norm])

Compute a 1-D inverse FFT of an array whose spectrum has Hermitian symmetry.

irfft(a[, n, axis, norm])

Compute a real-valued one-dimensional inverse discrete Fourier transform.
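Because rfft() keeps only the n // 2 + 1 non-negative frequency coefficients, irfft() cannot tell whether the original length was odd. Without `n`, it assumes an even length, so an odd-length signal needs `n` passed explicitly. A minimal sketch:

```python
import jax.numpy as jnp

x = jnp.arange(5.0)             # odd length n = 5

X = jnp.fft.rfft(x)             # n // 2 + 1 = 3 coefficients
# Without n, irfft assumes length 2 * (3 - 1) = 4
too_short = jnp.fft.irfft(X)
# Passing the original length recovers the signal
x_back = jnp.fft.irfft(X, n=x.shape[0])
```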

irfft2(a[, s, axes, norm])

Compute a real-valued two-dimensional inverse discrete Fourier transform.

irfftn(a[, s, axes, norm])

Compute a real-valued multidimensional inverse discrete Fourier transform.

rfft(a[, n, axis, norm])

Compute a one-dimensional discrete Fourier transform of a real-valued array.

rfft2(a[, s, axes, norm])

Compute a two-dimensional discrete Fourier transform of a real-valued array.

rfftfreq(n[, d, dtype, device])

Return sample frequencies for the real-input discrete Fourier transform (see rfft()).
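A typical pairing of rfft() and rfftfreq(): locating the dominant frequency of a real signal. The 64 Hz sample rate and 5 Hz tone are assumptions chosen so that the tone lands exactly on a frequency bin; `d` is the sample spacing in seconds.

```python
import jax.numpy as jnp

fs = 64.0                                  # assumed sample rate in Hz
t = jnp.arange(64) / fs                    # one second of samples
signal = jnp.sin(2 * jnp.pi * 5.0 * t)     # assumed 5 Hz tone

spectrum = jnp.abs(jnp.fft.rfft(signal))   # 64 // 2 + 1 = 33 bins
freqs = jnp.fft.rfftfreq(t.size, d=1 / fs) # 0, 1, ..., 32 Hz
peak = freqs[jnp.argmax(spectrum)]         # 5.0
```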

rfftn(a[, s, axes, norm])

Compute a multidimensional discrete Fourier transform of a real-valued array.

jax.numpy.linalg

cholesky(a, *[, upper])

Compute the Cholesky decomposition of a matrix.
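A minimal sketch for a small symmetric positive-definite matrix (values chosen so the factor is easy to check by hand), showing the default lower-triangular factor and the `upper` keyword from the signature above.

```python
import jax.numpy as jnp

a = jnp.array([[4.0, 2.0],
               [2.0, 3.0]])              # symmetric positive definite

L = jnp.linalg.cholesky(a)               # lower-triangular: [[2, 0], [1, sqrt(2)]]
U = jnp.linalg.cholesky(a, upper=True)   # upper-triangular factor, equal to L.T
```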

cond(x[, p])

Compute the condition number of a matrix.

cross(x1, x2, /, *[, axis])

Compute the cross-product of two 3D vectors.

det(a)

Compute the determinant of an array.

diagonal(x, /, *[, offset])

Extract the diagonal of a matrix or stack of matrices.

eig(a)

Compute the eigenvalues and eigenvectors of a square array.

eigh(a[, UPLO, symmetrize_input])

Compute the eigenvalues and eigenvectors of a Hermitian matrix.
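A sketch with a real symmetric (hence Hermitian) matrix whose eigenvalues are 1 and 3. eigh() returns eigenvalues in ascending order and orthonormal eigenvectors as columns, so the matrix can be rebuilt from them.

```python
import jax.numpy as jnp

a = jnp.array([[2.0, 1.0],
               [1.0, 2.0]])

w, v = jnp.linalg.eigh(a)        # w: eigenvalues in ascending order, [1., 3.]
# Columns of v are orthonormal eigenvectors, so a == v @ diag(w) @ v.T
recon = v @ jnp.diag(w) @ v.T
```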

eigvals(a)

Compute the eigenvalues of a general matrix.

eigvalsh(a[, UPLO])

Compute the eigenvalues of a Hermitian matrix.

inv(a)

Return the inverse of a square matrix.

lstsq(a, b[, rcond, numpy_resid])

Return the least-squares solution to a linear equation.
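A common use of lstsq() is fitting a line. In this sketch the data are synthetic and lie exactly on y = 2x + 1, so the recovered slope and intercept can be checked by eye. lstsq() returns the solution, residuals, rank, and singular values.

```python
import jax.numpy as jnp

x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = jnp.array([1.0, 3.0, 5.0, 7.0])                 # synthetic: y = 2x + 1

# Design matrix with a column of ones for the intercept, shape (4, 2)
A = jnp.stack([x, jnp.ones_like(x)], axis=1)
coef, resid, rank, sv = jnp.linalg.lstsq(A, y)      # coef ≈ [2., 1.]
```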

matmul(x1, x2, /, *[, precision, ...])

Perform a matrix multiplication.

matrix_norm(x, /, *[, keepdims, ord])

Compute the norm of a matrix or stack of matrices.

matrix_power(a, n)

Raise a square matrix to an integer power.

matrix_rank(M[, rtol, tol])

Compute the rank of a matrix.

matrix_transpose(x, /)

Transpose a matrix or stack of matrices.

multi_dot(arrays, *[, precision])

Efficiently compute matrix products between a sequence of arrays.

norm(x[, ord, axis, keepdims])

Compute the norm of a matrix or vector.

outer(x1, x2, /)

Compute the outer product of two 1-dimensional arrays.

pinv(a[, rtol, hermitian, rcond])

Compute the (Moore-Penrose) pseudo-inverse of a matrix.

qr(a[, mode])

Compute the QR decomposition of an array.

slogdet(a, *[, method])

Compute the sign and (natural) logarithm of the determinant of an array.
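slogdet() matters when the determinant itself would overflow or underflow. In this sketch (a scaled identity matrix chosen purely for illustration, with JAX's default 32-bit precision), the determinant 1e-60 is below the float32 range and comes out as 0, while the log-determinant stays finite.

```python
import jax.numpy as jnp

a = 1e-20 * jnp.eye(3)                   # true determinant is 1e-60

det = jnp.linalg.det(a)                  # underflows to 0.0 in float32
sign, logabsdet = jnp.linalg.slogdet(a)
# sign is 1.0 and logabsdet is 3 * log(1e-20) ≈ -138.16
```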

solve(a, b)

Solve a linear system of equations.
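A minimal sketch solving a 2x2 system chosen so the answer is [2, 3]. Calling solve() directly is generally preferable to forming inv(a) @ b, since it avoids computing the full inverse.

```python
import jax.numpy as jnp

a = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([9.0, 8.0])

x = jnp.linalg.solve(a, b)   # [2., 3.], since 3*2 + 1*3 = 9 and 1*2 + 2*3 = 8
```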

svd(a[, full_matrices, compute_uv, ...])

Compute the singular value decomposition.
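A sketch of the reduced SVD (`full_matrices=False`) of a small non-square matrix, showing the output shapes and the reconstruction a = u @ diag(s) @ vh. The input is arbitrary.

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)

u, s, vh = jnp.linalg.svd(a, full_matrices=False)
# u: (2, 2), s: (2,) in descending order, vh: (2, 3)
recon = u @ jnp.diag(s) @ vh

# svdvals returns only the singular values
s_only = jnp.linalg.svdvals(a)
```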

svdvals(x, /)

Compute the singular values of a matrix.

tensordot(x1, x2, /, *[, axes, precision, ...])

Compute the tensor dot product of two N-dimensional arrays.

tensorinv(a[, ind])

Compute the tensor inverse of an array.

tensorsolve(a, b[, axes])

Solve the tensor equation a x = b for x.

trace(x, /, *[, offset, dtype])

Compute the trace of a matrix.

vector_norm(x, /, *[, axis, keepdims, ord])

Compute the vector norm of a vector or batch of vectors.

vecdot(x1, x2, /, *[, axis, precision, ...])

Compute the (batched) vector conjugate dot product of two arrays.

JAX Array

The JAX Array (along with its alias, jax.numpy.ndarray) is the core array object in JAX: you can think of it as JAX’s equivalent of a numpy.ndarray. As with numpy.ndarray, most users will not need to instantiate Array objects manually; instead, they will create them via jax.numpy functions like array(), arange(), linspace(), and others listed above.
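A brief sketch of creating Arrays through jax.numpy, checking their type against both jax.Array and its alias jax.numpy.ndarray, and converting back to NumPy. The inputs are arbitrary.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)              # created by a jax.numpy function
y = jnp.asarray(np.ones(4))      # converted from a NumPy array

print(isinstance(x, jax.Array))    # True
print(isinstance(y, jnp.ndarray))  # True: jnp.ndarray is an alias of jax.Array

back = np.asarray(x)             # copied to the host as a numpy.ndarray
```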

Copying and Serialization

JAX Array objects are designed to work seamlessly with Python standard library tools where appropriate.

With the built-in copy module, when copy.copy() or copy.deepcopy() encounters an Array, it is equivalent to calling the copy() method, which creates a copy of the buffer on the same device as the original array. This works correctly within traced/JIT-compiled code, though copy operations may be elided by the compiler in this context.
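A minimal sketch of the behavior described above: both copy functions return a new Array object with the same contents, placed on the same device(s) as the original.

```python
import copy
import jax.numpy as jnp

x = jnp.arange(3)

y = copy.copy(x)       # equivalent to x.copy()
z = copy.deepcopy(x)   # also a buffer copy on the same device

print(y is x)                      # False: a new Array object
print(y.devices() == x.devices())  # True
```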

When the built-in pickle module encounters an Array, it will be serialized via a compact bit representation in a similar manner to pickled numpy.ndarray objects. When unpickled, the result will be a new Array object on the default device. This is because in general, pickling and unpickling may take place in different runtime environments, and there is no general way to map the device IDs of one runtime to the device IDs of another. If pickle is used in traced/JIT-compiled code, it will result in a ConcretizationTypeError.
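A round-trip sketch outside of traced code. The unpickled value is a fresh jax.Array on the default device with the same dtype and contents; the input array is arbitrary.

```python
import pickle
import jax
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 4)

payload = pickle.dumps(x)    # bytes, much like pickling a numpy.ndarray
x2 = pickle.loads(payload)   # a new jax.Array placed on the default device
```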

Python Array API standard

Note

Prior to JAX v0.4.32, you had to import jax.experimental.array_api in order to enable the array API for JAX arrays. Starting with JAX v0.4.32, importing this module is no longer required and will raise a deprecation warning. After JAX v0.5.0, this import will raise an error.

Starting with JAX v0.4.32, jax.Array and jax.numpy are compatible with the Python Array API Standard. You can access the Array API namespace via jax.Array.__array_namespace__():

>>> def f(x):
...   nx = x.__array_namespace__()
...   return nx.sin(x) ** 2 + nx.cos(x) ** 2

>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)

JAX departs from the standard in a few places. Most notably, because JAX arrays are immutable, in-place updates are not supported. Some of these incompatibilities are being addressed via the array-api-compat module.

For more information, refer to the Python Array API Standard documentation.