logo

PyTorchでテンソルをディープコピーする方法 📂機械学習

PyTorchでテンソルをディープコピーする方法

説明

PyTorchのテンソルも他のオブジェクトと同様にcopy.deepcopy()を使って深いコピーができる。

>>> import torch
>>> import copy

>>> a = torch.ones(2,2)
>>> b = a
>>> c = copy.deepcopy(a)
>>> a += 1

>>> a
tensor([[2., 2.],
        [2., 2.]])

>>> b
tensor([[2., 2.],
        [2., 2.]])

>>> c
tensor([[1., 1.],
        [1., 1.]])

しかし、これはユーザーが明示的に定義したテンソルに対してのみ可能だ。例えば、ディープラーニングモデルのアウトプットをcopy.deepcopy()でコピーしようとすると、次のようなエラーが発生する。

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

PyTorchでは、.clone()1によってテンソルの深いコピーをサポートしているので、これを利用すれば直接定義したテンソルであるか、出力値として得たテンソルであるかにかかわらず、深いコピーをすることができる。

>>> a = torch.ones(2,2)
>>> b = a.clone()
>>> a += 1

>>> a
tensor([[2., 2.],
        [2., 2.]])

>>> b
tensor([[1., 1.],
        [1., 1.]])