PyTorchでテンソルをディープコピーする方法
説明
PyTorchのテンソルも他のオブジェクトと同様にcopy.deepcopy()
を使って深いコピーができる。
>>> import torch
>>> import copy
>>> a = torch.ones(2,2)
>>> b = a
>>> c = copy.deepcopy(a)
>>> a += 1
>>> a
tensor([[2., 2.],
[2., 2.]])
>>> b
tensor([[2., 2.],
[2., 2.]])
>>> c
tensor([[1., 1.],
[1., 1.]])
しかし、これはユーザーが明示的に定義したテンソルに対してのみ可能だ。例えば、ディープラーニングモデルのアウトプットをcopy.deepcopy()
でコピーしようとすると、次のようなエラーが発生する。
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
PyTorchでは、.clone()
1によってテンソルの深いコピーをサポートしているので、これを利用すれば直接定義したテンソルであるか、出力値として得たテンソルであるかにかかわらず、深いコピーをすることができる。
>>> a = torch.ones(2,2)
>>> b = a.clone()
>>> a += 1
>>> a
tensor([[2., 2.],
[2., 2.]])
>>> b
tensor([[1., 1.],
[1., 1.]])