PyTorchでランダム順列を作成し、テンソルの順序をシャッフルする方法
torch.randperm()
1
torch.randperm(n)
: 0からn-1までのランダムな整数の順列を返す。もちろん、整数型でなければ入力に使えない。
>>> torch.randperm(4)
tensor([2, 1, 0, 3])
>>> torch.randperm(8)
tensor([4, 0, 1, 3, 2, 5, 6, 7])
>>> torch.randperm(16)
tensor([12, 5, 6, 3, 15, 13, 2, 4, 7, 11, 1, 0, 9, 10, 14, 8])
>>> torch.randperm(4.0)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: randperm(): argument 'n' (position 1) must be int, not float
tensor[indices]
インデックステンソルでインデキシングすると、それに応じてインデックスが変わる。基準はdim=0
だ。numpy arrayについても同様に可能だ。
>>> indices = torch.randperm(4)
>>> indices
tensor([1, 3, 0, 2])
>>> a = torch.tensor([1,2,3,4])
>>> a[indices]
tensor([2, 4, 1, 3])
>>> b = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]])
>>> b[indices]
tensor([[2, 2, 2],
[4, 4, 4],
[1, 1, 1],
[3, 3, 3]])
>>> c = torch.tensor([[[1,1],[1,1]],[[2,2],[2,2]],[[3,3],[3,3]],[[4,4],[4,4]]])
>>> c
tensor([[[1, 1],
[1, 1]],
[[2, 2],
[2, 2]],
[[3, 3],
[3, 3]],
[[4, 4],
[4, 4]]])
>>> c[indices]
tensor([[[2, 2],
[2, 2]],
[[4, 4],
[4, 4]],
[[1, 1],
[1, 1]],
[[3, 3],
[3, 3]]])