tensor2tensor abstracts out a Modality class to decouple the model body from task-dependent data-format conversions. For example, a self-attention module can operate on discrete token sequences as well as on sequences of image feature vectors, provided the inputs are first converted into the expected format.
Modality is responsible for these concrete conversions, including word embedding, dimension permutation, output projection, and loss computation. Since they depend on the specific data, the Modality is configured in the Problem class's hparams method.
The T2TModel class then invokes it in its bottom, top, and loss methods to perform the corresponding format conversions and to compute the loss.
Modality has four main methods: bottom, targets_bottom, top, and loss. Below, SymbolModality, the modality used in machine translation, serves as the running example for what each method actually does.
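To make the division of labor concrete, here is a minimal sketch of the call order. This is my own pseudocode with hypothetical names (model_fn_sketch, body), not the actual T2TModel source; the real implementation adds data sharding, hparams plumbing, and more.

```python
def model_fn_sketch(features, input_modality, target_modality, body):
  """Hypothetical illustration of the Modality/body call order."""
  # Task-dependent input conversion, e.g. embedding lookup for symbols.
  inputs = input_modality.bottom(features["inputs"])
  targets = target_modality.targets_bottom(features["targets"])
  # Task-independent model body, e.g. a Transformer.
  body_output = body(inputs, targets)
  # Task-dependent projection from hidden states back to vocabulary logits.
  logits = target_modality.top(body_output, features["targets"])
  # Task-dependent loss, returned as a (numerator, denominator) pair.
  loss_num, loss_den = target_modality.loss(logits, features["targets"])
  return logits, loss_num / loss_den
```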
bottom & targets_bottom
In SymbolModality, the core methods for input conversion are bottom_simple and _get_weights; bottom and targets_bottom are thin wrappers around them.
- `_get_weights` creates the vocabulary-sized embedding matrix. Thanks to TF's variable reuse mechanism, the same matrix can serve both the encoder's and the decoder's embedding lookups, and can also be reused as the projection that maps hidden states to logits.

```python
def _get_weights(self, hidden_dim=None):
  """Create or get concatenated embedding or softmax variable.

  Args:
    hidden_dim: dim of the variable. Defaults to self._body_input_depth

  Returns:
    a list of self._num_shards Tensors.
  """
  if hidden_dim is None:
    hidden_dim = self._body_input_depth
  num_shards = self._model_hparams.symbol_modality_num_shards
  shards = []
  for i in range(num_shards):
    # Spread the vocabulary across num_shards variables of near-equal size.
    shard_size = (self._vocab_size // num_shards) + (
        1 if i < self._vocab_size % num_shards else 0)
    var_name = "weights_%d" % i
    shards.append(
        tf.get_variable(
            var_name, [shard_size, hidden_dim],
            initializer=tf.random_normal_initializer(0.0, hidden_dim**-0.5)))
  if num_shards == 1:
    ret = shards[0]
  else:
    ret = tf.concat(shards, 0)
  # Convert ret to tensor.
  if not tf.contrib.eager.in_eager_mode():
    ret = common_layers.convert_gradient_to_tensor(ret)
  return ret
```
- `bottom_simple` embeds the discrete input ids with a gather; here `common_layers.gather` is implemented by one-hot encoding the ids and matrix-multiplying with the embedding matrix (see the toy snippet after this list).

```python
def bottom_simple(self, x, name, reuse):
  with tf.variable_scope(name, reuse=reuse):
    # Ensure the inputs are 3-D
    if len(x.get_shape()) == 4:
      x = tf.squeeze(x, axis=3)
    while len(x.get_shape()) < 3:
      x = tf.expand_dims(x, axis=-1)

    var = self._get_weights()
    x = common_layers.dropout_no_scaling(
        x, 1.0 - self._model_hparams.symbol_dropout)
    ret = common_layers.gather(var, x)
    if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
      ret *= self._body_input_depth**0.5
    ret *= tf.expand_dims(tf.to_float(tf.not_equal(x, 0)), -1)
    return ret
```

  Since tensor2tensor reserves index 0 for the padding symbol `<PAD>`, the line `ret *= tf.expand_dims(tf.to_float(tf.not_equal(x, 0)), -1)` resets the embeddings at padded positions to all zeros. Both the true sequence lengths and the attention mask can then be recovered from the embeddings alone.
- `bottom` and `targets_bottom` control the embedding-sharing mechanism. By default the encoder and decoder embeddings are shared, which reduces the parameter count while giving the shared matrix more gradient updates.

```python
def bottom(self, x):
  if (self._model_hparams.shared_embedding_and_softmax_weights or
      self._model_hparams.get("shared_embedding")):
    return self.bottom_simple(x, "shared", reuse=None)
  return self.bottom_simple(x, "input_emb", reuse=None)

def targets_bottom(self, x):
  if (self._model_hparams.shared_embedding_and_softmax_weights or
      self._model_hparams.get("shared_embedding")):
    try:
      return self.bottom_simple(x, "shared", reuse=True)
    except ValueError:
      # perhaps there were no inputs, and this is a new variable.
      return self.bottom_simple(x, "shared", reuse=None)
  else:
    return self.bottom_simple(x, "target_emb", reuse=None)
```
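As promised above, here is a toy TF1-style snippet (my own, with made-up shapes and names `ids`/`emb_matrix`) showing both the one-hot view of the gather and how the padding mask falls out of the zeroed embeddings. tensor2tensor's `common_attention.embedding_to_padding` recovers the mask from embeddings in essentially this way.

```python
import tensorflow as tf  # TF1.x, matching the code above

# Toy inputs: a batch of one sequence, with <PAD>=0 at the last two positions.
ids = tf.constant([[3, 5, 0, 0]])                 # [batch=1, length=4]
emb_matrix = tf.random_normal([10, 8])            # [vocab=10, depth=8]

# gather as one-hot @ embedding, equivalent to tf.gather(emb_matrix, ids).
one_hot = tf.one_hot(ids, depth=10)               # [1, 4, 10]
emb = tf.tensordot(one_hot, emb_matrix, axes=1)   # [1, 4, 8]

# Zero out <PAD> positions, exactly as bottom_simple does.
emb *= tf.expand_dims(tf.to_float(tf.not_equal(ids, 0)), -1)

# The padding mask is now recoverable from the embeddings alone:
# a position is padding iff its embedding vector is all zeros.
is_padding = tf.to_float(tf.equal(tf.reduce_sum(tf.abs(emb), axis=-1), 0.0))
seq_len = tf.reduce_sum(1.0 - is_padding, axis=-1)  # true length: [2.]
```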
top
top maps the hidden vectors back to vocabulary dimension. The projection matrix can share the embedding matrix; with sharing, the embedding matrix sits directly in front of the softmax, so its gradient back-propagation path is noticeably shorter and it gets trained more thoroughly.
```python
def top(self, body_output, _):
  """Generate logits.

  Args:
    body_output: A Tensor with shape [batch, p0, p1, body_input_depth]

  Returns:
    logits: A Tensor with shape [batch, p0, p1, ?, vocab_size].
  """
  if self._model_hparams.symbol_modality_skip_top:
    return tf.expand_dims(body_output, 3)
  if self._model_hparams.shared_embedding_and_softmax_weights:
    scope_name = "shared"
    reuse = True
  else:
    scope_name = "softmax"
    reuse = False
  with tf.variable_scope(scope_name, reuse=reuse):
    body_output_shape = common_layers.shape_list(body_output)
    var = self._get_weights(body_output_shape[-1])
    if (self._model_hparams.factored_logits and
        self._model_hparams.mode == tf.estimator.ModeKeys.TRAIN):
      # insert channels dimension
      body_output = tf.expand_dims(body_output, 3)
      return common_layers.FactoredTensor(body_output, var)
    else:
      body_output = tf.reshape(body_output, [-1, body_output_shape[-1]])
      logits = tf.matmul(body_output, var, transpose_b=True)
      if (common_layers.is_xla_compiled() and
          self._model_hparams.mode == tf.estimator.ModeKeys.TRAIN):
        # TPU does not react kindly to extra dimensions.
        # TODO(noam): remove this once TPU is more forgiving of extra dims.
        return logits
      else:
        return tf.reshape(logits,
                          body_output_shape[:-1] + [1, self._vocab_size])
```
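To make the reshape → matmul → reshape dance in the non-factored branch easier to follow, here is a toy shape walk-through with numbers of my own choosing:

```python
import tensorflow as tf  # TF1.x

batch, p0, p1, depth, vocab = 2, 4, 1, 8, 10
body_output = tf.random_normal([batch, p0, p1, depth])
var = tf.random_normal([vocab, depth])        # shared embedding/softmax matrix

flat = tf.reshape(body_output, [-1, depth])   # [batch*p0*p1, depth] = [8, 8]
logits = tf.matmul(flat, var, transpose_b=True)   # [batch*p0*p1, vocab]
logits = tf.reshape(logits, [batch, p0, p1, 1, vocab])
# -> [batch, p0, p1, 1, vocab_size], matching the docstring above.
```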
loss
loss is plain cross-entropy, plus the label-smoothing trick. With weights_fn=weights_nonzero, positions where targets equal zero (the padding) are ignored when computing the loss; a simplified sketch of the computation follows the code.
```python
def loss(self, top_out, targets, weights_fn=None):
  """Compute loss numerator and denominator for one shard of output."""
  logits = top_out
  if weights_fn is None:
    weights_fn = self.targets_weights_fn
  return common_layers.padded_cross_entropy(
      logits,
      targets,
      self._model_hparams.label_smoothing,
      weights_fn=weights_fn)
```
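As a rough guide to what common_layers.padded_cross_entropy computes, here is a simplified sketch of my own; the real function additionally handles shape padding, factored logits, and a normalizing constant, among other details.

```python
def padded_cross_entropy_sketch(logits, targets, smoothing):
  """Simplified stand-in for common_layers.padded_cross_entropy."""
  vocab_size = tf.shape(logits)[-1]
  confidence = 1.0 - smoothing
  low_confidence = smoothing / tf.to_float(vocab_size - 1)
  # Label smoothing: put `confidence` on the gold id, spread the rest evenly.
  soft_targets = tf.one_hot(
      targets, depth=vocab_size,
      on_value=confidence, off_value=low_confidence)
  xent = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=soft_targets, logits=logits)
  # weights_nonzero: padding positions (targets == 0) get weight 0.
  weights = tf.to_float(tf.not_equal(targets, 0))
  # Numerator and denominator are returned separately so that losses from
  # several shards can be summed before the final division.
  return tf.reduce_sum(xent * weights), tf.reduce_sum(weights)
```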