TensorFlow固化模型

前言

TensorFlow目前在移動(dòng)端是無(wú)法training的,只能跑已經(jīng)訓(xùn)練好的模型,但一般的保存方式只有單一保存參數(shù)或者graph的,如何將參數(shù)、graph同時(shí)保存呢?

生成模型

主要有兩種方法生成模型,一種是通過(guò)freeze_graph把tf.train.write_graph()生成的pb文件與tf.train.saver()生成的chkp文件固化之后重新生成一個(gè)pb文件,這一種現(xiàn)在不太建議使用。另一種是把變量轉(zhuǎn)成常量之后寫(xiě)入PB文件中。我們簡(jiǎn)單的介紹下freeze_graph方法。

freeze_graph

這種方法我們需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代碼如下:

with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.save(session, "model.ckpt")
    tf.train.write_graph(session.graph_def, '', 'graph.pb')

然后使用TensorFlow源碼中的freeze_graph工具進(jìn)行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然后使用這個(gè)工具進(jìn)行固化(/path/to/表示文件路徑):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb

convert_variables_to_constants

其實(shí)在TensorFlow中傳統(tǒng)的保存模型方式是保存常量以及graph的,而我們的權(quán)重主要是變量,如果我們把訓(xùn)練好的權(quán)重變成常量之后再保存成PB文件,這樣確實(shí)可以保存權(quán)重,就是方法有點(diǎn)繁瑣,需要一個(gè)一個(gè)調(diào)用eval方法獲取值之后賦值,再構(gòu)建一個(gè)graph,把W和b賦值給新的graph。

牛逼的Google為了方便大家使用,編寫(xiě)了一個(gè)方法供我們快速的轉(zhuǎn)換并保存。

  • 首先我們需要引入這個(gè)方法
from tensorflow.python.framework.graph_util import convert_variables_to_constants
  • 在想要保存的地方加入如下代碼,把變量轉(zhuǎn)換成常量
output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])

這里參數(shù)第一個(gè)是當(dāng)前的session,第二個(gè)為graph,第三個(gè)是輸出節(jié)點(diǎn)名(如我的輸出層代碼是這樣的:)

    with tf.name_scope('output'):
        w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
        tf.summary.histogram('output/weight', w_out)
        b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
        tf.summary.histogram('output/biases', b_out)
        out = tf.add(tf.matmul(dense2, w_out), b_out)
        out = tf.nn.softmax(out)
        predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')

由于我們采用了name_scope所以我們?cè)?code>predict之前需要加上output/

  • 生成文件
    with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
        f.write(output_graph_def.SerializeToString())

第一個(gè)參數(shù)是文件路徑,第二個(gè)是指文件操作的模式,這里指的是以二進(jìn)制的方式寫(xiě)入文件。

運(yùn)行代碼,系統(tǒng)會(huì)生成一個(gè)PB文件,接下來(lái)我們要測(cè)試下這個(gè)模型是否能夠正常的讀取、運(yùn)行。

測(cè)試模型

在Python環(huán)境下,我們首先需要加載這個(gè)模型,代碼如下:

with open('./model/rounded_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    output = tf.import_graph_def(graph_def,
                                 input_map={'inputs/X:0': newInput_X},
                                 return_elements=['output/predict:0'])

由于我們?cè)镜木W(wǎng)絡(luò)輸入值是一個(gè)placeholder,這里為了方便輸入我們也先定義一個(gè)新的placeholder:

newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")

在input_map的參數(shù)填入新的placeholder。

在調(diào)用我們的網(wǎng)絡(luò)的時(shí)候直接用這個(gè)新的placeholder接收數(shù)據(jù),如:

text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})

然后就是運(yùn)行我們的網(wǎng)絡(luò),看是否可以運(yùn)行吧。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 簡(jiǎn)單線(xiàn)性回歸 import tensorflow as tf import numpy # 創(chuàng)造數(shù)據(jù) x_dat...
    CAICAI0閱讀 3,668評(píng)論 0 49
  • 昨天晚上我隨手看看自己的手紋,紋身遍布,手心亂左一團(tuán),想一想自己這么多年真不知道自己干些啥了,感覺(jué)時(shí)間就是這么浪費(fèi)...
    娥眉山閱讀 234評(píng)論 0 0
  • 再回深圳,溫度還沒(méi)有完全升回來(lái)。在十幾度的溫度中,還是能明顯的感受到潮濕,慶幸的是,在出國(guó)之前,沒(méi)有巧遇令人厭惡的...
    壹言肆韻閱讀 857評(píng)論 0 2
  • 《超級(jí)個(gè)體-伽藍(lán)214》201/300,5.29打卡,陽(yáng)光繼續(xù)普照 【三件事】 1. [ ] pm課程學(xué)習(xí)13/9...
    伽藍(lán)214閱讀 131評(píng)論 0 0
  • 周末追完爸爸去哪兒,諾一和妹妹霓娜對(duì)甜食的如癡如醉,讓人久久不能忘懷。 看把咱們霓娜給饞的!再高冷的女神也抵擋不住...
    足記閱讀 767評(píng)論 0 3

友情鏈接更多精彩內(nèi)容