Initializing Weights in PyTorch
Assume we have defined a neural network as follows; the body of forward is omitted.
import torch
import torch.nn as nn

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()
        self.linear_1 = nn.Linear(1024, 1024, bias=False)
        self.linear_2 = nn.Linear(1024, 512, bias=False)
        self.linear_3 = nn.Linear(512, 10, bias=True)
        torch.nn.init.constant_(self.linear_1.weight.data, 0)
        torch.nn.init.uniform_(self.linear_2.weight.data)
        torch.nn.init.xavier_normal_(self.linear_3.weight.data)
        # Fan-based initializers such as xavier_normal_ require tensors with
        # at least 2 dimensions, so the 1-D bias is initialized with zeros instead.
        torch.nn.init.zeros_(self.linear_3.bias.data)

    def forward(self, x):
        ...
Weight initialization can be set through nn.init. For layers with a bias, the bias must be initialized separately as well. Note that the fan-based initializers (the Xavier and Kaiming variants below) require tensors with at least two dimensions, so they cannot be applied to a 1-D bias.
Basics
- torch.nn.init.constant_(tensor, val): Fills the tensor with the constant val.
- torch.nn.init.ones_(tensor): Fills the tensor with $1$.
- torch.nn.init.zeros_(tensor): Fills the tensor with $0$.
- torch.nn.init.eye_(tensor): Fills a 2-D tensor with the identity matrix.
- torch.nn.init.dirac_(tensor, groups=1): Fills a {3, 4, 5}-D tensor with the Dirac delta function, preserving the identity of the inputs in convolutional layers.
- torch.nn.init.uniform_(tensor, a=0.0, b=1.0): Fills the tensor with values drawn from the uniform distribution $U(a, b)$; the defaults are a=0.0, b=1.0.
- torch.nn.init.normal_(tensor, mean=0.0, std=1.0): Fills the tensor with values drawn from a normal distribution with the given mean and standard deviation.
- torch.nn.init.orthogonal_(tensor, gain=1): Fills the tensor with a (semi-)orthogonal matrix.
- torch.nn.init.sparse_(tensor, sparsity, std=0.01): Fills a 2-D tensor as a sparse matrix, with the non-zero elements drawn from a normal distribution with standard deviation std.
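As a minimal sketch (the $3 \times 4$ shape is an arbitrary assumption), these initializers modify a tensor in place; each call overwrites the previous one:

import torch
import torch.nn as nn

w = torch.empty(3, 4)                    # uninitialized 2-D tensor
nn.init.constant_(w, 0.5)                # every entry becomes 0.5
nn.init.uniform_(w, a=-1.0, b=1.0)       # entries drawn from U(-1, 1)
nn.init.normal_(w, mean=0.0, std=0.02)   # entries drawn from N(0, 0.02^2)
nn.init.eye_(w)                          # identity matrix, zero-padded if non-square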
Xavier
- torch.nn.init.xavier_uniform_(tensor, gain=1.0): Xavier (Glorot) initialization with a uniform distribution, scaled by the fan-in and fan-out of the tensor.
- torch.nn.init.xavier_normal_(tensor, gain=1.0): Xavier (Glorot) initialization with a normal distribution.
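A minimal sketch, assuming a tanh activation; nn.init.calculate_gain returns the recommended gain for a given nonlinearity:

import torch
import torch.nn as nn

w = torch.empty(256, 128)
gain = nn.init.calculate_gain('tanh')  # recommended scaling for tanh
nn.init.xavier_uniform_(w, gain=gain)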
Kaiming
- torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): Kaiming (He) initialization with a uniform distribution; a is the negative slope of the leaky ReLU that follows the layer.
- torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): Kaiming (He) initialization with a normal distribution; mode='fan_in' preserves the variance of activations in the forward pass, while 'fan_out' preserves it in the backward pass.
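A minimal sketch, assuming the layer is followed by a ReLU (the shape and the choice of mode='fan_out' are assumptions for illustration):

import torch
import torch.nn as nn

w = torch.empty(512, 256)
# preserve variance in the backward pass for a ReLU network
nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')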
Custom
If you want to use a specific array as an initial value, you can set it like this. When $A$ is a torch.Tensor of size $m \times n$, note that nn.Linear(in_features, out_features) stores its weight with shape (out_features, in_features), so the matching layer is nn.Linear(n, m):

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()
        # nn.Linear(n, m) has a weight of shape (m, n), matching A
        self.linear_1 = nn.Linear(n, m, bias=False)
        self.linear_1.weight.data = A
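As a sketch of an alternative, copying under torch.no_grad() fills the existing parameter in place instead of replacing its tensor:

with torch.no_grad():
    self.linear_1.weight.copy_(A)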
Loops
Initialization can also be performed in a loop over the network's modules, as follows.
import torch
import torch.nn as nn

class Custom_Net(nn.Module):
    def __init__(self):
        super(Custom_Net, self).__init__()
        self.linear_1 = nn.Linear(1024, 1024, bias=False)
        self.linear_2 = nn.Linear(1024, 512, bias=False)
        self.linear_3 = nn.Linear(512, 10, bias=True)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.uniform_(m.weight.data)

    def forward(self, x):
        ...
You can also do it as follows without using nn.init.
for m in self.modules():
    if isinstance(m, nn.Linear):
        m.weight.data.zero_()
        # Guard against layers defined with bias=False, where bias is None
        if m.bias is not None:
            m.bias.data.zero_()
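The same traversal can also be expressed with nn.Module.apply, which calls a function on every submodule recursively; a minimal sketch using the network above:

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.uniform_(m.weight.data)

net = Custom_Net()
net.apply(init_weights)  # applied to every submodule, including net itself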
Environment
- OS: Windows 10
- Version: Python 3.9.2, torch 1.8.1+cu111