[python] scikit-learn을 활용한 k-최근접 이웃 예측 시각화

k-최근접 이웃(K-Nearest Neighbors, KNN)은 분류 문제에 사용되는 간단하고 강력한 알고리즘입니다. 이웃들 중 가장 가까운 이웃들의 특징을 고려하여 예측을 수행합니다. 이번 블로그에서는 Python의 scikit-learn 라이브러리를 사용하여 k-최근접 이웃 알고리즘을 적용하고, 그 결과를 시각화해보겠습니다.

1. 데이터 불러오기

scikit-learn에서 제공하는 예제 데이터 중 하나인 iris 데이터셋을 사용하겠습니다. 이 데이터셋은 꽃의 품종을 분류하는 문제입니다.

from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
y = iris.target

2. 데이터 전처리

데이터를 훈련 데이터와 테스트 데이터로 나누고, 스케일링을 수행합니다.

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

3. k-최근접 이웃 모델 학습

from sklearn.neighbors import KNeighborsClassifier

k = 3
knn = KNeighborsClassifier(n_neighbors=k)
knn.fit(X_train_scaled, y_train)

4. 예측 및 시각화

테스트 데이터를 사용하여 예측을 수행한 후, 결과를 시각화합니다.

import matplotlib.pyplot as plt
import numpy as np

# 예측 결과 구하기
y_pred = knn.predict(X_test_scaled)

# 샘플들의 예측 클래스와 실제 클래스 비교
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_pred)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, marker='x')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('K-Nearest Neighbors Classification')
plt.legend(['Predicted', 'Actual'])
plt.show()

5. 결과 분석

시각화된 예측 결과를 통해, k-최근접 이웃 알고리즘이 얼마나 잘 분류하는지 확인할 수 있습니다. 예측된 클래스와 실제 클래스가 일치하는 경우는 예측이 올바르게 수행된 것입니다. 잘못 분류된 경우, 다른 알고리즘을 고려해볼 필요가 있습니다.

참고 자료

이상으로 scikit-learn을 활용한 k-최근접 이웃 예측의 시각화에 대해 알아보았습니다. 간단한 코드를 통해 어떻게 데이터를 준비하고 모델을 학습하여 예측 결과를 확인할 수 있는지 살펴보았습니다. 다양한 데이터셋에 대해 k-최근접 이웃 알고리즘을 적용해보고 결과를 시각화하여 분석해보세요!