[파이썬] PyTorch 직렬화 및 모델 전송

PyTorch는 딥 러닝 모델을 학습하고 저장하는 데 사용되는 강력한 프레임워크입니다. 이러한 모델을 다른 기기 또는 환경으로 전송하려면 모델을 직렬화하고 전송하는 작업을 수행해야 합니다.

PyTorch에서 모델을 직렬화하고 전송하는 방법을 배워봅시다.

1. 모델 직렬화

PyTorch에서 모델을 직렬화하는 가장 간단한 방법은 torch.save() 함수를 사용하는 것입니다. 이 함수를 사용하여 모델의 상태를 저장할 수 있습니다.

import torch

# 모델을 정의합니다.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 1)
)

# 모델을 학습시킵니다.

# 모델을 저장합니다.
torch.save(model.state_dict(), 'model.pt')

위의 코드는 model.pt라는 이름으로 모델을 저장합니다. 이 파일은 모델의 가중치와 편향 값을 포함하고 있으며, 이 정보를 사용하여 나중에 모델을 다시 로드할 수 있습니다.

2. 모델 전송

모델을 전송하기 위해서는 저장된 모델 파일을 다른 기기로 복사해야 합니다. 예를 들어, 모델을 CPU에서 GPU로 전송하거나, 서버에서 클라이언트로 전송하는 경우입니다.

import torch

# 저장된 모델 파일을 불러옵니다.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 1)
)
model.load_state_dict(torch.load('model.pt'))

# 모델을 GPU로 전송합니다.
device = torch.device('cuda')
model.to(device)

# 모델을 클라이언트로 전송합니다.
import socket

HOST = 'localhost'
PORT = 12345

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.connect((HOST, PORT))
    s.sendall(model.state_dict())

위의 코드에서는 먼저 저장된 모델 파일을 로드하고, GPU로 모델을 전송하기 위해 model.to(device)를 사용합니다. 마지막으로, 모델을 클라이언트로 전송하기 위해 소켓을 열고 model.state_dict()를 전송합니다.

전송된 모델은 클라이언트에서 동일한 방식으로 로드하여 사용할 수 있습니다.

3. 모델 역직렬화

전송된 모델을 로드하려면 역직렬화 작업이 필요합니다. 이를 위해 torch.load() 함수를 사용합니다.

import torch

# 전송된 모델을 로드합니다.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 1)
)
model.load_state_dict(torch.load('model.pt'))

# 모델을 사용합니다.

위의 코드에서는 torch.load() 함수를 사용하여 전송된 모델을 로드합니다. 이제 모델을 사용하여 예측 등의 작업을 수행할 수 있습니다.

결론

PyTorch에서 모델을 직렬화하고 전송하는 방법에 대해 알아보았습니다. 모델을 저장하고 다른 기기나 환경으로 전송하는 작업은 다양한 상황에서 유용하게 사용될 수 있습니다. 이러한 작업을 통해 모델을 보다 쉽게 공유하고, 다양한 응용 프로그램에서 활용할 수 있습니다.