Initializing Weights in PyTorch

Code

Assume we have defined a neural network as follows; the forward pass is omitted.

import torch
import torch.nn as nn

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(1024, 1024, bias=False)
        self.linear_2 = nn.Linear(1024, 512, bias=False)
        self.linear_3 = nn.Linear(512, 10, bias=True)

        torch.nn.init.constant_(self.linear_1.weight.data, 0)
        torch.nn.init.uniform_(self.linear_2.weight.data)
        torch.nn.init.xavier_normal_(self.linear_3.weight.data)
        # Xavier initialization needs a tensor with at least 2 dimensions,
        # so the 1D bias is initialized with zeros_ instead
        torch.nn.init.zeros_(self.linear_3.bias.data)
          
    def forward(self, x):
        ...

Weight initialization is set through torch.nn.init (the trailing underscore marks functions that modify the tensor in place). For layers with a bias, the bias must be initialized separately; since it is 1-D, initializers such as Xavier that require at least two dimensions cannot be applied to it.

Basics

  • torch.nn.init.constant_(tensor, val): Fills the tensor with the constant val.
  • torch.nn.init.ones_(tensor): Fills the tensor with $1$s.
  • torch.nn.init.zeros_(tensor): Fills the tensor with $0$s.
  • torch.nn.init.eye_(tensor): Fills a 2-D tensor with the identity matrix.
  • torch.nn.init.dirac_(tensor, groups=1): Fills a {3, 4, 5}-D tensor with the Dirac delta function, preserving the identity of the inputs in convolutional layers.
  • torch.nn.init.uniform_(tensor, a=0.0, b=1.0): Fills the tensor with values drawn from the uniform distribution $U(a, b)$; the defaults are a=0.0, b=1.0.
  • torch.nn.init.normal_(tensor, mean=0.0, std=1.0): Fills the tensor with values drawn from the normal distribution $N(\text{mean}, \text{std}^2)$; the defaults give mean $0$ and standard deviation $1$.
  • torch.nn.init.orthogonal_(tensor, gain=1): Fills the tensor with a (semi-)orthogonal matrix.
  • torch.nn.init.sparse_(tensor, sparsity, std=0.01): Fills a 2-D tensor as a sparse matrix, where the fraction sparsity of the elements in each column is set to zero; the non-zero elements are drawn from $N(0, \text{std}^2)$.
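
These initializers operate in place on any tensor. A minimal sketch, using an arbitrary $3 \times 5$ tensor:

import torch

w = torch.empty(3, 5)

torch.nn.init.constant_(w, 0.5)               # every entry becomes 0.5
torch.nn.init.uniform_(w, a=-1.0, b=1.0)      # entries drawn from U(-1, 1)
torch.nn.init.normal_(w, mean=0.0, std=0.02)  # entries drawn from N(0, 0.02^2)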

Xavier

  • torch.nn.init.xavier_uniform_(tensor, gain=1.0): Samples from the uniform distribution $U(-a, a)$ with $a = \text{gain} \times \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}$.
  • torch.nn.init.xavier_normal_(tensor, gain=1.0): Samples from the normal distribution $N(0, \text{std}^2)$ with $\text{std} = \text{gain} \times \sqrt{2 / (\text{fan\_in} + \text{fan\_out})}$.
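
The gain can be matched to the activation that follows the layer with torch.nn.init.calculate_gain. A minimal sketch, with arbitrary layer sizes:

import torch
import torch.nn as nn

linear = nn.Linear(512, 256, bias=False)

# calculate_gain('relu') returns sqrt(2), the gain recommended for ReLU
torch.nn.init.xavier_uniform_(linear.weight.data,
                              gain=torch.nn.init.calculate_gain('relu'))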

Kaiming

  • torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): Samples from $U(-\text{bound}, \text{bound})$ with $\text{bound} = \text{gain} \times \sqrt{3 / \text{fan\_mode}}$.
  • torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): Samples from $N(0, \text{std}^2)$ with $\text{std} = \text{gain} / \sqrt{\text{fan\_mode}}$.
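
Here a is the negative slope assumed by the default nonlinearity='leaky_relu', and mode='fan_in' preserves the variance of the activations in the forward pass ('fan_out' preserves it in the backward pass instead). A minimal sketch, with an arbitrary layer size:

import torch
import torch.nn as nn

linear = nn.Linear(512, 256, bias=True)

# He initialization for a layer followed by ReLU; the 1D bias is set separately
torch.nn.init.kaiming_normal_(linear.weight.data, mode='fan_in', nonlinearity='relu')
torch.nn.init.zeros_(linear.bias.data)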

Custom

If you want to use a specific array as the initial value, you can set it like this. When $A$ is a torch.Tensor of size $n \times m$ (note that the weight of nn.Linear(m, n) has shape $n \times m$, i.e. out_features × in_features),

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(m, n, bias=False)

        # A must have the same shape as the weight, namely n x m
        self.linear_1.weight.data = A
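
Assigning to .data as above works; a sketch of two equivalent patterns closer to current PyTorch style, continuing inside the same __init__ with the same $A$:

# ...inside Custom_Net.__init__: copy the values in without tracking gradients
with torch.no_grad():
    self.linear_1.weight.copy_(A)

# or replace the parameter outright
self.linear_1.weight = nn.Parameter(A.clone())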

Loops

Initialization can also be performed in a loop over the network's modules, as follows.

import torch
import torch.nn as nn

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()

        self.linear_1 = nn.Linear(1024, 1024, bias=False)
        self.linear_2 = nn.Linear(1024, 512, bias=False)
        self.linear_3 = nn.Linear(512, 10, bias=True)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.uniform_(m.weight.data)
          
    def forward(self, x):
        ...

You can also do it without nn.init, using the tensors' own in-place methods; note the guard for the layers defined above with bias=False.

for m in self.modules():
    if isinstance(m, nn.Linear):
        m.weight.data.zero_()
        if m.bias is not None:  # linear_1 and linear_2 have no bias
            m.bias.data.zero_()
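
Another common idiom, not shown above, is nn.Module.apply, which calls a function recursively on every submodule; a minimal sketch that reproduces the uniform initialization from the loop example:

import torch
import torch.nn as nn

def init_weights(m):
    # invoked once for every submodule of the network
    if isinstance(m, nn.Linear):
        torch.nn.init.uniform_(m.weight.data)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias.data)

net = Custom_Net()
net.apply(init_weights)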

Environment

  • OS: Windows 10
  • Version: Python 3.9.2, torch 1.8.1+cu111