jax.numpy.nonzero
jax.numpy.nonzero(a, *, size=None, fill_value=None)
Return the indices of the elements that are non-zero.
LAX-backend implementation of nonzero().
Because the size of the output of nonzero is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which specifies the size of the output arrays; it must be specified statically for jnp.nonzero to be traced. If size is specified, the first size nonzero elements are returned; if there are fewer nonzero elements than size indicates, the result is padded with fill_value, which defaults to zero.
Original docstring below.
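The JAX-specific size and fill_value behavior described above can be seen in a short sketch. The array values, the helper names first_two_nonzero and padded_nonzero, and the choice fill_value=-1 are illustrative assumptions; the sketch also shows that calling jnp.nonzero under jax.jit without size fails at trace time.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5, 7])

# Eager (un-jitted) call: the output length follows the data (3 nonzeros here).
(all_idx,) = jnp.nonzero(x)
print(all_idx.tolist())  # [1, 3, 4]

@jax.jit
def first_two_nonzero(v):
    # size=2 is a static Python int, so the output shape is known at trace time;
    # only the first 2 nonzero indices are kept.
    return jnp.nonzero(v, size=2)

@jax.jit
def padded_nonzero(v):
    # size=5 exceeds the 3 real nonzeros, so the tail is padded with fill_value.
    return jnp.nonzero(v, size=5, fill_value=-1)

first_two = first_two_nonzero(x)[0].tolist()
padded = padded_nonzero(x)[0].tolist()
print(first_two)  # [1, 3]
print(padded)     # [1, 3, 4, -1, -1]

# Without size, tracing under jit fails because the output shape is data-dependent.
raised = False
try:
    jax.jit(jnp.nonzero)(x)
except jax.errors.ConcretizationTypeError:
    raised = True
print(raised)  # True
```

Picking a fill_value such as -1, which can never be a valid index, makes padded slots easy to distinguish from real hits; with the default of zero, a padded slot looks the same as a genuine nonzero at index 0.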
Returns a tuple of arrays, one for each dimension of a, containing the indices of the non-zero elements in that dimension. The values in a are always tested and returned in row-major, C-style order.
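For a 2-D input, the result is one index array per dimension, ordered row-major. A minimal sketch with an arbitrary example matrix:

```python
import jax.numpy as jnp

m = jnp.array([[1, 0, 0],
               [0, 2, 3]])

# One index array per dimension: rows for axis 0, cols for axis 1.
rows, cols = jnp.nonzero(m)

# Row-major (C-style) order visits (0, 0), then (1, 1), then (1, 2).
print(rows.tolist())  # [0, 1, 1]
print(cols.tolist())  # [0, 1, 2]

# The tuple can index the array directly to gather the nonzero values.
vals = m[rows, cols].tolist()
print(vals)  # [1, 2, 3]
```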
To group the indices by element, rather than dimension, use argwhere, which returns a row for each non-zero element.
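The difference between grouping by dimension (nonzero) and grouping by element (argwhere) can be sketched as follows; the matrix is an arbitrary example.

```python
import jax.numpy as jnp

m = jnp.array([[1, 0, 0],
               [0, 2, 3]])

by_dim = jnp.nonzero(m)    # tuple (rows, cols), one array per dimension
by_elem = jnp.argwhere(m)  # shape (3, 2): one row per nonzero element
print(by_elem.tolist())    # [[0, 0], [1, 1], [1, 2]]

# Stacking the per-dimension arrays column-wise gives the per-element layout.
same = bool(jnp.array_equal(by_elem, jnp.stack(by_dim, axis=1)))
print(same)  # True
```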
Note
When called on a zero-d array or scalar, nonzero(a) is treated as nonzero(atleast_1d(a)).
Deprecated since version 1.17.0: Use atleast_1d explicitly if this behavior is deliberate.
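Because the implicit promotion of scalars is deprecated (and, depending on the JAX version, 0-d input may be rejected outright), a sketch that promotes explicitly with jnp.atleast_1d, as the note recommends. The scalar values are arbitrary examples.

```python
import jax.numpy as jnp

s = jnp.array(7)  # 0-d array

# Promote to 1-D explicitly before calling nonzero.
(idx,) = jnp.nonzero(jnp.atleast_1d(s))
print(idx.tolist())  # [0]

# A zero scalar promoted the same way has no nonzero entries.
(idx_zero,) = jnp.nonzero(jnp.atleast_1d(jnp.array(0)))
print(idx_zero.tolist())  # []
```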
Parameters
a (array_like) – Input array.
Returns
tuple_of_arrays – Indices of elements that are non-zero.
Return type