使用驗證集判斷模型效果
為了評測神經(jīng)網(wǎng)絡(luò)模型在不同參數(shù)下的效果,一般會從訓(xùn)練集中抽取一部分作為驗證數(shù)據(jù)。除了使用驗證數(shù)據(jù)集,還可以采用交叉驗證(cross validation )的方式驗證模型效果,但是使用交叉驗證會花費大量的時間。但在海量數(shù)據(jù)情況下,一般采用驗證數(shù)據(jù)集的形式評測模型的效果。
一般采用的驗證數(shù)據(jù)分布越接近測試數(shù)據(jù)分布,模型在驗證數(shù)據(jù)上的表現(xiàn)越可以體現(xiàn)模型在測試數(shù)據(jù)上的保險。
使用滑動平均模型和指數(shù)衰減的學(xué)習(xí)率在一定程度上都是限制神經(jīng)網(wǎng)絡(luò)中參數(shù)更新的速度。
在處理復(fù)雜問題時,使用滑動平均模型、指數(shù)衰減的學(xué)習(xí)率和正則化損失可以明顯提升模型的訓(xùn)練效果。
變量管理
Tensorflow提供了通過變量名稱來創(chuàng)建或者獲取一個變量的機(jī)制,避免了復(fù)雜神經(jīng)網(wǎng)絡(luò)頻繁傳遞參數(shù)的情況。通過該機(jī)制,在不同的函數(shù)中可以直接通過變量的名字來使用變量,而不需要將變量通過參數(shù)的形式到處傳遞。
Tensorflow中通過變量名獲取變量的機(jī)制主要通過tf.get_variable()和tf.variable_scope()函數(shù)實現(xiàn)。
-
tf.get_variable()
該函數(shù)創(chuàng)建變量的方法和tf.Variable()函數(shù)的用法基本一樣,提供維度信息(shape)以及初始化方法(initializer)的參數(shù)。該函數(shù)的變量名稱是一個必填參數(shù),函數(shù)會根據(jù)這個名字去創(chuàng)建或者獲取變量。當(dāng)已經(jīng)有同名參數(shù)時,會報錯。 -
tf.variable_scope()
該函數(shù)可以控制tf.get_variable()函數(shù)的語義。當(dāng)tf.variable_scope()函數(shù)使用參數(shù)reuse=True生成上下文管理器時,這個上下文管理器內(nèi)所有的tf.get_variable()函數(shù)會直接獲取已經(jīng)創(chuàng)建的變量。如果不存在,則報錯;當(dāng)reuse=False或者reuse=None創(chuàng)建上下文管理器時,tf.get_variable()操作將創(chuàng)建新的變量,如果同名變量已經(jīng)存在,則報錯。
同時tf.variable_scope()函數(shù)可以嵌套。新建一個嵌套的上下文管理器但不指定reuse,這時的reuse的取值和外面一層保持一致。當(dāng)退出reuse設(shè)置為True的上下文之后reuse的值又回到了False(內(nèi)層reuse不設(shè)置)。
同時,tf.variable_scope()函數(shù)生成的上下文管理器也會創(chuàng)建一個Tensorflow中的命名空間,在命名空間內(nèi)創(chuàng)建的變量名稱都會帶上這個命名空間名作為前綴??梢灾苯油ㄟ^帶命名空間名稱的變量名來獲取其它命名空間下的變量(創(chuàng)建一個名稱為空的命名空間,并設(shè)置為reuse=True)。
with tf.variable_scope(" ", reuse=True):
v5 = tf.get_variable("foo/bar/v", [1])
print(v5.name)
===>v:0 # 0表示variable這個運(yùn)算輸出的第一個結(jié)果
Tensorflow模型持久化
將訓(xùn)練得到的模型保存下來,可以方便下次直接使用(避免重新訓(xùn)練花費大量的時間)。Tensorflow提供的持久化機(jī)制可以將訓(xùn)練之后的模型保存到文件中。
Tensorflow提供了tf.train.Saver類來保存和還原神經(jīng)網(wǎng)絡(luò)模型。當(dāng)保存模型之后,目錄下一般會出現(xiàn)三個文件,這是因為Tensorflow會將計算圖的結(jié)構(gòu)和圖上參數(shù)值分開保存。
-
model.ckpy.meta文件,保存了Tensorflow計算圖的結(jié)構(gòu)。 -
model.ckpt文件,保存了Tensorflow程序每一個變量的取值。 -
checkpoint文件,保存了一個目錄下所有的模型文件列表。
保存模型
saver = tf.train.Saver()
saver.save(sess, "path/model.ckpt")
加載模型,此時不用進(jìn)行變量的初始化過程
saver.restore(sess, "path/model.ckpt")
sess.run(result)
為了保存和加載部分變量,在聲明tf.train.Saver類時可以提供一個列表來指定需要保存或加載的變量,saver = tf.train.Saver([v1])。同時,tf.train.Saver類也支持在保存或者加載時給變量重命名,如果直接加載就會導(dǎo)致程序報變量找不到的錯誤,Tensorflow提供通過字典將模型保存時的變量名和要加載的變量聯(lián)系起來。
v = tf.Variable(tf.constant(1.0, shape=[1]), name='other-v1')
saver = tf.train.Saver({"v1": v})
將原先變量名為v1的變量加載到變量v中,變量v的名稱為other-v1。
這樣做的目的時為了方便使用變量的滑動平均值。因為每一個變量的滑動平均值是通過影子變量維護(hù)的,如果在加載模型時直接將影子變量映射到變量自身,就不需要在調(diào)用函數(shù)來獲取變量的平均值了。
為了方便加載重命名滑動平均變量,tf.train.ExponentialMovingAverage類提供了variables_to_restore()函數(shù)來生成tf.train.Saver類所需要的變量重命名字典。
v = tf.Variable(0)
ema = tf.train.ExponentialMovingAverage(0.99)
saver = tf.train.Saver(ema.variable_to_restore())
with tf.Session() as sess:
saver.restore(sess, "path/model.ckpt")
sess.run(v)
有時候不需要類似于變量初始化、模型保存等輔助節(jié)點的信息,Tensorflow提供了convert_variables_to_constants()函數(shù)將計算圖中的變量及其取值通過常量的方式保存。
持久化原理及數(shù)據(jù)格式
Tensorflow程序中所有計算都會被表達(dá)為計算圖上的節(jié)點。
MetaGraphDef
Tensorflow通過元圖(MetaGraph)來記錄計算圖中節(jié)點的信息以及運(yùn)行計算圖中節(jié)點所需要的元數(shù)據(jù),元圖是由MetaGraphDef Protocol Buffer定義的,MetaGraphDef中的內(nèi)容構(gòu)成了Tensorflow持久化的第一個文件,也就是model.ckpt.meta文件。
-
meta_info_def屬性,記錄了Tensorflow計算圖中的元數(shù)據(jù)以及Tensorflow程序中所有使用到的運(yùn)算方法的信息。元數(shù)據(jù)包括了計算圖的版本號以及用戶指定的一些標(biāo)簽,其中meta_info_def屬性的stripped_op_list屬性保存了Tensorflow運(yùn)算方法的信息,如果一個運(yùn)算方法在計算圖中出現(xiàn)了多次,在該字段中也只出現(xiàn)一次。stripped_op_list屬性的類型是OpList,OpList類型是一個OpDef類型的列表,該類型定義了一個運(yùn)算的所有信息,包括運(yùn)算名、輸入輸出和運(yùn)算的參數(shù)信息。 -
graph_def屬性,主要記錄了Tensorflow計算圖上的節(jié)點信息,Tensorflow計算圖的每一個節(jié)點對應(yīng)了Tensorflow程序中的一個運(yùn)算。meta_info_def屬性已經(jīng)包含了所有運(yùn)算的具體信息,所以graph_def屬性只關(guān)注運(yùn)算的連接結(jié)果。
該屬性是通過GraphDef Protocol Buffer定義的,GraphDef主要包含了一個NodeDef類型的列表,GraphDef的versions屬性存儲了Tensorflow的版本號,node屬性記錄了所有的節(jié)點信息。node為NodeDef類型,該類型的op屬性給出了該節(jié)點使用的運(yùn)算方法名稱,具體信息可以通過meta_info_def獲取,input屬性是一個字符串列表,定義了運(yùn)算的輸入,device屬性定義了處理該運(yùn)算的設(shè)備,attr屬性定義了和當(dāng)前運(yùn)算相關(guān)的配置信息。 -
saver_def屬性,記錄了持久化模型所需要用到的一些參數(shù),比如保存到文件的文件名,保存操作和加載操作的名稱以及保存頻率、清理歷史記錄等。
該屬性主要通過SaverDef定義。 -
collention_def屬性,Tensorflow計算圖中可以維護(hù)不同的集合,底層實現(xiàn)就是通過collention_def這個屬性。collection_def屬性是一個從集合名稱到集合內(nèi)容的映射,其中集合名稱為字符串,集合內(nèi)容為CollentionDef Protocol Buffer。Tensorflow計算圖上的集合主要可以維護(hù)4類不同的集合:NodeList用于維護(hù)計算圖上的節(jié)點集合;BytesList用于維護(hù)字符串或者序列化之后的Protocol Buffer的集合;Int64List用于維護(hù)整數(shù)集合;FloatList用于維護(hù)實數(shù)集合。
SSTable
持久化Tensorflow中變量的取值,tf.Saver得到的model.ckpt文件保存了所有的變量,該文件使用SSTable格式存儲的,相當(dāng)于一個(key, value)列表。
CheckpointState
持久化的最后一個文件名叫checkpoint,這個文件是tf.train.Saver類自動生成且自動維護(hù)的。該文件中維護(hù)了一個由tf.train.Saver類持久化的所有Tensoflow模型文件的文件名,當(dāng)某個模型文件被刪除時,這個模型對應(yīng)的文件名也會被移除,checkpoint中內(nèi)容的格式為CheckpointState Protocol Buffer。