logo

파이토치에서 텐서 정렬에 관한 함수 📂머신러닝

파이토치에서 텐서 정렬에 관한 함수

torch.sort()

torch.sort()에 텐서를 입력하면, 정렬된 값과 인덱스를 반환한다.

1차원 텐서

>>> x = torch.tensor([1, 3, -2, 5, -1, 7, 0])
>>> values, indices = torch.sort(x)

>>> values
tensor([-2, -1,  0,  1,  3,  5,  7])

>>> indices
tensor([2, 4, 6, 0, 1, 3, 5])

다차원 텐서

텐서만 입력하면, 각각의 행마다 정렬한다. 즉, torch.sort(x)=torch.sort(x, dim=1)이다. 차원을 입력하면 해당 차원마다 정렬한다.

>>> x = torch.tensor([[1, 0, 4],
                      [-1, 5, 2],
                      [0, 4, 3]])

>>> value, indices = torch.sort(x)
>>> value
tensor([[ 0,  1,  4],
        [-1,  2,  5],
        [ 0,  3,  4]])
>>> indices
tensor([[1, 0, 2],
        [0, 2, 1],
        [0, 2, 1]])

>>> value, indices = torch.sort(x, dim=0)
>>> value
tensor([[-1,  0,  2],
        [ 0,  4,  3],
        [ 1,  5,  4]])
>>> indices
tensor([[1, 0, 1],
        [2, 2, 2],
        [0, 1, 0]])

torch.argsort()

특정 행(열) 기준으로 정렬

torch.sort()로 얻은 인덱스로 특정 차원을 기준으로 텐서를 정렬시킬 수 있다. 가령 x를 첫번째 열을 기준으로 오름차순 정렬을 하고 싶다면,

>>> value, indices = torch.sort(x, dim=0)
>>> x[indices[:,0], :]
tensor([[-1,  5,  2],
        [ 0,  4,  3],
        [ 1,  0,  4]])

그런데 이때 value는 계산할 필요가 없으므로 torch.argsort()를 쓸 수도 있다. 이는 인덱스만 반환하는 함수로, 아래의 코드는 위와 같다.

>>> indices = torch.argsort(x, dim=0)
>>> x[indices[:,0], :]
tensor([[-1,  5,  2],
        [ 0,  4,  3],
        [ 1,  0,  4]])

.sort(), .argsort()

텐서 내에 메서드로도 구현되어있다.

>>> x = torch.tensor([1, 3, -2, 5, -1, 7, 0])

>>> x.sort()
torch.return_types.sort(
values=tensor([-2, -1,  0,  1,  3,  5,  7]),
indices=tensor([2, 4, 6, 0, 1, 3, 5]))

>>> x.argsort()
tensor([2, 4, 6, 0, 1, 3, 5])

환경

  • Colab
  • Version: Python 3.8.10, PyTorch1.13.1+cu116