Functions for Tensor Sorting in PyTorch
torch.sort()
torch.sort() takes a tensor as input and returns the sorted values together with the indices of those values in the original tensor.
1-dimensional tensor
>>> import torch
>>> x = torch.tensor([1, 3, -2, 5, -1, 7, 0])
>>> values, indices = torch.sort(x)
>>> values
tensor([-2, -1, 0, 1, 3, 5, 7])
>>> indices
tensor([2, 4, 6, 0, 1, 3, 5])
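To make the meaning of the returned indices concrete, here is a minimal sketch using the same tensor as above. It only checks that indexing x with indices reproduces values; the variable name reordered is introduced here purely for illustration.

```python
import torch

x = torch.tensor([1, 3, -2, 5, -1, 7, 0])
values, indices = torch.sort(x)

# indices[i] is the position in x of the i-th smallest element,
# so indexing x with indices gives back the sorted values.
reordered = x[indices]
print(torch.equal(reordered, values))
```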
Multi-dimensional tensor
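If only the tensor is passed, it sorts along the last dimension (dim=-1), which for a 2-dimensional tensor means each row is sorted. That is, torch.sort(x) is equivalent to torch.sort(x, dim=1) here. If a dimension is specified, it sorts along that dimension.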
>>> x = torch.tensor([[1, 0, 4],
...                   [-1, 5, 2],
...                   [0, 4, 3]])
>>> value, indices = torch.sort(x)
>>> value
tensor([[ 0, 1, 4],
[-1, 2, 5],
[ 0, 3, 4]])
>>> indices
tensor([[1, 0, 2],
[0, 2, 1],
[0, 2, 1]])
>>> value, indices = torch.sort(x, dim=0)
>>> value
tensor([[-1, 0, 2],
[ 0, 4, 3],
[ 1, 5, 4]])
>>> indices
tensor([[1, 0, 1],
[2, 2, 2],
[0, 1, 0]])
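The dim parameter of torch.sort() defaults to -1 (the last dimension), so for a 2-dimensional tensor the default call and dim=1 should agree. The sketch below checks this on the same x, and also uses torch.gather() to show how the returned indices relate to the sorted values along a chosen dimension. The variable names are illustrative only.

```python
import torch

x = torch.tensor([[1, 0, 4],
                  [-1, 5, 2],
                  [0, 4, 3]])

# Default call, explicit last dimension, and dim=1 (rows of a 2-D tensor).
default_sorted = torch.sort(x)
last_dim_sorted = torch.sort(x, dim=-1)
rows_sorted = torch.sort(x, dim=1)

same_as_last_dim = torch.equal(default_sorted.values, last_dim_sorted.values)
same_as_dim1 = torch.equal(default_sorted.values, rows_sorted.values)
print(same_as_last_dim, same_as_dim1)

# Sorting along dim=0 sorts each column independently.
col_values, col_indices = torch.sort(x, dim=0)

# gather picks x[col_indices[i, j], j], which rebuilds the sorted columns.
rebuilt = torch.gather(x, 0, col_indices)
print(torch.equal(rebuilt, col_values))
```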
torch.argsort()
Sorting based on a specific row (column)
You can reorder the rows (or columns) of a tensor according to a single column (or row) using the indices obtained from torch.sort(). For example, to sort the rows of x in ascending order of the first column:
>>> value, indices = torch.sort(x, dim=0)
>>> x[indices[:,0], :]
tensor([[-1, 5, 2],
[ 0, 4, 3],
[ 1, 0, 4]])
However, since value is not needed here, you can use torch.argsort() instead. This function returns only the indices, so the code below produces the same result.
>>> indices = torch.argsort(x, dim=0)
>>> x[indices[:,0], :]
tensor([[-1, 5, 2],
[ 0, 4, 3],
[ 1, 0, 4]])
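The same idea can be wrapped in small helpers. This is only a sketch: sort_rows_by_column and sort_columns_by_row are hypothetical names, not PyTorch functions, and they assume a 2-dimensional tensor. Unlike the snippet above, they call argsort on the single key column (or row) rather than on the whole tensor, which gives the same ordering here without sorting the other columns.

```python
import torch


def sort_rows_by_column(t: torch.Tensor, col: int) -> torch.Tensor:
    """Reorder the rows of a 2-D tensor by ascending values in column `col`."""
    order = torch.argsort(t[:, col])  # row order derived from one column
    return t[order, :]


def sort_columns_by_row(t: torch.Tensor, row: int) -> torch.Tensor:
    """Reorder the columns of a 2-D tensor by ascending values in row `row`."""
    order = torch.argsort(t[row, :])  # column order derived from one row
    return t[:, order]


x = torch.tensor([[1, 0, 4],
                  [-1, 5, 2],
                  [0, 4, 3]])

by_first_column = sort_rows_by_column(x, 0)
by_first_row = sort_columns_by_row(x, 0)
print(by_first_column)
print(by_first_row)
```

Each row (or column) is moved as a whole, so the relationship between the elements within it is preserved, which is not the case when every column is sorted independently with torch.sort(x, dim=0).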
.sort(), .argsort()
Both functions are also available as tensor methods.
>>> x = torch.tensor([1, 3, -2, 5, -1, 7, 0])
>>> x.sort()
torch.return_types.sort(
values=tensor([-2, -1, 0, 1, 3, 5, 7]),
indices=tensor([2, 4, 6, 0, 1, 3, 5]))
>>> x.argsort()
tensor([2, 4, 6, 0, 1, 3, 5])
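The method forms accept the same dim argument as the functions. As a small sanity check on a 2-dimensional tensor (variable names are illustrative), the method and function calls should return identical results:

```python
import torch

x = torch.tensor([[1, 0, 4],
                  [-1, 5, 2],
                  [0, 4, 3]])

# Method form with an explicit dimension.
method_values, method_indices = x.sort(dim=0)
method_argsort = x.argsort(dim=0)

# Function form for comparison.
func_values, func_indices = torch.sort(x, dim=0)
func_argsort = torch.argsort(x, dim=0)

print(torch.equal(method_values, func_values))
print(torch.equal(method_indices, func_indices))
print(torch.equal(method_argsort, func_argsort))
```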
Environment
- Colab
- Version: Python 3.8.10, PyTorch 1.13.1+cu116