jax.numpy.nanargmax
jax.numpy.nanargmax(a, axis=None)

LAX-backend implementation of nanargmax().

Warning: jax.numpy.nanargmax returns -1 for all-NaN slices and does not raise an error.

Original docstring below.

Return the indices of the maximum values in the specified axis, ignoring NaNs. For all-NaN slices, ValueError is raised. Warning: the results cannot be trusted if a slice contains only NaNs and -Infs.
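
A minimal sketch of the behaviour described above, assuming a recent JAX release and the conventional `jnp`/`np` import aliases. The array `x` is made up for illustration. It contrasts the JAX sentinel value with the NumPy exception for an all-NaN slice.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative input: the second row contains only NaNs.
x = jnp.array([[1.0, jnp.nan, 3.0],
               [jnp.nan, jnp.nan, jnp.nan]])

# NaNs are ignored, so the maximum of the first row is 3.0 at index 2.
first_row_idx = jnp.nanargmax(x[0])

# Along axis=1, the all-NaN row yields -1 instead of raising an error.
row_idx = jnp.nanargmax(x, axis=1)

# NumPy raises ValueError for the same all-NaN slice.
try:
    np.nanargmax(np.array([np.nan, np.nan]))
    numpy_raised = False
except ValueError:
    numpy_raised = True

print(first_row_idx, row_idx, numpy_raised)
```

Because -1 is also a valid index in Python (it selects the last element), callers may want to check for it explicitly before using the result to index into an array.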