https://keras-zh.readthedocs.io/preprocessing/image/
https://www.cnblogs.com/hutao722/p/10075150.html
一般來(lái)說(shuō),都會(huì)使用ImageDataGenerator.flow()方法構(gòu)造一個(gè)迭代器,以提供給model.fit_generator()方法進(jìn)行訓(xùn)練。然而缺點(diǎn)是需要一次提供所有數(shù)據(jù)到內(nèi)存中,不適合大量圖片訓(xùn)練集。因此可使用ImageDataGenerator.flow_from_directory()方法。但是該方法依舊限制較多。
因此我們可不使用ImageDataGenerator提供的構(gòu)造迭代器方法,而是和之前文章中一樣,自定義一個(gè)迭代器繼承Sequence,然后在__getitem__()方法中僅僅加入ImageDataGenerator.random_transform()方法去進(jìn)行圖像增強(qiáng)。
注意:
- 需要先從tfrecordDataset的Iterator中獲取圖像,之后對(duì)圖像使用random_transform()方法:
這個(gè)方法設(shè)計(jì)的是針對(duì)0-255原圖而言的。從源碼上來(lái)看,它會(huì)組合不同的線性變換矩陣,矩陣相乘后,最后應(yīng)用于原圖上,比起在大的原圖上一步步做矩陣乘法的效率高。 - 再對(duì)圖像使用preprocess預(yù)處理函數(shù):
根據(jù)預(yù)訓(xùn)練模型的配置不同,預(yù)處理preprocess函數(shù)是不同的。比如from keras.applications.resnet50 import preprocess_input與
from keras.applications.inception_v3 import preprocess_input這兩個(gè),雖然本質(zhì)最后都會(huì)調(diào)用keras.application.imagenet_utils.preprocess_input()函數(shù),但是參數(shù)不同。
具體來(lái)說(shuō),Resnet50的預(yù)處理中,imagenet_utils.preprocess_input()的mode參數(shù)為默認(rèn)的"caffe"因此處理為調(diào)整為BGR后減去imagenet的通道均值。而Inception_v3中則會(huì)縮放到-1~1,這與預(yù)訓(xùn)練有關(guān)。因此需要將圖像增強(qiáng)置于該步驟之前。以下代碼,圖像增強(qiáng)和預(yù)處理的步驟就反了:
from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import keras
import keras.backend as K
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input
# from keras.applications.inception_v3 import preprocess_input
import numpy as np
import matplotlib.pyplot as plt
def main():
img = image.load_img('Data_sets/s_IMG_3488.jpg', target_size=(224, 224))
x = image.img_to_array(img)
print(x.shape) # (224, 224, 3)
print(x[0, 0]) # [238. 255. 251.]
x = preprocess_input(x) # 默認(rèn)Resnet使用caffee模式,inception使用tf模式
print(x[0, 0]) # [147.061 138.22101 114.32 ] 使用的是resnet的預(yù)處理配置
plt.imshow(x.astype(np.uint8)) # float默認(rèn)為0-1顯示,int默認(rèn)為0-255顯示
plt.show()
datagen = ImageDataGenerator(
rescale=1. / 255,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True)
x2 = datagen.random_transform(x)
print(x2[0, 0]) # 每次隨機(jī)組合:[114.885315 105.54411 110.899704]
plt.imshow(x2.astype(np.uint8))
plt.show()
# 必須在圖像增強(qiáng)之后,random_transform只接受3D tensor
x2 = np.expand_dims(x2, 0)
# ...model.predict(x2)...
if __name__ == '__main__':
main()