Saver的用法
1. Saver的背景介紹
我們經(jīng)常在訓(xùn)練完一個(gè)模型之后希望保存訓(xùn)練的結(jié)果,這些結(jié)果指的是模型的參數(shù),以便下次迭代的訓(xùn)練或者用作測(cè)試。Tensorflow針對(duì)這一需求提供了Saver類。
- Saver類提供了向checkpoints文件保存和從checkpoints文件中恢復(fù)變量的相關(guān)方法。Checkpoints文件是一個(gè)二進(jìn)制文件,它把變量名映射到對(duì)應(yīng)的tensor值 。
- 只要提供一個(gè)計(jì)數(shù)器,當(dāng)計(jì)數(shù)器觸發(fā)時(shí),Saver類可以自動(dòng)的生成checkpoint文件。這讓我們可以在訓(xùn)練過程中保存多個(gè)中間結(jié)果。例如,我們可以保存每一步訓(xùn)練的結(jié)果。
- 為了避免填滿整個(gè)磁盤,Saver可以自動(dòng)的管理Checkpoints文件。例如,我們可以指定保存最近的N個(gè)Checkpoints文件。
2. Saver的實(shí)例
下面以一個(gè)例子來講述如何使用Saver類
1. import tensorflow as tf
2. import numpy as np
4. x = tf.placeholder(tf.float32, shape=[None, 1])
5. y = 4 * x + 4
7. w = tf.Variable(tf.random_normal([1], -1, 1))
8. b = tf.Variable(tf.zeros([1]))
9. y_predict = w * x + b
12. loss = tf.reduce_mean(tf.square(y - y_predict))
13. optimizer = tf.train.GradientDescentOptimizer(0.5)
14. train = optimizer.minimize(loss)
16. isTrain = False
17. train_steps = 100
18. checkpoint_steps = 50
19. checkpoint_dir = ''
21. saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
22. x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
24. with tf.Session() as sess:
25. sess.run(tf.initialize_all_variables())
26. if isTrain:
27. for i in xrange(train_steps):
28. sess.run(train, feed_dict={x: x_data})
29. if (i + 1) % checkpoint_steps == 0:
30. saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
31. else:
32. ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
33. if ckpt and ckpt.model_checkpoint_path:
34. saver.restore(sess, ckpt.model_checkpoint_path)
35. else:
36. pass
37. print(sess.run(w))
38. print(sess.run(b))
isTrain:用來區(qū)分訓(xùn)練階段和測(cè)試階段,True表示訓(xùn)練,F(xiàn)alse表示測(cè)試
train_steps:表示訓(xùn)練的次數(shù),例子中使用100
checkpoint_steps:表示訓(xùn)練多少次保存一下checkpoints,例子中使用50
checkpoint_dir:表示checkpoints文件的保存路徑,例子中使用當(dāng)前路徑
2.1 訓(xùn)練階段
使用Saver.save()方法保存模型:
- sess:表示當(dāng)前會(huì)話,當(dāng)前會(huì)話記錄了當(dāng)前的變量值
- checkpoint_dir + 'model.ckpt':表示存儲(chǔ)的文件名
- global_step:表示當(dāng)前是第幾步
訓(xùn)練完成后,當(dāng)前目錄底下會(huì)多出5個(gè)文件。
打開名為“checkpoint”的文件,可以看到保存記錄,和最新的模型存儲(chǔ)位置。
2.2測(cè)試階段
測(cè)試階段使用saver.restore()方法恢復(fù)變量:
sess:表示當(dāng)前會(huì)話,之前保存的結(jié)果將被加載入這個(gè)會(huì)話
-
ckpt.model_checkpoint_path:表示模型存儲(chǔ)的位置,不需要提供模型的名字,它會(huì)去查看checkpoint文件,看看最新的是誰,叫做什么。
運(yùn)行結(jié)果如下圖所示,加載了之前訓(xùn)練的參數(shù)w和b的結(jié)果
部分變量保存
默認(rèn)情況下,saver.save()存儲(chǔ)圖形的所有變量,一般建議這樣做。但是,當(dāng)我們創(chuàng)建保存對(duì)象時(shí),您還可以通過將它們作為列表或字典傳入來選擇要存儲(chǔ)的變量。
v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')
# pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})
# pass them as a list
saver = tf.train.Saver([v1, v2])
# passing a list is equivalent to passing a dict with the variable op names # as keys
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
這里使用了三種不同的方式來創(chuàng)建 saver 對(duì)象, 但是它們內(nèi)部的原理是一樣的。我們都知道,參數(shù)會(huì)保存到 checkpoint 文件中,通過鍵值對(duì)的形式在 checkpoint中存放著。如果 Saver 的構(gòu)造函數(shù)中傳的是 dict,那么在 save 的時(shí)候,checkpoint文件中存放的就是對(duì)應(yīng)的 key-value。如下:
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")
saver = tf.train.Saver({"variable_1":v1, "variable_2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
tf.global_variables_initializer().run()
saver.save(sess, 'test-ckpt/model-2')
我們通過官方提供的工具來看一下 checkpoint 中保存了什么
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file("test-ckpt/model-2", None, True)
# 輸出:
#tensor_name: variable_1
#1.0
#tensor_name: variable_2
#2.0
如果構(gòu)建saver對(duì)象的時(shí)候,我們傳入的是 list, 那么將會(huì)用對(duì)應(yīng) Variable 的 variable.op.name 作為 key。
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")
saver = tf.train.Saver([v1, v2])
# Use the saver object normally after that.
with tf.Session() as sess:
tf.global_variables_initializer().run()
saver.save(sess, 'test-ckpt/model-2')
我們?cè)偈褂霉俜焦ぞ叽蛴〕?checkpoint 中的數(shù)據(jù),得到
tensor_name: v1
1.0
tensor_name: v2
2.0
如果我們現(xiàn)在想將 checkpoint 中v2的值restore到v1 中,v1的值restore到v2中,我們?cè)撛趺醋觯?/strong>
這時(shí),我們只能采用基于 dict 的 saver
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")
saver = tf.train.Saver({"variable_1":v1, "variable_2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
tf.global_variables_initializer().run()
saver.save(sess, 'test-ckpt/model-2')
save 部分的代碼如上所示,下面寫 restore 的代碼,和save代碼有點(diǎn)不同。
```python
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")
#restore的時(shí)候,variable_1對(duì)應(yīng)到v2,variable_2對(duì)應(yīng)到v1,就可以實(shí)現(xiàn)目的了。
saver = tf.train.Saver({"variable_1":v2, "variable_2": v1})
# Use the saver object normally after that.
with tf.Session() as sess:
tf.global_variables_initializer().run()
saver.restore(sess, 'test-ckpt/model-2')
print(sess.run(v1), sess.run(v2))
# 輸出的結(jié)果是 2.0 1.0,如我們所望
我們發(fā)現(xiàn),其實(shí) 創(chuàng)建 saver對(duì)象時(shí)使用的鍵值對(duì)就是表達(dá)了一種對(duì)應(yīng)關(guān)系:
- save時(shí), 表示:
variable的值應(yīng)該保存到checkpoint文件中的哪個(gè)key下 - restore時(shí),表示:
checkpoint文件中key對(duì)應(yīng)的值,應(yīng)該restore到哪個(gè)variable
其它
一個(gè)快速找到ckpt文件的方式
ckpt = tf.train.get_checkpoint_state(ckpt_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)