[python] 파이썬 PyTorch에서 모델 파라미터를 공유하는 방법은?

PyTorch에서는 여러 모델이 같은 파라미터를 공유할 수 있는 기능을 제공합니다. 이를 통해 모델 크기를 줄이고 학습 속도를 향상시킬 수 있습니다. 아래는 모델 파라미터를 공유하는 방법에 대한 예제 코드입니다.

import torch
import torch.nn as nn

class SharedModel(nn.Module):
    def __init__(self):
        super(SharedModel, self).__init__()
        
        # 공유할 파라미터 정의
        self.shared_weight = nn.Parameter(torch.Tensor(1))
        
        # 다른 모듈에서 사용할 파라미터 정의
        self.weight1 = nn.Parameter(torch.Tensor(1))
        self.weight2 = nn.Parameter(torch.Tensor(1))
        
    def forward(self, x):
        # 공유 파라미터 사용
        shared_output = self.shared_weight * x
        
        # 다른 파라미터 사용
        output1 = self.weight1 * shared_output
        output2 = self.weight2 * shared_output
        
        return output1, output2
        
# 모델 인스턴스 생성
model1 = SharedModel()
model2 = SharedModel()

# 모델 파라미터 공유
model2.shared_weight = model1.shared_weight

# 예제 입력
input = torch.Tensor([2])

# 모델 실행
output1, output2 = model1(input)
print(output1, output2)  # tensor([2.0357], grad_fn=<MulBackward0>) tensor([0.3283], grad_fn=<MulBackward0>)

output1, output2 = model2(input)
print(output1, output2)  # tensor([2.0357], grad_fn=<MulBackward0>) tensor([0.3283], grad_fn=<MulBackward0>)

위의 예제 코드에서는 SharedModel 클래스를 정의하고, shared_weight라는 공유할 파라미터와 weight1, weight2라는 다른 파라미터를 정의합니다. forward 메소드에서는 shared_weight를 사용하여 공유 파라미터를 계산하고, 이를 다른 파라미터와 곱하여 출력합니다.

model1model2는 동일한 SharedModel 인스턴스입니다. model2.shared_weight = model1.shared_weight와 같은 코드를 사용하여 모델 파라미터를 공유할 수 있습니다.

예제 입력을 모델에 전달하고 출력을 확인해보면, model1model2가 동일한 출력을 생성하는 것을 확인할 수 있습니다.

자세한 내용은 공식 PyTorch 문서를 참조하시기 바랍니다.