PyTorchでの重み、モデル、オプティマイザの保存と読み込み方法
再学習しない場合1 2 3
保存する
再学習しないのなら、ウェイトやモデルのみを保存すればいい。ただし、以下で述べるように、再学習をする予定なら、オプティマイザーも保存する必要がある。ウェイトは、以下のように簡単に保存できる。
# 모델 정의
class CustomModel(nn.module):
...(이하생략)
model = CustomModel()
# 가중치 저장
torch.save(model.state_dict(), 'weights.pt')
この時、ファイルの拡張子は.pt
または.pth
を使用する。モデルを丸ごと保存する方法は、以下の通りだ。PyTorchの公式ホームページでは、上記のようにmodel.state_dict()
を保存することを推奨している。
torch.save(model, 'model.pt')
読み込む
ウェイトのみを保存した場合は、以下のように読み込む。
# 저장된 가중치와 같은 구조의 모델 정의
model = CustomModel()
# 가중치 불러오기
moel.load_state_dict(torch.load('weights.pt'))
# evaluation mode
model.eval()
モデルを丸ごと保存した場合は、以下のように読み込む。
model = CustomModel()
model = torch.load()
model.eval()
公式ホームページの説明4によると、モデルを丸ごと保存した場合でも、読み込む前にモデルを事前に定義する必要があるそうだが、この2つの方法の間に実際にどんな違いがあるのかはよく分からない。とにかくモデル自体を保存する場合は、別の場所で使用するなどの状況で正しく読み込めない可能性があるそうなので、推奨されている方法で保存して読み込むのが良いだろう。
再学習する場合
上で述べたように、再学習する場合は、ウェイトと共に必ずオプティマイザーの状態も保存しなければならない。なぜそうする必要があるのかは分からないが、そうしないと学習が続かない可能性があるので注意しよう。必要に応じて、現在のエポック、ロスなども一緒に保存できる。
保存する
# 모델 정의
class CustomModel(nn.module):
...(이하생략)
model = CustomModel()
# 옵티마이저 정의
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
# 가중치, 옵티마이저, 현재 에포크 저장
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch' : epoch
}, 'checkpoint.tar')
このように様々な情報を保存する場合は、拡張子を.tar
とする。その差はよくわからないが、公式ホームページではそうするようにと言っている。
読み込む
読み込む前に、モデルとオプティマイザーが定義されていなければならない。
model = CustomModel()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
# 불러오고 적용하기
checkpoint = torch.load('checkpoint.tar')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
checkpoint_epoch = checkpoint['epoch']
# train mode
model.train()
環境
- OS: Windows10
- Version: Python 3.9.2, torch 1.8.1+cu111
https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html?highlight=saving%20loading ↩︎
https://tutorials.pytorch.kr/beginner/saving_loading_models.html ↩︎
https://tutorials.pytorch.kr/recipes/recipes/saving_and_loading_a_general_checkpoint.html ↩︎
https://pytorch.org/tutorials/beginner/saving_loading_models.html?highlight=saving%20loading ↩︎