[파이썬] PyTorch 사용자 정의 데이터 로더

PyTorch는 딥러닝을 위한 자연어 처리나 이미지 처리와 같은 분야에서 널리 사용되는 강력한 프레임워크입니다. 이 프레임워크는 기본적인 데이터 로딩 및 전처리를 위한 유틸리티를 제공하지만, 사용자 정의 데이터셋을 로딩하고 처리하기 위한 데이터 로더를 만들어야 할 때가 있습니다. 이번 블로그 포스트에서는 PyTorch에서 사용자 정의 데이터 로더를 생성하는 방법에 대해 알아보겠습니다.

데이터셋 클래스 생성

PyTorch에서 사용자 정의 데이터 로더를 생성하기 위해서는 먼저 torch.utils.data.Dataset 클래스를 상속하여 데이터셋 클래스를 생성해야 합니다. 데이터셋 클래스는 다음 메서드를 구현해야 합니다.

  1. __len__: 데이터셋의 총 샘플 수를 반환하는 메서드입니다.
  2. __getitem__: 인덱스를 기반으로 데이터셋에서 하나의 샘플을 로딩하고 반환하는 메서드입니다.

아래는 예제로 사용할 데이터셋 클래스의 기본 형태입니다.

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self):
        # 데이터셋 초기화 로직 작성
        pass
    
    def __len__(self):
        # 데이터셋의 총 샘플 수 반환
        pass
    
    def __getitem__(self, index):
        # 주어진 인덱스에 해당하는 샘플 반환
        pass

데이터셋 클래스 구현

실제로 데이터셋을 만들기 위해선 데이터셋 클래스 내부에서 데이터를 초기화하는 로직과 __len____getitem__ 메서드를 구현해야 합니다. 예를 들어, 텍스트 데이터를 처리하기 위한 데이터셋 클래스를 구현하는 예제를 살펴보겠습니다.

import torch
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, text_file):
        self.data = []
        with open(text_file, 'r') as f:
            for line in f:
                # 데이터 초기화 로직 작성
                self.data.append(line.strip())
    
    def __len__(self):
        # 데이터셋의 총 샘플 수 반환
        return len(self.data)
    
    def __getitem__(self, index):
        # 주어진 인덱스에 해당하는 샘플 반환
        return self.data[index]

위의 예제에서는 텍스트 파일에서 데이터를 로딩하여 self.data에 저장하고, __len__ 메서드에서는 데이터셋의 총 샘플 수를 반환하고, __getitem__ 메서드에서는 주어진 인덱스에 해당하는 샘플을 반환하고 있습니다.

데이터 로더 생성

데이터셋 클래스를 구현한 후에는 해당 데이터셋을 사용하여 데이터 로더를 생성할 수 있습니다. 데이터 로더는 torch.utils.data.DataLoader 클래스를 사용하여 생성할 수 있습니다. 데이터 로더는 효율적인 배치 처리를 위해 여러 가지 옵션을 제공합니다. 예를 들어, 배치 크기, 데이터 순서 섞기, 병렬 처리 등의 옵션을 설정할 수 있습니다.

from torch.utils.data import DataLoader

# 데이터셋 인스턴스 생성
dataset = TextDataset("data.txt")

# 데이터 로더 생성
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

위의 예제에서는 TextDataset 클래스를 사용하여 dataset 객체를 생성하고, DataLoader 클래스를 사용하여 dataloader 객체를 생성하고 있습니다. 이후 배치 크기를 32로 설정하고, 데이터 순서를 섞기 위해 shuffle=True를 설정하였으며, 병렬 처리를 위해 num_workers를 4로 설정하였습니다.

이제 dataloader 객체를 사용하여 모델 학습이나 추론을 위한 배치 단위의 데이터셋을 제공할 수 있습니다.

결론

이 블로그 포스트에서는 PyTorch에서 사용자 정의 데이터 로더를 구현하는 방법에 대해 알아보았습니다. 데이터셋 클래스를 생성하여 데이터를 초기화하고 __len____getitem__ 메서드를 구현하면 PyTorch의 데이터 로더를 사용하여 배치 단위의 데이터셋을 효율적으로 처리할 수 있습니다.