파이토치에서 가중치 초기화 하는 방법

파이토치에서 가중치 초기화 하는 방법

코드1

다음과 같이 뉴럴 네트워크를 정의했다고 하자. forward 부분은 생략하였다.

import torch
import torch.nn as nn

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(1024, 1024, bias=False)
        self.linear_2 = nn.Linear(1024, 512, bias=False)
        self.linear_3 = nn.Linear(512, 10, bias=True)

        torch.nn.init.constant_(self.linear_1.weight.data, 0)
        torch.nn.init.unifiom_(self.linear_2.weight.data) 
        torch.nn.init.xavier_normal_(self.linear_3.weight.data)
        torch.nn.init.xavier_normal_(self.linear_3.bias.data)
          
    def forward(self, x):
        ...

가중치의 초기화는 nn.init을 통해 설정할 수 있다. 바이어스가 있는 층의 경우 이도 따로 설정해주어야 한다.

기본

Xavier

Kaiming

직접지정

내가 원하는 특정한 배열을 초기값으로 쓰고싶다면, 다음과 같이 설정하면 된다. $A$가 크기가 $m \times n$인 torch.tensor()일 때,

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(m, n, bias=False)
        
        self.linear_1.weight.data = A 

반복문

또한 다음과 같이 반복문을 통해서 초기화할 수 있다.

import torch
import torch.nn as nn

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(1024, 1024, bias=False)
        self.linear_2 = nn.Linear(1024, 512, bias=False)
        self.linear_3 = nn.Linear(512, 10, bias=True)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.uniform_(m.weight.data)
          
    def forward(self, x):
        ...

nn.init을 쓰지 않고 다음과 같이 할 수도 있다.

for m in self.modules():
    if isinstance(m, nn.Linear):
        m.weight.data.zero_()
        m.bias.data.zero_()

환경


  1. https://pytorch.org/docs/stable/nn.init.html ↩︎

댓글