????????上一篇文章 TensorFlow 訓(xùn)練 CNN 分類器 中說(shuō)明了訓(xùn)練簡(jiǎn)單 CNN 模型的整個(gè)過(guò)程,并在訓(xùn)練結(jié)束后使用 .save 函數(shù)來(lái)保存訓(xùn)練的結(jié)果,其后通過(guò)使用 tf.train.import_meta_graph 和 .restore 函數(shù)來(lái)導(dǎo)入模型進(jìn)行推斷。本文承接上文,對(duì)模型保存與恢復(fù)做一個(gè)總結(jié)。
????????總的來(lái)說(shuō),模型在保存和恢復(fù)時(shí)最重要的是留下數(shù)據(jù)接口,方便使用時(shí)傳入數(shù)據(jù)和獲取結(jié)果。TensorFlow 中常用的模型保存格式為 .ckpt 和 .pb,下面分別進(jìn)行詳細(xì)說(shuō)明。
一、ckpt 格式模型保存與恢復(fù)
????????.ckpt 格式保存與恢復(fù)都很簡(jiǎn)單,具體可參考 TensorFlow 訓(xùn)練 CNN 分類器。
1. ckpt 格式模型保存
inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs') <-- 入口
···
prediction = tf.nn.softmax(logits, name='prediction') <-- 出口(僅作為例子,下同)
···
saver = tf.train.Saver()
···
with tf.Session() as sess:
··· <-- 訓(xùn)練過(guò)程
saver.save(sess, './xxx/xxx.ckpt') <-- 模型保存
????????如上述代碼所示,假設(shè)你定義了一個(gè) TensorFlow 模型,數(shù)據(jù)入口由占位符 inputs 給定,結(jié)果出口由張量 prediction 給定。通過(guò)語(yǔ)句 saver = tf.train.Saver() 定義了模型保存的一個(gè)實(shí)例對(duì)象 saver,當(dāng)模型訓(xùn)練結(jié)束之后只需要簡(jiǎn)單的一條語(yǔ)句:
saver.save(sess, path_to_model.ckpt)
就把訓(xùn)練結(jié)果保存到了指定的路徑。
????????以上代碼之所以把變量 inputs 和 prediction 單獨(dú)列出,一方面是因?yàn)樗鼈兪悄P?Graph 的起點(diǎn)和終點(diǎn)(戲稱為數(shù)據(jù)入口、出口),另一方面的原因是它們被特別的指定了名稱,因而在模型恢復(fù)時(shí)可以通過(guò)它們的名稱而得到 Graph 中對(duì)應(yīng)的節(jié)點(diǎn)。
2. ckpt 格式模型恢復(fù)
????????當(dāng)你需要導(dǎo)入模型進(jìn)行推斷時(shí),只需要通過(guò)張量名獲取數(shù)據(jù)入口和出口,然后傳入數(shù)據(jù)即可:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./xxx/xxx.ckpt.meta')
saver.restore(sess, './xxx/xxx.ckpt')
inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
prediction = tf.get_default_graph().get_tensor_by_name('prediction:0')
pred = sess.run(prediction, feed_dict={inputs: xxx}
????????保存為 .ckpt 模型的一個(gè)好處是,當(dāng)需要繼續(xù)訓(xùn)練時(shí),只需要將訓(xùn)練過(guò)的模型結(jié)果導(dǎo)入,然后在這個(gè)基礎(chǔ)上再繼續(xù)訓(xùn)練。而下面的 .pb 格式則不能繼續(xù)訓(xùn)練,因?yàn)檫@種格式保存的模型參數(shù)都已經(jīng)轉(zhuǎn)化為了常量(而不再是變量)。
二、pb 格式模型保存與恢復(fù)
????????.pb 格式模型保存與恢復(fù)相比于前面的 .ckpt 格式而言要稍微麻煩一點(diǎn),但使用更靈活,特別是模型恢復(fù),因?yàn)樗梢悦撾x會(huì)話(Session)而存在,便于部署。
1. pb 格式模型保存
????????與 .ckpt 格式模型保存類似,首先定義數(shù)據(jù)入口、出口:
from tensorflow.python.framework import graph_util
···
inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs')
···
prediction = tf.nn.softmax(logits, name='prediction')
···
with tf.Session() as sess:
··· <-- 訓(xùn)練過(guò)程
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(
sess,
graph_def,
['prediction'] <-- 參數(shù):output_node_names,輸出節(jié)點(diǎn)名
)
with tf.gfile.GFile('./xxx/xxx.pb', 'wb') as fid:
serialized_graph = output_graph_def.SerializeToString()
fid.write(serialized_graph)
然后通過(guò)函數(shù) graph_util.convert_variables_to_constants 將模型固話,使得所有變量轉(zhuǎn)化為常量,之后寫入到指定的路徑完成模型保存過(guò)程。
2. pb 格式模型恢復(fù)
????????.pb 格式模型恢復(fù)自由度較大,不需要在會(huì)話里進(jìn)行操作,可以獨(dú)立存在:
import os
def load_model(path_to_model.pb):
if not os.path.exists(path_to_model.pb):
raise ValueError("'path_to_model.pb' is not exist.")
model_graph = tf.Graph()
with model_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(path_to_model.pb, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
return model_graph
模型導(dǎo)入之后,便可以獲取數(shù)據(jù)入口和出口,然后進(jìn)行推斷:
model_graph = load_model('./xxx/xxx.pb')
inputs = model_graph.get_tensor_by_name('inputs:0')
prediction = model_graph.get_tensor_by_name('prediction:0')
with model_graph.as_default():
with tf.Session(graph=model_graph) as sess:
···
pred = sess.run(prediction, feed_dict={inputs: xxx}
三、ckpt 格式轉(zhuǎn) pb 格式
????????一般情況下,為了便于從斷點(diǎn)之處繼續(xù)訓(xùn)練,模型通常保存為 .ckpt 格式,而一旦對(duì)訓(xùn)練結(jié)果很滿意之后則可能需要將 .ckpt 格式轉(zhuǎn)化為 .pb 格式。轉(zhuǎn)化方法很簡(jiǎn)單,只需要綜合前面的一、二兩步即可:
from tensorflow.python.framework import graph_util
with tf.Session() as sess:
# Load .ckpt file
saver = tf.train.import_meta_graph('./xxx/xxx.ckpt.meta')
saver.restore(sess, './xxx/xxx.ckpt')
# Save as .pb file
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(
sess,
graph_def,
['prediction'] <-- 輸出節(jié)點(diǎn)名,以實(shí)際情況為準(zhǔn)
)
with tf.gfile.GFile('./xxx/xxx.pb', 'wb') as fid:
serialized_graph = output_graph_def.SerializeToString()
fid.write(serialized_graph)
????????預(yù)告:下一篇文章將簡(jiǎn)單介紹 tensorflow.contrib.slim 的應(yīng)用,敬請(qǐng)關(guān)注!