jax.numpy.flatnonzero
jax.numpy.flatnonzero(a, *, size=None)
Return indices that are non-zero in the flattened version of a.
LAX-backend implementation of flatnonzero().

Because the size of the output of nonzero is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which specifies the size of the output arrays; it must be specified statically for jnp.nonzero to be traced. If specified, the first size nonzero elements will be returned; if there are fewer nonzero elements than size indicates, the result will be padded with fill_value, which defaults to zero.

Original docstring below.
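A minimal sketch of the size argument under jax.jit. The wrapper name flat_nonzero is hypothetical; it uses functools.partial with static_argnames so that size is a compile-time constant, which is what lets the output shape be fixed at trace time. Only size is passed, as in the signature above, so padding relies on the default fill value of zero described in the note.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical helper: size is marked static, so each distinct value
# produces a separate trace whose output shape is (size,).
@partial(jax.jit, static_argnames="size")
def flat_nonzero(x, size):
    return jnp.flatnonzero(x, size=size)

x = jnp.array([[0, 2], [0, 0], [5, 7]])  # flattened: [0, 2, 0, 0, 5, 7]

print(flat_nonzero(x, size=2))  # [1 4]        only the first two nonzero indices
print(flat_nonzero(x, size=3))  # [1 4 5]      exactly the three nonzero entries
print(flat_nonzero(x, size=5))  # [1 4 5 0 0]  padded with the default value 0
```

Because size is static, calling the helper with many different size values triggers a recompilation for each one; in practice size is usually an upper bound chosen once.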
This is equivalent to np.nonzero(np.ravel(a))[0].
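The equivalence stated above can be checked directly outside of jit, where no size is required. The sketch below also shows that the returned indices refer to the flattened array, so they can gather the nonzero values from jnp.ravel(a) or be mapped back to coordinates with jnp.unravel_index; the input array is an arbitrary illustration.

```python
import numpy as np
import jax.numpy as jnp

a = np.array([[0, 3, 0], [1, 0, 2]])  # flattened: [0, 3, 0, 1, 0, 2]

# NumPy reference from the docstring
expected = np.nonzero(np.ravel(a))[0]
print(expected)  # [1 3 5]

# Outside of jit the result length simply follows the data
res = jnp.flatnonzero(a)
print(res)  # [1 3 5]

# The indices refer to a.ravel(), so they gather the nonzero values
print(jnp.ravel(a)[res])  # [3 1 2]

# Convert flat indices back to (row, col) coordinates of the 2-D input
rows, cols = jnp.unravel_index(res, a.shape)
print(rows, cols)  # [0 1 1] [1 0 2]
```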
- Parameters
a (array_like) – Input data.
- Returns
res – Output array, containing the indices of the elements of a.ravel() that are non-zero.
- Return type