[python] 파이썬 PyTorch에서 학습된 모델을 저장하고 불러오는 방법은?

PyTorch는 학습된 모델을 저장하고 불러오는 기능을 제공합니다. 모델을 저장하면 향후에 학습된 결과를 다시 사용하거나 모델을 공유할 수 있습니다. 이제 학습된 모델을 저장하고 불러오는 방법을 알아보겠습니다.

모델 저장하기

모델을 저장하기 위해 torch.save() 함수를 사용합니다. 이 함수는 모델의 가중치와 매개변수를 직렬화하여 파일에 저장합니다.

import torch

# 모델 저장 경로
PATH = "model.pt"

# 모델 저장
torch.save(model.state_dict(), PATH)

위의 코드에서 model.state_dict()는 모델의 가중치와 매개변수를 저장하는데 사용되는 PyTorch 내장 함수입니다. PATH 변수는 모델을 저장할 파일 경로를 설정합니다. 파일의 확장자는 .pt로 지정하는 것이 일반적입니다.

모델 불러오기

모델을 불러오려면 torch.load() 함수를 사용합니다. 이 함수는 저장된 모델 파일을 읽고 모델의 가중치와 매개변수를 복원합니다.

import torch

# 저장된 모델 불러오기
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ModelClass(*args, **kwargs)  # 모델 클래스 인스턴스화
model.load_state_dict(torch.load(PATH, map_location=device))
model.eval()

위의 코드에서 PATH 변수는 모델 파일의 경로를 설정합니다. device 변수는 현재 사용 가능한 GPU가 있는 경우 GPU를 사용하고, 그렇지 않으면 CPU를 사용하도록 설정합니다. model = ModelClass(*args, **kwargs)에서는 저장된 모델과 동일한 모델 클래스의 인스턴스를 만듭니다. 마지막으로 model.load_state_dict()를 사용하여 저장된 모델의 가중치와 매개변수를 불러옵니다. model.eval()는 모델을 추론 모드로 설정합니다.

참고 자료