[python] 파이썬 PyTorch에서 잠재 공간에 있는 데이터를 시각화하는 방법은?

PyTorch는 딥러닝을 위한 강력한 프레임워크입니다. 잠재 공간(latent space)은 학습된 모델에서 추출한 특징들의 공간입니다. 이 잠재 공간에 있는 데이터를 시각화하는 것은 모델의 성능을 이해하고, 데이터의 특성을 파악하는 데에 도움이 됩니다.

아래는 파이썬 PyTorch에서 잠재 공간에 있는 데이터를 시각화하는 방법에 대한 예제 코드입니다.

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

# 데이터셋 로드
root = './data'  # 데이터셋 경로 설정
dataset = datasets.MNIST(root=root, train=True, download=True, transform=ToTensor())
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

# 모델 정의 및 학습
# 여기에 모델을 정의하고 학습하는 코드를 작성하세요.

# 잠재 공간에서 데이터 추출
with torch.no_grad():
    latent_data = []
    true_labels = []
    for images, labels in loader:
        # 모델에 이미지를 전달하여 잠재 공간 벡터 추출
        latent_vector = model(images)
        latent_data.append(latent_vector)
        true_labels.extend(labels)

# 시각화
latent_data = torch.cat(latent_data, dim=0)  # 잠재 공간 데이터를 하나의 텐서로 통합
true_labels = torch.tensor(true_labels)
plt.scatter(latent_data[:, 0], latent_data[:, 1], c=true_labels, cmap='rainbow')  # 잠재 공간 데이터를 산점도로 시각화
plt.colorbar()
plt.show()

위의 코드에서는 MNIST 데이터셋을 예로 사용하였습니다. 데이터셋을 로드한 후 모델을 정의하고 학습시키는 부분은 별도로 작성해주어야 합니다. 이후 with torch.no_grad() 블록을 사용하여 모델을 통과한 이미지들의 잠재 공간 벡터를 추출하고, 그 벡터를 시각화하는 과정을 수행합니다. 위의 예제는 잠재 공간이 2차원인 경우를 가정하여 산점도로 시각화하였습니다.

추가적으로, 시각화에 사용되는 라이브러리인 matplotlib.pyplot을 설치해야 합니다. pip install matplotlib 명령을 통해 설치할 수 있습니다.

참고 자료: