jax.Array
- class jax.Array
Array base class for JAX
jax.Array is the public interface for instance checks and type annotation of JAX arrays and tracers. For example:

    import jax
    import jax.numpy as jnp
    from jax import Array

    x = jnp.arange(5)
    isinstance(x, jax.Array)  # returns True both inside and outside traced functions.

    def f(x: Array) -> Array:  # type annotations are valid for traced and non-traced types.
        return x
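To make the traced case concrete, the following is a minimal sketch that records the result of the instance check while a function is being traced under jax.jit. The names double and checks_during_trace are hypothetical and exist only for this illustration; the NumPy comparison is included to show that a plain numpy.ndarray is not a jax.Array.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical helper list: filled in while `double` is being traced.
checks_during_trace = []

@jax.jit
def double(x: jax.Array) -> jax.Array:
    # At trace time x is a tracer, not a concrete array,
    # yet it still passes the jax.Array instance check.
    checks_during_trace.append(isinstance(x, jax.Array))
    return x * 2

x = jnp.arange(5)
y = double(x)

outside_check = isinstance(x, jax.Array)            # concrete array
inside_check = checks_during_trace[0]               # tracer inside jit
numpy_check = isinstance(np.arange(5), jax.Array)   # plain NumPy array
```

Here outside_check and inside_check are both True, while numpy_check is False, which is why jax.Array is a convenient single annotation for functions that may run either eagerly or under a transformation.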
jax.Array should not be used directly to create arrays; instead, use the array creation routines offered in jax.numpy, such as jax.numpy.array(), jax.numpy.zeros(), jax.numpy.ones(), jax.numpy.full(), jax.numpy.arange(), etc.
- __init__()
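As an illustration of the creation routines named above, the sketch below builds a few small arrays through jax.numpy and confirms that each result is a jax.Array. The variable names are arbitrary and chosen only for this example.

```python
import jax
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])  # from a Python list
z = jnp.zeros((2, 3))           # all zeros, shape (2, 3)
o = jnp.ones(4)                 # all ones, shape (4,)
f = jnp.full((2, 2), 7)         # every element set to 7
r = jnp.arange(5)               # 0, 1, 2, 3, 4

# Every routine returns a jax.Array; the class itself is never called.
all_are_jax_arrays = all(isinstance(v, jax.Array) for v in (a, z, o, f, r))
```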
Methods
__init__()
addressable_data(index): Return an array of the addressable data at a particular index.
all([axis, out, keepdims, where]): Test whether all array elements along a given axis evaluate to True.
any([axis, out, keepdims, where]): Test whether any array elements along a given axis evaluate to True.
argmax([axis, out, keepdims]): Return the index of the maximum value.
argmin([axis, out, keepdims]): Return the index of the minimum value.
argpartition(kth[, axis]): Return the indices that partially sort the array.
argsort([axis, kind, order, stable, descending]): Return the indices that sort the array.
astype(dtype[, copy, device]): Copy the array and cast to a specified dtype.
choose(choices[, out, mode]): Construct an array choosing from elements of multiple arrays.
clip([min, max]): Return an array whose values are limited to a specified range.
compress(condition[, axis, out, size, ...]): Return selected slices of this array along given axis.
conj(): Return the complex conjugate of the array.
conjugate(): Return the complex conjugate of the array.
copy(): Return a copy of the array.
copy_to_host_async(): Copies an Array to the host asynchronously.
cumprod([axis, dtype, out]): Return the cumulative product of the array.
cumsum([axis, dtype, out]): Return the cumulative sum of the array.
diagonal([offset, axis1, axis2]): Return the specified diagonal from the array.
dot(b, *[, precision, preferred_element_type]): Compute the dot product of two arrays.
flatten([order]): Flatten array into a 1-dimensional shape.
item(*args): Copy an element of an array to a standard Python scalar and return it.
max([axis, out, keepdims, initial, where]): Return the maximum of array elements along a given axis.
mean([axis, dtype, out, keepdims, where]): Return the mean of array elements along a given axis.
min([axis, out, keepdims, initial, where]): Return the minimum of array elements along a given axis.
nonzero(*[, fill_value, size]): Return indices of nonzero elements of an array.
prod([axis, dtype, out, keepdims, initial, ...]): Return product of the array elements over a given axis.
ptp([axis, out, keepdims]): Return the peak-to-peak range along a given axis.
ravel([order]): Flatten array into a 1-dimensional shape.
repeat(repeats[, axis, total_repeat_length]): Construct an array from repeated elements.
reshape(*args[, order]): Returns an array containing the same data with a new shape.
round([decimals, out]): Round array elements to a given decimal.
searchsorted(v[, side, sorter, method]): Perform a binary search within a sorted array.
sort([axis, kind, order, stable, descending]): Return a sorted copy of an array.
squeeze([axis]): Remove one or more length-1 axes from array.
std([axis, dtype, out, ddof, keepdims, ...]): Compute the standard deviation along a given axis.
sum([axis, dtype, out, keepdims, initial, ...]): Sum of the elements of the array over a given axis.
swapaxes(axis1, axis2): Swap two axes of an array.
take(indices[, axis, out, mode, ...]): Take elements from an array.
to_device(device, *[, stream]): Return a copy of the array on the specified device.
trace([offset, axis1, axis2, dtype, out]): Return the sum along the diagonal.
transpose(*args): Returns a copy of the array with axes transposed.
var([axis, dtype, out, ddof, keepdims, ...]): Compute the variance along a given axis.
view([dtype, type]): Return a bitwise copy of the array, viewed as a new dtype.
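A handful of the methods listed above can be exercised on a small array as follows. This is only a sketch of typical usage, not an exhaustive tour, and the variable names are arbitrary.

```python
import jax.numpy as jnp

x = jnp.arange(6)              # [0, 1, 2, 3, 4, 5]
m = x.reshape(2, 3)            # [[0, 1, 2], [3, 4, 5]]

total = m.sum()                # sum over all elements
col_max = m.max(axis=0)        # maximum along axis 0
clipped = x.clip(1, 4)         # limit values to the range [1, 4]
as_float = x.astype(jnp.float32)
first = x[0].item()            # copy one element out as a Python scalar
flat_again = m.ravel()         # back to a 1-dimensional shape
```

Each call returns a new array (or, for item(), a Python scalar); the original x is left unchanged.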
Attributes
T: Compute the all-axis array transpose.
addressable_shards: List of addressable shards.
at: Helper property for index update functionality.
committed: Whether the array is committed or not.
device: Array API-compatible device attribute.
dtype: The data type (numpy.dtype) of the array.
flat: Use flatten() instead.
global_shards: List of global shards.
imag: Return the imaginary part of the array.
is_fully_addressable: Is this Array fully addressable?
is_fully_replicated: Is this Array fully replicated?
itemsize: Length of one array element in bytes.
mT: Compute the (batched) matrix transpose.
nbytes: Total bytes consumed by the elements of the array.
ndim: The number of dimensions in the array.
real: Return the real part of the array.
shape: The shape of the array.
sharding: The sharding for the array.
size: The total number of elements in the array.
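The sketch below reads several of these attributes from a small float32 array and uses the index-update helper to produce a modified copy. It assumes the standard attribute names of jax.Array (shape, ndim, size, itemsize, nbytes, T, at); the variable names are arbitrary.

```python
import jax.numpy as jnp

m = jnp.zeros((2, 3), dtype=jnp.float32)

shape = m.shape          # (2, 3)
ndim = m.ndim            # 2
size = m.size            # 6
itemsize = m.itemsize    # 4 bytes per float32 element
nbytes = m.nbytes        # 24 bytes in total
transposed = m.T         # all-axis transpose, shape (3, 2)

# JAX arrays are immutable: `.at[...].set(...)` returns an updated copy
# and leaves `m` untouched.
updated = m.at[0, 1].set(5.0)
```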