PyTorchでモデルの重み値を得る方法
説明
次のようなモデルを定義しよう。
import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(3, 3, bias=True)
        self.conv = nn.Conv2d(3, 5, 2)
f = Model()
すると、.weightや.biasメソッドで各層の重みやバイアスにアクセスできる。ただし、.weight (.bias)で得られる値はテンソルではなく、パラメータというオブジェクトなので、重みの値を持つテンソルを得たい場合は、.weight.data(.bias.data)のように使わなければならない。
