[python] 파이썬 PyTorch에서 신경망을 시각화하는 방법은?

PyTorch는 딥러닝을 위한 강력한 프레임워크로, 신경망 모델을 구축하고 학습하는 데 사용됩니다. 이때 신경망의 구조와 학습된 가중치를 시각화하는 것은 모델을 이해하고 디버깅하는 데 매우 유용합니다.

여기서는 PyTorch에서 신경망을 시각화하는 방법에 대해 알아보겠습니다.

1. Graphviz을 사용한 시각화

import torch
from torchviz import make_dot

# 예시 신경망 모델
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = torch.nn.Linear(10, 20)
        self.fc2 = torch.nn.Linear(20, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# 모델 인스턴스화
model = MyModel()

# 입력 데이터
x = torch.randn(1, 10)

# 모델의 그래프 시각화
y = model(x)
make_dot(y, params=dict(model.named_parameters())).render("model", format="png")

위의 코드를 실행하면 model.png 파일로 그래프가 생성됩니다. 이 그래프는 신경망의 계산 그래프를 시각적으로 표현합니다.

2. 인터페이스 시각화

PyTorch의 torchsummary 패키지를 사용하면 모델의 구조를 요약하여 표시할 수 있습니다.

from torchsummary import summary

# 모델 요약 출력
summary(model, (1, 10))

위의 코드를 실행하면 모델의 요약 정보가 출력됩니다. 이 정보에는 신경망 계층, 출력 크기, 학습 가능한 파라미터 등이 포함됩니다.

3. 필터 시각화

합성곱 신경망(Convolutional Neural Network)의 경우, 각 필터가 신경망의 학습된 특징을 나타냅니다. 이를 시각화하여 어떤 특징을 감지하는지 확인할 수 있습니다.

import matplotlib.pyplot as plt

# Convolutional Neural Network 모델
class CNNModel(torch.nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=3)
        self.pool = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        return x

# 모델 인스턴스화
cnn_model = CNNModel()

# 모델의 첫 번째 필터 시각화
filters = cnn_model.conv1.weight.data

fig = plt.figure(figsize=(10, 5))
for i in range(filters.size(0)):
    ax = fig.add_subplot(4, 4, i + 1)
    ax.imshow(filters[i][0], cmap='gray')
    ax.axis('off')

plt.show()

위의 코드를 실행하면 신경망의 첫 번째 합성곱 계층의 필터가 그려지는 그림이 생성됩니다.

이렇게 PyTorch를 사용하여 신경망을 시각화할 수 있습니다. 이를 통해 모델의 구조와 가중치, 특징 추출에 대한 이해를 높일 수 있습니다.