[python] 파이썬 PyTorch에서 언더샘플링과 오버샘플링을 수행하는 방법은?

머신러닝 모델을 훈련시킬 때, 클래스 간의 불균형한 분포로 인해 정확도가 낮아질 수 있습니다. 이를 해결하기 위해 언더샘플링과 오버샘플링 기법을 사용하여 학습 데이터의 클래스를 균형있게 만들 수 있습니다.

PyTorch에서는 ImbalancedDatasetSampler와 RandomOverSampler를 사용하여 언더샘플링과 오버샘플링을 수행할 수 있습니다.

언더샘플링

언더샘플링은 데이터셋에서 다수 클래스의 샘플을 제거하여 클래스 간의 균형을 맞추는 방법입니다. PyTorch에서는 ImbalancedDatasetSampler를 사용하여 언더샘플링을 수행할 수 있습니다.

from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from torch.utils.data import WeightedRandomSampler

sampler = ImbalancedDatasetSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)

위의 코드는 ImbalancedDatasetSampler를 사용하여 데이터셋을 언더샘플링한 후 DataLoader를 생성하는 방법을 보여줍니다. 이때 ImbalancedDatasetSampler는 클래스별로 다른 가중치를 적용하여 샘플을 선택합니다.

오버샘플링

오버샘플링은 소수 클래스의 샘플을 복제 또는 생성하여 데이터셋의 클래스 간 균형을 맞추는 방법입니다. PyTorch에서는 RandomOverSampler를 사용하여 오버샘플링을 수행할 수 있습니다.

from imblearn.over_sampling import RandomOverSampler

sampler = RandomOverSampler()
X_resampled, y_resampled = sampler.fit_resample(X, y)

위의 코드는 RandomOverSampler를 사용하여 X와 y 데이터를 오버샘플링하는 방법을 보여줍니다. fit_resample() 메서드를 사용하여 오버샘플링된 데이터를 얻을 수 있습니다.

언더샘플링과 오버샘플링은 모델의 성능을 개선하기 위해 효과적인 방법입니다. PyTorch에서 제공하는 ImbalancedDatasetSampler와 imblearn 패키지의 RandomOverSampler를 활용하여 데이터셋의 클래스 균형을 조절할 수 있습니다.

참고 자료