logo

PyTorchでのテンソルソートに関する関数 📂機械学習

PyTorchでのテンソルソートに関する関数

torch.sort()

torch.sort()にテンソルを入力すると、ソートされた値とインデックスが返される。

1次元テンソル

>>> x = torch.tensor([1, 3, -2, 5, -1, 7, 0])
>>> values, indices = torch.sort(x)

>>> values
tensor([-2, -1,  0,  1,  3,  5,  7])

>>> indices
tensor([2, 4, 6, 0, 1, 3, 5])

多次元テンソル

テンソルだけを入力すると、各行ごとにソートする。つまり、torch.sort(x)=torch.sort(x, dim=1)である。次元を入力すると、その次元ごとにソートする。

>>> x = torch.tensor([[1, 0, 4],
                      [-1, 5, 2],
                      [0, 4, 3]])

>>> value, indices = torch.sort(x)
>>> value
tensor([[ 0,  1,  4],
        [-1,  2,  5],
        [ 0,  3,  4]])
>>> indices
tensor([[1, 0, 2],
        [0, 2, 1],
        [0, 2, 1]])

>>> value, indices = torch.sort(x, dim=0)
>>> value
tensor([[-1,  0,  2],
        [ 0,  4,  3],
        [ 1,  5,  4]])
>>> indices
tensor([[1, 0, 1],
        [2, 2, 2],
        [0, 1, 0]])

torch.argsort()

特定の行(列)基準でソートする

特定の次元を基準にテンソルをソートすることができる。例えば、xを最初の列を基準に昇順ソートしたいなら、

>>> value, indices = torch.sort(x, dim=0)
>>> x[indices[:,0], :]
tensor([[-1,  5,  2],
        [ 0,  4,  3],
        [ 1,  0,  4]])

ただし、この場合はvalueを計算する必要がないので、torch.argsort()を使うこともできる。これはインデックスだけを返す関数で、以下のコードは上と同じだ。

>>> indices = torch.argsort(x, dim=0)
>>> x[indices[:,0], :]
tensor([[-1,  5,  2],
        [ 0,  4,  3],
        [ 1,  0,  4]])

.sort(), .argsort()

これらはテンソル内にメソッドとしても実装されている。

>>> x = torch.tensor([1, 3, -2, 5, -1, 7, 0])

>>> x.sort()
torch.return_types.sort(
values=tensor([-2, -1,  0,  1,  3,  5,  7]),
indices=tensor([2, 4, 6, 0, 1, 3, 5]))

>>> x.argsort()
tensor([2, 4, 6, 0, 1, 3, 5])

環境

  • Colab
  • Version: Python 3.8.10, PyTorch1.13.1+cu116