[파이썬] PyTorch에서의 멀티 GPU 학습

딥러닝 모델의 학습은 많은 계산량과 시간이 필요합니다. 이에 따라, 단일 GPU로는 학습 시간을 단축시키기 어렵고 보다 빠른 학습을 위해 멀티 GPU를 활용하는 것이 필요합니다. 이번 글에서는 PyTorch에서 멀티 GPU 학습을 어떻게 구현하는지 알아보겠습니다.

1. 멀티 GPU 설정

PyTorch는 torch.nn.DataParallel 클래스를 사용하여 멀티 GPU를 지원합니다. 이 클래스를 사용하면 모델이 여러 GPU에 자동으로 병렬로 분산됩니다. 다음은 멀티 GPU 학습을 위한 기본적인 설정 방법입니다.

import torch
import torch.nn as nn
from torch.nn import DataParallel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 모델 정의
model = MyModel()

# 멀티 GPU로 모델을 래핑합니다.
model = nn.DataParallel(model)

# 모델을 GPU에 전송합니다.
model.to(device)

2. 데이터 로딩 및 데이터 분배

멀티 GPU 학습을 위해 데이터를 분배하는 방법에는 여러 가지가 있습니다. 가장 간단한 방법은 데이터를 일정한 비율로 분배하여 각 GPU에 배치하는 것입니다. PyTorch에서는 torch.nn.DataParallel 클래스가 이를 자동으로 처리해 줍니다. 다음은 데이터를 분배하는 예제입니다.

from torch.utils.data import DataLoader

# 데이터셋 정의
dataset = MyDataset()

# 데이터로더 생성
data_loader = DataLoader(dataset, batch_size=128, shuffle=True)

# 모델 학습
for batch_data in data_loader:
    # 데이터를 GPU에 전송합니다.
    inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
    
    # 모델 예측
    outputs = model(inputs)
    
    # 손실 계산
    loss = criterion(outputs, labels)
    
    # 역전파 및 가중치 업데이트
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

3. 그래디언트 업데이트와 매개변수 합산

멀티 GPU 학습 시, 각 GPU에서 계산된 그래디언트를 평균하여 가중치를 업데이트해야 합니다. 이는 torch.nn.DataParallel이 자동으로 처리해 주므로 추가적인 코드 작성이 필요 없습니다. 다음은 그래디언트 업데이트와 가중치 합산을 보여주는 예제입니다.

# 손실 계산
loss = criterion(outputs, labels)

# 역전파
loss.backward()

# 그래디언트 업데이트
optimizer.step()

# 가중치 합산
model.module.load_state_dict(model.state_dict())

4. 학습 결과 평가

멀티 GPU 학습이 완료된 후, 학습된 모델의 성능을 평가해야 합니다. 이는 단일 GPU에서와 동일한 방식으로 진행됩니다. 다음은 학습 결과를 평가하는 예제입니다.

# 모델 평가 모드로 설정
model.eval()

# 정확도 초기화
total_correct = 0
total_samples = 0

# 데이터로더를 사용하여 테스트 데이터로 예측
for batch_data in test_data_loader:
    inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
    outputs = model(inputs)
    
    # 예측값 가져오기
    _, predicted = torch.max(outputs.data, 1)
    
    # 정확도 계산
    total_samples += labels.size(0)
    total_correct += (predicted == labels).sum().item()

accuracy = total_correct / total_samples
print(f"Accuracy: {accuracy}")

이제 PyTorch에서 멀티 GPU 학습을 위한 기본적인 구현 방법에 대해 알게 되었습니다. 멀티 GPU를 활용하면 학습 시간을 단축시킬 수 있으며, 대용량 데이터셋이나 복잡한 모델을 다룰 때 특히 유용합니다. 다양한 모델 학습에 멀티 GPU 기능을 적용하여 보다 빠르고 효율적인 딥러닝 모델을 구축하세요. Happy coding!