tf.Variable()參數(shù)
tf.Variable(initial_value=None,
trainable=None,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None,
constraint=None,
use_resource=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.VariableAggregation.NONE,
shape=None
)
經(jīng)常使用的參數(shù)有initial_value、name、shape三個(gè),分別是初始化,命名和規(guī)定所需要的形狀大小。舉個(gè)例子:
import tensorflow as tf
v1=tf.Variable(tf.random_normal(shape=[4,3],mean=0,stddev=1),name='v1')
v2=tf.Variable(tf.constant(2),name='v2')
v3=tf.Variable(tf.ones([4,3]),name='v3')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print 'v1:\n',sess.run(v1)
print 'v2:\n',sess.run(v2)
print 'v3:\n',sess.run(v3)
運(yùn)行結(jié)果
v1:
[[ 0.4027793 0.72299665 -1.4619899 ]
[-1.7155927 -0.8806208 -0.39554796]
[-0.4185343 -1.562368 1.9035501 ]
[-0.7704326 -1.9970375 2.224315 ]]
v2:
2
v3:
[[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]]
tf.get_variable()參數(shù)
tf.get_variable(name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=None,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
constraint=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.VariableAggregation.NONE
)
與tf.Variable()一樣,經(jīng)常使用的參數(shù)有initial_value、name、shape三個(gè),分別是初始化,命名和規(guī)定所需要的形狀大小。舉個(gè)例子:
import tensorflow as tf
v1 = tf.get_variable(name='v1', shape=[2,3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(1))
v3 = tf.get_variable(name='v3', shape=[2,3], initializer=tf.ones_initializer())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(v1)
print sess.run(v2)
print sess.run(v2)
運(yùn)行結(jié)果如下:
v1:
[[-0.06989016 0.44355923 -1.2255034 ]
[ 0.46685636 -0.8572208 -0.16504966]]
v2:
[1.]
v3:
[[1. 1. 1.]
[1. 1. 1.]]
tf.Variable()、tf.get_variable() 兩者區(qū)別
tf.get_variable創(chuàng)建變量時(shí),會(huì)進(jìn)行變量檢查,當(dāng)設(shè)置為共享變量時(shí)(通過with tf.variable_scope(name_or_scope='', reuse=tf.AUTO_REUSE)設(shè)置),檢查到第二個(gè)擁有相同名字的變量,就返回已創(chuàng)建的相同的變量;如果沒有設(shè)置共享變量,則會(huì)報(bào)ValueError: Variable varx alreadly exists, disallowed的錯(cuò)誤。而tf.Variable()創(chuàng)建變量時(shí),name屬性值允許重復(fù),檢查到相同名字的變量時(shí),由自動(dòng)別名機(jī)制創(chuàng)建不同的變量。舉個(gè)例子:
with tf.variable_scope(name_or_scope='', reuse=tf.AUTO_REUSE):
var1 = tf.get_variable(name='var1', shape=[1], initializer=None, dtype=tf.float32)
var11 = tf.get_variable(name='var1')
var2 = tf.Variable(name='var2', initial_value=[1], dtype=tf.float32)
var21 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
with tf.Session() as sess:
print var1.name
print var11.name
print var2.name
print var21.name
輸出name時(shí),如下:
var1:0
var1:0
name_scope_2/var2:0
name_scope_2/var2_1:0