[파이썬] TensorFlow에서 tf.data API

tf.data API 개요

tf.data API는 TensorFlow 2.0부터 새로 도입된 데이터 파이프라인 API입니다. 이 API를 사용하면 데이터를 로드, 변환 및 전처리하는 과정을 단순화할 수 있으며, 모델 훈련에 필요한 데이터를 고성능으로 공급할 수 있습니다.

tf.data API의 핵심 개념은 tf.data.Dataset입니다. Dataset은 일련의 데이터 요소를 나타내며, 이를 통해 데이터를 효율적으로 읽고 전처리할 수 있습니다. Dataset을 생성하고 변환하기 위해 다양한 함수들을 사용할 수 있습니다.

데이터셋 생성

가장 기본적인 방법으로는 디스크에서 데이터를 읽어와 tf.data.Dataset을 만들 수 있습니다. 아래 코드는 CSV 파일을 읽어 Dataset을 생성하는 예시입니다.

import tensorflow as tf

dataset = tf.data.experimental.CsvDataset("data.csv", [tf.int32, tf.float32], header=True)

또한 Python 리스트, NumPy 배열, 텐서 등의 데이터로부터도 Dataset을 생성할 수 있습니다. 예를 들어, 아래 코드는 Python 리스트를 이용해 Dataset을 생성하는 방법입니다.

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])

데이터 변환

Dataset은 데이터를 변환하기 위해 다양한 함수들을 제공합니다. 이러한 변환 함수들은 데이터를 섞거나 반복하는 등 다양한 처리를 수행할 수 있습니다.

# 데이터를 섞고, 배치로 만들고, 반복하는 변환
dataset = dataset.shuffle(1000)
dataset = dataset.batch(32)
dataset = dataset.repeat()

변환 함수들은 연결될 수 있으며, 체인 형태로 사용할 수 있습니다. 이를 통해 데이터 파이프라인을 유연하게 구성할 수 있습니다.

데이터 사용

Dataset을 만들고 변환한 후에는 모델 훈련에 사용할 수 있습니다. Dataset은 일련의 데이터 요소를 반환하는 이터레이터와 함께 사용됩니다. 이터레이터를 사용하여 데이터를 하나씩 추출할 수 있습니다.

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    while True:
        try:
            data = sess.run(next_element)
            # 데이터를 사용하여 모델 훈련 또는 추론 수행
        except tf.errors.OutOfRangeError:
            break

결론

이번 포스트에서는 TensorFlow에서 tf.data API를 사용하는 방법과 주요 기능들에 대해 알아보았습니다. tf.data API는 데이터 파이프라인을 간편하게 구성하고 모델의 성능을 최적화하는 데 도움이 됩니다. 자세한 내용은 TensorFlow 공식 문서를 참조하시기 바랍니다.