logo

PyTorchでランダム順列を作成し、テンソルの順序をシャッフルする方法 📂機械学習

PyTorchでランダム順列を作成し、テンソルの順序をシャッフルする方法

torch.randperm()1

  • torch.randperm(n): 0からn-1までのランダムな整数の順列を返す。もちろん、整数型でなければ入力に使えない。
>>> torch.randperm(4)
tensor([2, 1, 0, 3])

>>> torch.randperm(8)
tensor([4, 0, 1, 3, 2, 5, 6, 7])

>>> torch.randperm(16)
tensor([12,  5,  6,  3, 15, 13,  2,  4,  7, 11,  1,  0,  9, 10, 14,  8])

>>> torch.randperm(4.0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: randperm(): argument 'n' (position 1) must be int, not float

tensor[indices]

インデックステンソルでインデキシングすると、それに応じてインデックスが変わる。基準はdim=0だ。numpy arrayについても同様に可能だ。

>>> indices = torch.randperm(4)
>>> indices
tensor([1, 3, 0, 2])

>>> a = torch.tensor([1,2,3,4])
>>> a[indices]
tensor([2, 4, 1, 3])

>>> b = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]])
>>> b[indices]
tensor([[2, 2, 2],
        [4, 4, 4],
        [1, 1, 1],
        [3, 3, 3]])

>>> c = torch.tensor([[[1,1],[1,1]],[[2,2],[2,2]],[[3,3],[3,3]],[[4,4],[4,4]]])
>>> c
tensor([[[1, 1],
         [1, 1]],

        [[2, 2],
         [2, 2]],

        [[3, 3],
         [3, 3]],

        [[4, 4],
         [4, 4]]])
>>> c[indices]
tensor([[[2, 2],
         [2, 2]],

        [[4, 4],
         [4, 4]],

        [[1, 1],
         [1, 1]],

        [[3, 3],
         [3, 3]]])