jax.numpy.argmax
jax.numpy.argmax(a, axis=None, out=None)
Returns the indices of the maximum values along an axis.
LAX-backend implementation of numpy.argmax(). Original docstring below.
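Because the function is built from LAX primitives, it can be traced and compiled like any other jax.numpy call. The sketch below wraps it in jax.jit; the function name best_class and the sample logits are illustrative assumptions, not part of the API. The axis is passed as a Python constant, so it is fixed at trace time.

```python
import jax
import jax.numpy as jnp


@jax.jit
def best_class(logits):
    # Hypothetical helper: pick the highest-scoring column in each row.
    # axis=-1 is a Python literal, so it is constant inside the traced function.
    return jnp.argmax(logits, axis=-1)


logits = jnp.array([[0.1, 2.0, -1.0],
                    [3.0, 0.0, 1.0]])
labels = best_class(logits)  # row 0 -> column 1, row 1 -> column 0
```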
- Parameters
a (array_like) - Input array.
axis (int, optional) - By default, the index is into the flattened array; otherwise, it is taken along the specified axis.
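The example below illustrates how axis changes what the indices refer to. It is a minimal sketch on a small, arbitrarily chosen 2x3 array.

```python
import jax.numpy as jnp

# Illustrative 2x3 input; flattened it is [1, 7, 3, 9, 2, 5].
a = jnp.array([[1, 7, 3],
               [9, 2, 5]])

# axis=None (default): index into the flattened array.
flat_idx = jnp.argmax(a)         # 3, because 9 is at flat position 3

# axis=0: row index of the maximum in each column.
col_idx = jnp.argmax(a, axis=0)  # [1, 0, 1]

# axis=1: column index of the maximum in each row.
row_idx = jnp.argmax(a, axis=1)  # [1, 0]
```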
- Returns
index_array - Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If axis is None, the result is a single (0-d) index into the flattened array.
- Return type
ndarray of ints
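The shape rule for the return value can be checked directly. The sketch below uses an arbitrary 2x3x4 input. It also shows jnp.unravel_index, a separate jax.numpy function that converts the flat index returned with axis=None back into per-axis coordinates.

```python
import jax.numpy as jnp

# Illustrative input: values 0..23 in a (2, 3, 4) array.
a = jnp.arange(24).reshape(2, 3, 4)

# Reducing along axis=1 drops that dimension: (2, 3, 4) -> (2, 4).
idx = jnp.argmax(a, axis=1)
# Values increase along every axis, so each max is at position 2 of axis 1.

# With axis=None the result is a 0-d array holding one flat index.
flat = jnp.argmax(a)             # 23, the last element

# Convert the flat index back to (i, j, k) coordinates in a.shape.
coords = jnp.unravel_index(flat, a.shape)
```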