[DL] ImageDataGenerator의 사용 - 1

ImageDataGenerator1의 사용(개와 고양이)

TF2.X 에서 이미지 증식을 위해서 사용하는 ImageDataGenerator를 간단히 알아보고 image를 확인해 본다.

사용 Library

사용하는 library들이다.

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras.preprocessing.image import ImageDataGenerator

데이터셋

train data와 validation date를 정의한다.

train_dir      = './data/cat_dog_full/train'
validation_dir = './data/cat_dog_full/validation'

# 전체 데이터가 25,000장 (고양이 : 12,500, 개 : 12,500)
# 댕댕이 이미지 train      : 7,000장
# 댕댕이 이미지 validation : 3,000장
# 댕댕이 이미지 test       : 2,500장

ImageDataGenerator

ImageDataGenerator 에 대해서 간단히 알아본다. 데이터 증식(augmentation)은 아래에서 진행한다.

train_datagen = ImageDataGenerator(rescale=1/255) 
validation_datagen=ImageDataGenerator(rescale=1/255)

train_generator = train_datagen.flow_from_directory(train_dir,
                                                    classes=['cats', 'dogs'],
                                                    batch_size = 20,
                                                    target_size = (150, 150)
                                                    class_mode='binary')

validation_generator = validation_datagen.flow_from_directory(validataion_dir,
                                                              classes=['cats', 'dogs'],
                                                              batch_size = 20,
                                                              target_size = (150, 150),
                                                              class_mode = 'binary')
for x_data, t_data in train_generator:
    print(x_data.shape)  # (20, 150, 150, 3)
    print(type(x_data))  # <class 'numpy.ndarray'>
    print(t_data)        # [1. 0. 0. 1. 1. 1. 0. 1. 1. 0. 1. 1. 1. 1. 1. 0. 1. 1. 0. 0.]
    # 0 : 고양이,  1 : 댕댕이
fig = plt.figure(figsize=(10, 10))
axs = []

for x_data, t_data in train_generator:
    for idx, img enumerate(x_data):
        axs.append(fig.add_subplot(5,4,idx+1))
        plt.imshow(axs[idx])
    break

image-20201107144516477

data augmentation (데이터 증식)

ImageDataGenerator를 사용해 데이터 증식하는 법을 알아본다. 여기서는 .flow_from_directory가 아닌 .flow를 사용해 알아본다.

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
img = image.load_img('./data/cat_dog_small/train/cats/cat.3.jpg', target_size=(150, 150)) print(type(img)) # <class 'PIL.Image.Image'>

plt.imshow(img)
plt.show()

image-20201107165006652

x = image.img_to_array(img)
print('x_shape : ',x.shape)  # x_shape :  (150, 150, 3)
print('x_type : ',type(x))   # x_type :  <class 'numpy.ndarray'>
print(x.dtype)               # float32
# plt.imshow(x) : error 발생, 원인 : imshow는 0-1 float 또는 0-255 int 값을 input값으로 받음
plt.imshow(np.unit8(x))
plt.show()

x = x.reshape((1,) + x.shape)

image-20201107165745422

datagen = ImageDataGenerator(rotation_range = 90,
                             width_shift_range=0.4,
                             height_shift_range=0.4,
                             vertical_flip =True,
                             horizontal_flip =True)

idx = 0
fig = plt.figure(figsize=(10, 10))
axs = []
for batch in datage.flow(x , batch_size=1): # 여기서 batch는 x가 됨
    axs.append(flg.add_subplot(5, 4, idx+1))
    axs[idx].imshow(image.array_to_img(batch[0]))
    idx += 1
    if idx%20 == 0:
        break
fig.tight_layout()
plt.show()

image-20201107172032692