파이토치에서 가중치 초기화 하는 방법

파이토치에서 가중치 초기화 하는 방법

Weights Initialization in PyTorch

코드1

다음과 같이 뉴럴 네트워크를 정의했다고 하자. forward 부분은 생략하였다.

import torch
import torch.nn as nn

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(1024, 1024, bias=False)
        self.linear_2 = nn.Linear(1024, 512, bias=False)
        self.linear_3 = nn.Linear(512, 10, bias=True)

        torch.nn.init.constant_(self.linear_1.weight.data, 0)
        torch.nn.init.unifiom_(self.linear_2.weight.data) 
        torch.nn.init.xavier_normal_(self.linear_3.weight.data)
        torch.nn.init.xavier_normal_(self.linear_3.bias.data)
          
    def forward(self, x):
        ...

가중치의 초기화는 nn.init을 통해 설정할 수 있다. 바이어스가 있는 층의 경우 이도 따로 설정해주어야 한다.

기본

  • torch.nn.init.constant_(tensor, val): 상수로 설정한다.
  • torch.nn.init.ones_(tensor): $1$로 설정한다.
  • torch.nn.init.zeros_(tensor): $0$으로 설정한다.
  • torch.nn.init.eye_(tensor)
  • torch.nn.init.dirac_(tensor, groups=1)
  • torch.nn.init.unifiom_(tensor, a=0.0, b=1.0): a부터 b사이의 값을 균일한 분포로 설정한다. 디폴트 설정은 a=0.0, b=1.0이다.
  • torch.nn.init.normal_(tensor, mean=0.0, std=1.0): 평균이 0이고 표준편차가 1인 분포로 설정한다.
  • torch.nn.init.orthogonal_(tensor, gain=1)
  • torch.nn.init.sparse_(tensor, sparsity, std=0.01)

Xavier

  • torch.nn.init.xavier_uniform_(tensor, gain=1.0)
  • torch.nn.init.xavier_normal_(tensor, gain=1.0)

Kaiming

  • torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
  • torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

직접지정

내가 원하는 특정한 배열을 초기값으로 쓰고싶다면, 다음과 같이 설정하면 된다. $A$가 크기가 $m \times n$인 torch.tensor()일 때,

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(m, n, bias=False)
        
        self.linear_1.weight.data = A 

반복문

또한 다음과 같이 반복문을 통해서 초기화할 수 있다.

import torch
import torch.nn as nn

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(1024, 1024, bias=False)
        self.linear_2 = nn.Linear(1024, 512, bias=False)
        self.linear_3 = nn.Linear(512, 10, bias=True)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.uniform_(m.weight.data)
          
    def forward(self, x):
        ...

nn.init을 쓰지 않고 다음과 같이 할 수도 있다.

for m in self.modules():
    if isinstance(m, nn.Linear):
        m.weight.data.zero_()
        m.bias.data.zero_()

환경

  • OS: Windows10
  • Version: Python 3.9.2, torch 1.8.1+cu111

  1. https://pytorch.org/docs/stable/nn.init.html ↩︎

댓글