[파이썬] TensorFlow에서 TFRecord 데이터 형식
먼저 TFRecord를 생성하기 위해서는 데이터를 직렬화하는 과정이 필요합니다. 예를 들어, 이미지 데이터셋을 TFRecord 형식으로 변환하는 과정을 살펴보겠습니다.
import tensorflow as tf
import os
# 데이터셋 디렉토리 경로
data_dir = 'path/to/dataset'
# TFRecord 파일 경로
tfrecord_file = 'path/to/tfrecord.tfrecord'
# 이미지 데이터셋 로드
image_paths = [os.path.join(data_dir, filename) for filename in os.listdir(data_dir)]
labels = ['cat', 'dog', 'bird']
# TFRecord 파일 생성
with tf.io.TFRecordWriter(tfrecord_file) as writer:
for image_path, label in zip(image_paths, labels):
image = tf.io.read_file(image_path)
image = tf.image.decode_jpeg(image)
# 이미지를 직렬화하여 TFRecord에 저장
image_bytes = tf.io.encode_jpeg(image).numpy()
feature = {
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label.encode()]))
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
위 코드는 디렉토리에 있는 이미지 데이터셋을 읽어 TFRecord 파일로 저장하는 과정입니다. TFRecordWriter
를 사용하여 TFRecord 파일을 생성하고, tf.train.Example
을 사용하여 각 데이터를 직렬화한 후 TFRecord에 저장합니다.
다음은 TFRecord 데이터를 읽는 예제 코드입니다.
import tensorflow as tf
# TFRecord 파일 경로
tfrecord_file = 'path/to/tfrecord.tfrecord'
# TFRecord 파일 열기
dataset = tf.data.TFRecordDataset(tfrecord_file)
# 데이터 파싱 함수 정의
def _parse_function(record):
features = {
'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.string)
}
parsed_record = tf.io.parse_single_example(record, features)
# 이미지 디코딩
image = tf.image.decode_jpeg(parsed_record['image'])
image = tf.image.resize(image, [224, 224])
image = tf.image.convert_image_dtype(image, tf.float32)
# 레이블 디코딩
label = parsed_record['label']
return image, label
# 데이터셋 파싱 및 전처리
dataset = dataset.map(_parse_function)
dataset = dataset.shuffle(100)
dataset = dataset.batch(32)
# 학습 루프
for image, label in dataset:
# 모델 학습 또는 추론을 수행
pass
위 코드는 TFRecord 파일을 열고 데이터를 파싱하는 과정을 보여줍니다. TFRecordDataset
로 TFRecord 파일을 열고, 파싱 함수를 정의하여 데이터를 전처리합니다. 그 후, 데이터셋을 셔플하고 배치 단위로 나누어 모델의 학습 또는 추론에 활용할 수 있습니다.
TFRecord는 TensorFlow에서 대용량 데이터셋을 효율적으로 처리하기 위한 유용한 데이터 형식입니다. 이 포스트를 통해 TFRecord 데이터를 생성하고 읽는 방법을 학습할 수 있었습니다. TensorFlow의 다양한 기능들을 활용하여 데이터 처리와 모델 학습을 보다 효율적으로 수행할 수 있습니다.