[파이썬] statsmodels에서 그래프 및 시각화

데이터 분석과 통계 모델링을 위한 파이썬 패키지인 statsmodels는 다양한 통계 기법을 제공합니다. 그러나 모델링 결과를 시각화하여 더욱 직관적인 이해를 돕는 것은 매우 중요합니다. 이번 글에서는 statsmodels에서 그래프와 시각화를 수행하는 방법에 대해 알아보겠습니다.

1. 회귀 모델의 예측값 시각화하기

먼저, 회귀 모델의 예측값을 시각화하는 방법에 대해 살펴보겠습니다. 아래의 예제 코드는 statsmodels에서 제공하는 OLS 클래스를 사용하여 단순 선형 회귀 모델을 적합하고, 이를 시각화하는 방법을 보여줍니다.

import numpy as np
import matplotlib.pyplot as plt
import statsmodels.api as sm

# 예제 데이터 생성
np.random.seed(0)
x = np.random.rand(100) * 10
y = 2 * x + 3 + np.random.randn(100)

# 회귀 모델 적합
X = sm.add_constant(x)
model = sm.OLS(y, X)
results = model.fit()

# 예측값 계산
pred = results.predict(X)

# 산점도와 회귀선 시각화
plt.scatter(x, y, label='Data')
plt.plot(x, pred, color='red', label='Regression Line')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.show()

위 코드에서는 먼저 예제 데이터를 생성하고, add_constant() 함수를 사용하여 설명 변수 x에 상수항을 추가합니다. 그리고 OLS 클래스를 이용하여 회귀 모델을 적합하고, fit() 메서드를 통해 모델을 학습시킵니다. 마지막으로 predict() 메서드를 사용하여 모델의 예측값을 계산하고, 산점도와 회귀선을 함께 그리는 것으로 결과를 시각화합니다.

2. 잔차 분석하기

회귀 분석에서 모델의 성능을 평가하기 위해 잔차(residuals) 분석을 수행하는 것이 일반적입니다. 잔차는 실제 관측값과 모델로부터 예측된 값의 차이를 의미하며, 이를 시각화하여 잔차의 분포와 패턴을 확인할 수 있습니다.

# 잔차 분석 시각화
residuals = results.resid
plt.scatter(pred, residuals)
plt.axhline(y=0, color='red', linestyle='--')
plt.xlabel('Predicted Values')
plt.ylabel('Residuals')
plt.title('Residual Analysis')
plt.show()

위 코드에서는 resid 속성을 사용하여 모델의 잔차를 추출하고, 잔차와 예측값을 산점도로 그립니다. 수평선을 그려서 잔차의 평균이 0인지 확인하며, 잔차가 일정한 패턴을 따르는지 검토합니다.

3. QQ 플롯으로 정규성 검증하기

통계 모델링에서는 가정된 분포의 적합성을 검증하기 위해 QQ 플롯을 활용합니다. QQ 플롯은 잔차의 분포와 정규분포의 분위수를 비교하여, 가정한 분포와 잔차의 분포가 일치하는지 확인하는 도구입니다.

# QQ 플롯으로 정규성 검증
sm.qqplot(residuals, line='s')
plt.title('QQ Plot')
plt.show()

위 코드에서는 qqplot() 함수를 사용하여 QQ 플롯을 그립니다. 잔차가 정규분포를 따른다면, 잔차의 점들이 대각선 주위로 랜덤하게 분포할 것입니다. 주황색 선은 기대되는 정규 분포의 분위수를 나타내며, 점들이 이선 주위에 모여있는지 확인해야 합니다.

4. 시계열 데이터의 그래프와 추세 분석하기

statsmodels는 시계열 데이터 분석을 위한 다양한 기능을 제공합니다. 예를 들어, ARIMA 모델을 사용하여 시계열 데이터의 예측을 수행하고, 그 결과를 시각화할 수 있습니다.

from statsmodels.tsa.arima.model import ARIMA

# 예제 시계열 데이터 생성
np.random.seed(0)
n = 200
t = np.arange(n)
y = np.cumsum(np.random.randn(n))

# ARIMA 모델 적합
model = ARIMA(y, order=(1, 1, 1))
results = model.fit()

# 시계열 데이터와 예측 결과 시각화
plt.plot(t, y, label='Time Series Data')
plt.plot(t, results.fittedvalues, color='red', label='ARIMA Fit')
plt.xlabel('Time')
plt.ylabel('Value')
plt.title('ARIMA Prediction')
plt.legend()
plt.show()

위 코드에서는 ARIMA 클래스를 사용하여 ARIMA 모델을 생성하고, fit() 메서드를 통해 모델을 학습시킵니다. 마지막으로 fittedvalues 속성을 사용하여 학습된 모델의 예측 결과를 얻어와서 데이터와 함께 시각화합니다.

statsmodels는 이외에도 다양한 그래프와 시각화 기능을 제공합니다. 위에서 소개한 방법을 활용하여 모델링 결과를 더욱 직관적으로 이해하고, 데이터의 특성과 패턴을 파악할 수 있습니다.