tf.train.saver保存模型,加載時提示變量未初始化

問題描述:

使用TensorFlow的saver()方法保存模型,在加載模型時提示變量未初始化。

import tensorflow as tf

#定義變量
b1 = tf.Variable(2.0, name="bias1")
#...

# 創(chuàng)建一個Saver對象,用于保存所有變量
saver = tf.train.Saver()

graph = tf.get_default_graph()
with tf.Session(graph=graph) as sess:
  sess.run(tf.global_variables_initializer())

  # ...執(zhí)行操作...

  # 保存模型
  saver.save(sess, './checkpoint_dir/MyModel', global_step=1000)
error : Attempting to use uninitialized value

問題解決:

把saver = tf.train.Saver()這條語句移到with tf.Session() as sess里面就可以了。
(出現(xiàn)這種情況的一般是代碼中有多個graph)

import tensorflow as tf

#定義變量
b1 = tf.Variable(2.0, name="bias1")
#...
graph = tf.get_default_graph()
with tf.Session(graph=graph) as sess:
  sess.run(tf.global_variables_initializer())

  # 創(chuàng)建一個Saver對象,用于保存所有變量
  saver = tf.train.Saver()

  # ...執(zhí)行操作...

  # 保存模型
  saver.save(sess, './checkpoint_dir/MyModel', global_step=1000)

注意:
saver = tf.train.Saver() 這條語句會影響訓(xùn)練模型時GPU的利用率,切不可放在循環(huán)里面!


對于保存某次迭代的模型,直接在if語句下使用saver = tf.train.Saver(),加載模型時報變量未初始化錯誤的話,可以在外部在加一條saver = tf.train.Saver()語句。

    with tf.Session(graph=graph) as sess:
        tf.initialize_all_variables().run()
        # 外部加一條saver語句
        saver = tf.train.Saver()
        for i in range(num_steps):
            # training
            results = sess.run([xxx])
            if i == 1999:
                saver = tf.train.Saver()
                print("-----saving-----")
                save_path = saver.save(sess, model_path, global_step=i)
                print("-----saved-----")
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容