Functions for Tensor Sorting in PyTorch

torch.sort()

torch.sort() takes a tensor as input and returns its values sorted in ascending order, together with the indices of those values in the original tensor.

1-dimensional tensor

>>> import torch
>>> x = torch.tensor([1, 3, -2, 5, -1, 7, 0])
>>> values, indices = torch.sort(x)

>>> values
tensor([-2, -1,  0,  1,  3,  5,  7])

>>> indices
tensor([2, 4, 6, 0, 1, 3, 5])
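
To make the relationship between the two return values concrete, here is a minimal sketch: indexing the input with the returned indices reproduces the sorted values. It also uses the descending argument of torch.sort(), which this article does not otherwise cover; the variable names values_desc and indices_desc are just illustrative.

```python
import torch

x = torch.tensor([1, 3, -2, 5, -1, 7, 0])
values, indices = torch.sort(x)

# Indexing the input with the returned indices gives back the sorted values.
assert torch.equal(x[indices], values)

# descending=True sorts from largest to smallest instead.
values_desc, indices_desc = torch.sort(x, descending=True)
print(values_desc)   # tensor([ 7,  5,  3,  1,  0, -1, -2])
print(indices_desc)  # tensor([5, 3, 1, 0, 6, 4, 2])
```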

Multi-dimensional tensor

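If only the tensor is passed, torch.sort() sorts along the last dimension (the default is dim=-1), so a 2-dimensional tensor is sorted row by row. That is, for a 2-dimensional x, torch.sort(x) is the same as torch.sort(x, dim=1). If dim is specified, it sorts along that dimension.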

>>> x = torch.tensor([[1, 0, 4],
                      [-1, 5, 2],
                      [0, 4, 3]])

>>> value, indices = torch.sort(x)
>>> value
tensor([[ 0,  1,  4],
        [-1,  2,  5],
        [ 0,  3,  4]])
>>> indices
tensor([[1, 0, 2],
        [0, 2, 1],
        [0, 2, 1]])

>>> value, indices = torch.sort(x, dim=0)
>>> value
tensor([[-1,  0,  2],
        [ 0,  4,  3],
        [ 1,  5,  4]])
>>> indices
tensor([[1, 0, 1],
        [2, 2, 2],
        [0, 1, 0]])

torch.argsort()

Sorting based on a specific row (column)

The indices returned by torch.sort() can be used to reorder a whole tensor by a specific row or column. For example, to reorder the rows of x so that the first column is in ascending order:

>>> value, indices = torch.sort(x, dim=0)
>>> x[indices[:,0], :]
tensor([[-1,  5,  2],
        [ 0,  4,  3],
        [ 1,  0,  4]])

However, since the sorted values are not needed here, you can use torch.argsort() instead. This function returns only the indices, so the code below gives the same result.

>>> indices = torch.argsort(x, dim=0)
>>> x[indices[:,0], :]
tensor([[-1,  5,  2],
        [ 0,  4,  3],
        [ 1,  0,  4]])
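
Since only one column acts as the sort key, a slightly leaner variant is to call torch.argsort() on that column alone rather than on the whole tensor. The sketch below shows this for rows ordered by the first column and, by the same idea, for columns ordered by the first row; the names order, sorted_rows, col_order, and sorted_cols are illustrative.

```python
import torch

x = torch.tensor([[1, 0, 4],
                  [-1, 5, 2],
                  [0, 4, 3]])

# Sort only the key column, then use that order to index the rows.
order = torch.argsort(x[:, 0])
sorted_rows = x[order]
print(sorted_rows)

# Same idea for columns: order them by the values in the first row.
col_order = torch.argsort(x[0, :])
sorted_cols = x[:, col_order]
print(sorted_cols)
```

This avoids sorting columns whose order is never used, which may matter for wide tensors.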

.sort(), .argsort()

Both functions are also available as tensor methods.

>>> x = torch.tensor([1, 3, -2, 5, -1, 7, 0])

>>> x.sort()
torch.return_types.sort(
values=tensor([-2, -1,  0,  1,  3,  5,  7]),
indices=tensor([2, 4, 6, 0, 1, 3, 5]))

>>> x.argsort()
tensor([2, 4, 6, 0, 1, 3, 5])
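
As the output of x.sort() above suggests, the result is a named tuple with values and indices fields. A short sketch of how it can be read, either by field name or by unpacking, and a check that the method and function forms agree on this input:

```python
import torch

x = torch.tensor([1, 3, -2, 5, -1, 7, 0])

result = x.sort()

# The return value is a named tuple: read fields by name or unpack it.
values, indices = result
assert torch.equal(result.values, values)
assert torch.equal(result.indices, indices)

# The method and function forms give the same result for this tensor.
assert torch.equal(x.argsort(), torch.argsort(x))
assert torch.equal(x.argsort(), result.indices)
assert torch.equal(torch.sort(x).values, result.values)
```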

Environment

  • Colab
  • Version: Python 3.8.10, PyTorch 1.13.1+cu116