[파이썬] PyTorch 초음파 및 의료 데이터 분석

PyTorch Logo

PyTorch는 딥 러닝 프레임워크로, 그래픽 처리 장치(GPU)와 함께 연산을 수행하여 딥러닝 모델을 빠르게 학습하고 실행할 수 있습니다. 이러한 기능을 활용하여 의료 분야에서 초음파 이미지를 분석하고 의료 데이터를 이해하는 방법을 알아보겠습니다.

데이터 불러오기

먼저, PyTorch에서 데이터를 불러오는 방법을 알아봅시다. 의료 데이터의 경우, 보통 초음파 이미지와 해당 이미지의 레이블로 구성됩니다. 이미지는 이미지 파일 형식으로 저장되며, 레이블은 해당 이미지가 어떤 카테고리에 속하는지를 나타냅니다. 예를 들어, 초음파 이미지가 심장 질환 유무를 판단하는 용도로 사용된다면, 레이블은 “정상” 또는 “이상”으로 표시될 수 있습니다.

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class UltrasoundDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        label = self.labels[index]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# 데이터 경로와 레이블을 지정합니다.
image_paths = ["path/to/image1.png", "path/to/image2.png", ...]
labels = ["normal", "abnormal", ...]

# 데이터 전처리를 위한 변환 함수를 정의합니다.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

# 데이터셋 객체를 생성합니다.
dataset = UltrasoundDataset(image_paths, labels, transform=transform)

# 데이터 로더를 생성합니다.
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

위의 코드는 UltrasoundDataset 클래스를 정의하고, 데이터셋을 생성하는 방법을 보여줍니다. UltrasoundDataset 클래스에서는 __getitem__ 함수를 사용하여 이미지와 레이블을 반환하며, 필요한 경우 데이터 전처리를 위해 transform 함수를 적용합니다. image_pathslabels는 데이터의 경로와 해당 경로에 대응하는 레이블로 초기화되어야 합니다. 마지막으로, DataLoader로 데이터 로더를 생성하여 배치 단위로 데이터를 로드할 수 있습니다.

모델 학습하기

데이터를 로드했다면, 이제 모델을 정의하고 학습하는 과정을 알아봅시다. 의료 데이터 분석에 가장 많이 사용되는 딥 러닝 모델 중 하나인 CNN(Convolutional Neural Network)을 예로 들겠습니다.

import torch.nn as nn
import torch.optim as optim

# CNN 모델 정의하기
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(32 * 32 * 32, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 2)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = x.view(-1, 32 * 32 * 32)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        return x

# 모델 인스턴스 생성하기
model = CNN()

# 손실 함수와 옵티마이저 정의하기
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 모델 학습하기
num_epochs = 10

for epoch in range(num_epochs):
    for images, labels in dataloader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

위의 코드는 간단한 CNN 모델을 정의하고, 모델을 학습하는 방법을 보여줍니다. CNN 클래스에서는 forward 함수를 사용하여 입력을 모델에 통과시키고 출력을 반환합니다. criterion은 손실 함수로, optimizer는 경사하강법 알고리즘 중 하나인 Adam을 사용하여 가중치를 업데이트합니다. 반복문을 통해 데이터를 여러 번 사용하여 모델을 학습시키고, 각 에포크마다 손실을 출력합니다.

모델 평가하기

모델을 학습했다면, 이제 학습된 모델을 평가하여 성능을 알아볼 수 있습니다. 대표적인 평가 지표로는 정확도(Accuracy)가 있습니다.

correct = 0
total = 0

with torch.no_grad():
    for images, labels in dataloader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Accuracy: {accuracy}%")

위의 코드는 학습된 모델을 사용하여 데이터를 예측하고, 예측값과 실제값을 비교하여 정확도를 계산하는 방법을 보여줍니다. torch.max 함수를 사용하여 각 예측값에서 가장 큰 값을 찾고, sum 함수를 사용하여 정확히 예측한 개수를 세어 정확도를 계산합니다.

결론

이처럼 PyTorch를 사용하여 의료 데이터를 분석하고 딥러닝 모델을 학습할 수 있습니다. 초음파 이미지와 같은 의료 데이터를 활용하면 질환의 진단과 예측을 더욱 정확하게 할 수 있습니다. 딥 러닝을 통해 의료 분야의 성과를 향상시키는 데는 많은 잠재력이 있으며, PyTorch는 이를 위한 강력한 도구입니다.

References:

참고 자료: