파이토치에서 가중치, 모델, 옵티마이저 저장하고 불러오는 방법
재학습하지 않는 경우1 2 3
저장하기
재학습하지 않는 경우라면 간단하게 가중치 혹은 모델만 저장해도 된다. 아래에서 말하겠지만 재학습을 할거라면 옵티마이저까지 저장해야한다. 가중치는 다음과 같이 간단히 저장할 수 있다.
# 모델 정의
class CustomModel(nn.module):
...(이하생략)
model = CustomModel()
# 가중치 저장
torch.save(model.state_dict(), 'weights.pt')
이때 확장자는 .pt
혹은 .pth
를 사용한다. 모델을 통째로 저장하는 방법은 다음과 같다. 파이토치 공식 홈페이지에서는 위와 같이 model.state_dict()
을 저장하는 것을 권장한다.
torch.save(model, 'model.pt')
불러오기
가중치만 저장한 경우에는 불러올 수 있다.
# 저장된 가중치와 같은 구조의 모델 정의
model = CustomModel()
# 가중치 불러오기
moel.load_state_dict(torch.load('weights.pt'))
# evaluation mode
model.eval()
모델을 통째로 저장한 경우에는 다음과 같이 불러온다.
model = CustomModel()
model = torch.load()
model.eval()
공식 홈페이지의 설명4 대로면 모델을 통째로 저장한 경우에도 모델을 불러오기 전에 미리 정의를 해야한다는 것 같은데, 그러면 이 두 방식에서 실질적으로 무슨 차이가 있는지는 잘 모르겠다. 어찌됐건 모델 자체를 저장하는 경우는 다른 경로에서 사용하거나 하는 등의 상황에서 불러오기가 제대로 되지 않을 수 있다고 하니, 권장하고 있는 방식으로 저장하고 불러오는 것이 좋겠다.
재학습하는 경우
위에서 말했듯이 재학습하는 경우에는 가중치와 더불어 반드시 옵티마이저의 상태까지 같이 저장해야한다. 왜 그래야하는지는 이해가 안되지만 이렇게 안하면 학습이 이어지지 않으니 주의하자. 필요에 따라서 현재 에포크, 로스 등도 같이 저장할 수 있다.
저장하기
# 모델 정의
class CustomModel(nn.module):
...(이하생략)
model = CustomModel()
# 옵티마이저 정의
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
# 가중치, 옵티마이저, 현재 에포크 저장
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch' : epoch
}, 'checkpoint.tar')
이런식으로 여러 정보를 저장하는 정우에는 확장자를 .tar
로 쓴다고 한다. 무슨 차이인지는 모르겠는데 공식 홈페이지에서 그렇게 쓰라고 한다.
불러오기
불러오기 전에 모델과 옵티마이저가 정의되어 있어야 한다.
model = CustomModel()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
# 불러오고 적용하기
checkpoint = torch.load('checkpoint.tar')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
checkpoint_epoch = checkpoint['epoch']
# train mode
model.train()
환경
- OS: Windows10
- Version: Python 3.9.2, torch 1.8.1+cu111
https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html?highlight=saving%20loading ↩︎
https://tutorials.pytorch.kr/beginner/saving_loading_models.html ↩︎
https://tutorials.pytorch.kr/recipes/recipes/saving_and_loading_a_general_checkpoint.html ↩︎
https://pytorch.org/tutorials/beginner/saving_loading_models.html?highlight=saving%20loading ↩︎