파이토치에서 가중치, 모델, 옵티마이저 저장하고 불러오는 방법

파이토치에서 가중치, 모델, 옵티마이저 저장하고 불러오는 방법

재학습하지 않는 경우1 2 3

저장하기

재학습하지 않는 경우라면 간단하게 가중치 혹은 모델만 저장해도 된다. 아래에서 말하겠지만 재학습을 할거라면 옵티마이저까지 저장해야한다. 가중치는 다음과 같이 간단히 저장할 수 있다.

# 모델 정의
class CustomModel(nn.module):
    ...(이하생략)

model = CustomModel()

# 가중치 저장
torch.save(model.state_dict(), 'weights.pt')

이때 확장자는 .pt 혹은 .pth를 사용한다. 모델을 통째로 저장하는 방법은 다음과 같다. 파이토치 공식 홈페이지에서는 위와 같이 model.state_dict()을 저장하는 것을 권장한다.

torch.save(model, 'model.pt')

불러오기

가중치만 저장한 경우에는 불러올 수 있다.

# 저장된 가중치와 같은 구조의 모델 정의
model = CustomModel()

# 가중치 불러오기
moel.load_state_dict(torch.load('weights.pt'))

# evaluation mode
model.eval()

모델을 통째로 저장한 경우에는 다음과 같이 불러온다.

model = CustomModel()
model = torch.load()
model.eval()

공식 홈페이지의 설명4 대로면 모델을 통째로 저장한 경우에도 모델을 불러오기 전에 미리 정의를 해야한다는 것 같은데, 그러면 이 두 방식에서 실질적으로 무슨 차이가 있는지는 잘 모르겠다. 어찌됐건 모델 자체를 저장하는 경우는 다른 경로에서 사용하거나 하는 등의 상황에서 불러오기가 제대로 되지 않을 수 있다고 하니, 권장하고 있는 방식으로 저장하고 불러오는 것이 좋겠다.

재학습하는 경우

위에서 말했듯이 재학습하는 경우에는 가중치와 더불어 반드시 옵티마이저의 상태까지 같이 저장해야한다. 왜 그래야하는지는 이해가 안되지만 이렇게 안하면 학습이 이어지지 않으니 주의하자. 필요에 따라서 현재 에포크, 로스 등도 같이 저장할 수 있다.

저장하기

# 모델 정의
class CustomModel(nn.module):
    ...(이하생략)

model = CustomModel()

# 옵티마이저 정의
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

# 가중치, 옵티마이저, 현재 에포크 저장
torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch' : epoch
        }, 'checkpoint.tar')

이런식으로 여러 정보를 저장하는 정우에는 확장자를 .tar로 쓴다고 한다. 무슨 차이인지는 모르겠는데 공식 홈페이지에서 그렇게 쓰라고 한다.

불러오기

불러오기 전에 모델과 옵티마이저가 정의되어 있어야 한다.

model = CustomModel()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

# 불러오고 적용하기
checkpoint = torch.load('checkpoint.tar')

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
checkpoint_epoch = checkpoint['epoch']

# train mode
model.train()

환경


  1. https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html?highlight=saving%20loading ↩︎

  2. https://tutorials.pytorch.kr/beginner/saving_loading_models.html ↩︎

  3. https://tutorials.pytorch.kr/recipes/recipes/saving_and_loading_a_general_checkpoint.html ↩︎

  4. https://pytorch.org/tutorials/beginner/saving_loading_models.html?highlight=saving%20loading ↩︎

댓글