最近在用 java 改寫一個用 python 編寫的 model,遇到了有關模型保存與恢復的問題,發(fā)現(xiàn)網上的資料有些混亂,在這里做一些記錄。
.ckpt
1. .ckpt 全稱為 checkpoint,代表著一個檢查點,即為 model 訓練過程中的一個快照,可能是在訓練開始,也可能是在訓練完成。
2. .ckpt 是由 Saver 調用 save 產生的:
saver.save(sess,"/tmp/model.ckpt")
3. 由 Saver 調用 restore 來復原 model 的數(shù)據:
saver.restore(sess,path)
注意這里,復原的只有數(shù)據,不含 graph 信息。
4. .ckpt 不是單獨的一個文件,而是一系列文件。

其內部包含了:
①checkpoint: .ckpt 的標記信息。
②.data: model 中 graph 的數(shù)據,包括各種變量,不含常量。
③.index: 索引信息。
④.meta: graph 信息。
在這里要搞明白一點,一個 model 是由 graph(④) + 數(shù)據(②) 組成的。
graph 代表著執(zhí)行邏輯,在 tensorflow 中,每個算子用一個 node 來表示,眾多 node 組合起來便是一張圖(graph),也就是我們的執(zhí)行邏輯,而這些執(zhí)行邏輯在 Saver 調用 save 時,會被存到 .meta 中(不含數(shù)據)。各個 node 中含有各種參數(shù)(變量,比如訓練的權重),這些參數(shù)則被存儲到 .data 中。graph 與數(shù)據是分別存儲的。
tf.train.import_meta_graph
該方法只能恢復 graph,不恢復數(shù)據。
注意與上面提及的 saver.restore 區(qū)分,saver.restore 只恢復數(shù)據,不恢復 graph。
recover model
現(xiàn)在我們來討論下,如何能恢復一個model。前面已經提過了,一個 model 由 graph 和 數(shù)據組成,所以只要能恢復這兩部分就可以了,依據恢復的方法不同,可以分為兩類。
①分別恢復 graph 和數(shù)據:
對于數(shù)據來說,可以用 saver.restore 來恢復。
對于graph來說,依據恢復方法不同可以分為兩種:
A.硬編碼恢復:在調用方法中,重新書寫 graph 信息。
B. .meta 恢復:通過調用 tf.train.import_meta_graph 方法獲得 graph,并配合 get_tensor_by_name 的方法來調用 model 中特定的算子(node)。
saver = tf.train.import_meta_graph('~/tmp/model.ckpt-1000.meta')
graph = tf.get_default_graph()
input = graph.get_tensor_by_name('input:0')
② freezing(固化):
該方法將變量(訓練的權重)固化在 graph 中,即用常量來替換 graph 中的變量,從而達到無需恢復數(shù)據,直接調用 graph 即可。權重一旦被固化就不能再修改,該方法一般用于生產環(huán)境。
注:筆者在測試 Java API 時,其只支持調用 freezing 后的圖。
References: