jax.lax.gather

jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False)

Gather operator.

Wraps XLA's Gather operator.

The semantics of gather are complicated, and its API might change in the future. For most use cases, you should prefer NumPy-style indexing (e.g., x[:, (1, 4, 7), ...]) rather than using gather directly.
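As a rough illustration of the NumPy-style alternative mentioned above, the sketch below selects three columns from a 2-D array. The array contents and the names x, cols, and picked are made up for this example and are not part of the original documentation.

```python
import jax.numpy as jnp

# Illustrative data: a 3x8 array holding 0..23.
x = jnp.arange(24).reshape(3, 8)

# Integer-array indexing: pick columns 1, 4 and 7 from every row.
cols = jnp.array([1, 4, 7])
picked = x[:, cols]

print(picked.shape)  # (3, 3)
```

Indexing of this kind is typically all that is needed, without building a GatherDimensionNumbers object by hand.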

Parameters
  • operand (Any) – an array from which slices should be taken

  • start_indices (Any) – the indices at which slices should be taken

  • dimension_numbers (GatherDimensionNumbers) – a lax.GatherDimensionNumbers object that describes how dimensions of operand, start_indices and the output relate.

  • slice_sizes (Sequence[Union[int, Any]]) – the size of each slice. Must be a sequence of non-negative integers with length equal to ndim(operand).

  • indices_are_sorted (bool) – whether start_indices is known to be sorted. If true, may improve performance on some backends.

  • unique_indices (bool) – whether the elements gathered from operand are guaranteed not to overlap with each other. If true, may improve performance on some backends.

Return type

Any

Returns

An array containing the gather output.
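
To make the parameters above more concrete, here is a minimal sketch of a direct call, assuming the goal is to pick whole rows out of a 2-D array. The data and the names dnums and rows are illustrative only, and the dimension numbers shown are one possible configuration, not the only valid one.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative data: a 3x4 array holding 0..11.
operand = jnp.arange(12).reshape(3, 4)

# Two start indices, each a length-1 vector naming a row of `operand`.
start_indices = jnp.array([[2], [0]])

dnums = lax.GatherDimensionNumbers(
    offset_dims=(1,),           # output dimension 1 holds the slice contents
    collapsed_slice_dims=(0,),  # drop the size-1 row dimension of each slice
    start_index_map=(0,),       # each index addresses operand dimension 0
)

# Each slice is one full row: size 1 along dimension 0, size 4 along dimension 1.
rows = lax.gather(
    operand,
    start_indices,
    dimension_numbers=dnums,
    slice_sizes=(1, 4),
)

print(rows.shape)  # (2, 4)
```

Under these assumptions the result should match operand[jnp.array([2, 0])], which is the form most code would use in practice.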