關(guān)于Ryan Dahl的tensorflow-resnet中Config類的說(shuō)明

本文旨在學(xué)習(xí)Ryan Dahl的tensorflow-resnet源碼中的Config類的基本作用。因?yàn)?,?strong>真的真的很有趣!

tensorflow-resnet的repository中有個(gè)文件叫config.py,Config類就是在這個(gè)文件中被定義的。它能夠很方便的實(shí)現(xiàn)基于tensorflow編寫(xiě)程序的不同參數(shù)在不同scope中的隔離管理(很拗口,一會(huì)兒上例子),Config類有以下幾個(gè)特點(diǎn):

  • 它可以被認(rèn)為是包含了多個(gè)dict的list
  • 它的內(nèi)部參數(shù)在不同variable scope中是“隔離”的

說(shuō)了那么多,相信誰(shuí)都沒(méi)看明白,那么舉個(gè)栗子:

c = Config()
c['p1'] = 1
c['p2'] = 1
c['p3'] = 1
# c['p1'] = 1, c['p2'] = 1, c['p3'] = 1, c['p4']不存在

with tf.variable_scope('foo'):
    c['p1'] = 2
    c['p4'] = 2
    # c['p1'] = 2, c['p2'] = 1, c['p3'] = 1, c['p4'] = 2

    with tf.variable_scope('bar'):
        c['p2'] = 2
        # c['p1'] = 2, c['p2'] = 2, c['p3'] = 1, c['p4'] = 2

with tf.variable_scope('baz'):
    c['p3'] = 2
    # c['p1'] = 1, c['p2'] = 1, c['p3'] = 2, c['p4']不存在

# c['p1'] = 1, c['p2'] = 1, c['p3'] = 1, c['p4']不存在

程序內(nèi)各項(xiàng)參數(shù)在不同位置的取值我已經(jīng)注釋出來(lái)了,很明顯,不同variable scope中的參數(shù)是隔離的,你在’foo‘中設(shè)置的參數(shù)在’baz‘不起作用,在’foo‘中新定義的參數(shù)在其他scope中看不到(但在’foo‘中的’bar‘內(nèi)可以看到)。

關(guān)于Config類的特點(diǎn)還有待挖掘,以上只是說(shuō)明了它最基本的特點(diǎn),下面給出Ryan Dahl的config.py的源碼,你也可以去他的repository中看,這是鏈接。

# This is a variable scope aware configuation object for TensorFlow

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

class Config:
    def __init__(self):
        root = self.Scope('')
        for k, v in FLAGS.__dict__['__flags'].iteritems():
            root[k] = v
        self.stack = [ root ]

    def iteritems(self):
        return self.to_dict().iteritems()

    def to_dict(self):
        self._pop_stale()
        out = {}
        # Work backwards from the flags to top fo the stack
        # overwriting keys that were found earlier.
        for i in range(len(self.stack)):
            cs = self.stack[-i]
            for name in cs:
                out[name] = cs[name]
        return out

    def _pop_stale(self):
        var_scope_name = tf.get_variable_scope().name
        top = self.stack[0]
        while not top.contains(var_scope_name):
            # We aren't in this scope anymore
            self.stack.pop(0)
            top = self.stack[0]

    def __getitem__(self, name):
        self._pop_stale()
        # Recursively extract value
        for i in range(len(self.stack)):
            cs = self.stack[i]
            if name in cs:
                return cs[name]

        raise KeyError(name)

    def set_default(self, name, value):
        if not (name in self):
            self[name] = value

    def __contains__(self, name):
        self._pop_stale()
        for i in range(len(self.stack)):
            cs = self.stack[i]
            if name in cs:
                return True
        return False

    def __setitem__(self, name, value):
        self._pop_stale()
        top = self.stack[0]
        var_scope_name = tf.get_variable_scope().name
        assert top.contains(var_scope_name)

        if top.name != var_scope_name:
            top = self.Scope(var_scope_name)
            self.stack.insert(0, top)

        top[name] = value

    class Scope(dict):
        def __init__(self, name):
            self.name = name

        def contains(self, var_scope_name):
            return var_scope_name.startswith(self.name)



# Test
if __name__ == '__main__':

    def assert_raises(exception, fn):
        try:
            fn()
        except exception:
            pass
        else:
            assert False, "Expected exception"

    c = Config()

    c['hello'] = 1
    assert c['hello'] == 1

    with tf.variable_scope('foo'):
        c.set_default("bar", 10)
        c['bar'] = 2
        assert c['bar'] == 2
        assert c['hello'] == 1

        c.set_default("mario", True)

        with tf.variable_scope('meow'):
            c['dog'] = 3
            assert c['dog'] == 3
            assert c['bar'] == 2
            assert c['hello'] == 1

            assert c['mario'] == True

        assert_raises(KeyError, lambda: c['dog'])
        assert c['bar'] == 2
        assert c['hello'] == 1

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容