[python] 파이썬 PyTorch에서 범주형 데이터 처리를 위한 원-핫 인코딩 방법은?

원-핫 인코딩은 범주형 데이터를 수치형 데이터로 변환하는 일반적인 기법입니다. 이를 PyTorch에서 수행하는 방법은 다음과 같습니다.

import torch

# 입력 범주형 데이터
categories = ['사과', '바나나', '오렌지', '파인애플']

# 원-핫 인코딩을 수행할 범주형 데이터
data = ['사과', '바나나', '오렌지', '파인애플', '바나나']

# 범주형 데이터를 정수로 변환
category_to_index = {category: index for index, category in enumerate(categories)}
data_indices = [category_to_index[datum] for datum in data]

# 변환된 정수를 원-핫 인코딩
one_hot_encoding = torch.zeros(len(data), len(categories))
one_hot_encoding.scatter_(1, torch.tensor(data_indices).unsqueeze(1), 1)

print(one_hot_encoding)

위 코드에서, categories에는 가능한 모든 범주를 담은 리스트를 생성합니다. 그리고 data는 원-핫 인코딩을 수행할 범주형 데이터를 담고 있는 리스트입니다.

먼저, category_to_index라는 딕셔너리를 생성하여 각 범주를 해당하는 정수로 매핑합니다. 그리고 data_indicesdata 리스트의 각 요소를 해당하는 정수로 변환한 값들을 담은 리스트입니다.

마지막으로, torch.zeros를 사용하여 모든 요소가 0인 텐서를 초기화한 후, scatter_ 함수를 호출하여 변환된 정수에 해당하는 위치에 1을 할당하여 원-핫 인코딩을 수행합니다.

입력된 범주형 데이터에 대한 원-핫 인코딩 결과를 출력합니다.

이 코드를 실행하면 다음과 같은 출력을 얻을 수 있습니다.

tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.]])

위 출력은 data 리스트의 각 요소를 원-핫 인코딩한 결과입니다. 첫 번째 요소인 ‘사과’는 [1, 0, 0, 0]으로 인코딩되었고, 두 번째 요소인 ‘바나나’는 [0, 1, 0, 0]으로 인코딩되었음을 의미합니다.

이와 같은 방법으로 PyTorch에서 범주형 데이터를 처리하고 원-핫 인코딩할 수 있습니다.

참고 자료: