[파이썬] PyTorch 피처 추출 및 시각화

PyTorch는 딥 러닝을 위한 강력한 프레임워크로 알려져 있습니다. 이 프레임워크를 사용하면 컨볼루션 신경망(CNN)과 같은 다양한 딥러닝 모델을 생성하고 훈련할 수 있습니다. 그런데 이러한 모델을 구성하는 인공신경망의 내부에서 특징(피처)을 추출하고 시각화하는 것도 중요한 작업 중 하나입니다.

이 블로그 포스트에서는 PyTorch를 사용하여 이미지 데이터의 특징을 추출하고 시각화하는 방법에 대해 알아보겠습니다.

데이터 준비

먼저, 피처 추출을 위해 사용할 이미지 데이터를 준비해야 합니다. 본 예시에서는 CIFAR-10 데이터셋을 사용하겠습니다. CIFAR-10 데이터셋은 10개의 다른 클래스로 구성된 60,000개의 32x32 컬러 이미지로 이루어져 있습니다.

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

모델 생성

다음으로, 이미지의 피처를 추출하기 위해 사전 훈련된 신경망 모델을 사용하겠습니다. 여기서는 torchvision 패키지에서 제공하는 ResNet-50 모델을 사용하겠습니다.

import torchvision.models as models

# 미리 훈련된 ResNet-50 모델 로딩
model = models.resnet50(pretrained=True)

피처 추출

이제 사전 훈련된 모델을 사용하여 이미지의 피처를 추출해보겠습니다. 피처를 추출하기 위해서는 이미지를 모델에 입력하고, 모델의 중간 피처를 반환받아야 합니다.

# 이미지 추출할 레이어 선택
layer = model.avgpool

# 피처 추출을 위해 forward pass 진행
features = []
def hook_fn(module, input, output):
    features.append(output.flatten().detach().numpy())

model.avgpool.register_forward_hook(hook_fn)

# 이미지별로 피처 추출
all_features = []
for images, labels in trainloader:
    outputs = model(images)
    all_features.extend(features)
    features = []

이제 all_features 변수에 모든 이미지의 피처들이 저장되었습니다. 이렇게 추출된 피처는 다양한 시각화나 머신 러닝 모델에 사용될 수 있습니다.

피처 시각화

마지막으로, 추출한 피처를 시각화해보겠습니다. 여기서는 t-SNE(Stochastic Neighbor Embedding)라는 알고리즘을 사용하여 피처를 2차원 공간에 표현하겠습니다.

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# t-SNE 알고리즘으로 2차원으로 변환
tsne = TSNE(n_components=2, random_state=0)
tsne_results = tsne.fit_transform(all_features)

# 시각화
plt.figure(figsize=(10, 10))
plt.scatter(tsne_results[:, 0], tsne_results[:, 1])
plt.title("t-SNE Visualization of Image Features")
plt.show()

위 코드를 실행하면, 이미지의 피처들을 2D 공간에 시각화한 그래프가 출력됩니다.

PyTorch를 사용하여 이미지 데이터에서 피처를 추출하고 시각화하는 방법에 대해 알아보았습니다. 이러한 기법은 이미지 데이터의 특징을 시각화하거나, 이미지 분류 및 검색과 같은 작업에서 중요한 역할을 합니다. 피처 추출과 시각화 기법을 응용하여 다양한 딥러닝 애플리케이션을 개발해보세요.