[python] 파이썬 PyTorch에서 추론(inference) 과정을 수행하는 방법은?

PyTorch는 딥러닝 모델을 학습하는 데 주로 사용되지만, 학습된 모델을 사용하여 추론을 수행하는 것도 간단합니다. 추론을 수행하기 위해서는 우선 학습된 모델을 로드해야 합니다. 이후에는 추론할 데이터를 모델에 입력으로 전달하고, 모델의 출력값을 얻을 수 있습니다.

다음은 추론 과정을 수행하는 예제 코드입니다.

import torch
import torchvision.transforms as transforms
from PIL import Image

# 학습된 모델을 로드합니다.
model = torch.load('trained_model.pt')
model.eval()  # 모델을 평가 모드로 설정합니다.

# 추론할 이미지를 불러옵니다.
image = Image.open('test_image.jpg')

# 이미지 전처리를 수행합니다.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)

# 추론을 수행합니다.
with torch.no_grad():
    output = model(input_batch)

# 출력값을 확인합니다.
_, predicted_idx = torch.max(output, 1)
print(f"추론 결과: {predicted_idx.item()}")

위의 예제 코드에서는 먼저 torch.load 함수를 사용하여 학습된 모델을 로드합니다. 그리고 model.eval()을 호출하여 모델을 평가 모드로 설정합니다. 이후에는 추론할 이미지를 불러와 전처리 작업을 수행한 후 모델에 입력으로 전달합니다. torch.no_grad() 블록 내에서 모델을 실행하고, 출력값을 얻어낼 수 있습니다. 마지막으로, 출력값을 활용하여 추론 결과를 확인합니다.

위의 코드는 단순한 예제이며, 실제로는 데이터 로딩, 전처리, 후처리, 배치 처리 등 다양한 요소를 고려해야할 수 있습니다. 따라서 실제 환경에서 추론 과정을 수행할 때에는 이러한 요소들을 고려하여 코드를 작성해야 합니다.

더 자세한 내용은 PyTorch 공식 문서(https://pytorch.org/docs/stable/index.html)를 참고하세요.

[출처 및 참고 자료]