[Python] PyTorch Image Style Transfer

Image style transfer is a technique for applying the artistic style of one image to another, producing a new image that combines the content of the content image with the artistic style of the style image. PyTorch, a popular deep learning framework, provides a powerful and flexible platform for implementing style transfer algorithms.

In this tutorial, we will walk through the process of performing image style transfer using PyTorch. We will be using a pre-trained convolutional neural network (CNN) called VGG19, which has been shown to be effective for style transfer. The steps involved in the style transfer process include:

  1. Loading the content and style images
  2. Preprocessing the images
  3. Defining the content and style loss functions
  4. Extracting features from the pre-trained VGG19 network
  5. Computing the Gram matrices for the style features
  6. Initializing the generated image
  7. Performing gradient descent optimization to update the generated image

Let’s start by installing the necessary libraries:

pip install torch torchvision

Next, let’s import the required packages:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

Loading the Content and Style Images

To begin, we need to load both the content image and the style image. The content image supplies the subject matter of the final stylized image, while the style image provides the desired artistic style. You can use any pair of images you like.

content_path = 'path/to/content/image.jpg'
style_path = 'path/to/style/image.jpg'

content_image = Image.open(content_path).convert('RGB')  # force 3 channels (handles grayscale/RGBA inputs)
style_image = Image.open(style_path).convert('RGB')

Preprocessing the Images

Before we can pass the images through the VGG19 network, we need to preprocess them by scaling and normalizing the pixel values. Additionally, we will convert the images to PyTorch tensors.

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

content_tensor = preprocess(content_image).unsqueeze(0)  # add a batch dimension
style_tensor = preprocess(style_image).unsqueeze(0)
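
Style transfer is compute-heavy, so if a GPU is available you may want to run everything on it. A minimal optional sketch (the rest of this tutorial works on CPU as written):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
content_tensor = content_tensor.to(device)
style_tensor = style_tensor.to(device)
# After loading VGG19 in the next section, move it to the same device: vgg = vgg.to(device)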

Defining the Content and Style Loss Functions

In style transfer, we aim to minimize two differences: the difference between the content features of the generated image and those of the content image, and the difference between the style representations (Gram matrices, introduced below) of the generated image and the style image. We quantify both with loss functions.

content_loss = nn.MSELoss()

def get_style_loss(generated_gram, style_gram):
    # Mean squared difference between two Gram matrices (computed in a later section)
    return torch.mean((generated_gram - style_gram) ** 2)

style_loss = get_style_loss

Extracting Features from the VGG19 Network

torchvision provides a pre-trained VGG19 model from which we can easily extract features for the content and style images. We will use the features from several layers of the network to capture both low-level and high-level style information.

vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.eval()  # the 'pretrained=True' argument is deprecated in recent torchvision
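
One subtlety: calling vgg(image) on the whole features stack returns only the final layer's activations, but we need activations from several intermediate layers. The helper below walks the stack and collects them. The layer indices are an assumption based on torchvision's VGG19 layout, matching the conv1_1 through conv5_1 style layers and conv4_2 content layer used by Gatys et al.:

# Freeze VGG19: we optimize only the generated image's pixels, never the network
for param in vgg.parameters():
    param.requires_grad_(False)

# Indices into torchvision's VGG19 features stack (assumed layout):
# 0=conv1_1, 5=conv2_1, 10=conv3_1, 19=conv4_1, 28=conv5_1, 21=conv4_2
style_layers = [0, 5, 10, 19, 28]
content_layer = 21

def get_features(x, layers):
    # Run x through the stack, collecting activations at the requested indices
    features = []
    for idx, layer in enumerate(vgg):
        x = layer(x)
        if idx in layers:
            features.append(x)
    return features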

Computing the Gram Matrices for the Style Features

To calculate the style loss, we need to compute the Gram matrices for the style features. The Gram matrix measures the correlations between different feature maps and helps capture the style information.

def get_gram_matrix(input_features):
    batch_size, num_channels, height, width = input_features.size()
    # Flatten each feature map into a row vector (we work with a batch size of 1)
    features = input_features.view(batch_size * num_channels, height * width)
    gram_matrix = torch.matmul(features, features.transpose(0, 1))
    # Normalize by the total number of elements so the loss scale is comparable across layers
    return gram_matrix / (batch_size * num_channels * height * width)
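
As a quick sanity check, the Gram matrix of a feature map with C channels is always C x C, independent of the spatial size. For example, with a hypothetical 64-channel feature map:

dummy_features = torch.randn(1, 64, 32, 32)  # hypothetical feature map, not from VGG19
print(get_gram_matrix(dummy_features).shape)  # torch.Size([64, 64])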

Now we extract the style features once and precompute their Gram matrices:

with torch.no_grad():  # the style targets are fixed, so no gradients are needed here
    style_features = get_features(style_tensor, style_layers)
style_gram_matrices = [get_gram_matrix(features) for features in style_features]

Initializing the Generated Image

The generated image is initially set to be a copy of the content image. We will optimize the pixel values of the generated image to minimize the content and style losses.

generated_image = content_tensor.clone().requires_grad_(True)
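
Starting from the content image converges quickly and keeps the layout stable. As an alternative (not used in the rest of this tutorial), you can start from random noise, which gives the optimizer more freedom at the cost of slower convergence:

# Alternative: initialize from Gaussian noise with the same shape as the content image
generated_image = torch.randn_like(content_tensor).requires_grad_(True)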

Performing Gradient Descent Optimization

Finally, we will perform gradient descent optimization to update the generated image. The optimization process iteratively adjusts the pixel values of the generated image to minimize the combined content and style losses.

optimizer = optim.Adam([generated_image], lr=0.01)

# The content target never changes, so compute it once outside the loop
with torch.no_grad():
    content_features = get_features(content_tensor, [content_layer])[0]

num_iterations = 500
for i in range(num_iterations):
    optimizer.zero_grad()

    generated_content = get_features(generated_image, [content_layer])[0]
    generated_style = get_features(generated_image, style_layers)

    content_loss_value = content_loss(generated_content, content_features)

    style_loss_value = 0
    for generated_features, style_gram in zip(generated_style, style_gram_matrices):
        style_loss_value += style_loss(get_gram_matrix(generated_features), style_gram)

    # In practice the two terms are usually weighted (e.g. content + 1e6 * style);
    # tune the balance for your image pair
    total_loss = content_loss_value + style_loss_value

    total_loss.backward()
    optimizer.step()

    if i % 50 == 0:
        print(f"Iteration: {i}, Content Loss: {content_loss_value.item()}, Style Loss: {style_loss_value.item()}")

After the optimization is complete, the generated image will reflect the desired artistic style while preserving the main subject of the content image. Because the tensor is still normalized with the ImageNet statistics, we undo that normalization and clamp to the valid range before saving:

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
output_tensor = (generated_image.squeeze(0).detach().cpu() * std + mean).clamp(0, 1)  # undo ImageNet normalization
output_image = transforms.ToPILImage()(output_tensor)
output_image.save('path/to/output/image.jpg')

That’s it! You have now implemented image style transfer using PyTorch. Experiment with different content and style images to create your own stunning stylized images. Happy coding!