將tensorflow的ckpt模型轉(zhuǎn)化為pb模型

標(biāo)簽:tensorflow
作者:煉己者


本博客所有內(nèi)容以學(xué)習(xí)、研究和分享為主,如需轉(zhuǎn)載,請聯(lián)系本人,標(biāo)明作者和出處,并且是非商業(yè)用途,謝謝!


1.摘要

為什么好好的ckpt模型我要把它轉(zhuǎn)為pb模型呢?因?yàn)槲掖蛩惆裵ython版本的代碼轉(zhuǎn)化為c++版本,最好的方法就是把python訓(xùn)練好的模型直接移植過去。但是ckpt模型不能用,所以我要想辦法把訓(xùn)練好的ckpt模型轉(zhuǎn)化為pb模型,然后再用c++調(diào)用這個模型。

2.ckpt轉(zhuǎn)化為pb

我們用tensorflow訓(xùn)練模型,一般是用tf.train.Saver()保存模型,然后得到多個文件,一般長這個樣子

這四個文件主要是記錄了神經(jīng)網(wǎng)絡(luò)的網(wǎng)絡(luò)結(jié)構(gòu)以及這個結(jié)構(gòu)中涉及到的權(quán)重參數(shù)等內(nèi)容。

具體怎么轉(zhuǎn)呢?詳細(xì)代碼看ckpt轉(zhuǎn)化為pb
點(diǎn)進(jìn)去之后,直接clone即可。然后可以得到以下文件

checkpoints里面裝著模型文件,點(diǎn)進(jìn)去會發(fā)現(xiàn)里面有一個textcnn文件夾,這里面裝著ckpt模型的四個文件。通過convert_ckpt_to_pb.py可以得到pb模型文件,這個文件也保存在checkpoints文件夾中,是frozen_model.pb。

def freeze_graph(input_checkpoint,output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路徑
    :return:
    '''
    # 指定輸出的節(jié)點(diǎn)名稱,該節(jié)點(diǎn)名稱必須是原模型中存在的節(jié)點(diǎn)
    # 直接用最后輸出的節(jié)點(diǎn),可以在tensorboard中查找到,tensorboard只能在linux中使用
    output_node_names = "score/output"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph() # 獲得默認(rèn)的圖
    input_graph_def = graph.as_graph_def()  # 返回一個序列化的圖代表當(dāng)前的圖
 
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) #恢復(fù)圖并得到數(shù)據(jù)
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,將變量值固定
            sess=sess,
            input_graph_def=input_graph_def,# 等于:sess.graph_def
            output_node_names=output_node_names.split(","))# 如果有多個輸出節(jié)點(diǎn),以逗號隔開
 
        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化輸出
        print("%d ops in the final graph." % len(output_graph_def.node)) #得到當(dāng)前圖有幾個操作節(jié)點(diǎn)

代碼很簡單,里面你唯一要改動的就是output_node_names,指定的最后一層輸出節(jié)點(diǎn)名稱,這個是你自己設(shè)定的。怎么找到它呢?首先你要去查看你的代碼,
下面的代碼是我自己定義的,你可以點(diǎn)擊 訓(xùn)練ckpt模型,clone到本地,然后你會找到cnn_model.py這個文件,這里面就是定義著我的cnn模型

        with tf.name_scope("score"):
            # 全連接層,后面接dropout以及relu激活
            fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
            fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            fc = tf.nn.relu(fc)

            # 分類器
            self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
            self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1,name='output')  # 預(yù)測類別

我一開始也沒明白,然后翻閱大量資料才發(fā)現(xiàn)的,這個要自己改動,記得是最后輸出的那層,你要自己把這個節(jié)點(diǎn)命名出來。另外怎么才能確定你得圖結(jié)構(gòu)里有這個節(jié)點(diǎn)呢?去查看tensorboard。不知道為啥在windows下得到的tensorboard查看不了,得到的是亂碼文件。在linux下就不會,所以我把代碼放到linux環(huán)境下跑,就可以得到tensoboard了。現(xiàn)在要跑的這個代碼是指你自己訓(xùn)練ckpt模型的代碼,如果你的tensorboard可以查看,那么就不用重新到linux環(huán)境下跑了。

查看tensorboard

tensorboard --logdir = “保存tensorboard的絕對路徑”

敲入上面的命令,然后就可以得到一個網(wǎng)址,把這個網(wǎng)址復(fù)制到瀏覽器上打開,就可以得到圖的結(jié)構(gòu),然后你點(diǎn)開看看,有沒有output這個節(jié)點(diǎn),也可以順便看一下你自己的網(wǎng)絡(luò)圖
查看tensorboard的方法可以看這篇博客,TensorBoard:計算圖的查看

有的話,把節(jié)點(diǎn)寫入output_node_names,改這一個就行了。
改好之后,運(yùn)行這個文件,便可以得到pb模型了。

3.pb模型文件測試數(shù)據(jù)

然后怎么用pb模型像ckpt模型那樣來測試呢?c++版本的我還沒開始做,但是Python版本的可以跑通測試了。這意味著我離c++跑通更進(jìn)一步了。
這個具體去看pb_test.py文件,里面有很完整的注釋。網(wǎng)上很多都是針對圖像的資料,我這里是針對的文本的,相信對大家會有所幫助!
另外注意tensorflow是一個batch一個batch的逐批次輸入,一次性輸入的話會報內(nèi)存錯誤!

4.總結(jié)

這兩份代碼大家可以對著看,很相似,理解怎么訓(xùn)練出來的ckpt模型,又是怎么把ckpt模型轉(zhuǎn)化為pb模型的,怎么用pb模型去測試數(shù)據(jù),代碼很好理解的。有不懂的歡迎在博客下面評論提問,我們一起交流解決
等后期我用c++調(diào)用pb模型成功后,再來和大家分享?。。?/p>

以下是我所有文章的目錄,大家如果感興趣,也可以前往查看
??戳右邊:打開它,也許會看到很多對你有幫助的文章

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容