此文翻譯自:A quick complete tutorial to save and restore Tensorflow models
這篇tensorflow的教程,將解釋:
1. Tensorflow模型是什么樣的?
2. 如何保存一個Tensorflow模型??
3. 如何恢復(fù)一個Tensorflow模型,用于預(yù)測或者遷移學(xué)習(xí)?
4. 如何利用導(dǎo)入的預(yù)訓(xùn)練好模型,進(jìn)行fine-tuning或改造。
這篇教程,假定讀者對神經(jīng)網(wǎng)絡(luò)的訓(xùn)練有基本的了解。如果不是,請先閱讀Tensorflow Tutorial 2: image classifier using convolutional neural network,然后閱讀本文。
1. Tensorflow 模型是什么?
當(dāng)訓(xùn)練完一個神經(jīng)網(wǎng)絡(luò),你就會保存它,以便日后使用和產(chǎn)品發(fā)布。所以,Tensorflow模型是如何表示的呢?Tensorflow模型主要包含網(wǎng)絡(luò)設(shè)計(Graph)和訓(xùn)練好的參數(shù)的值。因此,Tensorflow模型包含兩個主要的文件:
a) Meta graph:
這是一個協(xié)議緩沖區(qū)(protocol buffer,google推出的數(shù)據(jù)存儲格式),保存完整的Tensorflow的graph信息;例如:所有的變量,操作(ops),集合(collection)等。此文件帶有.meta擴展。
b) Checkpoint file:
它是一個二進(jìn)制文件,包含所有的權(quán)重,偏置,導(dǎo)數(shù)和其他保存變量的值。文件后綴為: .ckpt。但自從0.11版本之后,Temsorflow作了改變,不再是一個單獨的.ckpt文件,取而代之的是兩個文件:
<<mymodel.data-00000-of-00001>>
<<mymodel.index>>
.data文件包含著訓(xùn)練好的變量的值,除此之外,Tensorflow還有一個名為checkpoint的文件,持續(xù)記錄著最新的保存數(shù)據(jù)。
所以,總結(jié)下來,0.10之后的Tensorflow模型如下圖所示:

而,0.11版本之前的Tensorflow模型,僅僅包含三個文件:
<<inception_v1.meta>>
<<inception_v1.ckpt>>
<<checkpoint>>
2. 保存一個Tensorflow模型:
假設(shè),你正在訓(xùn)練一個卷積神經(jīng)網(wǎng)絡(luò),用于圖片分類。作為一個標(biāo)準(zhǔn)操作,你持續(xù)觀測Loss function和Accuracy。一旦你看到網(wǎng)絡(luò)收斂,你可以人為停止訓(xùn)練或者只訓(xùn)練固定數(shù)目的epochs。當(dāng)訓(xùn)練完成之后,我們想要保存所有的變量和網(wǎng)絡(luò)圖(network graph)到一個文件,以便日后使用。因此,在Tensorflow中,為了保存graph和變量,我們應(yīng)該新建一個tf.train.Saver()類。

謹(jǐn)記Tensorflow的變量只有在一個session中才是有效的。因此,你不得不在一個session中保存模型,使用剛剛新建的saver對象,調(diào)用save方法,如下:

這里,sess是一個session對象,“my-test-model”是你想要保存的模型的名字。完整的例子如下:

如果,我們想要在1000次迭代之后保存模型,可以傳入表示步數(shù)的參數(shù):

這行代碼將添加‘-1000’至模型的名字,以下文件將被建立:

假設(shè),訓(xùn)練時,我們每隔1000次迭代保存一次模型,因此,.meta文件第1000次迭代生成.meta文件后,我們不必要每次新建.meta文件(即在2000,3000次等迭代無須新建.meta文件)。我們僅僅保存最新的迭代模型。因為graph結(jié)構(gòu)并沒有改變,因此,也沒必要寫meta-graph,使用如下代碼:

如果你想要只記錄最新的4個模型,并每隔2個小時保存一個模型,可以使用這兩個參數(shù):max_to_keep和keep_checkpoint_every_n_hours,如下:

需要指出的是,如果我們在tf.train.Saver()中不指定任何事情,它將保存所有的變量。如果,我們不想保存所有的變量,僅僅是一部分。我們可以指定想要保存的變量或集合。當(dāng)新建tf.train.Saver實例時,傳遞給它一個想要保存的變量的列表或者字典。看下面的例子:

可以保存Tensorflow Graph的指定的需要的部分。
3. 導(dǎo)入預(yù)訓(xùn)練的模型
如果你想要使用別人訓(xùn)練好的模型做fine-tuning,有兩件事需要做:
a) 構(gòu)建網(wǎng)絡(luò):
你可以寫python代碼,像寫預(yù)訓(xùn)練的網(wǎng)絡(luò)一樣,人為地復(fù)原每一層或者每一個模塊。但是,如果你想到我們已經(jīng)將網(wǎng)絡(luò)保存到.meta文件里了,就可以使用tf.train.import()函數(shù),恢復(fù)網(wǎng)絡(luò)結(jié)構(gòu),如下:
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
記住,import_meta_graph方法將預(yù)定義在.meta文件的網(wǎng)絡(luò)添加到當(dāng)前網(wǎng)絡(luò)。因此,該方法構(gòu)造graph結(jié)構(gòu),但我們?nèi)孕枰虞d預(yù)訓(xùn)練的參數(shù)的值。
b) 加載參數(shù):
我們可以通過tf.train.Saver()的restore方法,恢復(fù)網(wǎng)絡(luò)的參數(shù):

執(zhí)行完上述代碼,w1和w2張量的值就被恢復(fù)了,可以通過如下代碼獲?。?/p>

所以,至此你已經(jīng)理解了如何保存和導(dǎo)入Tensorflow模型的工作。下一章節(jié),我將描述加載任意預(yù)訓(xùn)練模型的實際使用。
4. 使用恢復(fù)模型
既然你已經(jīng)理解如何保存并恢復(fù)Tensorflow模型,讓我們養(yǎng)成一個規(guī)范去恢復(fù)任意預(yù)訓(xùn)練模型,并使用它做預(yù)測,fine-tuning或者進(jìn)一步訓(xùn)練。不管什么時候使用Tensorflow,你將定義一個Graph,包含輸入,一些超參數(shù),如learning rate, global step等。一個標(biāo)準(zhǔn)的喂入數(shù)據(jù)和超參數(shù)的方式是使用placeholders。讓我們構(gòu)建一個小的使用placeholders的網(wǎng)絡(luò),并保存它。值得指出的是。當(dāng)網(wǎng)絡(luò)被保存。placeholders的值并未保存。

現(xiàn)在,當(dāng)我們想要恢復(fù)模型時,不僅需要恢復(fù)graph和權(quán)重,也需要準(zhǔn)備新的feed_dict去喂新的訓(xùn)練數(shù)據(jù)給網(wǎng)絡(luò)。我們可以通過graph.get_tensor_by_name()等方法得到保存的ops和placeholder變量的引用。

如果我們僅僅想要在網(wǎng)絡(luò)上跑不同的數(shù)據(jù),可以通過feed_dict傳遞新的數(shù)據(jù)給網(wǎng)絡(luò)。

如果想要增加更多的操作(增加更多的layers)到graph里,并訓(xùn)練它。當(dāng)然,你也可以如下:

但是,可以只恢復(fù)一部分的graph然后增加一些操作進(jìn)行fine-tuning么?當(dāng)然可以。利用graph.get_tensor_by_name()方法得到相應(yīng)操作的引用,在頂層構(gòu)建網(wǎng)絡(luò)。這里有個實際的例子。我們加載一個預(yù)訓(xùn)練的VGG網(wǎng)絡(luò),改變輸出的單元數(shù)目為2,利用新的訓(xùn)練數(shù)據(jù)fine-tuning。

希望這篇文章能讓你清晰地理解Tensorflow模型的保存和恢復(fù)。
轉(zhuǎn)載請注明來源,謝謝。