[python] scikit-learn을 이용한 나이브 베이즈 분류 시각화

나이브 베이즈 분류는 머신 러닝의 기초적인 알고리즘 중 하나로, 특히 텍스트 분류에 널리 사용됩니다. scikit-learn은 파이썬에서 머신 러닝 모델을 쉽게 구현하고 시각화할 수 있는 라이브러리입니다. 이번 블로그 포스트에서는 scikit-learn을 사용하여 나이브 베이즈 분류 모델을 구현하고 시각화하는 방법에 대해 알아보겠습니다.

데이터 준비

나이브 베이즈 분류 모델을 시각화하기 위해, 우선적으로 데이터를 준비해야 합니다. 예제로 사용할 데이터는 스팸 메일 분류를 위한 이메일 데이터입니다. scikit-learn에는 이러한 예제 데이터셋을 쉽게 불러올 수 있는 기능이 있습니다.

import numpy as np
from sklearn.datasets import fetch_20newsgroups

# 스팸 메일 데이터셋 불러오기
categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']
data_train = fetch_20newsgroups(subset='train', categories=categories)
data_test = fetch_20newsgroups(subset='test', categories=categories)

# 데이터셋 확인
print("훈련 데이터 개수:", len(data_train.data))
print("테스트 데이터 개수:", len(data_test.data))

위 코드는 fetch_20newsgroups 함수를 사용하여 스팸 메일 데이터셋을 불러온 뒤, 훈련 데이터와 테스트 데이터의 개수를 확인하는 예제입니다. 데이터를 불러온 후에는 이를 바탕으로 나이브 베이즈 분류 모델을 구현할 수 있습니다.

나이브 베이즈 분류 모델 구현

scikit-learn에서 제공하는 MultinomialNB 클래스를 사용하여 나이브 베이즈 분류 모델을 구현할 수 있습니다. 이 클래스는 다항분포 나이브 베이즈 모형을 사용하며, 텍스트 데이터를 처리하는 데에 적합합니다.

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# 특징 벡터 추출 및 나이브 베이즈 분류 모델 구성
model = make_pipeline(
    CountVectorizer(),
    MultinomialNB()
)

# 모델 학습
model.fit(data_train.data, data_train.target)

위 코드는 CountVectorizer를 사용하여 텍스트 데이터를 특징 벡터로 변환하고, MultinomialNB를 사용하여 나이브 베이즈 분류 모형을 구성하는 예제입니다. 이후에는 데이터를 학습시킨 후, 테스트 데이터에 대한 예측을 수행할 수 있습니다.

나이브 베이즈 분류 결과 시각화

나이브 베이즈 분류 모델의 결과를 시각화하여 성능을 확인할 수 있습니다. 예를 들어, 테스트 데이터의 예측 결과와 실제 결과를 비교하여 정확도를 계산할 수 있습니다.

from sklearn.metrics import accuracy_score

# 테스트 데이터에 대한 예측 수행
predicted = model.predict(data_test.data)

# 정확도 계산
accuracy = accuracy_score(data_test.target, predicted)
print("Accuracy:", accuracy)

위 코드는 accuracy_score 함수를 사용하여 예측 결과와 실제 결과를 비교하고, 정확도를 계산하는 예제입니다. 정확도를 확인함으로써 나이브 베이즈 분류 모델의 성능을 평가할 수 있습니다.

결론

이 블로그 포스트에서는 scikit-learn을 사용하여 나이브 베이즈 분류 모델을 구현하고 시각화하는 방법에 대해 알아보았습니다. 나이브 베이즈 분류는 텍스트 데이터를 처리하는 데에 특히 유용한 알고리즘이며, scikit-learn을 사용하면 쉽게 구현할 수 있습니다.

더 자세한 내용은 scikit-learn 공식 문서를 참조하시기 바랍니다.