Class Model
class Model:
    """Single-layer softmax classifier whose graph is built in the constructor.

    The constructor creates the weight and bias variables, the softmax
    prediction, an RMSProp training step and a classification-error metric.
    The last three are exposed through read-only properties.
    """

    def __init__(self, data, target):
        """Build the prediction, optimisation and error operations.

        Args:
            data: Input tensor. Its second dimension must be statically
                known because it sizes the weight matrix.
            target: Target tensor. Its second dimension must be statically
                known because it sizes the output layer. Presumably one-hot
                rows, since argmax over axis 1 is used as the label --
                TODO confirm against the caller.
        """
        data_size = int(data.get_shape()[1])
        target_size = int(target.get_shape()[1])
        weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
        bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
        incoming = tf.matmul(data, weight) + bias
        self._prediction = tf.nn.softmax(incoming)
        # Cross entropy is the negated sum of target * log(prediction).
        # The two tensors must be multiplied: the second positional argument
        # of tf.reduce_sum is the reduction axis, not a second operand.
        cross_entropy = -tf.reduce_sum(target * tf.log(self._prediction))
        self._optimize = tf.train.RMSPropOptimizer(0.03).minimize(cross_entropy)
        # A sample counts as a mistake when the arg-max class of the target
        # differs from the arg-max class of the prediction.
        mistakes = tf.not_equal(
            tf.argmax(target, 1), tf.argmax(self._prediction, 1))
        self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))

    @property
    def prediction(self):
        """Return the softmax output tensor."""
        return self._prediction

    @property
    def optimize(self):
        """Return the RMSProp operation that minimises the cross entropy."""
        return self._optimize

    @property
    def error(self):
        """Return the mean rate of arg-max mismatches as a float32 tensor."""
        return self._error
@property 裝飾器可以將類函數與其屬性相關聯。但是以上的方法可讀性和可重複利用性太差。
Use property
class Model:
    """Single-layer softmax classifier whose graph parts are built lazily.

    Each property builds its part of the graph on first access and stores
    the result on the instance, so later accesses return the same object
    and no duplicate operations are added.
    """

    def __init__(self, data, target):
        """Store the input tensors; no graph operations are created here.

        Args:
            data: Input tensor. Its second dimension must be statically
                known because it sizes the weight matrix.
            target: Target tensor. Its second dimension must be statically
                known because it sizes the output layer.
        """
        self.data = data
        self.target = target
        self._prediction = None
        self._optimize = None
        self._error = None

    @property
    def prediction(self):
        """Return the softmax output tensor, building it on first access."""
        # Compare against None explicitly: truth-testing a tf.Tensor raises
        # TypeError in graph mode, so `not self._prediction` would fail on
        # every access after the first.
        if self._prediction is None:
            data_size = int(self.data.get_shape()[1])
            target_size = int(self.target.get_shape()[1])
            weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
            bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
            incoming = tf.matmul(self.data, weight) + bias
            self._prediction = tf.nn.softmax(incoming)
        return self._prediction

    @property
    def optimize(self):
        """Return the RMSProp training operation, building it on first access."""
        if self._optimize is None:
            # The two tensors must be multiplied: the second positional
            # argument of tf.reduce_sum is the reduction axis, not a second
            # operand.
            cross_entropy = -tf.reduce_sum(
                self.target * tf.log(self.prediction))
            optimizer = tf.train.RMSPropOptimizer(0.03)
            self._optimize = optimizer.minimize(cross_entropy)
        return self._optimize

    @property
    def error(self):
        """Return the mean arg-max mismatch rate, building it on first access."""
        if self._error is None:
            mistakes = tf.not_equal(
                tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
            self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))
        return self._error
嗯……這裡存在 lazy loading(延遲加載)的問題,目前還不太理解。
Code is still a bit bloated due to the lazy loading logic
Lazy Property Decorator
import functools
def lazy_property(function):
    """Turn a zero-argument method into a property that is computed once.

    On first access the wrapped method is called and its result is stored
    on the instance under ``'_cache_' + function.__name__``. Every later
    access returns the stored value without calling the method again.

    Args:
        function: The method to wrap; it is called with the instance only.

    Returns:
        A read-only ``property`` whose getter keeps the name and docstring
        of ``function``.
    """
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        # The getter must accept only `self`: property access passes no
        # other argument. The cache attribute name comes from the closure.
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return decorator
通過這個裝飾器可以簡化模型
class Model:
    """Single-layer softmax classifier built from ``lazy_property`` members."""

    def __init__(self, data, target):
        """Store the input tensors and build every part of the graph.

        Args:
            data: Input tensor. Its second dimension must be statically
                known because it sizes the weight matrix.
            target: Target tensor. Its second dimension must be statically
                known because it sizes the output layer.
        """
        self.data = data
        self.target = target
        # Touch each property once so that all operations and variables
        # exist as soon as the model has been constructed.
        self.prediction
        self.optimize
        self.error

    @lazy_property
    def prediction(self):
        """Return the softmax output of a single dense layer."""
        data_size = int(self.data.get_shape()[1])
        target_size = int(self.target.get_shape()[1])
        weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
        bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
        incoming = tf.matmul(self.data, weight) + bias
        return tf.nn.softmax(incoming)

    @lazy_property
    def optimize(self):
        """Return the RMSProp operation that minimises the cross entropy."""
        # The two tensors must be multiplied: the second positional argument
        # of tf.reduce_sum is the reduction axis, not a second operand.
        cross_entropy = -tf.reduce_sum(self.target * tf.log(self.prediction))
        optimizer = tf.train.RMSPropOptimizer(0.03)
        return optimizer.minimize(cross_entropy)

    @lazy_property
    def error(self):
        """Return the mean rate of arg-max mismatches as a float32 tensor."""
        mistakes = tf.not_equal(
            tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))
Note that we mention the properties in the constructor. This way the full graph is ensured to be defined by the time we run tf.initialize_variables().
定義計算圖的範圍
通過函數定義
import functools
def define_scope(function):
    """Wrap a method as a compute-once property inside its own variable scope.

    The first access runs ``function`` within ``tf.variable_scope`` named
    after the method and stores the result on the instance under
    ``'_cache_' + function.__name__``. Later accesses return the stored
    value.

    Args:
        function: The method to wrap; it is called with the instance only.

    Returns:
        A read-only ``property`` whose getter keeps the metadata of
        ``function``.
    """
    scope_name = function.__name__
    cache_slot = '_cache_' + scope_name

    @functools.wraps(function)
    def getter(self):
        # Fast path: the value was already built for this instance.
        if hasattr(self, cache_slot):
            return getattr(self, cache_slot)
        with tf.variable_scope(scope_name):
            setattr(self, cache_slot, function(self))
        return getattr(self, cache_slot)

    return property(getter)
插入 tf.variable_scope(function.__name__) 或者 tf.name_scope(function.__name__) 來定義 Scope。
自定義Scope
def doublewrap(function):
    """Let a decorator be applied with or without parentheses.

    The decorated decorator can be used bare (``@deco``) or with optional
    arguments (``@deco(...)``). A call with exactly one positional callable
    and no keyword arguments is treated as the bare form, so every extra
    argument of the decorator must be optional.
    """
    @functools.wraps(function)
    def decorator(*args, **kwargs):
        used_bare = not kwargs and len(args) == 1 and callable(args[0])
        if not used_bare:
            # Called with arguments: return the real decorator, which will
            # receive the wrapped object next.
            return lambda wrapee: function(wrapee, *args, **kwargs)
        (target,) = args
        return function(target)
    return decorator
@doublewrap
def define_scope(function, scope=None, *args, **kwargs):
    """Decorate a method that defines TensorFlow operations.

    The wrapped method runs at most once per instance. Its result is stored
    on the instance and returned on every later access, so the operations
    are added to the graph only once. The method body runs inside a
    ``tf.variable_scope()``. Any extra decorator arguments are forwarded to
    that scope, and the scope name defaults to the name of the wrapped
    method.
    """
    scope_name = scope or function.__name__
    cache_slot = '_cache_' + function.__name__

    @functools.wraps(function)
    def getter(self):
        # Fast path: the value was already built for this instance.
        if hasattr(self, cache_slot):
            return getattr(self, cache_slot)
        with tf.variable_scope(scope_name, *args, **kwargs):
            setattr(self, cache_slot, function(self))
        return getattr(self, cache_slot)

    return property(getter)
雙層裝飾器以保證無參時也可照常調用。
本文參考翻譯自 Danijar Hafner