[파이썬] TensorFlow에서 TFRecord 데이터 형식

먼저 TFRecord를 생성하기 위해서는 데이터를 직렬화하는 과정이 필요합니다. 예를 들어, 이미지 데이터셋을 TFRecord 형식으로 변환하는 과정을 살펴보겠습니다.

import tensorflow as tf
import os

# 데이터셋 디렉토리 경로
data_dir = 'path/to/dataset'

# TFRecord 파일 경로
tfrecord_file = 'path/to/tfrecord.tfrecord'

# 이미지 데이터셋 로드
image_paths = [os.path.join(data_dir, filename) for filename in os.listdir(data_dir)]
labels = ['cat', 'dog', 'bird']

# TFRecord 파일 생성
with tf.io.TFRecordWriter(tfrecord_file) as writer:
    for image_path, label in zip(image_paths, labels):
        image = tf.io.read_file(image_path)
        image = tf.image.decode_jpeg(image)

        # 이미지를 직렬화하여 TFRecord에 저장
        image_bytes = tf.io.encode_jpeg(image).numpy()
        
        feature = {
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
            'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label.encode()]))
        }
        
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())

위 코드는 디렉토리에 있는 이미지 데이터셋을 읽어 TFRecord 파일로 저장하는 과정입니다. TFRecordWriter를 사용하여 TFRecord 파일을 생성하고, tf.train.Example을 사용하여 각 데이터를 직렬화한 후 TFRecord에 저장합니다.

다음은 TFRecord 데이터를 읽는 예제 코드입니다.

import tensorflow as tf

# TFRecord 파일 경로
tfrecord_file = 'path/to/tfrecord.tfrecord'

# TFRecord 파일 열기
dataset = tf.data.TFRecordDataset(tfrecord_file)

# 데이터 파싱 함수 정의
def _parse_function(record):
    features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.string)
    }
    
    parsed_record = tf.io.parse_single_example(record, features)
    
    # 이미지 디코딩
    image = tf.image.decode_jpeg(parsed_record['image'])
    image = tf.image.resize(image, [224, 224])
    image = tf.image.convert_image_dtype(image, tf.float32)
    
    # 레이블 디코딩
    label = parsed_record['label']
    
    return image, label

# 데이터셋 파싱 및 전처리
dataset = dataset.map(_parse_function)
dataset = dataset.shuffle(100)
dataset = dataset.batch(32)

# 학습 루프
for image, label in dataset:
    # 모델 학습 또는 추론을 수행
    pass

위 코드는 TFRecord 파일을 열고 데이터를 파싱하는 과정을 보여줍니다. TFRecordDataset로 TFRecord 파일을 열고, 파싱 함수를 정의하여 데이터를 전처리합니다. 그 후, 데이터셋을 셔플하고 배치 단위로 나누어 모델의 학습 또는 추론에 활용할 수 있습니다.

TFRecord는 TensorFlow에서 대용량 데이터셋을 효율적으로 처리하기 위한 유용한 데이터 형식입니다. 이 포스트를 통해 TFRecord 데이터를 생성하고 읽는 방법을 학습할 수 있었습니다. TensorFlow의 다양한 기능들을 활용하여 데이터 처리와 모델 학습을 보다 효율적으로 수행할 수 있습니다.