[파이썬] PyTorch GANs를 이용한 이미지 생성

GANs (Generative Adversarial Networks)은 딥 러닝 모델의 한 유형으로, 주어진 데이터를 학습하여 새로운 데이터를 생성하는 능력을 갖추고 있습니다. 이 기술은 이미지, 음악, 텍스트 등 다양한 유형의 데이터 생성에 사용될 수 있습니다. 이번 블로그에서는 PyTorch를 사용하여 GANs를 이용해 이미지를 생성하는 방법을 알아보겠습니다.

GANs란?

GANs은 생성자(Generator)와 판별자(Discriminator) 두 개의 네트워크로 구성되어 있습니다. 생성자는 무작위 노이즈 벡터를 입력으로 받아 실제 데이터와 비슷한 가짜 데이터를 생성하는 역할을 합니다. 판별자는 생성자가 생성한 가짜 데이터와 실제 데이터를 구분하는 역할을 합니다. 양쪽 네트워크는 경쟁적인 학습을 통해 서로를 발전시키는 구조로 이루어져 있습니다.

PyTorch를 이용한 GANs 구현

GANs를 구현하기 위해 PyTorch를 사용할 것입니다. PyTorch는 딥 러닝 프레임워크로서 GPU 가속을 지원하여 빠른 학습 속도와 유연한 모델 개발을 가능하게 합니다. 다음은 PyTorch로 GANs를 구현하는 예시 코드입니다.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

# 생성자 클래스 정의
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 784),
            nn.Tanh()
        )
    
    def forward(self, x):
        x = self.model(x)
        return x

# 판별자 클래스 정의
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        x = self.model(x)
        return x

# 데이터 로딩
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = MNIST(root="data/", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 모델 초기화
generator = Generator()
discriminator = Discriminator()

# 손실 함수와 옵티마이저 정의
criterion = nn.BCELoss()
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)

# 학습 루프
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

for epoch in range(num_epochs):
    for batch_idx, (real_images, _) in enumerate(dataloader):
        real_images = real_images.view(-1, 784).to(device)
        batch_size = real_images.shape[0]

        # 생성자 입력을 위한 무작위 노이즈 벡터 생성
        noise = torch.randn(batch_size, 100).to(device)

        # 생성자에 의해 가짜 이미지 생성
        fake_images = generator(noise)

        # 판별자에 실제 이미지와 가짜 이미지를 입력한 결과 반환
        real_pred = discriminator(real_images)
        fake_pred = discriminator(fake_images)

        # 손실 계산
        real_loss = criterion(real_pred, torch.ones_like(real_pred))
        fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))
        discriminator_loss = (real_loss + fake_loss) / 2

        # 판별자 업데이트
        discriminator_optimizer.zero_grad()
        discriminator_loss.backward(retain_graph=True)
        discriminator_optimizer.step()

        # 생성자 손실 계산
        generator_loss = criterion(fake_pred, torch.ones_like(fake_pred))

        # 생성자 업데이트
        generator_optimizer.zero_grad()
        generator_loss.backward()
        generator_optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Discriminator Loss: {discriminator_loss.item():.4f}, "
          f"Generator Loss: {generator_loss.item():.4f}")

위의 코드에서는 MNIST 데이터셋을 사용하여 GANs를 학습하고 있습니다. 생성자와 판별자는 각각 Fully-Connected 신경망으로 정의되어 있으며, 생성자는 100차원의 랜덤 노이즈 벡터를 입력으로 받아 784차원의 가짜 이미지를 생성합니다. 판별자는 입력 이미지가 실제 이미지인지 가짜 이미지인지 구분합니다.

GANs를 학습하기 위한 일반적인 절차는 다음과 같습니다.

  1. 랜덤한 노이즈 벡터를 생성하고 생성자를 통해 가짜 이미지를 생성합니다.
  2. 실제 이미지와 가짜 이미지를 판별자에 입력하여 판별 결과를 얻습니다.
  3. 판별자와 생성자의 손실을 계산하고 역전파를 통해 각 네트워크를 업데이트합니다.
  4. 이 과정을 반복하며 GANs를 학습시킵니다.

위의 예시 코드에서는 한 번의 학습 루프에서 생성자와 판별자를 번갈아가며 업데이트하고 있습니다. 생성자는 판별자의 판별을 통과하여 가짜 이미지를 실제 이미지처럼 보이도록 학습하고, 판별자는 실제 이미지와 가짜 이미지를 구분할 수 있는 능력을 향상시킵니다.

결론

이번 블로그에서는 PyTorch를 사용하여 GANs를 구현하여 이미지를 생성하는 방법을 알아보았습니다. GANs은 딥 러닝 모델 중에서도 특히 흥미로운 분야 중 하나이며, 다양한 응용 분야에서 활용될 수 있습니다. 음악 생성, 캐릭터 생성, 스타일 변환 등 다양한 예제에서 GANs가 성공적으로 활용되고 있습니다. PyTorch를 사용하면 비교적 간단하게 GANs를 구현할 수 있으며, 이를 통해 창의적이고 흥미로운 모델을 개발해 볼 수 있습니다.