[python] scikit-learn을 활용한 k-최근접 이웃 예측 시각화
k-최근접 이웃(K-Nearest Neighbors, KNN)은 분류 문제에 사용되는 간단하고 강력한 알고리즘입니다. 이웃들 중 가장 가까운 이웃들의 특징을 고려하여 예측을 수행합니다. 이번 블로그에서는 Python의 scikit-learn 라이브러리를 사용하여 k-최근접 이웃 알고리즘을 적용하고, 그 결과를 시각화해보겠습니다.
1. 데이터 불러오기
scikit-learn에서 제공하는 예제 데이터 중 하나인 iris 데이터셋을 사용하겠습니다. 이 데이터셋은 꽃의 품종을 분류하는 문제입니다.
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
2. 데이터 전처리
데이터를 훈련 데이터와 테스트 데이터로 나누고, 스케일링을 수행합니다.
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
3. k-최근접 이웃 모델 학습
from sklearn.neighbors import KNeighborsClassifier
k = 3
knn = KNeighborsClassifier(n_neighbors=k)
knn.fit(X_train_scaled, y_train)
4. 예측 및 시각화
테스트 데이터를 사용하여 예측을 수행한 후, 결과를 시각화합니다.
import matplotlib.pyplot as plt
import numpy as np
# 예측 결과 구하기
y_pred = knn.predict(X_test_scaled)
# 샘플들의 예측 클래스와 실제 클래스 비교
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_pred)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, marker='x')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('K-Nearest Neighbors Classification')
plt.legend(['Predicted', 'Actual'])
plt.show()
5. 결과 분석
시각화된 예측 결과를 통해, k-최근접 이웃 알고리즘이 얼마나 잘 분류하는지 확인할 수 있습니다. 예측된 클래스와 실제 클래스가 일치하는 경우는 예측이 올바르게 수행된 것입니다. 잘못 분류된 경우, 다른 알고리즘을 고려해볼 필요가 있습니다.
참고 자료
- scikit-learn Documentation: https://scikit-learn.org/stable/
- matplotlib Documentation: https://matplotlib.org/stable/contents.html
이상으로 scikit-learn을 활용한 k-최근접 이웃 예측의 시각화에 대해 알아보았습니다. 간단한 코드를 통해 어떻게 데이터를 준비하고 모델을 학습하여 예측 결과를 확인할 수 있는지 살펴보았습니다. 다양한 데이터셋에 대해 k-최근접 이웃 알고리즘을 적용해보고 결과를 시각화하여 분석해보세요!