Tensorflow實(shí)驗(yàn)管理

Saver的用法

1. Saver的背景介紹

我們經(jīng)常在訓(xùn)練完一個(gè)模型之后希望保存訓(xùn)練的結(jié)果,這些結(jié)果指的是模型的參數(shù),以便下次迭代的訓(xùn)練或者用作測(cè)試。Tensorflow針對(duì)這一需求提供了Saver類。

  1. Saver類提供了向checkpoints文件保存和從checkpoints文件中恢復(fù)變量的相關(guān)方法。Checkpoints文件是一個(gè)二進(jìn)制文件,它把變量名映射到對(duì)應(yīng)的tensor值 。
  2. 只要提供一個(gè)計(jì)數(shù)器,當(dāng)計(jì)數(shù)器觸發(fā)時(shí),Saver類可以自動(dòng)的生成checkpoint文件。這讓我們可以在訓(xùn)練過程中保存多個(gè)中間結(jié)果。例如,我們可以保存每一步訓(xùn)練的結(jié)果。
  3. 為了避免填滿整個(gè)磁盤,Saver可以自動(dòng)的管理Checkpoints文件。例如,我們可以指定保存最近的N個(gè)Checkpoints文件。

2. Saver的實(shí)例

下面以一個(gè)例子來講述如何使用Saver類

1.  import tensorflow as tf  
2.  import numpy as np  

4.  x = tf.placeholder(tf.float32, shape=[None, 1])  
5.  y = 4 * x + 4  

7.  w = tf.Variable(tf.random_normal([1], -1, 1))  
8.  b = tf.Variable(tf.zeros([1]))  
9.  y_predict = w * x + b  

12.  loss = tf.reduce_mean(tf.square(y - y_predict))  
13.  optimizer = tf.train.GradientDescentOptimizer(0.5)  
14.  train = optimizer.minimize(loss)  

16.  isTrain = False  
17.  train_steps = 100  
18.  checkpoint_steps = 50  
19.  checkpoint_dir = ''  

21.  saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b  
22.  x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))  

24.  with tf.Session() as sess:  
25.  sess.run(tf.initialize_all_variables())  
26.  if isTrain:  
27.  for i in xrange(train_steps):  
28.  sess.run(train, feed_dict={x: x_data})  
29.  if (i + 1) % checkpoint_steps == 0:  
30.  saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)  
31.  else:  
32.  ckpt = tf.train.get_checkpoint_state(checkpoint_dir)  
33.  if ckpt and ckpt.model_checkpoint_path:  
34.  saver.restore(sess, ckpt.model_checkpoint_path)  
35.  else:  
36.  pass  
37.  print(sess.run(w))  
38.  print(sess.run(b))  

isTrain:用來區(qū)分訓(xùn)練階段和測(cè)試階段,True表示訓(xùn)練,F(xiàn)alse表示測(cè)試
train_steps:表示訓(xùn)練的次數(shù),例子中使用100
checkpoint_steps:表示訓(xùn)練多少次保存一下checkpoints,例子中使用50
checkpoint_dir:表示checkpoints文件的保存路徑,例子中使用當(dāng)前路徑

2.1 訓(xùn)練階段

使用Saver.save()方法保存模型:

  1. sess:表示當(dāng)前會(huì)話,當(dāng)前會(huì)話記錄了當(dāng)前的變量值
  2. checkpoint_dir + 'model.ckpt':表示存儲(chǔ)的文件名
  3. global_step:表示當(dāng)前是第幾步

訓(xùn)練完成后,當(dāng)前目錄底下會(huì)多出5個(gè)文件。

打開名為“checkpoint”的文件,可以看到保存記錄,和最新的模型存儲(chǔ)位置。

2.2測(cè)試階段

測(cè)試階段使用saver.restore()方法恢復(fù)變量:
  1. sess:表示當(dāng)前會(huì)話,之前保存的結(jié)果將被加載入這個(gè)會(huì)話

  2. ckpt.model_checkpoint_path:表示模型存儲(chǔ)的位置,不需要提供模型的名字,它會(huì)去查看checkpoint文件,看看最新的是誰,叫做什么。

    運(yùn)行結(jié)果如下圖所示,加載了之前訓(xùn)練的參數(shù)w和b的結(jié)果

部分變量保存

默認(rèn)情況下,saver.save()存儲(chǔ)圖形的所有變量,一般建議這樣做。但是,當(dāng)我們創(chuàng)建保存對(duì)象時(shí),您還可以通過將它們作為列表或字典傳入來選擇要存儲(chǔ)的變量。

v1 = tf.Variable(..., name='v1') 
v2 = tf.Variable(..., name='v2') 

# pass the variables as a dict: 
saver = tf.train.Saver({'v1': v1, 'v2': v2}) 

# pass them as a list
saver = tf.train.Saver([v1, v2]) 

# passing a list is equivalent to passing a dict with the variable op names # as keys
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})

這里使用了三種不同的方式來創(chuàng)建 saver 對(duì)象, 但是它們內(nèi)部的原理是一樣的。我們都知道,參數(shù)會(huì)保存到 checkpoint 文件中,通過鍵值對(duì)的形式在 checkpoint中存放著。如果 Saver 的構(gòu)造函數(shù)中傳的是 dict,那么在 save 的時(shí)候,checkpoint文件中存放的就是對(duì)應(yīng)的 key-value。如下:

import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")

saver = tf.train.Saver({"variable_1":v1, "variable_2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess, 'test-ckpt/model-2')

我們通過官方提供的工具來看一下 checkpoint 中保存了什么

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

print_tensors_in_checkpoint_file("test-ckpt/model-2", None, True)
# 輸出:
#tensor_name:  variable_1
#1.0
#tensor_name:  variable_2
#2.0

如果構(gòu)建saver對(duì)象的時(shí)候,我們傳入的是 list, 那么將會(huì)用對(duì)應(yīng) Variable 的 variable.op.name 作為 key。

import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")

saver = tf.train.Saver([v1, v2])
# Use the saver object normally after that.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess, 'test-ckpt/model-2')

我們?cè)偈褂霉俜焦ぞ叽蛴〕?checkpoint 中的數(shù)據(jù),得到

tensor_name:  v1
1.0
tensor_name:  v2
2.0

如果我們現(xiàn)在想將 checkpoint 中v2的值restore到v1 中,v1的值restore到v2中,我們?cè)撛趺醋觯?/strong>
這時(shí),我們只能采用基于 dictsaver

import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")

saver = tf.train.Saver({"variable_1":v1, "variable_2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess, 'test-ckpt/model-2')

save 部分的代碼如上所示,下面寫 restore 的代碼,和save代碼有點(diǎn)不同。

```python
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")
#restore的時(shí)候,variable_1對(duì)應(yīng)到v2,variable_2對(duì)應(yīng)到v1,就可以實(shí)現(xiàn)目的了。
saver = tf.train.Saver({"variable_1":v2, "variable_2": v1})
# Use the saver object normally after that.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.restore(sess, 'test-ckpt/model-2')
    print(sess.run(v1), sess.run(v2))
# 輸出的結(jié)果是 2.0 1.0,如我們所望

我們發(fā)現(xiàn),其實(shí) 創(chuàng)建 saver對(duì)象時(shí)使用的鍵值對(duì)就是表達(dá)了一種對(duì)應(yīng)關(guān)系:

  • save時(shí), 表示:variable的值應(yīng)該保存到 checkpoint文件中的哪個(gè) key
  • restore時(shí),表示:checkpoint文件中key對(duì)應(yīng)的值,應(yīng)該restore到哪個(gè)variable

其它

一個(gè)快速找到ckpt文件的方式

ckpt = tf.train.get_checkpoint_state(ckpt_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • Saver的用法 1. Saver的背景介紹 我們經(jīng)常在訓(xùn)練完一個(gè)模型之后希望保存訓(xùn)練的結(jié)果,這些結(jié)果指的是模型的...
    Bruce_Szh閱讀 1,776評(píng)論 0 0
  • 使用TensorFlow訓(xùn)練一個(gè)模型,可以多次運(yùn)行訓(xùn)練操作,并在完成后保存訓(xùn)練參數(shù)的檢查點(diǎn)(checkpoint)...
    是neinei啊閱讀 7,414評(píng)論 1 6
  • 李暖安閱讀 505評(píng)論 12 23
  • 早上的白霧正濃,天才微亮,蛐蛐兒也還在鳴叫,路邊的葉子上滾動(dòng)著晶瑩剔透的露珠。你騎著自行車來到我家窗戶外,俯下身體...
    木木青苔閱讀 271評(píng)論 1 3
  • 2016年10月7日對(duì)我來說是一個(gè)很特殊的日子,一年前的今天,我在糾結(jié)了兩個(gè)月之后終于鼓足勇氣來到素語茶緣,第一次...
    為底遲閱讀 449評(píng)論 3 0

友情鏈接更多精彩內(nèi)容