logo

PyTorchでテンソルを結合またはスタックする方法 📂機械学習

PyTorchでテンソルを結合またはスタックする方法

テンソルを連結する cat()1

cat(tensors, dim=0)は、指定された次元を基準に2つ以上のテンソルを連結する。つまり、指定された次元のサイズが増加するようにテンソルが連結されるということだ。そのため、当然、指定された次元以外の他の部分のサイズが同じでなければならない。例えば、$(2,2)$ テンソルと $(2,3)$ テンソルがある場合、0番目の次元で連結することはできないが、1番目の次元であれば連結することができる。

$$ \text{cat} \Big( [(a,b),(a,b)], \dim=0 \Big) = (2a,b) \\ \text{cat} \Big( [(a,b),(a,b)], \dim=1 \Big) = (a,2b) $$

以下の stack()と比較すると、テンソルが '連結される' ので、'全体の次元数は変わらない.'

2.PNG 3.PNG

例コード

import torch

A = torch.ones(2,3)
tensor([[1., 1., 1.],
        [1., 1., 1.]])
B = 2*torch.ones(2,3)
tensor([[2., 2., 2.],
        [2., 2., 2.]])
C = torch.cat([A,B], dim=0)
torch.Size([4, 3])
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [2., 2., 2.],
        [2., 2., 2.]])
D = torch.cat([A,B], dim=1)
torch.Size([2, 6])
tensor([[1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 2., 2., 2.]])

テンソルを積む stack()2

stack(tensors, dim=0)は、指定された次元を基準に2つ以上のテンソルを積む。上の cat()と比較して、テンソルが '積まれる' ので、'新しい次元が追加され、' 指定した次元の場所に新しい次元ができる。

$$ \text{cat} \Big( [(a,b),(a,b)], \dim=0 \Big) = (2,a,b) \\ \text{cat} \Big( [(a,b),(a,b)], \dim=1 \Big) = (a,2,b) \\ \text{cat} \Big( [(a,b),(a,b)], \dim=2 \Big) = (a,b,2) $$

例コード

import torch

A = torch.ones(3,4)
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
B = 2*torch.ones(3,4)
tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])
C = torch.stack([A,B], dim=0)
torch.Size([2, 3, 4])
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[2., 2., 2., 2.],
         [2., 2., 2., 2.],
         [2., 2., 2., 2.]]])
D = torch.stack([A,B], dim=1)
torch.Size([3, 2, 4])
tensor([[[1., 1., 1., 1.],
         [2., 2., 2., 2.]],

        [[1., 1., 1., 1.],
         [2., 2., 2., 2.]],

        [[1., 1., 1., 1.],
         [2., 2., 2., 2.]]])
E = torch.stack([A,B], dim=2)
torch.Size([3, 4, 2])
tensor([[[1., 2.],
         [1., 2.],
         [1., 2.],
         [1., 2.]],

        [[1., 2.],
         [1., 2.],
         [1., 2.],
         [1., 2.]],

        [[1., 2.],
         [1., 2.],
         [1., 2.],
         [1., 2.]]])

環境

  • OS: Windows10
  • Version: Python 3.9.2, torch 1.8.1+cu111