PyTorchでのテンソルソートに関する関数
torch.sort()
torch.sort()
にテンソルを入力すると、ソートされた値とインデックスが返される。
1次元テンソル
>>> x = torch.tensor([1, 3, -2, 5, -1, 7, 0])
>>> values, indices = torch.sort(x)
>>> values
tensor([-2, -1, 0, 1, 3, 5, 7])
>>> indices
tensor([2, 4, 6, 0, 1, 3, 5])
多次元テンソル
テンソルだけを入力すると、各行ごとにソートする。つまり、torch.sort(x)
=torch.sort(x, dim=1)
である。次元を入力すると、その次元ごとにソートする。
>>> x = torch.tensor([[1, 0, 4],
[-1, 5, 2],
[0, 4, 3]])
>>> value, indices = torch.sort(x)
>>> value
tensor([[ 0, 1, 4],
[-1, 2, 5],
[ 0, 3, 4]])
>>> indices
tensor([[1, 0, 2],
[0, 2, 1],
[0, 2, 1]])
>>> value, indices = torch.sort(x, dim=0)
>>> value
tensor([[-1, 0, 2],
[ 0, 4, 3],
[ 1, 5, 4]])
>>> indices
tensor([[1, 0, 1],
[2, 2, 2],
[0, 1, 0]])
torch.argsort()
特定の行(列)基準でソートする
特定の次元を基準にテンソルをソートすることができる。例えば、x
を最初の列を基準に昇順ソートしたいなら、
>>> value, indices = torch.sort(x, dim=0)
>>> x[indices[:,0], :]
tensor([[-1, 5, 2],
[ 0, 4, 3],
[ 1, 0, 4]])
ただし、この場合はvalue
を計算する必要がないので、torch.argsort()
を使うこともできる。これはインデックスだけを返す関数で、以下のコードは上と同じだ。
>>> indices = torch.argsort(x, dim=0)
>>> x[indices[:,0], :]
tensor([[-1, 5, 2],
[ 0, 4, 3],
[ 1, 0, 4]])
.sort()
, .argsort()
これらはテンソル内にメソッドとしても実装されている。
>>> x = torch.tensor([1, 3, -2, 5, -1, 7, 0])
>>> x.sort()
torch.return_types.sort(
values=tensor([-2, -1, 0, 1, 3, 5, 7]),
indices=tensor([2, 4, 6, 0, 1, 3, 5]))
>>> x.argsort()
tensor([2, 4, 6, 0, 1, 3, 5])
環境
- Colab
- Version: Python 3.8.10, PyTorch1.13.1+cu116