파이토치에서 텐서 정렬에 관한 함수
torch.sort()
torch.sort()
에 텐서를 입력하면, 정렬된 값과 인덱스를 반환한다.
1차원 텐서
>>> x = torch.tensor([1, 3, -2, 5, -1, 7, 0])
>>> values, indices = torch.sort(x)
>>> values
tensor([-2, -1, 0, 1, 3, 5, 7])
>>> indices
tensor([2, 4, 6, 0, 1, 3, 5])
다차원 텐서
텐서만 입력하면, 각각의 행마다 정렬한다. 즉, torch.sort(x)
=torch.sort(x, dim=1)
이다. 차원을 입력하면 해당 차원마다 정렬한다.
>>> x = torch.tensor([[1, 0, 4],
[-1, 5, 2],
[0, 4, 3]])
>>> value, indices = torch.sort(x)
>>> value
tensor([[ 0, 1, 4],
[-1, 2, 5],
[ 0, 3, 4]])
>>> indices
tensor([[1, 0, 2],
[0, 2, 1],
[0, 2, 1]])
>>> value, indices = torch.sort(x, dim=0)
>>> value
tensor([[-1, 0, 2],
[ 0, 4, 3],
[ 1, 5, 4]])
>>> indices
tensor([[1, 0, 1],
[2, 2, 2],
[0, 1, 0]])
torch.argsort()
특정 행(열) 기준으로 정렬
torch.sort()
로 얻은 인덱스로 특정 차원을 기준으로 텐서를 정렬시킬 수 있다. 가령 x
를 첫번째 열을 기준으로 오름차순 정렬을 하고 싶다면,
>>> value, indices = torch.sort(x, dim=0)
>>> x[indices[:,0], :]
tensor([[-1, 5, 2],
[ 0, 4, 3],
[ 1, 0, 4]])
그런데 이때 value
는 계산할 필요가 없으므로 torch.argsort()
를 쓸 수도 있다. 이는 인덱스만 반환하는 함수로, 아래의 코드는 위와 같다.
>>> indices = torch.argsort(x, dim=0)
>>> x[indices[:,0], :]
tensor([[-1, 5, 2],
[ 0, 4, 3],
[ 1, 0, 4]])
.sort()
, .argsort()
텐서 내에 메서드로도 구현되어있다.
>>> x = torch.tensor([1, 3, -2, 5, -1, 7, 0])
>>> x.sort()
torch.return_types.sort(
values=tensor([-2, -1, 0, 1, 3, 5, 7]),
indices=tensor([2, 4, 6, 0, 1, 3, 5]))
>>> x.argsort()
tensor([2, 4, 6, 0, 1, 3, 5])
환경
- Colab
- Version: Python 3.8.10, PyTorch1.13.1+cu116