logo

파이토치에서 텐서 붙이거나 쌓는 방법 📂머신러닝

파이토치에서 텐서 붙이거나 쌓는 방법

텐서 붙이기 cat()1

cat(tensors, dim=0)는 2개 이상의 텐서들을 지정한 차원을 기준으로 붙인다. 기준이 된다는 말은 지정한 차원의 크기가 늘어나도록 붙인다는 뜻이다. 따라서 당연하게도 지정한 차원 외의 나머지 부분의 크기가 같아야한다. 예를 들어 $(2,2)$ 텐서와 $(2,3)$ 텐서가 있다면 0번째 차원으로는 붙일 수 없고, 1번째 차원으로는 붙일 수 있다.

$$ \text{cat} \Big( [(a,b),(a,b)], \dim=0 \Big) = (2a,b) \\ \text{cat} \Big( [(a,b),(a,b)], \dim=1 \Big) = (a,2b) $$

아래의 stack()과 비교하면 텐서를 '붙이는' 것이므로 '전체 차원의 수는 변하지 않는다.'

2.PNG 3.PNG

예제 코드

import torch

A = torch.ones(2,3)
tensor([[1., 1., 1.],
        [1., 1., 1.]])
B = 2*torch.ones(2,3)
tensor([[2., 2., 2.],
        [2., 2., 2.]])
C = torch.cat([A,B], dim=0)
torch.Size([4, 3])
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [2., 2., 2.],
        [2., 2., 2.]])
D = torch.cat([A,B], dim=1)
torch.Size([2, 6])
tensor([[1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 2., 2., 2.]])

텐서 쌓기 stack()2

stack(tensors, dim=0)는 2개 이상의 텐서들을 지정한 차원을 기준으로 쌓는다. 위의 cat()과 비교하면 텐서를 '쌓는' 것 이므로 '새로운 차원이 추가되고', 지정해준 차원의 자리에 새로운 차원이 생긴다.

$$ \text{cat} \Big( [(a,b),(a,b)], \dim=0 \Big) = (2,a,b) \\ \text{cat} \Big( [(a,b),(a,b)], \dim=1 \Big) = (a,2,b) \\ \text{cat} \Big( [(a,b),(a,b)], \dim=2 \Big) = (a,b,2) $$

예제 코드

import torch

A = torch.ones(3,4)
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
B = 2*torch.ones(3,4)
tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])
C = torch.stack([A,B], dim=0)
torch.Size([2, 3, 4])
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[2., 2., 2., 2.],
         [2., 2., 2., 2.],
         [2., 2., 2., 2.]]])
D = torch.stack([A,B], dim=1)
torch.Size([3, 2, 4])
tensor([[[1., 1., 1., 1.],
         [2., 2., 2., 2.]],

        [[1., 1., 1., 1.],
         [2., 2., 2., 2.]],

        [[1., 1., 1., 1.],
         [2., 2., 2., 2.]]])
E = torch.stack([A,B], dim=2)
torch.Size([3, 4, 2])
tensor([[[1., 2.],
         [1., 2.],
         [1., 2.],
         [1., 2.]],

        [[1., 2.],
         [1., 2.],
         [1., 2.],
         [1., 2.]],

        [[1., 2.],
         [1., 2.],
         [1., 2.],
         [1., 2.]]])

환경

  • OS: Windows10
  • Version: Python 3.9.2, torch 1.8.1+cu111