[python] 파이썬 PyTorch에서 배치 정규화(batch normalization)를 사용하는 방법은?

배치 정규화는 딥러닝 모델에서 훈련 과정에서 안정적인 학습을 도와주는 중요한 기술 중 하나입니다. PyTorch에서는 torch.nn.BatchNorm2d 클래스를 사용하여 간편하게 배치 정규화를 적용할 수 있습니다. 다음은 PyTorch에서 배치 정규화를 사용하는 예제 코드입니다.

import torch
import torch.nn as nn

# 예제로 사용할 모델 생성
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # 이후 모델 구성

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # 이후 모델 연산

# 모델 인스턴스 생성
model = MyModel()

# 배치 정규화 적용
input = torch.randn(1, 3, 32, 32)  # 입력 데이터 예시
output = model(input)

# 학습 과정에서 배치 정규화 적용
model.train()

위 코드에서 MyModel 클래스에서 self.bn1 = nn.BatchNorm2d(64)를 사용하여 첫 번째 컨볼루션 레이어의 출력에 배치 정규화를 적용했습니다. 이후 forward 메서드에서 배치 정규화된 출력을 사용하여 모델 연산을 수행할 수 있습니다.

또한, 학습 과정에서는 model.train()을 호출하여 배치 정규화의 학습과정을 활성화할 수 있습니다. 배치 정규화는 학습될 때마다 입력 분포가 변하므로 학습 중에만 사용되어야 합니다.

더 자세한 내용은 PyTorch 공식 문서를 참고하시기 바랍니다.