Saving and Loading Weights, Models, and Optimizers in PyTorch

Not Re-training [1][2][3]

Saving

If you’re not planning to re-train, you can simply save the weights or the entire model. However, as mentioned below, if you’re planning to re-train, you also need to save the optimizer. Weights can be easily saved as follows:

import torch
import torch.nn as nn

# Define the model
class CustomModel(nn.Module):
    ...  # (rest omitted)

model = CustomModel()

# Save the weights
torch.save(model.state_dict(), 'weights.pt')

By convention, the file is given the .pt or .pth extension. The official PyTorch website recommends saving model.state_dict() as shown above, but the entire model can also be saved like this:

torch.save(model, 'model.pt')

Loading

If you’ve only saved the weights, they can be loaded as follows:

# Define a model with the same structure as the saved weights
model = CustomModel()

# Load the weights
model.load_state_dict(torch.load('weights.pt'))

# evaluation mode
model.eval()
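The save and load steps above can be checked with a small round trip. This is a minimal sketch, not the author's code: TinyModel is a hypothetical stand-in for CustomModel (whose body is omitted in the original), and the file is written to a temporary directory so nothing is left behind.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for CustomModel: a single linear layer."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


torch.manual_seed(0)
model = TinyModel()
model.eval()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "weights.pt")

    # Save only the weights
    torch.save(model.state_dict(), path)

    # Build a fresh model with the same structure, then load the weights
    restored = TinyModel()
    restored.load_state_dict(torch.load(path))
    restored.eval()

# Both models should now produce the same output for the same input
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.allclose(model(x), restored(x))
print(same_output)
```

The fresh TinyModel() starts with different random weights, so matching outputs indicate that load_state_dict really did overwrite them.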

To load a model that was saved entirely, you do it like this:

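# The CustomModel class definition must be available before loading
# (on PyTorch 2.6 and later, pass weights_only=False to load a whole model)
model = torch.load('model.pt')
model.eval()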

According to the official website [4], even if you’ve saved the entire model, the model class still has to be defined (or importable) when loading it. The practical difference is that saving the entire model pickles the object together with a reference to its class and source file, so loading can fail once the code is moved or refactored, whereas a state_dict holds only the parameter and buffer tensors. For that reason, it’s recommended to follow the official guideline of saving and loading the state_dict.

In Case of Re-training

As mentioned earlier, if you plan to re-train, you need to save the state of the optimizer along with the weights. Optimizers such as Adam keep their own internal state (for example, running averages of the gradients and a step count) that is updated as training proceeds, so without it training will not resume from where it left off. If necessary, you can also save the current epoch, loss, etc.
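To see what the optimizer state actually contains, it can be inspected before and after a single update. This is a rough sketch under assumptions: a plain nn.Linear layer stands in for the model, and the loss is an arbitrary one chosen only to produce gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)  # stand-in model, used only for illustration
optimizer = torch.optim.Adam(layer.parameters(), lr=0.001)

# Before any update, Adam has not created per-parameter state yet
state_before = len(optimizer.state_dict()["state"])

# One update step with an arbitrary loss
loss = layer(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# After the update, each parameter (weight and bias) has its own state
opt_state = optimizer.state_dict()
state_after = len(opt_state["state"])
per_param_keys = sorted(opt_state["state"][0].keys())
print(state_before, state_after, per_param_keys)
```

The per-parameter entries include Adam's running averages (exp_avg, exp_avg_sq), which are lost if only model.state_dict() is saved.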

Saving

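import torch
import torch.nn as nn

# Define the model
class CustomModel(nn.Module):
    ...  # (rest omitted)

model = CustomModel()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Placeholder: in practice, this is the current epoch from your training loop
epoch = 0

# Save the weights, optimizer state, and current epoch
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
}, 'checkpoint.tar')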

When saving several pieces of information together like this, the .tar extension is used. This is only a naming convention suggested by the official website; torch.save writes the file the same way regardless of the extension.

Loading

The model and optimizer must be defined before loading.

model = CustomModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Load the checkpoint and apply it
checkpoint = torch.load('checkpoint.tar')

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
checkpoint_epoch = checkpoint['epoch']

# train mode
model.train()
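Putting the pieces together, the sketch below trains for a few epochs, writes a checkpoint, restores it into a fresh model and optimizer, and then checks that both continue identically. The train_one_epoch helper, the random data, and the use of nn.Linear in place of CustomModel are all assumptions made for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


def train_one_epoch(model, optimizer, inputs, targets):
    """Hypothetical helper: one full-batch update, returns the loss value."""
    model.train()
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
inputs, targets = torch.randn(16, 4), torch.randn(16, 2)

model = nn.Linear(4, 2)  # stand-in for CustomModel
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train for three epochs before checkpointing
for epoch in range(3):
    train_one_epoch(model, optimizer, inputs, targets)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.tar")
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
    }, path)

    # Define a fresh model and optimizer, then restore both from the checkpoint
    resumed_model = nn.Linear(4, 2)
    resumed_optimizer = torch.optim.Adam(resumed_model.parameters(), lr=0.001)
    checkpoint = torch.load(path)
    resumed_model.load_state_dict(checkpoint["model_state_dict"])
    resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1  # resume from the next epoch

# One more epoch on each: the resumed run should track the original run
loss_original = train_one_epoch(model, optimizer, inputs, targets)
loss_resumed = train_one_epoch(resumed_model, resumed_optimizer, inputs, targets)
weights_match = torch.allclose(model.weight, resumed_model.weight)
print(start_epoch, weights_match)
```

Storing the last completed epoch and resuming from checkpoint["epoch"] + 1 is one possible convention; whichever one is chosen, the saving and loading code need to agree on it.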

Environment

  • OS: Windows 10
  • Version: Python 3.9.2, torch 1.8.1+cu111