How to Concatenate or Stack Tensors in PyTorch
Concatenate Tensors cat()
cat(tensors, dim=0) concatenates two or more tensors along an existing dimension. The size of that dimension in the result is the sum of its sizes in the inputs, so every other dimension must have the same size in all inputs. For example, tensors of shape $(2,2)$ and $(2,3)$ cannot be concatenated along dimension 0, because their sizes in dimension 1 differ (2 vs. 3). They can be concatenated along dimension 1, which gives a tensor of shape $(2,5)$.
$$ \text{cat} \Big( [(a,b),(a,b)], \dim=0 \Big) = (2a,b) \\ \text{cat} \Big( [(a,b),(a,b)], \dim=1 \Big) = (a,2b) $$
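The shape rule can be checked directly with the $(2,2)$ and $(2,3)$ example above. This is a minimal sketch: the tensors are filled with zeros only for illustration, and because the wording of the error message may differ between PyTorch versions, the code only relies on a RuntimeError being raised.

```python
import torch

X = torch.zeros(2, 2)
Y = torch.zeros(2, 3)

# Along dim=1 only dimension 0 must match (2 == 2); the widths add up: 2 + 3 = 5
Z = torch.cat([X, Y], dim=1)
print(Z.shape)  # torch.Size([2, 5])

# Along dim=0, dimension 1 must match, but 2 != 3, so torch.cat raises
try:
    torch.cat([X, Y], dim=0)
    cat_dim0_failed = False
except RuntimeError as err:
    cat_dim0_failed = True
    print("RuntimeError:", err)
```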
Unlike stack() below, cat() joins tensors along a dimension they already have, so the total number of dimensions does not change.
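To make the contrast concrete, the sketch below compares the number of dimensions after cat() and stack() on the same two $(2,3)$ tensors. It also checks one way to relate the two functions: stacking along dim gives the same result as first inserting a size-1 axis into each input with unsqueeze(dim) and then concatenating along that axis. This is an illustration of the relationship, not a claim about how PyTorch implements stack() internally.

```python
import torch

A = torch.ones(2, 3)
B = 2 * torch.ones(2, 3)

cat_result = torch.cat([A, B], dim=0)      # joins along the existing axis 0
stack_result = torch.stack([A, B], dim=0)  # creates a new axis 0
print(cat_result.dim(), stack_result.dim())  # 2 3

# stack() matches: unsqueeze each input at `dim`, then cat() along `dim`
for dim in range(A.dim() + 1):  # the new axis can go at position 0, 1 or 2
    via_cat = torch.cat([A.unsqueeze(dim), B.unsqueeze(dim)], dim=dim)
    print(dim, torch.equal(torch.stack([A, B], dim=dim), via_cat))  # True
```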
Example Code
import torch
A = torch.ones(2, 3)
print(A)  # tensor([[1., 1., 1.], [1., 1., 1.]])
B = 2 * torch.ones(2, 3)
print(B)  # tensor([[2., 2., 2.], [2., 2., 2.]])
C = torch.cat([A, B], dim=0)  # join along rows
print(C.shape, C)
# torch.Size([4, 3]) tensor([[1., 1., 1.], [1., 1., 1.], [2., 2., 2.], [2., 2., 2.]])
D = torch.cat([A, B], dim=1)  # join along columns
print(D.shape, D)
# torch.Size([2, 6]) tensor([[1., 1., 1., 2., 2., 2.], [1., 1., 1., 2., 2., 2.]])
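Because cat() places the inputs next to each other, slicing the result along the same dimension gives the original tensors back. A minimal sketch that redefines A, B, C and D from the example above so it runs on its own:

```python
import torch

A = torch.ones(2, 3)
B = 2 * torch.ones(2, 3)
C = torch.cat([A, B], dim=0)  # shape (4, 3)
D = torch.cat([A, B], dim=1)  # shape (2, 6)

# Rows 0-1 of C come from A, rows 2-3 from B
print(torch.equal(C[:2], A), torch.equal(C[2:], B))        # True True
# Columns 0-2 of D come from A, columns 3-5 from B
print(torch.equal(D[:, :3], A), torch.equal(D[:, 3:], B))  # True True
```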
Stack Tensors stack()
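stack(tensors, dim=0) stacks two or more tensors along a new dimension. Unlike cat() above, which joins tensors along an existing dimension, stack() inserts a new dimension at the position given by dim, so the result has one more dimension than the inputs. The size of the new dimension equals the number of stacked tensors, and all input tensors must have exactly the same shape.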
$$ \text{stack} \Big( [(a,b),(a,b)], \dim=0 \Big) = (2,a,b) \\ \text{stack} \Big( [(a,b),(a,b)], \dim=1 \Big) = (a,2,b) \\ \text{stack} \Big( [(a,b),(a,b)], \dim=2 \Big) = (a,b,2) $$
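The shape rule above can be written as a small function: take the common input shape and insert the number of tensors at position dim. stack_shape is a hypothetical helper written only for illustration (it is not part of PyTorch and only handles non-negative dim); the loop compares it with the shape torch.stack actually returns. The last part checks the equal-shape requirement, relying only on a RuntimeError being raised.

```python
import torch

def stack_shape(shape, n, dim):
    """Hypothetical helper: shape of stacking n tensors of `shape` along a non-negative `dim`."""
    out = list(shape)
    out.insert(dim, n)  # the new axis has size n and sits at position dim
    return tuple(out)

a, b = 3, 4
tensors = [torch.ones(a, b), 2 * torch.ones(a, b)]
for dim in range(3):
    predicted = stack_shape((a, b), len(tensors), dim)
    actual = tuple(torch.stack(tensors, dim=dim).shape)
    print(dim, predicted, actual)

# Unlike cat(), stack() needs every input to have exactly the same shape
try:
    torch.stack([torch.ones(2, 2), torch.ones(2, 3)], dim=0)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```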
Example Code
import torch
A = torch.ones(3, 4)
print(A)  # tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]])
B = 2 * torch.ones(3, 4)
print(B)  # tensor([[2., 2., 2., 2.], [2., 2., 2., 2.], [2., 2., 2., 2.]])
C = torch.stack([A, B], dim=0)  # new axis in front
print(C.shape, C)
# torch.Size([2, 3, 4]) tensor([[[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]], [[2., 2., 2., 2.], [2., 2., 2., 2.], [2., 2., 2., 2.]]])
D = torch.stack([A, B], dim=1)  # new axis in the middle
print(D.shape, D)
# torch.Size([3, 2, 4]) tensor([[[1., 1., 1., 1.], [2., 2., 2., 2.]], [[1., 1., 1., 1.], [2., 2., 2., 2.]], [[1., 1., 1., 1.], [2., 2., 2., 2.]]])
E = torch.stack([A, B], dim=2)  # new axis at the end
print(E.shape, E)
# torch.Size([3, 4, 2]) tensor([[[1., 2.], [1., 2.], [1., 2.], [1., 2.]], [[1., 2.], [1., 2.], [1., 2.], [1., 2.]], [[1., 2.], [1., 2.], [1., 2.], [1., 2.]]])
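Indexing 0 or 1 along the newly created axis returns the first or second input, which is a direct way to see where stack() put the new dimension. A minimal sketch that redefines the tensors from the example above; torch.unbind, a standard PyTorch function that removes a dimension and returns the slices as a tuple, undoes the stack in one call.

```python
import torch

A = torch.ones(3, 4)
B = 2 * torch.ones(3, 4)
C = torch.stack([A, B], dim=0)  # (2, 3, 4)
D = torch.stack([A, B], dim=1)  # (3, 2, 4)
E = torch.stack([A, B], dim=2)  # (3, 4, 2)

# Index 0 along the new axis returns A, index 1 returns B
print(torch.equal(C[0], A), torch.equal(C[1], B))              # True True
print(torch.equal(D[:, 0], A), torch.equal(D[:, 1], B))        # True True
print(torch.equal(E[:, :, 0], A), torch.equal(E[:, :, 1], B))  # True True

# torch.unbind removes the given axis and returns the slices as a tuple
a, b = torch.unbind(E, dim=2)
print(torch.equal(a, A), torch.equal(b, B))                    # True True
```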
Environment
- OS: Windows 10
- Version: Python 3.9.2, torch 1.8.1+cu111