jax.Array

class jax.Array

Array base class for JAX.

jax.Array is the public interface for JAX arrays and tracers. Its main uses are instance checks and type annotations; for example:

import jax
import jax.numpy as jnp
from jax import Array

x = jnp.arange(5)
isinstance(x, jax.Array)  # returns True both inside and outside traced functions.

def f(x: Array) -> Array:  # type annotations are valid for traced and non-traced types.
  return x

jax.Array should not be used directly to create arrays; instead, use the array creation routines in jax.numpy, such as jax.numpy.array(), jax.numpy.zeros(), jax.numpy.ones(), jax.numpy.full(), jax.numpy.arange(), etc.
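As a minimal sketch of this pattern, the snippet below builds arrays with the jax.numpy creation routines named above and checks that every result, and the tracer a jax.jit-compiled function sees, passes an isinstance check against jax.Array. The function name `double` and the `seen` list are illustrative only.

```python
import jax
import jax.numpy as jnp

# Build arrays with jax.numpy creation routines, never by calling jax.Array().
a = jnp.array([1, 2, 3])
z = jnp.zeros((2, 3))
o = jnp.ones(4, dtype=jnp.int32)
f = jnp.full((2, 2), 7.0)
r = jnp.arange(5)
assert all(isinstance(arr, jax.Array) for arr in (a, z, o, f, r))

seen = []  # illustrative list that records isinstance results during tracing

@jax.jit
def double(x):
    # Under jit, x is a tracer; this Python side effect runs once, at trace time.
    seen.append(isinstance(x, jax.Array))
    return 2 * x

y = double(r)  # y is a concrete jax.Array again: [0, 2, 4, 6, 8]
```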

__init__()

Methods

__init__()

addressable_data(index)

Return an array of the addressable data at a particular index.

all([axis, out, keepdims, where])

Test whether all array elements along a given axis evaluate to True.

any([axis, out, keepdims, where])

Test whether any array elements along a given axis evaluate to True.

argmax([axis, out, keepdims])

Return the index of the maximum value.

argmin([axis, out, keepdims])

Return the index of the minimum value.

argpartition(kth[, axis])

Return the indices that partially sort the array.

argsort([axis, kind, order, stable, descending])

Return the indices that sort the array.

astype(dtype[, copy, device])

Copy the array and cast to a specified dtype.

choose(choices[, out, mode])

Construct an array choosing from elements of multiple arrays.

clip([min, max])

Return an array whose values are limited to a specified range.

compress(condition[, axis, out, size, ...])

Return selected slices of this array along a given axis.

conj()

Return the complex conjugate of the array.

conjugate()

Return the complex conjugate of the array.

copy()

Return a copy of the array.

copy_to_host_async()

Copy the array to the host asynchronously.

cumprod([axis, dtype, out])

Return the cumulative product of the array.

cumsum([axis, dtype, out])

Return the cumulative sum of the array.

diagonal([offset, axis1, axis2])

Return the specified diagonal from the array.

dot(b, *[, precision, preferred_element_type])

Compute the dot product of two arrays.

flatten([order])

Flatten array into a 1-dimensional shape.

item(*args)

Copy an element of an array to a standard Python scalar and return it.

max([axis, out, keepdims, initial, where])

Return the maximum of array elements along a given axis.

mean([axis, dtype, out, keepdims, where])

Return the mean of array elements along a given axis.

min([axis, out, keepdims, initial, where])

Return the minimum of array elements along a given axis.

nonzero(*[, fill_value, size])

Return the indices of the nonzero elements of the array.

prod([axis, dtype, out, keepdims, initial, ...])

Return the product of the array elements over a given axis.

ptp([axis, out, keepdims])

Return the peak-to-peak range along a given axis.

ravel([order])

Flatten array into a 1-dimensional shape.

repeat(repeats[, axis, total_repeat_length])

Construct an array from repeated elements.

reshape(*args[, order])

Return an array containing the same data with a new shape.

round([decimals, out])

Round array elements to a given number of decimals.

searchsorted(v[, side, sorter, method])

Perform a binary search within a sorted array.

sort([axis, kind, order, stable, descending])

Return a sorted copy of an array.

squeeze([axis])

Remove one or more length-1 axes from array.

std([axis, dtype, out, ddof, keepdims, ...])

Compute the standard deviation along a given axis.

sum([axis, dtype, out, keepdims, initial, ...])

Return the sum of the array elements over a given axis.

swapaxes(axis1, axis2)

Swap two axes of an array.

take(indices[, axis, out, mode, ...])

Take elements from an array.

to_device(device, *[, stream])

Return a copy of the array on the specified device.

trace([offset, axis1, axis2, dtype, out])

Return the sum along the diagonal.

transpose(*args)

Return a copy of the array with axes transposed.

var([axis, dtype, out, ddof, keepdims, ...])

Compute the variance along a given axis.

view([dtype, type])

Return a bitwise copy of the array, viewed as a new dtype.
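The reduction methods above (max, min, sum, mean, argmax, argmin, any, all) follow the axis, keepdims, and where conventions of their jax.numpy counterparts. A short sketch on a small integer matrix; the variable names are illustrative:

```python
import jax.numpy as jnp

x = jnp.array([[1, 5, 2],
               [4, 0, 3]])

col_max = x.max(axis=0)                 # [4, 5, 3]
row_min = x.min(axis=1, keepdims=True)  # [[1], [0]]; keepdims keeps shape (2, 1)
total = x.sum()                         # 15: with no axis, reduce over all elements
masked = x.sum(where=x > 1)             # 14: only 5, 2, 4 and 3 contribute
row_mean = x.mean(axis=1)               # [8/3, 7/3] as floating point
flat_argmax = x.argmax()                # 1: an index into the flattened array
col_argmin = x.argmin(axis=0)           # [0, 1, 0]
has_zero = (x == 0).any()               # True
row_all_pos = (x > 0).all(axis=1)       # [True, False]
```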
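The sorting and searching methods (argsort, sort, searchsorted) never modify the array in place; they return new arrays. A minimal sketch, with illustrative names:

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 2, 5, 4])

order = x.argsort()   # [1, 2, 0, 4, 3]: indices that would sort x
s = x.sort()          # [1, 2, 3, 4, 5]: a sorted copy; x itself is unchanged
# searchsorted expects a sorted array and returns insertion points for the queries.
left = s.searchsorted(jnp.array([0, 3, 6]))   # [0, 2, 5]
right = s.searchsorted(3, side="right")       # 3: insertion point after equal values
```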
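The shape-manipulation methods (reshape, transpose, swapaxes, ravel, squeeze) accept either separate integers or a tuple where the signatures above show *args. A sketch with illustrative names:

```python
import jax.numpy as jnp

x = jnp.arange(6)                  # [0, 1, 2, 3, 4, 5]
m = x.reshape(2, 3)                # [[0, 1, 2], [3, 4, 5]]
m_tuple = x.reshape((2, 3))        # a shape tuple is accepted too
t = m.transpose()                  # shape (3, 2); same result as m.T
s = m.swapaxes(0, 1)               # for a 2-D array, also shape (3, 2)
flat = m.ravel()                   # shape (6,) again
col = m[:, :1]                     # shape (2, 1)
sq = col.squeeze(axis=1)           # shape (2,): [0, 3]
batch = jnp.zeros((4, 2, 3))
bt = batch.transpose(0, 2, 1)      # shape (4, 3, 2): axes permuted explicitly
```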
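Methods whose output length depends on the data, such as nonzero, repeat, and compress, take a static size argument (size or total_repeat_length in the signatures above) so they can run under jax.jit, where output shapes must be known at trace time. A minimal sketch using nonzero and repeat; the function names are hypothetical:

```python
import jax
import jax.numpy as jnp

@jax.jit
def nonzero_padded(x):
    # The number of nonzeros is data-dependent, so fix the output length with
    # `size`; unused slots are filled with `fill_value`.
    (idx,) = x.nonzero(size=3, fill_value=-1)
    return idx

@jax.jit
def repeat_by_counts(x, counts):
    # `counts` is traced, so the total output length must be given explicitly.
    return x.repeat(counts, total_repeat_length=6)

idx = nonzero_padded(jnp.array([0, 7, 0, 9]))                      # [1, 3, -1]
rep = repeat_by_counts(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))  # [1, 2, 2, 3, 3, 3]
```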
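The dtype and scalar methods differ in what they preserve: astype converts values, view reinterprets the underlying bits, round rounds values, and item leaves JAX entirely by returning a Python scalar. A sketch with illustrative names, assuming the default 32-bit mode:

```python
import jax.numpy as jnp

i = jnp.arange(3)                    # integer array [0, 1, 2]
f = i.astype(jnp.float32)            # [0., 1., 2.] as float32
one = jnp.array([1.0], dtype=jnp.float32)
bits = one.view(jnp.uint32)          # [0x3F800000]: the IEEE-754 bit pattern of 1.0
halves = jnp.array([0.5, 1.5, 2.5])
r = halves.round()                   # [0., 2., 2.]: ties round to even, as in NumPy
val = f[2].item()                    # 2.0 as a plain Python float
```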
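A brief sketch of the cumulative and matrix-oriented methods (cumsum, cumprod, diagonal, trace, dot); the arrays and names are illustrative:

```python
import jax.numpy as jnp

v = jnp.array([1, 2, 3, 4])
cs = v.cumsum()                    # [1, 3, 6, 10]
cp = v.cumprod()                   # [1, 2, 6, 24]

m = jnp.arange(9).reshape(3, 3)    # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
d = m.diagonal()                   # [0, 4, 8]
d_up = m.diagonal(offset=1)        # [1, 5]: the diagonal just above the main one
tr = m.trace()                     # 12, the sum of the main diagonal
mv = m.dot(jnp.array([1, 0, -1]))  # [-2, -2, -2]: matrix-vector product
```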
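The selection methods clip, take, and compress each return a new array built from elements of the original. A minimal sketch outside jax.jit, where compress may return a data-dependent length; the names are illustrative:

```python
import jax.numpy as jnp

x = jnp.array([5, -3, 8, 0])
clipped = x.clip(0, 6)              # [5, 0, 6, 0]: values limited to [0, 6]
picked = x.take(jnp.array([2, 0]))  # [8, 5]: elements at indices 2 and 0
kept = x.compress(x > 0)            # [5, 8]: entries where the condition holds
```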

Attributes

T

Compute the all-axis array transpose.

addressable_shards

List of addressable shards.

at

Helper property for indexed updates such as x.at[i].set(v); because JAX arrays are immutable, these updates return a new array.

committed

Whether the array is committed, i.e. explicitly placed on its device(s).

device

Array API-compatible device attribute.

dtype

The data type (numpy.dtype) of the array.

flat

Use flatten() instead.

global_shards

List of global shards.

imag

Return the imaginary part of the array.

is_fully_addressable

Whether every shard of the array is addressable from the current process.

is_fully_replicated

Whether every shard holds a full copy of the array's data.

itemsize

Length of one array element in bytes.

mT

Compute the (batched) matrix transpose.

nbytes

Total bytes consumed by the elements of the array.

ndim

The number of dimensions in the array.

real

Return the real part of the array.

shape

The shape of the array.

sharding

The sharding for the array.

size

The total number of elements in the array.
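Since JAX arrays are immutable, the at property is the way to express indexed updates: each operation returns a new array and leaves the original untouched. A minimal sketch with illustrative names:

```python
import jax.numpy as jnp

x = jnp.zeros(5)
raised = False
try:
    x[1] = 10.0                 # in-place item assignment is not supported
except TypeError:
    raised = True               # JAX arrays are immutable

y = x.at[1].set(10.0)                  # [0., 10., 0., 0., 0.]
y = y.at[2:4].add(1.0)                 # [0., 10., 1., 1., 0.]
y = y.at[jnp.array([0, 0])].add(1.0)   # repeated indices accumulate: [2., 10., 1., 1., 0.]
g = y.at[1].get()                      # 10.0
```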
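The metadata attributes (shape, ndim, size, dtype, itemsize, nbytes) are related: nbytes equals size times itemsize. A short sketch on a float32 array; the names are illustrative:

```python
import jax.numpy as jnp

x = jnp.ones((2, 3, 4), dtype=jnp.float32)
shape = x.shape        # (2, 3, 4)
ndim = x.ndim          # 3
size = x.size          # 24 = 2 * 3 * 4
dtype = x.dtype        # float32
itemsize = x.itemsize  # 4 bytes per float32 element
nbytes = x.nbytes      # 96 = size * itemsize
```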
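The two transpose attributes differ for arrays with more than two dimensions: T reverses every axis, while mT swaps only the last two (a batched matrix transpose). The sketch also shows the real and imag attributes on a complex array, assuming the default 32-bit mode (so the complex dtype is complex64); names are illustrative:

```python
import jax.numpy as jnp

b = jnp.arange(24).reshape(2, 3, 4)
all_t = b.T      # shape (4, 3, 2): every axis reversed
mat_t = b.mT     # shape (2, 4, 3): only the last two axes swapped

z = jnp.array([1 + 2j, 3 - 4j])
re = z.real      # [1., 3.]
im = z.imag      # [2., -4.]
zc = z.conj()    # [1-2j, 3+4j]
```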
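A sketch of the placement attributes (sharding, is_fully_addressable, addressable_shards) and the related addressable_data method, assuming a single-process program with the default single-device placement, in which the array has exactly one shard. The variable names are illustrative:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8)
sharding = x.sharding              # describes how x is laid out across devices
print(x.is_fully_addressable)      # True in a single-process program
for shard in x.addressable_shards:
    # Each jax.Shard records its device, its index into the global array, and its data.
    print(shard.device, shard.index, shard.data.shape)
first = x.addressable_data(0)      # the data held by the first addressable shard
```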