파이토치에서 텐서 깊은 복사하는 방법

파이토치에서 텐서 깊은 복사하는 방법

설명

파이토치 텐서도 다른 객체와 마찬가지로 copy.deepcopy()를 이용해 깊은 복사를 할 수 있다.

>>> import torch
>>> import copy

>>> a = torch.ones(2,2)
>>> b = a
>>> c = copy.deepcopy(a)
>>> a += 1

>>> a
tensor([[2., 2.],
        [2., 2.]])

>>> b
tensor([[2., 2.],
        [2., 2.]])

>>> c
tensor([[1., 1.],
        [1., 1.]])

하지만 이는 사용자가 명시적으로 정의한 텐서에 한해서만 가능하다. 가령 딥러닝 모델의 아웃풋을 copy.deepcopy()로 복사하려고 하면 다음과 같은 오류가 발생한다.

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

파이토치에서는 .clone()1으로 텐서의 깊은 복사를 지원하니 이를 이용하면 직접 정의한 텐서인지 출력값으로 얻은 텐서인지에 관계없이 깊은 복사를 할 수 있다.

>>> a = torch.ones(2,2)
>>> b = a.clone()
>>> a += 1

>>> a
tensor([[2., 2.],
        [2., 2.]])

>>> b
tensor([[1., 1.],
        [1., 1.]])

  1. https://pytorch.org/docs/stable/generated/torch.clone.html ↩︎

댓글