[python] scikit-learn을 사용한 랜덤 포레스트 시각화
랜덤 포레스트는 앙상블 학습 방법 중 하나로, 여러 개의 의사 결정 트리를 사용하여 분류 또는 회귀 분석을 수행하는 알고리즘입니다. scikit-learn은 파이썬에서 머신 러닝을 위한 유명한 라이브러리입니다. 이 글에서는 scikit-learn을 사용하여 랜덤 포레스트를 시각화하는 방법에 대해 알아보겠습니다.
랜덤 포레스트 모델 훈련
먼저, scikit-learn의 RandomForestClassifier 클래스를 사용하여 랜덤 포레스트 모델을 훈련해야 합니다. 아래의 코드는 scikit-learn의 예제 데이터인 iris 데이터셋을 사용하여 랜덤 포레스트를 훈련하는 방법을 보여줍니다.
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# 데이터 로드
iris = load_iris()
X = iris.data
y = iris.target
# 훈련 및 테스트 데이터 분리
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 랜덤 포레스트 모델 훈련
model = RandomForestClassifier()
model.fit(X_train, y_train)
위 코드에서는 RandomForestClassifier
클래스를 사용하여 랜덤 포레스트 모델을 초기화하고 fit
메서드를 사용하여 모델을 훈련합니다. 훈련 데이터는 X_train
과 y_train
에 저장되어 있습니다.
랜덤 포레스트 시각화
랜덤 포레스트 모델을 시각화하려면 scikit-learn의 tree
모듈을 사용해야 합니다. tree
모듈에는 의사 결정 트리를 시각화하기 위한 export_graphviz
함수가 있습니다. 아래의 코드는 랜덤 포레스트 모델을 시각화하는 방법을 보여줍니다.
from sklearn import tree
import matplotlib.pyplot as plt
# 시각화할 의사 결정 트리 선택
estimator = model.estimators_[0]
# 의사 결정 트리 시각화
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4,4), dpi=300)
tree.plot_tree(estimator,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True)
# 시각화 저장
plt.savefig('random_forest_tree.png')
위 코드에서는 model.estimators_[0]
를 사용하여 랜덤 포레스트 모델의 첫 번째 의사 결정 트리를 선택합니다. tree.plot_tree
함수를 사용하여 의사 결정 트리를 시각화합니다. 시각화된 의사 결정 트리는 matplotlib
을 사용하여 그림 파일로 저장됩니다.
참고 자료
- scikit-learn 랜덤 포레스트 문서: http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html
- scikit-learn 트리 시각화 문서: http://scikit-learn.org/stable/modules/generated/sklearn.tree.plot_tree.html