[파이썬] PyTorch 의료 영상 분석

PyTorch는 딥 러닝 프레임워크 중 하나로, 주로 의료 영상 분석 작업에 많이 사용됩니다. 이 글에서는 PyTorch를 사용하여 의료 영상 분석을 수행하는 방법에 대해 알아보겠습니다.

준비 사항

의료 영상 분석을 위해 다음과 같은 패키지를 설치해야 합니다.

pip install torch torchvision matplotlib

또한, 트레이닝 데이터셋과 테스트 데이터셋이 필요합니다. 의료 영상 데이터셋은 Medical Image Analysis Contest (MICCAI)Kaggle 등에서 다운로드할 수 있습니다.

데이터 로딩

PyTorch에서는 torchvision 패키지를 사용하여 의료 영상 데이터를 로딩할 수 있습니다. 다음은 예제 코드입니다.

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

# 데이터셋 경로
train_data_path = "path/to/train/data"
test_data_path = "path/to/test/data"

# 이미지 변환 설정
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# 트레이닝 데이터셋 로딩
train_dataset = ImageFolder(train_data_path, transform=transform)

# 테스트 데이터셋 로딩
test_dataset = ImageFolder(test_data_path, transform=transform)

# 로딩된 데이터셋 정보 확인
print("Number of classes:", len(train_dataset.classes))
print("Number of training images:", len(train_dataset))
print("Number of test images:", len(test_dataset))

모델 정의

의료 영상 분석에는 다양한 딥 러닝 모델을 사용할 수 있습니다. PyTorch에서는 간단하게 사용할 수 있는 사전 학습된 모델들을 제공합니다. 다음은 torchvision.models 패키지에서 사전 학습된 ResNet 모델을 로딩하는 예제입니다.

import torch
import torchvision.models as models

# 사전 학습된 ResNet 모델 로딩
model = models.resnet50(pretrained=True)

# 모델을 GPU로 이동
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

학습 및 평가

의료 영상 분석에서는 큰 규모의 데이터셋과 복잡한 모델을 학습해야 합니다. 따라서, 학습을 위해 GPU를 사용하는 것이 좋습니다. 다음은 학습 및 평가를 수행하는 예제입니다.

import torch.nn as nn
import torch.optim as optim

# 학습 설정
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 학습 함수
def train(model, dataloader, criterion, optimizer):
    model.train()
    train_loss = 0

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * images.size(0)

    avg_loss = train_loss / len(dataloader.dataset)
    return avg_loss

# 테스트 함수
def test(model, dataloader, criterion):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item() * images.size(0)

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()

    avg_loss = test_loss / len(dataloader.dataset)
    accuracy = correct / len(dataloader.dataset) * 100
    return avg_loss, accuracy

# 데이터로더 생성
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

# 학습 및 테스트
for epoch in range(10):
    train_loss = train(model, train_dataloader, criterion, optimizer)
    test_loss, accuracy = test(model, test_dataloader, criterion)

    print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Test Loss = {test_loss:.4f}, Accuracy = {accuracy:.2f}%")

결론

PyTorch를 사용하여 의료 영상 분석을 수행하는 방법에 대해 알아보았습니다. PyTorch는 다양한 딥 러닝 모델과 편리한 데이터 로딩 및 학습/평가 기능을 제공하여 의료 영상 분석을 더욱 쉽게 할 수 있도록 도와줍니다. 향후 더 복잡한 의료 영상 분석 작업에 대해 더 자세히 알아보기를 추천합니다.

Reference