logo

PyTorchでモデルの重み値を得る方法 📂機械学習

PyTorchでモデルの重み値を得る方法

説明

次のようなモデルを定義しよう。

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(3, 3, bias=True)
        self.conv = nn.Conv2d(3, 5, 2)

f = Model()

すると、.weight.biasメソッドで各層の重みやバイアスにアクセスできる。ただし、.weight (.bias)で得られる値はテンソルではなく、パラメータというオブジェクトなので、重みの値を持つテンソルを得たい場合は、.weight.data(.bias.data)のように使わなければならない。