文章作者:Tyan
博客:noahsnail.com ?|? CSDN ?|? 簡書
本文主要是使用tensorfl保存神經網絡參數和加載神經網絡參數。
#!/usr/bin/env python
# _*_ coding: utf-8 _*_
import tensorflow as tf
import numpy as np
# 保存神經網絡參數
def save_para():
# 定義權重參數
W = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype = tf.float32, name = 'weights')
# 定義偏置參數
b = tf.Variable([[1, 2, 3]], dtype = tf.float32, name = 'biases')
# 參數初始化
init = tf.global_variables_initializer()
# 定義保存參數的saver
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
# 保存session中的數據
save_path = saver.save(sess, 'my_net/save_net.ckpt')
# 輸出保存路徑
print 'Save to path: ', save_path
# 恢復神經網絡參數
def restore_para():
# 定義權重參數
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype = tf.float32, name = 'weights')
# 定義偏置參數
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype = tf.float32, name = 'biases')
# 定義提取參數的saver
saver = tf.train.Saver()
with tf.Session() as sess:
# 加載文件中的參數數據,會根據name加載數據并保存到變量W和b中
save_path = saver.restore(sess, 'my_net/save_net.ckpt')
# 輸出保存路徑
print 'Weights: ', sess.run(W)
print 'biases: ', sess.run(b)
# save_para()
restore_para()
執(zhí)行結果如下:
# save
Save to path: my_net/save_net.ckpt
# restore
Weights: [[ 1. 2. 3.]
[ 4. 5. 6.]]
biases: [[ 1. 2. 3.]]