[python] 파이썬 PyTorch에서 적대적 생성 신경망(GAN)을 구성하고 학습시키는 방법은?

GAN은 적대적인 방식으로 생성 모델과 판별 모델을 함께 학습시켜 실제같은 이미지를 생성하는 생성 모델입니다. 이번 블로그 포스트에서는 PyTorch를 사용하여 GAN을 구성하고 학습시키는 방법을 알아보겠습니다.

필요한 패키지 설치

먼저, GAN을 구현하기 위해 아래와 같은 패키지들을 설치해야 합니다.

pip install torch torchvision

데이터셋 로딩

GAN을 학습시키기 위해 먼저 데이터셋을 로딩해야 합니다. PyTorch에서는 torchvision 패키지를 사용하여 데이터셋을 쉽게 로딩할 수 있습니다. 데이터셋에 따라 필요한 전처리 작업을 수행한 후 DataLoader를 사용하여 배치 단위로 데이터를 공급할 수 있습니다.

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

# 이미지 전처리
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

# MNIST 데이터셋 로딩
dataset = MNIST(root='data/', train=True, transform=transform, download=True)

# 데이터 로더 생성
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

생성자 모델 정의

GAN의 생성자 모델은 랜덤 노이즈 벡터를 입력으로 받아 실제같은 이미지를 생성해내는 역할을 합니다. 일반적으로 픽셀 값을 [-1, 1] 범위로 정규화하는 활성화 함수인 tanh를 사용하여 출력을 조정합니다.

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, input_size, output_size):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, output_size),
            nn.Tanh()
        )
        
    def forward(self, x):
        x = self.model(x)
        return x

판별자 모델 정의

GAN의 판별자 모델은 생성된 이미지와 실제 이미지를 구분하는 역할을 합니다. 이진 분류 문제이므로 출력은 [0, 1] 범위로 정규화하는 sigmoid 활성화 함수를 사용합니다.

class Discriminator(nn.Module):
    def __init__(self, input_size):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        x = self.model(x)
        return x

모델 초기화 및 하이퍼파라미터 설정

GAN 모델을 초기화하고 학습에 사용할 하이퍼파라미터를 설정해야 합니다. 생성자와 판별자의 입력 크기, 학습률, 손실 함수 등을 정의합니다.

# 하이퍼파라미터 설정
input_size = 100
output_size = 28 * 28
lr = 0.0002
num_epochs = 100

# 생성자와 판별자 초기화
generator = Generator(input_size, output_size)
discriminator = Discriminator(output_size)

# 손실 함수와 최적화 알고리즘 설정
criterion = nn.BCELoss()
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

학습 루프 구성

GAN 모델을 학습시키기 위해 생성자와 판별자를 번갈아가면서 학습시키는 학습 루프를 구성합니다.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
discriminator.to(device)

for epoch in range(num_epochs):
    for batch_idx, (real_images, _) in enumerate(dataloader):
        # 실제 이미지 배치를 디바이스에 전송
        real_images = real_images.view(-1, output_size).to(device)
        
        # 판별자 학습
        discriminator_optimizer.zero_grad()
        
        # 진짜 이미지 판별
        real_labels = torch.ones(real_images.size(0), 1).to(device)
        discriminator_output = discriminator(real_images)
        real_loss = criterion(discriminator_output, real_labels)
        
        # 가짜 이미지 생성 및 판별
        noise = torch.randn(real_images.size(0), input_size).to(device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(real_images.size(0), 1).to(device)
        discriminator_output = discriminator(fake_images)
        fake_loss = criterion(discriminator_output, fake_labels)
        
        # 판별자 손실 계산 및 역전파
        discriminator_loss = real_loss + fake_loss
        discriminator_loss.backward()
        discriminator_optimizer.step()
        
        # 생성자 학습
        generator_optimizer.zero_grad()
        
        # 가짜 이미지 생성 및 판별
        noise = torch.randn(real_images.size(0), input_size).to(device)
        fake_images = generator(noise)
        discriminator_output = discriminator(fake_images)
        generator_loss = criterion(discriminator_output, real_labels)
        
        # 생성자 손실 계산 및 역전파
        generator_loss.backward()
        generator_optimizer.step()
        
        # 학습 진행 상황 출력
        if (batch_idx+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(dataloader)}], Generator Loss: {generator_loss.item():.4f}, Discriminator Loss: {discriminator_loss.item():.4f}')

결과 확인

학습이 완료된 생성자 모델을 사용하여 랜덤 노이즈 벡터를 입력으로 받아 이미지를 생성해봅니다.

import matplotlib.pyplot as plt
import numpy as np

# 랜덤 노이즈 벡터 생성
noise = torch.randn(1, input_size).to(device)

# 생성자로부터 이미지 생성
generated_image = generator(noise).detach().cpu().numpy()

# 생성된 이미지 시각화
plt.imshow(np.transpose(generated_image.reshape(28, 28), (1, 0)), cmap='gray')
plt.axis('off')
plt.show()

GAN 모델을 구성하고 학습시키는 방법에 대해 알아보았습니다. 이제 이 코드를 실행하여 GAN 모델을 구현하고 다양한 데이터셋에 적용해보세요.


참고 자료: