使用TensorFlow訓(xùn)練一個(gè)模型,可以多次運(yùn)行訓(xùn)練操作,并在完成后保存訓(xùn)練參數(shù)的檢查點(diǎn)(checkpoint)。這對(duì)能夠在幾個(gè)小時(shí)內(nèi)訓(xùn)練的小模型很有效。但是如果是訓(xùn)練的數(shù)據(jù)量比較大,可能需要訓(xùn)練幾天或者幾個(gè)月。。。
那原生的tensorflow的健壯性可能就比較堪憂。。。
萬(wàn)一斷電了之類。。。
這時(shí)候我們就可以使用supervisor
需要長(zhǎng)時(shí)間訓(xùn)練的較大模型,需要更魯棒(robust)的訓(xùn)練過(guò)程:
- 能處理關(guān)機(jī)以及徹底崩潰的情況。
- 可以在關(guān)機(jī)或崩潰后恢復(fù)。
- 可以通過(guò)TensorBoard進(jìn)行監(jiān)控。
為了能夠在停機(jī)或崩潰后恢復(fù)訓(xùn)練,訓(xùn)練過(guò)程必須周期保存檢查點(diǎn)。在重新啟動(dòng)時(shí),它必須查找最新的檢查點(diǎn),并在恢復(fù)訓(xùn)練之前加載它。supervisor可以看做一個(gè)工具,或者說(shuō)是對(duì)原生tensorflow的一層封裝,目的主要是通過(guò)定期save的方法增強(qiáng)訓(xùn)練健壯性,
就算程序掛掉了也可以從上一次save的checkpoint恢復(fù),而不是從頭再來(lái)(雖然這些也可以手動(dòng)實(shí)現(xiàn),同時(shí)也可以簡(jiǎn)化代碼量
tf.train.Supervisor提供了一套有助于實(shí)施魯棒的訓(xùn)練過(guò)程的服務(wù)。除了supervisor,還有tf.learn庫(kù),里面提供對(duì)原生tensorflow更高層的封裝,也提供更豐富的功能。
請(qǐng)注意,Supervisor對(duì)訓(xùn)練大模型非常有幫助,但也可以用于較小型號(hào),不會(huì)有任何不好的地方。
supervisor可以看做一個(gè)工具,或者說(shuō)是對(duì)原生tensorflow的一層封裝,目的主要是通過(guò)定期save的方法增強(qiáng)訓(xùn)練健壯性。
1.一個(gè)簡(jiǎn)單方案
使用Supervisor的最簡(jiǎn)單的方案是:
創(chuàng)建一個(gè)
Supervisor對(duì)象,將其傳遞到保存檢查點(diǎn)和summary的目錄。用
tf.train.Supervisor.managed_session向Supervisor請(qǐng)求一個(gè)會(huì)話(session)。使用會(huì)話執(zhí)行訓(xùn)練操作,如果Supervisor要求訓(xùn)練停止,請(qǐng)檢查每一步。
...create graph...
my_train_op = ...
sv = tf.train.Supervisor(logdir="/my/training/directory")
with sv.managed_session() as sess:
for step in range(100000):
if sv.should_stop():
break
sess.run(my_train_op)
開(kāi)始服務(wù)
managed_session()啟動(dòng)一些服務(wù),它們?cè)谧约旱木€程中運(yùn)行,并利用managed session在圖中運(yùn)行各種操作。
如果圖中包含一個(gè)名為global_step的整型變量,則服務(wù)使用其值來(lái)測(cè)量執(zhí)行的訓(xùn)練步驟數(shù)量。有關(guān)如何創(chuàng)建global_step變量,請(qǐng)參閱MNIST訓(xùn)練教程。
檢查點(diǎn)服務(wù):在logdir中保存圖形變量的副本。
global_step如果添加到您的圖中,則檢查點(diǎn)文件名將使用該變量的值。默認(rèn)運(yùn)行10分鐘。summary服務(wù):運(yùn)行所有summary操作,并將其輸出附加到logdir 中的 事件文件中。默認(rèn)情況下每2分鐘運(yùn)行一次。
步驟計(jì)數(shù)器:通過(guò)查看
global_step變量的更改來(lái)計(jì)算執(zhí)行了多少步。向事件文件追加一個(gè)summary,報(bào)告每秒鐘的全局步數(shù)。 summary tag 為“global_step / sec”。這也默認(rèn)每2分鐘運(yùn)行一次。Queue Runners:如果
tf.train.QueueRunner添加到圖形中,Supervisor將在自己的線程中啟動(dòng)它們。
構(gòu)建Supervisor對(duì)象時(shí)可以更改所有時(shí)間間隔。有關(guān)詳細(xì)信息,請(qǐng)參閱Supervisor參考。
檢查停止
在主訓(xùn)練循環(huán)中對(duì)停止的檢查是重要和必要的。
在服務(wù)線程中引發(fā)的異常報(bào)告給Supervisor,然后將其should_stop()條件設(shè)置為true。其他服務(wù)線程告知此情形并合理終止。managed_session()塊內(nèi)的主訓(xùn)練循環(huán) 還必須檢查停止條件并終止。
請(qǐng)注意managed_session()捕獲從訓(xùn)練循環(huán)中引發(fā)的異常情況,將其報(bào)告給Supervisor。主循環(huán)不需要對(duì)異常做任何特別的處理。它只需要檢查停止條件。
復(fù)蘇
如果訓(xùn)練程序關(guān)閉或崩潰,其最新的檢查點(diǎn)和事件文件將留在logdir中。當(dāng)重新啟動(dòng)程序時(shí), managed_session()從最近的檢查點(diǎn)恢復(fù)圖形,并恢復(fù)停止的訓(xùn)練。
創(chuàng)建一個(gè)新的事件文件。如果啟動(dòng)TensorBoard并將其指向logdir,它將會(huì)知道如何合并兩個(gè)事件文件的內(nèi)容,并將在檢查點(diǎn)的最后一個(gè)全局步驟中顯示訓(xùn)練恢復(fù)。
2.較大的模式場(chǎng)景
最簡(jiǎn)單的情景已經(jīng)足以處理大多數(shù)小到中模型的訓(xùn)練。更大的模型也許會(huì)在運(yùn)行summary sevice的時(shí)候耗盡內(nèi)存:summary ops是與main loop中的train op一起并行地run的。這會(huì)導(dǎo)致內(nèi)存使用達(dá)到通常使用的兩倍多。
對(duì)于打得模型你可以通知supervisor不要運(yùn)行summary服務(wù),作為替代,你在自己的主訓(xùn)練循環(huán)中來(lái)運(yùn)行:創(chuàng)建supervisor的時(shí)候傳遞summary_op=None。
例如,該代碼在訓(xùn)練循環(huán)中每100個(gè)步驟運(yùn)行摘要:
...create graph...
my_train_op = ...
my_summary_op = tf.summary.merge_all()
sv = tf.train.Supervisor(logdir="/my/training/directory",
summary_op=None) # Do not run the summary service
with sv.managed_session() as sess:
for step in range(100000):
if sv.should_stop():
break
if step % 100 == 0:
_, summ = session.run([my_train_op, my_summary_op])
sv.summary_computed(sess, summ)
else:
session.run(my_train_op)
預(yù)訓(xùn)練的模型情景
managed_session()調(diào)用很關(guān)心在session中初始化模型。模型會(huì)在可能的時(shí)候從一個(gè)checkpoint中加載,亦或從scratch中初始化。
一個(gè)常見(jiàn)的情景是要用加載的預(yù)訓(xùn)練的checkpoint來(lái)初始化模型,而該預(yù)訓(xùn)練模型和當(dāng)前模型有些許的不同。
你可以通過(guò)給supervisor傳遞init function的方式來(lái)加載預(yù)訓(xùn)練的checkpoint。這個(gè)函數(shù)只有在模型需要從scratch初始化時(shí)才被調(diào)用,而模型從logdir中的checkpoint恢復(fù)的時(shí)候并不會(huì)。
為了加載預(yù)訓(xùn)練模型,init 函數(shù)需要一個(gè)tf.train.Saver對(duì)象,所以你應(yīng)該創(chuàng)建一個(gè)saver。新模型也許包含一些預(yù)訓(xùn)練的checkpoint中不存在的變量,所以這是一個(gè)很好的思想:這個(gè)saver必須只加載預(yù)訓(xùn)練的變量。如果你正在使用默認(rèn)的saver,你會(huì)在嘗試加載所有變量的時(shí)候得到一個(gè)錯(cuò)誤。
...create graph...
my_train_op = ...
my_summary_op = tf.summary.merge_all()
sv = tf.train.Supervisor(logdir="/my/training/directory",
summary_op=None) # Do not run the summary service
with sv.managed_session() as sess:
for step in range(100000):
if sv.should_stop():
break
if step % 100 == 0:
_, summ = session.run([my_train_op, my_summary_op])
sv.summary_computed(sess, summ)
else:
session.run(my_train_op)
運(yùn)行你自己的服務(wù)
Supervisor服務(wù),比如checkpointing服務(wù),與主訓(xùn)練循環(huán)并行運(yùn)行。有時(shí)候你想加入你自己的服務(wù),比如取出和 通常的summary的schedule不一樣的不同設(shè)置的summaries。
使用supervisor中的tf.train.Supervisor.loop來(lái)達(dá)成這個(gè)目的。它會(huì)根據(jù)你選擇的定時(shí)器重復(fù)地調(diào)用一個(gè)函數(shù),直到supervisor的stop condition為true,所以它和其他服務(wù)很協(xié)調(diào)。
例如:每20分鐘調(diào)用一次my_additional_summaries():
def my_additional_sumaries(sv, sess):
...fetch and write summaries, see below...
...
sv = tf.train.Supervisor(logdir="/my/training/directory")
with sv.managed_session() as sess:
# Call my_additional_sumaries() every 1200s, or 20mn,
# passing (sv, sess) as arguments.
sv.loop(1200, my_additional_sumaries, args=(sv, sess))
...main training loop...
寫(xiě)summaries
supervisor總是在其logdir中生成一個(gè)事件文件,同時(shí)用一個(gè)tf.summary.FileWriter將事件和summaries添加到事件文件。如果你想寫(xiě)自己的summaries,也可以將它們添加到同一個(gè)事件文件中去:TensorBoard很喜歡在目錄中只有一個(gè)事件文件。
supervisor提供了一個(gè)輔助函數(shù)來(lái)添加summaries:tf.train.Supervisor.summary_computed:只需要傳遞一份summary_op的返回輸出函數(shù)。以下是使用該函數(shù)實(shí)現(xiàn)之前例子中my_additional_sumaries()的例子:
def my_additional_sumaries(sv, sess):
summaries = sess.run(my_additional_summary_op)
sv.summary_computed(sess, summaries)
更多前沿的用法參看tf.train.Supervisor.summary_writer屬性。
supervisor 參考
在簡(jiǎn)單的情景以及更大的模型方案的情景展示了supervisor的基本用法。更高級(jí)的情景可以用supervisor提供的很多選項(xiàng)來(lái)創(chuàng)建。
Checkpointing:何時(shí)何處
managed_session()調(diào)用開(kāi)啟了checkpointing服務(wù),而這可以通過(guò)對(duì)Supervisor()創(chuàng)建時(shí)以下的參數(shù)來(lái)配置:
- logdir: checkpointing服務(wù)床創(chuàng)建checkpoints的目錄路徑。如果需要,創(chuàng)建該目錄。傳遞None禁用checkpointing以及summary服務(wù)。
- checkpoint_basename: 欲創(chuàng)建的checkpoint文件的名稱,默認(rèn)為”model.ckpt”。
如果模型包含一個(gè)名為的標(biāo)量整數(shù)變量global_step,則該變量的值將附加到檢查點(diǎn)文件名。
例如,在global_step 1234,checkpoint 文件名就是 “model.ckpt-1234”。
- save_model_secs: 每個(gè)checkpoint之間的秒數(shù)。默認(rèn)為600,即10分鐘。
當(dāng)選擇一個(gè)值時(shí),要考慮一旦有crash時(shí)你要丟失多少工作:你永遠(yuǎn)不會(huì)丟失多于save_model_secs秒的工作。設(shè)置為0就禁用了checkpointing服務(wù)。
- saver: 一個(gè)tf.train.Saver對(duì)象,用來(lái)checkpointing。
如果不傳遞saver,supervisor會(huì)調(diào)用tf.train.Saver()來(lái)創(chuàng)建一個(gè),該saver會(huì)把所有的ops保存,并加載你模型中所有的變量。你通常也需要這么做。
示例:每30秒使用自定義保護(hù)程序和檢查點(diǎn)。
...create graph...
my_saver = tf.train.Saver(<only some variables>)
sv = tf.train.Supervisor(logdir="/my/training/directory",
saver=my_saver,
save_model_secs=30)
with sv.managed_session() as sess:
...training loop...
Summaries:何時(shí)何處
類似checkpointing,logdir對(duì)summaries的作用也是一樣的。事件文件在此創(chuàng)建,如果None則禁用了summary服務(wù)。
save_summaries_secs:該參數(shù)代表每次運(yùn)行summary sevice服務(wù)的間隔的秒數(shù)。默認(rèn)為120秒,即兩分鐘。同樣,設(shè)置為0時(shí)則禁用了summary服務(wù)。
-
summary_op,用來(lái)取得summaries的op。
如果沒(méi)指定,supervisor會(huì)使用
tf.GraphKeys.SUMMARY_OP圖集合(graph collection)中第一個(gè)op。如果該集合為空,supervisor則創(chuàng)建一個(gè)op,它會(huì)將圖中的所有summaries使用tf.summary.merge_all()聚集在一起。如果給summary_op傳遞None則禁用了summary服務(wù)。
-
global_step:用來(lái)計(jì)算全局步數(shù)的張量。
如果沒(méi)有指明,supervisor使用
tf.GraphKeys.GLOBAL_STEP圖集合(graph collection)中第一個(gè)tensor,如果該集合為空,supervisor在圖中尋找一個(gè)name為global_step的整型的變量的標(biāo)量。
如果找到,global step張量被用來(lái)衡量訓(xùn)練步數(shù)執(zhí)行的數(shù)量。注意,你的訓(xùn)練op會(huì)增加global step的值。
模型的初始化和恢復(fù)
managed_session()調(diào)用野專注于初始化以及恢復(fù)一個(gè)session。它返回一個(gè)session同時(shí)伴隨一個(gè)全部初始化了的模型,準(zhǔn)備去訓(xùn)練。如果managed_session()調(diào)用時(shí)logdir里有一個(gè)checkpoint,模型會(huì)通過(guò)加載checkpoint初始化,否則會(huì)通過(guò)調(diào)用一個(gè)初始化op或者選擇一個(gè)init function。
如果沒(méi)有可用的checkpoint,模型的初始化則有下面的參數(shù)傳遞給supervisor()的創(chuàng)建器來(lái)控制:
-
init_op: 需要被運(yùn)行來(lái)初始化模型的op。
如果沒(méi)有指定,supervisor會(huì)使用tf.GraphKeys.INIT_OP圖集合( collection)中第一個(gè)op。如果集合是空的,則會(huì)通過(guò)調(diào)用tf.global_variables_initializer()添加一個(gè)初始化所有變量的op。
傳遞None則不適用初始化op。
-
init_fn: 調(diào)用它來(lái)初始化模型。
如果指定則這樣調(diào)用 :init_fn(sess),這里的sess是managed session。如果init op同時(shí)使用,則init function在init op之后被調(diào)用。
-
local_init_op: 一個(gè)額外的op,用來(lái)初始化圖段一部分,這部分沒(méi)有被保存在checkpoints中。比如比如tables以及一些local variables。local init op在init op以及 init function之后運(yùn)行。
如果沒(méi)有指定,supervisor使用tf.GraphKeys.LOCAL_INIT_OP集合里的第一個(gè)op。如果集合為空,則通過(guò)調(diào)用tf.tables_initializer() 和 tf.local_variables_initializer()添加一初始化所有tables以及l(fā)ocal variables的op。
傳遞None禁用local init op。
ready_op: 核查模型是否被初始化的op。
運(yùn)行了local init op,init op以及init function之后,supervisor會(huì)通過(guò)執(zhí)行ready op來(lái)驗(yàn)證模型是否被完全初始化。如果初始化則該op返回空字符串,否則返回模型那部分未被初始化的一個(gè)描述。
如果未指定,supervisor會(huì)使用tf.GraphKeys.READY_OP 集合中的第一個(gè)op。若集合未空,supervisosr通過(guò)調(diào)用tf.report_uninitialized_variables()創(chuàng)建一個(gè)ready op來(lái)確保所有變量都被初始化。
傳遞None來(lái)禁用ready op。在這種情況下模型初始化之后不進(jìn)行核查。
checkpoint的恢復(fù)是由以下傳給superfisor()創(chuàng)建器的參數(shù)控制:
-
logdir:尋找checkpoints的路徑。checkpoint服務(wù)保存了一個(gè)metadata文件,名為 “checkpoint”,在這個(gè)checkpoint目錄中指明最近的一個(gè)checkpoint的路徑。
這個(gè)文件是文本格式的。你可以手工編輯它來(lái)從一個(gè)不同于最近的checkpoint的checkpoint中恢復(fù)。
ready_op:和上面的一樣。ready op在加載checkpoint之前和之后運(yùn)行。第一次運(yùn)行檢查模型是否需要被初始化,第二次驗(yàn)證模型完全被初始化。
local_init_op:和上面的一樣。local init op在第一次運(yùn)行ready op之前運(yùn)行,來(lái)初始化局部變量以及tables。
saver:和上面的一樣。用來(lái)加載checkpoint的的Saver對(duì)象。