pb文件的能夠保存tensorflow計(jì)算圖中的操作節(jié)點(diǎn)以及對應(yīng)的各張量,方便我們?nèi)蘸笾苯诱{(diào)用之前已經(jīng)訓(xùn)練好的計(jì)算圖。
本文代碼的運(yùn)行軟件為pycharm
保存pb文件
下面的代碼展示了最簡單的tensorflow四則運(yùn)算計(jì)算圖
import tensorflow as tf
x = tf.placeholder(tf.float32,name="input")
a = tf.Variable(tf.constant(5.,shape=[1]),name="a")
b = tf.Variable(tf.constant(6.,shape=[1]),name="b")
c = tf.Variable(tf.constant(10.,shape=[1]),name="c")
d = tf.Variable(tf.constant(2.,shape=[1]),name="d")
tensor1 = tf.multiply(a,b,"mul")
tensor2 = tf.subtract(tensor1,c,"sub")
tensor3 = tf.div(tensor2,d,"div")
result = tf.add(tensor3,x,"add")
inial = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(inial)
print(sess.run(a))
print(result)
result = sess.run(result,feed_dict={x:1.0})
print(result)
constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["add"])
with tf.gfile.FastGFile("wsj.pb", mode='wb') as f:
f.write(constant_graph.SerializeToString())
保存pb文件的功能主要是通過最后三行代碼實(shí)現(xiàn)的
constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["add"])
with tf.gfile.FastGFile("wsj.pb", mode='wb') as f:
f.write(constant_graph.SerializeToString())
第一行代碼的作用是將計(jì)算圖中的變量轉(zhuǎn)化為常量,并指定輸出節(jié)點(diǎn)為“add”
第二行代碼用來生成一個(gè)名為wsj.pb的文件(未指定路徑的話,默認(rèn)在該python代碼的同路徑下生成)
第三行代碼的作用是將計(jì)算圖寫入該pb文件中
讀取pb文件
import tensorflow as tf
with tf.gfile.FastGFile("wsj.pb", "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
result, x = tf.import_graph_def(graph_def,return_elements=["add:0", "input:0"])
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(a))
result = sess.run(result, feed_dict={x: 5.0})
print(result)
上面代碼主要分為兩部分:讀取pb文件并設(shè)置為默認(rèn)的計(jì)算圖;填充一個(gè)新的x值來計(jì)算結(jié)果。
讀取pb文件時(shí)候需要注意的是,若要獲取對應(yīng)的張量必須用“tensor_name:0”的形式,這是tensorflow默認(rèn)的。
若您覺得本文章對您有用,請您為我點(diǎn)上一顆小心心以表支持。感謝!