[파이썬] PyTorch에서의 데이터 로딩

PyTorch는 딥러닝을 위한 강력한 프레임워크로, 데이터를 로딩하여 모델 학습 및 테스트에 사용할 수 있습니다. 이 블로그 포스트에서는 PyTorch에서 데이터를 로딩하는 방법을 다루겠습니다.

1. 데이터셋 준비

먼저, 데이터를 로딩하기 전에 데이터셋을 준비해야 합니다. PyTorch에서는 torchvision이라는 라이브러리를 통해 여러 가지 데이터셋을 제공합니다. 예를 들어, CIFAR10, MNIST 등의 유명한 데이터셋을 쉽게 사용할 수 있습니다. 또는 직접 데이터셋을 만들어 사용할 수도 있습니다.

import torchvision.datasets as datasets

# CIFAR10 데이터셋 로딩
train_dataset = datasets.CIFAR10(root='data/', train=True, transform=None, download=True)
test_dataset = datasets.CIFAR10(root='data/', train=False, transform=None, download=True)

# MNIST 데이터셋 로딩
train_dataset = datasets.MNIST(root='data/', train=True, transform=None, download=True)
test_dataset = datasets.MNIST(root='data/', train=False, transform=None, download=True)

# 직접 만든 데이터셋 로딩
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __getitem__(self, idx):
        item = self.data[idx]
        return item
    
    def __len__(self):
        return len(self.data)

data = [1, 2, 3, 4, 5]
custom_dataset = CustomDataset(data)

앞선 코드 예제에서 root는 데이터셋을 저장할 경로를 의미하며, train 파라미터는 훈련 데이터인지 테스트 데이터인지를 나타냅니다. transform은 데이터셋에 적용할 전처리 함수를 지정할 수 있는 옵션입니다. download는 데이터셋을 인터넷에서 자동으로 다운로드할지 여부를 설정하는 옵션입니다.

2. 데이터 로더 생성

데이터셋을 준비했다면, 이제 데이터를 로딩하기 위한 데이터 로더를 생성해야 합니다. 데이터 로더는 데이터셋으로부터 데이터를 일괄적으로 로딩하는 역할을 담당합니다.

import torch
from torch.utils.data import DataLoader

# 데이터 로더 생성
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 데이터 로더 사용 예시
for batch in train_loader:
    inputs, labels = batch
    # 학습 데이터로 모델 훈련

for batch in test_loader:
    inputs, labels = batch
    # 테스트 데이터로 모델 평가

위의 코드 예제에서 batch_size는 한 번에 로딩하는 데이터의 개수를 의미합니다. shuffle은 데이터를 섞을지 여부를 설정하는 옵션입니다. 모델 훈련 시에는 shuffle=True로 설정하여 데이터를 섞어 다양한 순서로 학습하도록 합니다.

이제 PyTorch에서 데이터를 로딩하는 방법을 알아보았습니다. 데이터 로딩은 딥러닝 모델 학습 및 테스트에 있어 중요한 부분이므로, 다양한 데이터셋과 데이터 로더 옵션을 활용하여 효율적인 데이터 처리를 할 수 있습니다.