jax.lax.gather
jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False)
Gather operator.
Wraps XLA's Gather operator.
The semantics of gather are complicated, and its API might change in the future. For most use cases, you should prefer NumPy-style indexing (e.g., x[:, (1,4,7), ...]), rather than using gather directly.
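As a minimal sketch of the NumPy-style indexing recommended above: the array `x`, its shape, and the name `picked` are assumptions chosen for illustration, and the index tuple is written as an explicit `jnp.array` rather than a bare tuple.

```python
import jax.numpy as jnp

# Hypothetical input: 2 batches, 8 positions along axis 1, 3 features.
x = jnp.arange(2 * 8 * 3).reshape(2, 8, 3)

# Select positions 1, 4 and 7 along axis 1 and keep every other axis whole.
# JAX lowers this kind of indexing to a gather internally.
picked = x[:, jnp.array([1, 4, 7]), ...]

print(picked.shape)
```

Indexing this way leaves the construction of the dimension numbers and slice sizes to JAX.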
- Parameters
operand (Any) – an array from which slices should be taken
start_indices (Any) – the indices at which slices should be taken
dimension_numbers (GatherDimensionNumbers) – a lax.GatherDimensionNumbers object that describes how dimensions of operand, start_indices and the output relate.
slice_sizes (Sequence[Union[int, Any]]) – the size of each slice. Must be a sequence of non-negative integers with length equal to ndim(operand).
indices_are_sorted (bool) – whether start_indices is known to be sorted. If true, may improve performance on some backends.
unique_indices (bool) – whether the indices in operand are guaranteed to not overlap with each other. If true, may improve performance on some backends.
- Return type
- Returns
An array containing the gather output.
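For cases where gather is called directly, the following sketch selects whole rows of a 2-D array. The array contents, the variable names, and the particular dimension numbers are illustrative assumptions, not the only valid configuration.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical operand: 3 rows, 4 columns.
operand = jnp.arange(12).reshape(3, 4)

# One index vector per gathered slice; the last axis holds the index vector.
# Each vector here has a single component: the starting row.
start_indices = jnp.array([[0], [2]])

dnums = lax.GatherDimensionNumbers(
    offset_dims=(1,),           # output axis 1 carries the slice's column axis
    collapsed_slice_dims=(0,),  # the size-1 row axis of each slice is dropped
    start_index_map=(0,),       # index component 0 addresses operand axis 0
)

# Each slice is 1 row by 4 columns, so slice_sizes has length ndim(operand).
rows = lax.gather(operand, start_indices, dnums, slice_sizes=(1, 4))

print(rows.shape)
```

With these settings the result should match `operand[jnp.array([0, 2])]`, which is the form most code would use instead.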