TensorFlow 模型保存與恢復(fù)

????????上一篇文章 TensorFlow 訓(xùn)練 CNN 分類器 中說(shuō)明了訓(xùn)練簡(jiǎn)單 CNN 模型的整個(gè)過(guò)程,并在訓(xùn)練結(jié)束后使用 .save 函數(shù)來(lái)保存訓(xùn)練的結(jié)果,其后通過(guò)使用 tf.train.import_meta_graph.restore 函數(shù)來(lái)導(dǎo)入模型進(jìn)行推斷。本文承接上文,對(duì)模型保存與恢復(fù)做一個(gè)總結(jié)。

????????總的來(lái)說(shuō),模型在保存和恢復(fù)時(shí)最重要的是留下數(shù)據(jù)接口,方便使用時(shí)傳入數(shù)據(jù)和獲取結(jié)果。TensorFlow 中常用的模型保存格式為 .ckpt 和 .pb,下面分別進(jìn)行詳細(xì)說(shuō)明。

一、ckpt 格式模型保存與恢復(fù)

????????.ckpt 格式保存與恢復(fù)都很簡(jiǎn)單,具體可參考 TensorFlow 訓(xùn)練 CNN 分類器。

1. ckpt 格式模型保存

inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs')  <-- 入口
···
prediction = tf.nn.softmax(logits, name='prediction')  <-- 出口(僅作為例子,下同)
···
saver = tf.train.Saver()
···

with tf.Session() as sess:
    ···    <-- 訓(xùn)練過(guò)程
    saver.save(sess, './xxx/xxx.ckpt')  <-- 模型保存

????????如上述代碼所示,假設(shè)你定義了一個(gè) TensorFlow 模型,數(shù)據(jù)入口由占位符 inputs 給定,結(jié)果出口由張量 prediction 給定。通過(guò)語(yǔ)句 saver = tf.train.Saver() 定義了模型保存的一個(gè)實(shí)例對(duì)象 saver,當(dāng)模型訓(xùn)練結(jié)束之后只需要簡(jiǎn)單的一條語(yǔ)句:

saver.save(sess, path_to_model.ckpt)

就把訓(xùn)練結(jié)果保存到了指定的路徑。

????????以上代碼之所以把變量 inputsprediction 單獨(dú)列出,一方面是因?yàn)樗鼈兪悄P?Graph 的起點(diǎn)和終點(diǎn)(戲稱為數(shù)據(jù)入口、出口),另一方面的原因是它們被特別的指定了名稱,因而在模型恢復(fù)時(shí)可以通過(guò)它們的名稱而得到 Graph 中對(duì)應(yīng)的節(jié)點(diǎn)。

2. ckpt 格式模型恢復(fù)

????????當(dāng)你需要導(dǎo)入模型進(jìn)行推斷時(shí),只需要通過(guò)張量名獲取數(shù)據(jù)入口和出口,然后傳入數(shù)據(jù)即可:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('./xxx/xxx.ckpt.meta')
    saver.restore(sess, './xxx/xxx.ckpt')

    inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
    prediction = tf.get_default_graph().get_tensor_by_name('prediction:0')

    pred = sess.run(prediction, feed_dict={inputs: xxx}

????????保存為 .ckpt 模型的一個(gè)好處是,當(dāng)需要繼續(xù)訓(xùn)練時(shí),只需要將訓(xùn)練過(guò)的模型結(jié)果導(dǎo)入,然后在這個(gè)基礎(chǔ)上再繼續(xù)訓(xùn)練。而下面的 .pb 格式則不能繼續(xù)訓(xùn)練,因?yàn)檫@種格式保存的模型參數(shù)都已經(jīng)轉(zhuǎn)化為了常量(而不再是變量)。

二、pb 格式模型保存與恢復(fù)

????????.pb 格式模型保存與恢復(fù)相比于前面的 .ckpt 格式而言要稍微麻煩一點(diǎn),但使用更靈活,特別是模型恢復(fù),因?yàn)樗梢悦撾x會(huì)話(Session)而存在,便于部署。

1. pb 格式模型保存

????????與 .ckpt 格式模型保存類似,首先定義數(shù)據(jù)入口、出口:

from tensorflow.python.framework import graph_util

···
inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs') 
···
prediction = tf.nn.softmax(logits, name='prediction') 
···

with tf.Session() as sess:
    ···    <-- 訓(xùn)練過(guò)程
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, 
        graph_def, 
        ['prediction']  <-- 參數(shù):output_node_names,輸出節(jié)點(diǎn)名
    )
    with tf.gfile.GFile('./xxx/xxx.pb', 'wb') as fid:
        serialized_graph = output_graph_def.SerializeToString()
        fid.write(serialized_graph)

然后通過(guò)函數(shù) graph_util.convert_variables_to_constants 將模型固話,使得所有變量轉(zhuǎn)化為常量,之后寫入到指定的路徑完成模型保存過(guò)程。

2. pb 格式模型恢復(fù)

????????.pb 格式模型恢復(fù)自由度較大,不需要在會(huì)話里進(jìn)行操作,可以獨(dú)立存在:

import os

def load_model(path_to_model.pb):
    if not os.path.exists(path_to_model.pb):
        raise ValueError("'path_to_model.pb' is not exist.")

    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(path_to_model.pb, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    return model_graph

模型導(dǎo)入之后,便可以獲取數(shù)據(jù)入口和出口,然后進(jìn)行推斷:

model_graph = load_model('./xxx/xxx.pb')

inputs = model_graph.get_tensor_by_name('inputs:0')
prediction = model_graph.get_tensor_by_name('prediction:0')

with model_graph.as_default():
    with tf.Session(graph=model_graph) as sess:
        ···
        pred = sess.run(prediction, feed_dict={inputs: xxx}

三、ckpt 格式轉(zhuǎn) pb 格式

????????一般情況下,為了便于從斷點(diǎn)之處繼續(xù)訓(xùn)練,模型通常保存為 .ckpt 格式,而一旦對(duì)訓(xùn)練結(jié)果很滿意之后則可能需要將 .ckpt 格式轉(zhuǎn)化為 .pb 格式。轉(zhuǎn)化方法很簡(jiǎn)單,只需要綜合前面的一、二兩步即可:

from tensorflow.python.framework import graph_util

with tf.Session() as sess:
    # Load .ckpt file
    saver = tf.train.import_meta_graph('./xxx/xxx.ckpt.meta')
    saver.restore(sess, './xxx/xxx.ckpt')

    # Save as .pb file
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, 
        graph_def, 
        ['prediction']  <-- 輸出節(jié)點(diǎn)名,以實(shí)際情況為準(zhǔn)
    )
    with tf.gfile.GFile('./xxx/xxx.pb', 'wb') as fid:
        serialized_graph = output_graph_def.SerializeToString()
        fid.write(serialized_graph)

????????預(yù)告:下一篇文章將簡(jiǎn)單介紹 tensorflow.contrib.slim 的應(yīng)用,敬請(qǐng)關(guān)注!

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容