[파이썬][scikit-learn] scikit-learn Imbalanced 데이터 처리

데이터 불균형(Imbalanced data)은 분류 문제에서 자주 발생하는 문제입니다. 데이터 불균형은 한 클래스의 샘플 수가 다른 클래스에 비해 지나치게 적은 경우를 의미합니다. 이러한 데이터 불균형은 모델의 학습을 어렵게 하고, 예측 성능을 저하시킬 수 있습니다. 이번 글에서는 scikit-learn을 사용하여 Imbalanced 데이터를 처리하는 방법에 대해 알아보겠습니다.

1. 데이터 로딩

우선, 데이터를 로딩해야 합니다. scikit-learn에서는 sklearn.datasets 모듈을 사용하여 유명한 데이터셋을 로딩할 수 있습니다. 아래는 예시로 iris 데이터셋을 로딩하는 코드입니다.

from sklearn import datasets

# 데이터 로딩
iris = datasets.load_iris()
X = iris.data
y = iris.target

2. 데이터 확인

데이터를 로딩한 후, 데이터의 분포를 확인해야 합니다. 이 단계에서는 각 클래스의 샘플 수를 확인하고, 데이터 불균형이 있는지 확인해야 합니다. numpypandas를 사용하여 간단히 데이터를 확인할 수 있습니다.

import numpy as np
import pandas as pd

# 클래스별 샘플 수 확인
unique, counts = np.unique(y, return_counts=True)
class_counts = dict(zip(unique, counts))

# 데이터 불균형 확인
df = pd.DataFrame({'Class': unique, 'Count': counts})
print(df)

3. 데이터 샘플링

데이터 불균형을 해결하기 위해 데이터를 샘플링하는 방법을 사용할 수 있습니다. 대표적인 방법으로는 과소표집과대표집이 있습니다. 과소표집은 다수 클래스의 샘플을 제거하여 데이터를 균형있게 만드는 방법이고, 과대표집은 소수 클래스의 샘플을 복제하여 데이터를 균형있게 만드는 방법입니다.

from imblearn.under_sampling import RandomUnderSampler
from imblearn.over_sampling import RandomOverSampler

# 과소표집
under_sampler = RandomUnderSampler()
X_resampled, y_resampled = under_sampler.fit_resample(X, y)

# 과대표집
over_sampler = RandomOverSampler()
X_resampled, y_resampled = over_sampler.fit_resample(X, y)

4. 모델 학습 및 예측

데이터 샘플링을 통해 불균형을 해결한 후, 모델을 학습하고 예측할 수 있습니다. 이때, scikit-learn의 다양한 분류 알고리즘을 사용할 수 있습니다.

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report

# 학습 데이터 및 테스트 데이터 분할
X_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled, test_size=0.2, random_state=42)

# 모델 학습
model = DecisionTreeClassifier()
model.fit(X_train, y_train)

# 예측
y_pred = model.predict(X_test)

# 분류 보고서 출력
print(classification_report(y_test, y_pred))

5. 평가

마지막으로, 모델의 예측 성능을 평가해야 합니다. 분류 문제에서는 여러가지 평가 지표를 사용할 수 있습니다. 가장 대표적인 평가 지표 중 하나는 정확도(Accuracy)입니다. 또한, 오차 행렬(Confusion Matrix)를 사용하여 모델의 성능을 확인할 수도 있습니다.

from sklearn.metrics import accuracy_score, confusion_matrix

# 정확도 계산
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')

# 오차 행렬 출력
conf_matrix = confusion_matrix(y_test, y_pred)
print(conf_matrix)

이제 scikit-learn을 사용하여 Imbalanced 데이터를 처리하는 방법에 대해 알아보았습니다. 데이터 로딩, 데이터 확인, 데이터 샘플링, 모델 학습 및 예측, 평가의 과정을 차례대로 진행하여 데이터 불균형 문제를 처리할 수 있습니다.