Tensorflow使用SavedModel格式模型

Saved Model是Tensorflow支持的一種保存模型的方式,在使用TF-Serving的時(shí)候需要使用這種格式的模型文件。
下面以mnist手寫數(shù)字識別為例,介紹一下這種格式的save和restore以及如何使用。

  • 保存Saved Model格式模型
    這個可以參照Tensorflow Serving(http://github.com/tensorflow/serving.git)自帶的mnist訓(xùn)練的例子,具體在./tensorflow_serving/example/mnist_saved_model.py文件中,大家可以執(zhí)行一下這個腳本就可以生成mnist模型,并且格式是Saved Model。
    腳本中最后的代碼就是如何保存Saved Model格式的模型文件,如下:
# Export model
  # WARNING(break-tutorial-inline-code): The following code snippet is
  # in-lined in tutorials, please update tutorial documents accordingly
  # whenever code changes.
  export_path_base = sys.argv[-1]
  export_path = os.path.join(
      tf.compat.as_bytes(export_path_base),
      tf.compat.as_bytes(str(FLAGS.model_version)))
  print('Exporting trained model to', export_path)
  builder = tf.saved_model.builder.SavedModelBuilder(export_path)

  # Build the signature_def_map.
  classification_inputs = tf.saved_model.utils.build_tensor_info(
      serialized_tf_example)
  classification_outputs_classes = tf.saved_model.utils.build_tensor_info(
      prediction_classes)
  classification_outputs_scores = tf.saved_model.utils.build_tensor_info(values)

  classification_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={
              tf.saved_model.signature_constants.CLASSIFY_INPUTS:
                  classification_inputs
          },
          outputs={
              tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
                  classification_outputs_classes,
              tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
                  classification_outputs_scores
          },
          method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME))

  tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
  tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

  prediction_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={'images': tensor_info_x},
          outputs={'scores': tensor_info_y},
          method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

  legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
  builder.add_meta_graph_and_variables(
      sess, [tf.saved_model.tag_constants.SERVING],
      signature_def_map={
          'predict_images':
              prediction_signature,
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              classification_signature,
      },
      legacy_init_op=legacy_init_op)

  builder.save()
  • 恢復(fù)Saved Model模型并推理使用
    這里給出一段示例代碼:
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
import random
from PIL import Image
import sys

signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'images'
output_key = 'scores'


graph = tf.Graph()

export_dir = "./models/mnist/1/"

image_content = []

#for i in range(0, 784):
#       image_content.append(random.random())

img = Image.open(sys.argv[1])
pix = img.load()
width = img.width
height = img.height

print "Image: width = %d, height = %d, mode = %s" %(width, height, img.mode)

for y in range(0, height):
        for x in range(0, width):
                if img.mode == 'P' or img.mode == 'L':
                        print "%3d" % (pix[x, y]),
                        image_content.append(pix[x, y]/255.0)
                elif img.mode == 'RGB':
                        r, g, b = pix[x, y]
                        gray = (0.3 * r) + (0.59 * g) + (0.11 * b)
                        print "%3d" % (int(gray)),
                        image_content.append(gray/255.0)
                else:
                        print "unsupported mode"
        print ""


with tf.Session(graph = graph) as sess:
        meta_graph_def = tf.saved_model.loader.load(sess, [tag_constants.SERVING], export_dir)
        signature = meta_graph_def.signature_def
        x_tensor_name = signature[signature_key].inputs[input_key].name
        y_tensor_name = signature[signature_key].outputs[output_key].name

        x = sess.graph.get_tensor_by_name(x_tensor_name)
        y = sess.graph.get_tensor_by_name(y_tensor_name)



        y_out = sess.run(y, feed_dict = {x: [image_content]})

        print '---------- inference results ----------------'
        print y_out

選擇一張MNIST手寫圖片進(jìn)行測試,如下效果:


image.png
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容