介紹
seq2seq中的decoder是一個(gè)自回歸的生成模型,那么在訓(xùn)練階段,第t步輸入的前綴序列是來(lái)自真實(shí)數(shù)據(jù)分布的,這種學(xué)習(xí)方式稱(chēng)為教師強(qiáng)制(Teacher Forcing)。然而在預(yù)測(cè)階段,前綴序列則是來(lái)自模型分布的
。由于模型分布和真實(shí)數(shù)據(jù)分布并不嚴(yán)格一致,因此一旦預(yù)測(cè)前綴
的過(guò)程中存在錯(cuò)誤,會(huì)導(dǎo)致錯(cuò)誤傳播,使得后續(xù)生成的序列偏離真實(shí)分布,這個(gè)問(wèn)題稱(chēng)為曝光偏差(Exposure Bias)。
一個(gè)簡(jiǎn)單的想法就是在訓(xùn)練decoder的時(shí)候?qū)⒄鎸?shí)前綴序列中某些位置隨機(jī)替換成隨機(jī)詞,讓decoder不過(guò)分依賴(lài)前綴輸入。但是每一步中不管輸入如何選擇, 目標(biāo)輸出依然是來(lái)自于真實(shí)數(shù)據(jù). 這可能使得模型預(yù)測(cè)一些不正確的序列。比如一個(gè)真實(shí)的序列是 “吃飯”, 如果在第一步生成時(shí)使用模型預(yù)測(cè)的詞是 “喝”, 模型就會(huì)強(qiáng)制記住 “喝飯” 這個(gè)不正確的序列,這個(gè)問(wèn)題被稱(chēng)為過(guò)度糾正(Overcorrection)。
方案
ACL2019最佳長(zhǎng)文Bridging the Gap between Training and Inference for Neural Machine Translation提出了Oracle Word的概念,也就是說(shuō)不是隨機(jī)選取詞來(lái)替換,而是在word level或者sentence level考慮“合乎情理”的詞來(lái)替換。
- Word-Level Oracle是指在第t-1步根據(jù)softmax輸出的概率分布做詞采樣,為了增加魯棒性,可以在概率分布上加入Gumbel noise。
- Sentence-Level Oracle是指先利用beam search獲得一些候選翻譯結(jié)果,再和真實(shí)結(jié)果計(jì)算BLEU值,選擇對(duì)應(yīng)最優(yōu)BLEU值的候選翻譯結(jié)果作為decoder輸入。其中針對(duì)候選翻譯結(jié)果可能和真實(shí)結(jié)果長(zhǎng)度不一樣,又引入了Force Decoding技巧。
- 文中Sampling with Decay技巧是考慮在模型未得到充分訓(xùn)練時(shí),decoder的解碼結(jié)果可能很不可靠,為了避免模型無(wú)法收斂,替換前綴序列概率伴隨著訓(xùn)練的step緩慢增加。
實(shí)現(xiàn)
初讀這篇文章的時(shí)候有這樣的疑問(wèn):文章中的實(shí)現(xiàn)都是基于RNN的,怎么在基于Transformer的機(jī)器翻譯模型中應(yīng)用以上的方法,難不成為了使用這些技巧,放棄模型的并行性?相比Sentence-Level Oracle,Word-Level Oracle更易于并行實(shí)現(xiàn),以下我的實(shí)現(xiàn)方案:
- 預(yù)先前向計(jì)算一次decoder部分,并映射到字典維度,得到logits
- 利用由top_K logits計(jì)算的概率分布并采樣詞,得到候選替換序列
- 根據(jù)訓(xùn)練步數(shù)global_steps計(jì)算替換概率p,利用預(yù)設(shè)的最大概率值截?cái)?/li>
- 依照p,替換decoder的輸入序列
注:預(yù)先前向計(jì)算后需要使用tf.stop_gradients,防止反向傳播時(shí)冗余的梯度回傳。
# 利用由top_K logits計(jì)算的概率分布并采樣詞,得到候選替換序列
def sample_with_topk(logits, k):
reshaped_logits = (tf.reshape(logits, [-1, shape_list(logits)[-1]]))
reshaped_logits_values, reshaped_logits_indices = tf.nn.top_k(input=reshaped_logits, k=k, sorted=True)
choices = tf.multinomial(reshaped_logits_values, 1)
choices = tf.concat(
[tf.expand_dims(tf.cast(tf.range(tf.reduce_prod(shape_list(logits)[0:-1])), dtype=tf.int64), axis=-1),
choices], axis=-1)
choices = tf.gather_nd(params=reshaped_logits_indices, indices=choices)
choices = tf.reshape(choices, shape_list(logits)[:logits.get_shape().ndims - 1])
return tf.cast(choices, dtype=tf.int64)
# encoder_outpur: encoder所有隱層結(jié)果
# encoder_decoder_attention_bias: decoder中計(jì)算enc_dec_atten所涉及的mask偏置
# targets: 目標(biāo)id序列
# decoder_input: decoder輸入序列的嵌入向量
# decoder_self_attention_bias: decoder中self_atten的偏置
targets = common_layers.flatten4d3d(targets)
decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(targets, hparams, features=features)
# decoder_output_tmp: 預(yù)先計(jì)算decoder的最后一層隱層輸出
decoder_output_tmp = self.decode(
decoder_input,
encoder_output,
encoder_decoder_attention_bias,
decoder_self_attention_bias,
hparams,
nonpadding=features_to_nonpadding(features, "targets"),
losses=losses)
# 將隱層向量映射到字典維度
with tf.variable_scope(self._variable_scopes['model_fn'],reuse=tf.AUTO_REUSE):
logits_tmp = self.top(decoder_output_tmp, features)
# 防止梯度冗余傳播
logits_tmp = tf.stop_gradient(logits_tmp)
# 采樣得到候選序列
targets_proposal = sample_with_topk(logits_tmp, 10)
# 獲取全局訓(xùn)練步數(shù)
global_steps = tf.cast(tf.train.get_global_step(), dtype=tf.float32)
# 計(jì)算保留概率=1-替換概率,最小保留概率是0.5
p = tf.maximum(1.0 - tf.math.floordiv(global_steps, 10000.) * 0.5 / 75., 0.5)
# 判斷本次是保留, 還是替換;0表示保留, 1表示替換
pred = tf.cond(tf.less(tf.random.uniform(shape=(), minval=0, maxval=1), p),
true_fn=lambda: 0.,
false_fn=lambda: 1.)
# 隨機(jī)選擇15%序列中位置做替換
mask = tf.less(tf.random_uniform(tf.shape(features["targets_raw"])), 0.15 * pred)
# 利用mask融合原始目標(biāo)序列和候選目標(biāo)序列
targets_proposal = (cast_like(mask, targets_proposal) * targets_proposal +
cast_like(tf.logical_not(mask), targets_proposal) *
cast_like(features["targets_raw"], targets_proposal)) * \
cast_like(common_layers.weights_nonzero(features["targets_raw"]), targets_proposal)
# 利用融合的目標(biāo)序列作為decoder的輸入序列,計(jì)算decoder隱層向量
with tf.variable_scope(self._variable_scopes['symbol_modality_{}_{}'.format(hparams.problem_hparams.vocab_size["targets"],hparams.hidden_size)],reuse=tf.AUTO_REUSE):
targets_proposal = self._problem_hparams.modality["targets"].bottom(targets_proposal)
targets_proposal = common_layers.flatten4d3d(targets_proposal)
decoder_input_random, _ = transformer_prepare_decoder(targets_proposal, hparams, features=features)
decoder_output_random = self.decode(
decoder_input_random,
encoder_output,
encoder_decoder_attention_bias,
decoder_self_attention_bias,
hparams,
nonpadding=features_to_nonpadding(features, "targets"),
losses=losses)
總結(jié)
- 為了緩解過(guò)度糾正問(wèn)題,選擇Oracle Word策略不能過(guò)于隨機(jī),一定程度上需要考慮當(dāng)前語(yǔ)義。
- 緩解Exposure Bias問(wèn)題不能以犧牲并行性為代價(jià)。
拓展
- 蘇神的實(shí)現(xiàn)方式Seq2Seq中Exposure Bias現(xiàn)象的淺析與對(duì)策,并引入對(duì)抗訓(xùn)練的概念來(lái)緩解Exposure Bias問(wèn)題,也發(fā)人深省。
- 本質(zhì)上Exposure Bias問(wèn)題來(lái)源于自回歸生成中訓(xùn)練和測(cè)試的mismatch,目前利用前綴序列的方式都是離散的,是不是可以連續(xù)的利用前綴序列,從而在不失并行性的前提下統(tǒng)一訓(xùn)練與測(cè)試兩階段還有待探究。