Keras 數(shù)據(jù)增強(qiáng) ImageDataGenerator

https://keras-zh.readthedocs.io/preprocessing/image/
https://www.cnblogs.com/hutao722/p/10075150.html

一般來(lái)說(shuō),都會(huì)使用ImageDataGenerator.flow()方法構(gòu)造一個(gè)迭代器,以提供給model.fit_generator()方法進(jìn)行訓(xùn)練。然而缺點(diǎn)是需要一次提供所有數(shù)據(jù)到內(nèi)存中,不適合大量圖片訓(xùn)練集。因此可使用ImageDataGenerator.flow_from_directory()方法。但是該方法依舊限制較多。
因此我們可不使用ImageDataGenerator提供的構(gòu)造迭代器方法,而是和之前文章中一樣,自定義一個(gè)迭代器繼承Sequence,然后在__getitem__()方法中僅僅加入ImageDataGenerator.random_transform()方法去進(jìn)行圖像增強(qiáng)。

注意

  1. 需要先從tfrecordDataset的Iterator中獲取圖像,之后對(duì)圖像使用random_transform()方法:
    這個(gè)方法設(shè)計(jì)的是針對(duì)0-255原圖而言的。從源碼上來(lái)看,它會(huì)組合不同的線性變換矩陣,矩陣相乘后,最后應(yīng)用于原圖上,比起在大的原圖上一步步做矩陣乘法的效率高。
  2. 再對(duì)圖像使用preprocess預(yù)處理函數(shù):
    根據(jù)預(yù)訓(xùn)練模型的配置不同,預(yù)處理preprocess函數(shù)是不同的。比如from keras.applications.resnet50 import preprocess_input
    from keras.applications.inception_v3 import preprocess_input這兩個(gè),雖然本質(zhì)最后都會(huì)調(diào)用keras.application.imagenet_utils.preprocess_input()函數(shù),但是參數(shù)不同。
    具體來(lái)說(shuō),Resnet50的預(yù)處理中,imagenet_utils.preprocess_input()的mode參數(shù)為默認(rèn)的"caffe"因此處理為調(diào)整為BGR后減去imagenet的通道均值。而Inception_v3中則會(huì)縮放到-1~1,這與預(yù)訓(xùn)練有關(guān)。因此需要將圖像增強(qiáng)置于該步驟之前。以下代碼,圖像增強(qiáng)和預(yù)處理的步驟就反了:
from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import keras
import keras.backend as K
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input
# from keras.applications.inception_v3 import preprocess_input
import numpy as np
import matplotlib.pyplot as plt

def main():
    img = image.load_img('Data_sets/s_IMG_3488.jpg', target_size=(224, 224))

    x = image.img_to_array(img)
    print(x.shape)  # (224, 224, 3)
    print(x[0, 0])  # [238. 255. 251.]

    x = preprocess_input(x)  # 默認(rèn)Resnet使用caffee模式,inception使用tf模式
    print(x[0, 0])  # [147.061   138.22101 114.32   ] 使用的是resnet的預(yù)處理配置
    plt.imshow(x.astype(np.uint8))  # float默認(rèn)為0-1顯示,int默認(rèn)為0-255顯示
    plt.show()

    datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True)
    x2 = datagen.random_transform(x)
    print(x2[0, 0])  # 每次隨機(jī)組合:[114.885315 105.54411  110.899704]
    plt.imshow(x2.astype(np.uint8))
    plt.show()

    # 必須在圖像增強(qiáng)之后,random_transform只接受3D tensor
    x2 = np.expand_dims(x2, 0)

    # ...model.predict(x2)...

if __name__ == '__main__':
    main()

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容