jax.numpy.argwhere
jax.numpy.argwhere(a, *, size=None)
Find the indices of array elements that are non-zero, grouped by element.
LAX-backend implementation of NumPy's argwhere().

Because the size of the output of argwhere depends on the data, the function cannot normally be used under JIT. The JAX version adds the optional size argument, which specifies the size of the leading dimension of the output; it must be given as a static value for jnp.argwhere to be traced. If size is specified, the indices of the first size nonzero elements are returned; if there are fewer nonzero elements than size indicates, the index array is zero-padded.

Original docstring below.
Parameters
    a (array_like) – Input data.
Returns
    index_array – Indices of elements that are non-zero. Indices are grouped by element. This array will have shape (N, a.ndim) where N is the number of non-zero items.
Return type
    (N, a.ndim) ndarray
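The following is a minimal sketch, assuming a recent JAX release, of the (N, a.ndim) return layout and of the size argument described above. The sample arrays and the helper name nonzero_coords are hypothetical, chosen only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

a = jnp.array([[0, 3, 0],
               [5, 0, 7]])

# Eager call (outside jit): N is computed from the data.
# Each row holds the coordinates of one nonzero element, in row-major order:
# (0, 1), (1, 0), (1, 2), so the shape is (3, a.ndim) = (3, 2).
idx = jnp.argwhere(a)

# jnp.nonzero returns the same indices grouped by dimension instead of by
# element; stacking its outputs column-wise reproduces the argwhere layout.
rows, cols = jnp.nonzero(a)
same = jnp.array_equal(jnp.stack([rows, cols], axis=1), idx)

# Each column indexes one axis, so the rows can gather the nonzero values.
values = a[idx[:, 0], idx[:, 1]]  # values 3, 5, 7

# The number of columns always equals a.ndim: every entry of a 2x2x2 array
# of ones is nonzero, giving 8 rows of 3 coordinates each.
cube_idx = jnp.argwhere(jnp.ones((2, 2, 2)))

# Under jax.jit the output shape must be known at trace time, so size is
# marked static (hypothetical helper). Calling jnp.argwhere inside jit
# without size fails, because the nonzero count is not known while tracing.
@partial(jax.jit, static_argnames="size")
def nonzero_coords(x, size):
    return jnp.argwhere(x, size=size)

padded = nonzero_coords(a, size=4)     # 3 real rows, then one [0, 0] padding row
truncated = nonzero_coords(a, size=2)  # only the first 2 nonzero elements
```

Because size is a static argument, each distinct value triggers a separate compilation. Since padding rows are zeros, a padding row cannot be told apart from a genuine match at [0, 0]; if that matters, compare against jnp.count_nonzero(a).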