TensorFlow利用卷積神經(jīng)網(wǎng)絡(luò)在谷歌inception_v3模型基礎(chǔ)上解決花朵分類問題

本篇更多的是在代碼實戰(zhàn)方向,不會涉及太多的理論。本文主要針對TensorFlow和卷積神經(jīng)網(wǎng)絡(luò)有一定基礎(chǔ)的同學(xué),并對圖像處理有一定的了解。

閱讀本文你大概需要以下知識:

1.TensorFlow基礎(chǔ)
2.TensorFlow實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)的前向傳播過程
3.TFRecord數(shù)據(jù)格式
4.Dataset的使用
5.Slim的使用

好了廢話不多說,下面開始。

一.數(shù)據(jù)準(zhǔn)備

首先我們需要有一個讓我們訓(xùn)練的數(shù)據(jù)集,這里谷歌已經(jīng)幫我們做好了。這里要把數(shù)據(jù)集下載下來,打開命令行,執(zhí)行如下命令:

wget http://download.tensorflow.org/example_image/flower_photo.tgz
//解壓
tar xzf flower_photos.tgz

這里需要注意的是,文件最好是下載到你的工程目錄下方便你的讀取。什么?你還不會搭建TensorFlow程序?請移步https://www.tensorflow.org/install/
選擇自己的操作系統(tǒng),在這里我的是macOS。我使用的是Virtualenv來搭建TensorFlow運行環(huán)境。
數(shù)據(jù)集下載并解壓后,我們可以看到大概是這個樣子

每一個文件夾里都是一個種類的花的圖片,這里總共有五種花。
好了,數(shù)據(jù)有了?接下來該怎么辦呢?當(dāng)然是把數(shù)據(jù)進行預(yù)處理拉,你不會覺得我們的TensorFlow可以直接識別這些圖片進行訓(xùn)練吧,hhhhhh。

二.數(shù)據(jù)預(yù)處理

接下來我們在目錄下新建pre_data.python文件。TensorFlow對圖片做處理一般是生成TFRecord文件。什么是TFRecord?后面我們會講到。

首先我們要引入我們需要的庫。

# glob模塊的主要方法就是glob,該方法返回所有匹配的文件路徑列表(list)
import glob
#os.path生成路徑方便glob獲取
import os.path
#這里主要用到隨機數(shù)
import numpy as np
#引入tensorflow框架
import tensorflow as tf
#引入gflie對圖片做處理
from tensorflow.python.platform import gfile

相關(guān)庫在我們這個程序中的功能都作了簡單介紹,下面用到的時候我們會更加詳細(xì)的說明。

大家都知道我們的數(shù)據(jù)集一般分訓(xùn)練,測試和驗證數(shù)據(jù)集。觀察上面的數(shù)據(jù)集,谷歌只是給出了每一種花的圖片,并沒有給去哪些我訓(xùn)練,哪些是測試,哪些是驗證數(shù)據(jù)集。所以在這里我們要進行劃分。

#輸入圖片地址
INPUT_DATA = '../../flower_photos'
#訓(xùn)練數(shù)據(jù)集
OUTPUT_FILE = './path/to/output.tfrecords'
#測試數(shù)據(jù)集
OUTPUT_TEST_FILE = './path/to/output_test.tfrecords'
#驗證數(shù)據(jù)集
OUTPUT_VALIDATION_FILE = './path/to/output_validation.tfrecords'
#測試數(shù)據(jù)和驗證數(shù)據(jù)的比例
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10

關(guān)于VALIDATION_PERCENTAGE和TEST_PERCENTAGE這兩個常量,我們在后面的例子會給出。

下面我們就來定義處理數(shù)據(jù)的方法:

def create_image_lists(sess,testing_percentage,validation_percentage):
    #拿到INPUT_DATA文件夾下的所有目錄(包括root)
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
    #如果是root_dir不需要做處理
    is_root_dir = True
    #定義圖片對應(yīng)的標(biāo)簽,從0-4分別代表不同的花
    current_label = 0
    #寫入TFRecord的數(shù)據(jù)需要首先定義writer
    #這里定義三個writer分別存儲訓(xùn)練,測試和驗證數(shù)據(jù)
    writer = tf.python_io.TFRecordWriter(OUTPUT_FILE)
    writer_test = tf.python_io.TFRecordWriter(OUTPUT_TEST_FILE)
    writer_validation = tf.python_io.TFRecordWriter(OUTPUT_VALIDATION_FILE)
    #循環(huán)目錄
    for sub_dir in sub_dirs:
        if is_root_dir:
            #跳過根目錄
            is_root_dir = False
            continue
        #定義空數(shù)組來裝圖片路徑
        file_list = []
        #生成查找路徑
        dir_name = os.path.basename(sub_dir)
        file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + "jpg")
        # extend合并兩個數(shù)組
        # glob模塊的主要方法就是glob,該方法返回所有匹配的文件路徑列表(list)
        # 比如:glob.glob(r’c:*.txt’) 這里就是獲得C盤下的所有txt文件
        file_list.extend(glob.glob(file_glob))
        #路徑下沒有文件就跳過,不繼續(xù)操作
        if not file_list: continue
        #這里我定義index來打印當(dāng)前進度
        index = 0
        #file_list此時是圖片路徑列表
        for file_name in file_list:
            #使用gfile從路徑中讀取圖片
            image_raw_data = gfile.FastGFile(file_name, 'rb').read()
            #對圖像解碼,解碼結(jié)果為一個張量
            image = tf.image.decode_jpeg(image_raw_data)

            #對圖像矩陣進行歸一化處理
            #因為為了將圖片數(shù)據(jù)能夠保存到 TFRecord 結(jié)構(gòu)體中
            #所以需要將其圖片矩陣轉(zhuǎn)換成 string
            #所以為了在使用時能夠轉(zhuǎn)換回來
            #這里確定下數(shù)據(jù)格式為 tf.float32  
            if image.dtype != tf.float32:
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)
            # 將圖片轉(zhuǎn)化成299*299方便模型處理
            image = tf.image.resize_images(image, [299, 299])
            #為了拿到圖片的真實數(shù)據(jù)這里我們要運行一個session op
            image_value = sess.run(image)
           
            pixels = image_value.shape[1]
            #存儲在TFrecord里面的不能是array的形式
            #所以我們需要利用tostring()將上面的矩陣
            #轉(zhuǎn)化成字符串
            #再通過tf.train.BytesList轉(zhuǎn)化成可以存儲的形式
            image_raw = image_value.tostring()

            #存到features
            #隨機劃分測試集和訓(xùn)練集
            #這里存入TFRecord三個數(shù)據(jù),圖像的pixels像素
            #圖像原張量,這里我們需要轉(zhuǎn)成string
            #以及當(dāng)前圖像對應(yīng)的標(biāo)簽
            example = tf.train.Example(features=tf.train.Features(feature={
                'pixels': _int64_feature(pixels),
                'label': _int64_feature(current_label),
                'image_raw': _bytes_feature(image_raw)
            }))
            chance = np.random.randint(100)
            #隨機劃分?jǐn)?shù)據(jù)集
            if chance < validation_percentage:
                writer_validation.write(example.SerializeToString())
            elif chance < (testing_percentage+validation_percentage):
                writer_test.write(example.SerializeToString())
            else:
                writer.write(example.SerializeToString())
            # print('example',index)
            index = index + 1

        #每一個文件夾下的所有圖片都是一個類別
        #所以這里每遍歷完一個文件夾,標(biāo)簽就增加1
        current_label += 1

    writer.close()
    writer_validation.close()
    writer_test.close()

運行上述程序需要一定時間,我的電腦比較爛,大概跑了三十分鐘左右。這時候在你的./path/to目錄下可以看到output.tfrecords,output_test.tfrecords,output_validation.tfrecords三個文件,分別存放了訓(xùn)練,測試和驗證數(shù)據(jù)集。上述代碼將所有圖片劃分成訓(xùn)練、驗證和測試數(shù)據(jù)集。并且把圖片從原始的jpg格式轉(zhuǎn)換成inception-v3模型需要的299 * 299 * 3的數(shù)字矩陣。在數(shù)據(jù)處理完畢之后,通過以下命令可以下載谷歌提供好的Inception_v3模型。

wget http://download.tensorflow.org/models/inception_v3_2016_08_26.tar.gz
//解壓之后可以得到訓(xùn)練好的模型文件inception_v3.ckpt
tar xzf inception_v3_2016_08

二.訓(xùn)練

當(dāng)新的數(shù)據(jù)集和已經(jīng)訓(xùn)練好的模型都準(zhǔn)備好之后,我們來寫代碼在谷歌inception_v3的基礎(chǔ)上訓(xùn)練新數(shù)據(jù)集。

首先同樣我們導(dǎo)入相關(guān)的庫并且定義相關(guān)常量。在這里我們通過slim工具來直接加載模型,而不用自己再定義前向傳播過程。

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
# 加載通過TensorFlow-Silm定義好的 inception_v3模型
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3

# 輸入數(shù)據(jù)文件
INPUT_DATA = './path/to/output.tfrecords'
# 驗證數(shù)據(jù)集
VALIDATION_DATA = './path/to/output_validation.tfrecords'
# 保存訓(xùn)練好的模型的路徑
ls = './path/to/save_model'
# 谷歌提供的訓(xùn)練好的模型文件地址
CKPT_FILE = './path/to/inception_v3.ckpt'
TRAIN_FILE = './path/to/save_model'

# 定義訓(xùn)練中使用的參數(shù)
LEARNING_RATE = 0.01
#組合batch的大小
BATCH = 32

#用于one_hot函數(shù)輸出概率分布
N_CLASSES = 5
#打亂順序,并設(shè)置出隊和入隊中元素最少的個數(shù),這里是10000個
shuffle_buffer = 10000

# 不需要從谷歌模型中加載的參數(shù),這里就是最后的全連接層。因為輸出類別不一樣,所以最后全連接層的參數(shù)也不一樣
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
# 需要訓(xùn)練的網(wǎng)絡(luò)層參數(shù) 這里就是最后的全連接層
TRAINABLE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'

接下來我們定義幾個輔助方法。首先因為我們的數(shù)據(jù)存在TFRecord里,需要定義方法從TFRecord解析數(shù)據(jù)。

def parse(record):
    features = tf.parse_single_example(
        record,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'pixels': tf.FixedLenFeature([], tf.int64)
        }
    )
    #decode_raw用于解析TFRecord里面的字符串
    decoded_image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = features['label']
    #要注意這里的decoded_image并不能直接進行reshape操作
    #之前我們在存儲的時候,把圖片進行了tostring()操作
    #這會導(dǎo)致圖片的長度在原來基礎(chǔ)上*8
    #后面我們要用到numpy的fromstring來處理
    return decoded_image, label

接下來定義兩個方法。因為我們已經(jīng)下載了谷歌訓(xùn)練好的inception_v3模型的參數(shù),下面我們需要定義兩個方法從里面加載參數(shù)。

#直接從inception_v3.ckpt中讀取的參數(shù)
def get_tuned_variables():
    #strip刪除頭尾字符,默認(rèn)為空格
    exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(",")]
    variables_to_restore = []
    #這里給出了所有slim模型下的參數(shù)
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
            if not excluded:
                variables_to_restore.append(var)
        return variables_to_restore

#需要重新訓(xùn)練的參數(shù)
def get_trainable_variables():
    #strip刪除頭尾字符,默認(rèn)為空格
    scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(",")]
    variables_to_train = []
    # 枚舉所有需要訓(xùn)練的參數(shù)前綴,并通過這些前綴找到所有的參數(shù)。
    for scope in scopes:
      #從TRAINABLE_VARIABLES集合中獲取名為scope的變量
      #也就是我們需要重新訓(xùn)練的參數(shù)
        variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train

這里我們就寫完了所需要的工具函數(shù),接下來我們定義主函數(shù)。主函數(shù)主要完成數(shù)據(jù)讀取,模型定義,通過模型得出前向傳播結(jié)果,通過損失函數(shù)計算損失,最后把損失交給優(yōu)化器做處理。首先我們先來完成數(shù)據(jù)讀取的代碼,這里我們使用的是TensorFlow高層API Dataset。不清楚的可以去看一下Dataset的用法。

這里我們在訓(xùn)練的同時也對模型做了驗證。所以我們需要加載訓(xùn)練和驗證數(shù)據(jù)

#讀取測試數(shù)據(jù)
    #利用TFRecordDataset讀取TFRecord文件
    dataset = tf.data.TFRecordDataset([INPUT_DATA])
    #解析TFRecord
    dataset = dataset.map(parse)
    #把數(shù)據(jù)打亂順序并組裝成batch
    dataset = dataset.shuffle(shuffle_buffer).batch(BATCH)
    #定義數(shù)據(jù)重復(fù)的次數(shù)
    NUM_EPOCHS = 10
    dataset = dataset.repeat(NUM_EPOCHS)
    #定義迭代器來獲取處理后的數(shù)據(jù)
    iterator = dataset.make_one_shot_iterator()
    #迭代器開始迭代
    img, label = iterator.get_next()

    #讀取驗證數(shù)據(jù)(同上)
    valida_dataset = tf.data.TFRecordDataset([VALIDATION_DATA])
    valida_dataset = valida_dataset.map(parse)
    valida_dataset = valida_dataset.batch(BATCH)
    valida_iterator = valida_dataset.make_one_shot_iterator()
    valida_img,valida_label = valida_iterator.get_next()

    #定義inception-v3的輸入,images為輸入圖片,label為每一張圖片對應(yīng)的標(biāo)簽
    #再解釋下每一個維度 None為batch的大小,299為圖片大小,3為通道
    images = tf.placeholder(tf.float32,[None,299,299,3],name='input_images')
    labels = tf.placeholder(tf.int64,[None],name='labels')

要注意上述定義的只是tensorflow的張量,保存的只是計算過程并沒有具體的數(shù)據(jù)。只有運行session之后才會拿到具體的數(shù)據(jù)。

下面我們來通過slim加載inception-v3模型

 #定義inception-v3模型結(jié)構(gòu) inception_v3.ckpt里只有參數(shù)的取值
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        #logits  inception_v3前向傳播得到的結(jié)果
        logits,_ = inception_v3.inception_v3(images,num_classes=N_CLASSES)
        #獲取需要訓(xùn)練的變量
        trainable_variables = get_trainable_variables()
        #這里用交叉熵作為損失函數(shù),注意一下tf.losses.softmax_cross_entropy的參數(shù)
        # tf.losses.softmax_cross_entropy(
        #     onehot_labels,  # 注意此處參數(shù)名就叫 onehot_labels
        #     logits,
        #     weights=1.0,
        #     label_smoothing=0,
        #     scope=None,
        #     loss_collection=tf.GraphKeys.LOSSES,
        #     reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
        # )
        #這里要把labels轉(zhuǎn)成one_hot類型,logits就是神經(jīng)網(wǎng)絡(luò)的輸出        
        tf.losses.softmax_cross_entropy(tf.one_hot(labels,N_CLASSES),logits,weights=1.0)
        #把計算的損失交給優(yōu)化器處理
        train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(tf.losses.get_total_loss())

        #計算正確率。
        with tf.name_scope('evaluation'):
            correct_prediction = tf.equal(tf.argmax(logits,1),labels)
            evaluation_step = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
        #定義加載模型的函數(shù)
        load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE,get_tuned_variables(),ignore_missing_vars=True)
        #定義保存新的訓(xùn)練好的模型的函數(shù)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            #初始化所有變量
            init = tf.global_variables_initializer()
            sess.run(init)
            print('Loading tuned variables from %s'%CKPT_FILE)
            #加載谷歌已經(jīng)訓(xùn)練好的模型
            load_fn(sess)
            step = 0;
            #在這里我們用一個while來循環(huán)訓(xùn)練,直到dataset里沒有數(shù)據(jù)就結(jié)束循環(huán)
            while True:
                try:
                    if step % 30  == 0 or step + 1 == STEPS:
                      #每30輪輸出一次正確率
                        if step != 0:
                            #每30輪保存一次當(dāng)前模型的參數(shù),以便中途訓(xùn)練中斷可以繼續(xù)
                            saver.save(sess,TRAIN_FILE,global_step=step)
                       #運行session拿到真實圖片的數(shù)據(jù)
                        valida_img_batch,valida_label_batch = sess.run([valida_img,valida_label])
                        #上面有提到TFRecord里圖片數(shù)據(jù)被轉(zhuǎn)成了string,在這里轉(zhuǎn)回來
                        valida_img_batch = np.fromstring(valida_img_batch, dtype=np.float32)
                        #把圖片張量拉成新的維度
                        valida_img_batch = tf.reshape(valida_img_batch, [32, 299, 299, 3])
                        #用session運行上述操作,得到處理后的圖片張量
                        valida_img_batch = sess.run(valida_img_batch)
                        #把圖片張量傳到feed_dict算出正確率并顯示
                        validation_accuracy = sess.run(evaluation_step,feed_dict={
                            images:valida_img_batch,
                            labels:valida_label_batch
                        })
                        print('Step %d: Validation accurary = %.1f%%'%(step,validation_accuracy*100.0))
                    #下面是對訓(xùn)練數(shù)據(jù)的操作,同上
                    img_batch,label_batch = sess.run([img,label])
                    img_batch = np.fromstring(img_batch, dtype=np.float32)
                    img_batch = tf.reshape(img_batch, [32,299, 299, 3])
                    img_batch = sess.run(img_batch)

                    sess.run(train_step,feed_dict={
                        images:img_batch,
                        labels:label_batch
                    })
                    #step僅僅用于記錄
                    step = step + 1
                except tf.errors.OutOfRangeError:
                    break

運行上述程序開始訓(xùn)練。在這里我暫時是使用cpu進行訓(xùn)練,訓(xùn)練過程大約3小時,可以得到類型下面的結(jié)果。

step 0:Validation accuracy = 12.5%
step 30:Validation accuracy = 22.2%
step 60:Validation accuracy = 63.2%
step 90:Validation accuracy = 79.8%
step 120:Validation accuracy = 86.4%
step 150:Validation accuracy = 88.5%
.....

以上就是我使用谷歌Inception-v3模型訓(xùn)練新的數(shù)據(jù)集的全部內(nèi)容。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容