logo

PyTorchでの重みの初期化方法 📂機械学習

PyTorchでの重みの初期化方法

コード1

次のようにニューラルネットワークを定義したとする。forward部分は省略されている。

import torch
import torch.nn as nn

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(1024, 1024, bias=False)
        self.linear_2 = nn.Linear(1024, 512, bias=False)
        self.linear_3 = nn.Linear(512, 10, bias=True)

        torch.nn.init.constant_(self.linear_1.weight.data, 0)
        torch.nn.init.unifiom_(self.linear_2.weight.data) 
        torch.nn.init.xavier_normal_(self.linear_3.weight.data)
        torch.nn.init.xavier_normal_(self.linear_3.bias.data)
          
    def forward(self, x):
        ...

重みの初期化はnn.initを通じて設定できる。バイアスがある層の場合、これも別途設定する必要がある。

基本

  • torch.nn.init.constant_(tensor, val): 定数に設定する。
  • torch.nn.init.ones_(tensor):$1$に設定する。
  • torch.nn.init.zeros_(tensor): $0$に設定する。
  • torch.nn.init.eye_(tensor)
  • torch.nn.init.dirac_(tensor, groups=1)
  • torch.nn.init.unifiom_(tensor, a=0.0, b=1.0): aからbまでの値を一様分布に設定する。デフォルトはa=0.0、b=1.0だ。
  • torch.nn.init.normal_(tensor, mean=0.0, std=1.0): 平均が0で標準偏差が1の分布に設定する。
  • torch.nn.init.orthogonal_(tensor, gain=1)
  • torch.nn.init.sparse_(tensor, sparsity, std=0.01)

Xavier

  • torch.nn.init.xavier_uniform_(tensor, gain=1.0)
  • torch.nn.init.xavier_normal_(tensor, gain=1.0)

Kaiming

  • torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
  • torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

直接指定

特定の配列を初期値として使いたい場合、次のように設定できる。$A$が$m \times n$のサイズのtorch.tensor()の場合、

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(m, n, bias=False)
        
        self.linear_1.weight.data = A 

ループ

さらに、次のようにループを通じて初期化することができる。

import torch
import torch.nn as nn

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(1024, 1024, bias=False)
        self.linear_2 = nn.Linear(1024, 512, bias=False)
        self.linear_3 = nn.Linear(512, 10, bias=True)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.uniform_(m.weight.data)
          
    def forward(self, x):
        ...

nn.initを使わずに、次のようにもできる。

for m in self.modules():
    if isinstance(m, nn.Linear):
        m.weight.data.zero_()
        m.bias.data.zero_()

環境

  • OS: Windows10
  • Version: Python 3.9.2, torch 1.8.1+cu111