PyTorchでモデルの重み値を得る方法
説明
次のようなモデルを定義しよう。
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = nn.Linear(3, 3, bias=True)
self.conv = nn.Conv2d(3, 5, 2)
f = Model()
すると、.weight
や.bias
メソッドで各層の重みやバイアスにアクセスできる。ただし、.weight
(.bias
)で得られる値はテンソルではなく、パラメータというオブジェクトなので、重みの値を持つテンソルを得たい場合は、.weight.data
(.bias.data
)のように使わなければならない。