[python] 파이썬 PyTorch에서 배치 정규화(batch normalization)를 사용하는 방법은?
배치 정규화는 딥러닝 모델에서 훈련 과정에서 안정적인 학습을 도와주는 중요한 기술 중 하나입니다. PyTorch에서는 torch.nn.BatchNorm2d
클래스를 사용하여 간편하게 배치 정규화를 적용할 수 있습니다. 다음은 PyTorch에서 배치 정규화를 사용하는 예제 코드입니다.
import torch
import torch.nn as nn
# 예제로 사용할 모델 생성
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
# 이후 모델 구성
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
# 이후 모델 연산
# 모델 인스턴스 생성
model = MyModel()
# 배치 정규화 적용
input = torch.randn(1, 3, 32, 32) # 입력 데이터 예시
output = model(input)
# 학습 과정에서 배치 정규화 적용
model.train()
위 코드에서 MyModel
클래스에서 self.bn1 = nn.BatchNorm2d(64)
를 사용하여 첫 번째 컨볼루션 레이어의 출력에 배치 정규화를 적용했습니다. 이후 forward
메서드에서 배치 정규화된 출력을 사용하여 모델 연산을 수행할 수 있습니다.
또한, 학습 과정에서는 model.train()
을 호출하여 배치 정규화의 학습과정을 활성화할 수 있습니다. 배치 정규화는 학습될 때마다 입력 분포가 변하므로 학습 중에만 사용되어야 합니다.
더 자세한 내용은 PyTorch 공식 문서를 참고하시기 바랍니다.