This post looks at what the Config class in Ryan Dahl's tensorflow-resnet source code does, because it is really, really interesting!
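The tensorflow-resnet repository contains a file named config.py, and that is where the Config class is defined. It makes it easy to keep the parameters of a TensorFlow program separate per scope (a mouthful, I know; an example follows shortly). The Config class has the following characteristics: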
- It can be thought of as a list of dicts, one dict per variable scope
- Its parameters are "isolated" between different variable scopes
That is probably still not very clear, so here is an example:
# Assumes the TensorFlow 1.x API (tf.variable_scope) and config.py on the import path
import tensorflow as tf
from config import Config

c = Config()
c['p1'] = 1
c['p2'] = 1
c['p3'] = 1
# c['p1'] = 1, c['p2'] = 1, c['p3'] = 1, c['p4'] does not exist
with tf.variable_scope('foo'):
    c['p1'] = 2
    c['p4'] = 2
    # c['p1'] = 2, c['p2'] = 1, c['p3'] = 1, c['p4'] = 2
    with tf.variable_scope('bar'):
        c['p2'] = 2
        # c['p1'] = 2, c['p2'] = 2, c['p3'] = 1, c['p4'] = 2
with tf.variable_scope('baz'):
    c['p3'] = 2
    # c['p1'] = 1, c['p2'] = 1, c['p3'] = 2, c['p4'] does not exist
# c['p1'] = 1, c['p2'] = 1, c['p3'] = 1, c['p4'] does not exist
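The comments show the value of each parameter at each point in the program. Parameters in different variable scopes are clearly isolated: a value set in 'foo' has no effect in 'baz', and a parameter newly defined in 'foo' cannot be seen in other scopes (although it can be seen in 'bar', which is nested inside 'foo').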
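To make the "list of dicts" idea concrete without needing TensorFlow, here is a minimal sketch of the same mechanism in plain Python. The class name `ScopedConfig` and its `scope` context manager are hypothetical stand-ins that I made up for illustration (the real class reads the scope name from `tf.get_variable_scope()`), so treat it as an approximation of the behaviour, not as the author's implementation.

```python
from contextlib import contextmanager


class ScopedConfig:
    """Toy, TensorFlow-free stand-in for Config (hypothetical name)."""

    def __init__(self):
        self._path = []           # names of the scopes that are currently open
        self._stack = [("", {})]  # (scope name, values); innermost first, root last

    def _current(self):
        return "/".join(self._path)

    @contextmanager
    def scope(self, name):
        # Stand-in for tf.variable_scope(name).
        self._path.append(name)
        try:
            yield
        finally:
            self._path.pop()

    def _pop_stale(self):
        # Drop the dicts of scopes we have already left.
        # The root is named "", so it always matches and is never dropped.
        while not self._current().startswith(self._stack[0][0]):
            self._stack.pop(0)

    def __setitem__(self, key, value):
        self._pop_stale()
        if self._stack[0][0] != self._current():
            # First write in this scope: push a fresh dict for it.
            self._stack.insert(0, (self._current(), {}))
        self._stack[0][1][key] = value

    def __getitem__(self, key):
        self._pop_stale()
        for _, values in self._stack:  # the innermost scope wins
            if key in values:
                return values[key]
        raise KeyError(key)

    def __contains__(self, key):
        self._pop_stale()
        return any(key in values for _, values in self._stack)


c = ScopedConfig()
c["p1"] = 1
with c.scope("foo"):
    c["p1"] = 2
    c["p4"] = 2
    with c.scope("bar"):
        print(c["p1"], c["p4"])  # 2 2
with c.scope("baz"):
    print(c["p1"], "p4" in c)    # 1 False
```

A write pushes a new dict the first time a scope is touched, a read walks the stack from the innermost dict outwards, and dicts belonging to scopes that have been left are discarded lazily on the next access.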
There is more to discover about the Config class; the above only covers its most basic behaviour. Below is the source of Ryan Dahl's config.py, which you can also read in his tensorflow-resnet repository.
# This is a variable scope aware configuration object for TensorFlow

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS


class Config:
    def __init__(self):
        root = self.Scope('')
        for k, v in FLAGS.__dict__['__flags'].iteritems():
            root[k] = v
        self.stack = [ root ]

    def iteritems(self):
        return self.to_dict().iteritems()

    def to_dict(self):
        self._pop_stale()
        out = {}
        # Work backwards from the flags to top of the stack
        # overwriting keys that were found earlier.
        for i in range(len(self.stack)):
            cs = self.stack[-i]
            for name in cs:
                out[name] = cs[name]
        return out

    def _pop_stale(self):
        var_scope_name = tf.get_variable_scope().name
        top = self.stack[0]
        while not top.contains(var_scope_name):
            # We aren't in this scope anymore
            self.stack.pop(0)
            top = self.stack[0]

    def __getitem__(self, name):
        self._pop_stale()
        # Recursively extract value
        for i in range(len(self.stack)):
            cs = self.stack[i]
            if name in cs:
                return cs[name]

        raise KeyError(name)

    def set_default(self, name, value):
        if not (name in self):
            self[name] = value

    def __contains__(self, name):
        self._pop_stale()
        for i in range(len(self.stack)):
            cs = self.stack[i]
            if name in cs:
                return True
        return False

    def __setitem__(self, name, value):
        self._pop_stale()
        top = self.stack[0]
        var_scope_name = tf.get_variable_scope().name
        assert top.contains(var_scope_name)

        if top.name != var_scope_name:
            top = self.Scope(var_scope_name)
            self.stack.insert(0, top)

        top[name] = value

    class Scope(dict):
        def __init__(self, name):
            self.name = name

        def contains(self, var_scope_name):
            return var_scope_name.startswith(self.name)

# Test
if __name__ == '__main__':

    def assert_raises(exception, fn):
        try:
            fn()
        except exception:
            pass
        else:
            assert False, "Expected exception"

    c = Config()

    c['hello'] = 1
    assert c['hello'] == 1

    with tf.variable_scope('foo'):
        c.set_default("bar", 10)
        c['bar'] = 2
        assert c['bar'] == 2
        assert c['hello'] == 1

        c.set_default("mario", True)

        with tf.variable_scope('meow'):
            c['dog'] = 3
            assert c['dog'] == 3
            assert c['bar'] == 2
            assert c['hello'] == 1

            assert c['mario'] == True

        assert_raises(KeyError, lambda: c['dog'])
        assert c['bar'] == 2
        assert c['hello'] == 1
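One detail in to_dict deserves a second look. The stack keeps the innermost scope at index 0 and the root (the flags) at the end, and the comment says the merge should run from the flags towards the top of the stack. Indexing with self.stack[-i] starts at i = 0, though, so as far as I can tell it visits the innermost scope first and lets the outer scopes overwrite it. The plain-Python sketch below compares that loop with a reversed traversal. The helper names `merge_as_written` and `merge_root_first` are hypothetical, and the stack is modelled as a bare list of dicts; this is my reading of the listing, not something confirmed upstream.

```python
def merge_as_written(stack):
    """Same index pattern as to_dict: stack[-i] for i = 0 .. len(stack) - 1."""
    out = {}
    for i in range(len(stack)):
        out.update(stack[-i])
    return out


def merge_root_first(stack):
    """Visit the root first and the innermost scope last, so inner values win."""
    out = {}
    for scope_values in reversed(stack):
        out.update(scope_values)
    return out


# Innermost scope first, root last, as in Config.stack while inside 'foo'.
stack = [{"p1": 2, "p4": 2}, {"p1": 1, "p2": 1, "p3": 1}]

print(merge_as_written(stack)["p1"])  # 1: the root value overwrote the scoped one
print(merge_root_first(stack)["p1"])  # 2: agrees with c['p1'] inside 'foo'
```

If that reading is right, `to_dict()` and `iteritems()` can disagree with `c[name]` whenever an inner scope overrides an outer value, while `__getitem__` itself is unaffected because it walks the stack from index 0.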