[파이썬] PyTorch 캡슐 네트워크 구현

소개

캡슐 네트워크(Capsule Network)는 딥러닝 모델 중 하나로, Hinton 등이 제안한 아키텍처입니다. 캡슐 네트워크는 기존 컨볼루션 신경망(Convolutional Neural Network)과 달리, 객체의 성질을 보다 잘 표현할 수 있는 캡슐(Capsule)이라는 개념을 도입하여 성능을 향상시킬 수 있습니다.

이 블로그 포스트에서는 PyTorch를 사용하여 간단한 캡슐 네트워크를 구현하는 방법을 소개합니다.

코드 구현

아래는 캡슐 네트워크 모델을 구현하는 예시 코드입니다.

import torch
import torch.nn as nn
import torch.nn.functional as F


class CapsuleNetwork(nn.Module):
    def __init__(self):
        super(CapsuleNetwork, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9)
        self.primary_capsules = PrimaryCapsules()
        self.digit_capsules = DigitCapsules()

    def forward(self, x):
        x = self.conv1(x)
        x = self.primary_capsules(x)
        x = self.digit_capsules(x)
        return x


class PrimaryCapsules(nn.Module):
    def __init__(self):
        super(PrimaryCapsules, self).__init__()
        self.conv = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), 32 * 8, -1)
        x = self.squash(x)
        return x

    def squash(self, x):
        squared_norm = (x ** 2).sum(dim=-1, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        x = scale * x / torch.sqrt(squared_norm)
        return x


class DigitCapsules(nn.Module):
    def __init__(self):
        super(DigitCapsules, self).__init__()
        self.routing_iterations = 3
        self.W = nn.Parameter(torch.randn(10, 32 * 8, 16, 8))

    def forward(self, x):
        batch_size = x.size(0)
        x = x[:, :, None, None, :]
        x_hat = x @ self.W
        logits = torch.zeros(batch_size, 10, 16, 1)
        if torch.cuda.is_available():
            logits = logits.cuda()

        for iteration in range(self.routing_iterations):
            route_probs = F.softmax(logits, dim=1)
            x_hat_weighted = (x_hat * route_probs).sum(dim=1)
            v = self.squash(x_hat_weighted)
            if iteration < self.routing_iterations - 1:
                agreement = (x_hat_weighted * v[:, None, :, :]).sum(dim=-1, keepdim=True)
                logits += agreement

        return v

    def squash(self, x):
        squared_norm = (x ** 2).sum(dim=-1, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        x = scale * x / torch.sqrt(squared_norm)
        return x

요약

이번 포스트에서는 PyTorch를 사용하여 캡슐 네트워크를 구현하는 방법을 알아보았습니다. 이는 캡슐 네트워크의 기본 아키텍처를 간단하게 구현한 예시이며, 추가적인 세부 조정이나 데이터에 맞는 모델링은 필요합니다.

더 자세한 내용은 PyTorch 공식 문서를 참고하세요.