tensorflow的基本用法(十)——保存神經網絡參數和加載神經網絡參數

文章作者:Tyan
博客:noahsnail.com ?|? CSDN ?|? 簡書

本文主要是使用tensorfl保存神經網絡參數和加載神經網絡參數。

#!/usr/bin/env python
# _*_ coding: utf-8 _*_

import tensorflow as tf
import numpy as np


# 保存神經網絡參數
def save_para():
    # 定義權重參數
    W = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype = tf.float32, name = 'weights')
    # 定義偏置參數
    b = tf.Variable([[1, 2, 3]], dtype = tf.float32, name = 'biases')
    # 參數初始化
    init = tf.global_variables_initializer()
    # 定義保存參數的saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init)
        # 保存session中的數據
        save_path = saver.save(sess, 'my_net/save_net.ckpt')
        # 輸出保存路徑
        print 'Save to path: ', save_path

# 恢復神經網絡參數
def restore_para():
    # 定義權重參數
    W = tf.Variable(np.arange(6).reshape((2, 3)), dtype = tf.float32, name = 'weights')
    # 定義偏置參數
    b = tf.Variable(np.arange(3).reshape((1, 3)), dtype = tf.float32, name = 'biases')
    # 定義提取參數的saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # 加載文件中的參數數據,會根據name加載數據并保存到變量W和b中
        save_path = saver.restore(sess, 'my_net/save_net.ckpt')
        # 輸出保存路徑
        print 'Weights: ', sess.run(W)
        print 'biases:  ', sess.run(b)


# save_para()
restore_para()

執(zhí)行結果如下:

# save
Save to path:  my_net/save_net.ckpt


# restore
Weights:  [[ 1.  2.  3.]
 [ 4.  5.  6.]]
biases:   [[ 1.  2.  3.]]
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

  • 風林一葉落 露草百蟲鳴 我從未想過,有那么一天,我會遇見你,陽光,剛好,我不樂觀,...
    風起loyi閱讀 801評論 0 3
  • 我是一個大俠,有著一身好武功,俠義肝膽的氣質,浪跡天涯的性格。會說詩,好喝酒,酒勁一來就作一首詩,路見不平就拔...
    TKD小勝閱讀 1,102評論 0 6
  • 死亡是人生最大的一個課題。因為死亡是專屬于人的。有人說其他的生命也會死亡,不僅僅是人。不錯,任何自然生命都有一個衰...
    妙所閱讀 1,033評論 2 6
  • ?1.扉頁 1.1 本話扉頁是應讀者要求,羅賓的畫展中一幅名為《KAWAII》的作品引來了藝術大師和記者的圍觀,羅...
    頑皮仕閱讀 736評論 0 0

友情鏈接更多精彩內容