緩解Exposure Bias的一種實(shí)現(xiàn)

介紹

seq2seq中的decoder是一個(gè)自回歸的生成模型,那么在訓(xùn)練階段,第t步輸入的前綴序列是來(lái)自真實(shí)數(shù)據(jù)分布的x_{1:(t-1)},這種學(xué)習(xí)方式稱(chēng)為教師強(qiáng)制(Teacher Forcing)。然而在預(yù)測(cè)階段,前綴序列則是來(lái)自模型分布的\hat{x}_{1:(t-1)}。由于模型分布和真實(shí)數(shù)據(jù)分布并不嚴(yán)格一致,因此一旦預(yù)測(cè)前綴\hat{x}_{1:(t-1)}的過(guò)程中存在錯(cuò)誤,會(huì)導(dǎo)致錯(cuò)誤傳播,使得后續(xù)生成的序列偏離真實(shí)分布,這個(gè)問(wèn)題稱(chēng)為曝光偏差(Exposure Bias)。

一個(gè)簡(jiǎn)單的想法就是在訓(xùn)練decoder的時(shí)候?qū)⒄鎸?shí)前綴序列x_{1:(t-1)}?中某些位置隨機(jī)替換成隨機(jī)詞,讓decoder不過(guò)分依賴(lài)前綴輸入。但是每一步中不管輸入如何選擇, 目標(biāo)輸出依然是來(lái)自于真實(shí)數(shù)據(jù). 這可能使得模型預(yù)測(cè)一些不正確的序列。比如一個(gè)真實(shí)的序列是 “吃飯”, 如果在第一步生成時(shí)使用模型預(yù)測(cè)的詞是 “喝”, 模型就會(huì)強(qiáng)制記住 “喝飯” 這個(gè)不正確的序列,這個(gè)問(wèn)題被稱(chēng)為過(guò)度糾正(Overcorrection)。

方案

ACL2019最佳長(zhǎng)文Bridging the Gap between Training and Inference for Neural Machine Translation提出了Oracle Word的概念,也就是說(shuō)不是隨機(jī)選取詞來(lái)替換,而是在word level或者sentence level考慮“合乎情理”的詞來(lái)替換。

  • Word-Level Oracle是指在第t-1步根據(jù)softmax輸出的概率分布做詞采樣,為了增加魯棒性,可以在概率分布上加入Gumbel noise。
  • Sentence-Level Oracle是指先利用beam search獲得一些候選翻譯結(jié)果,再和真實(shí)結(jié)果計(jì)算BLEU值,選擇對(duì)應(yīng)最優(yōu)BLEU值的候選翻譯結(jié)果作為decoder輸入。其中針對(duì)候選翻譯結(jié)果可能和真實(shí)結(jié)果長(zhǎng)度不一樣,又引入了Force Decoding技巧。
  • 文中Sampling with Decay技巧是考慮在模型未得到充分訓(xùn)練時(shí),decoder的解碼結(jié)果可能很不可靠,為了避免模型無(wú)法收斂,替換前綴序列概率伴隨著訓(xùn)練的step緩慢增加。

實(shí)現(xiàn)

初讀這篇文章的時(shí)候有這樣的疑問(wèn):文章中的實(shí)現(xiàn)都是基于RNN的,怎么在基于Transformer的機(jī)器翻譯模型中應(yīng)用以上的方法,難不成為了使用這些技巧,放棄模型的并行性?相比Sentence-Level Oracle,Word-Level Oracle更易于并行實(shí)現(xiàn),以下我的實(shí)現(xiàn)方案:

  1. 預(yù)先前向計(jì)算一次decoder部分,并映射到字典維度,得到logits
  2. 利用由top_K logits計(jì)算的概率分布并采樣詞,得到候選替換序列
  3. 根據(jù)訓(xùn)練步數(shù)global_steps計(jì)算替換概率p,利用預(yù)設(shè)的最大概率值截?cái)?/li>
  4. 依照p,替換decoder的輸入序列

注:預(yù)先前向計(jì)算后需要使用tf.stop_gradients,防止反向傳播時(shí)冗余的梯度回傳。


# 利用由top_K logits計(jì)算的概率分布并采樣詞,得到候選替換序列
def sample_with_topk(logits, k):
    reshaped_logits = (tf.reshape(logits, [-1, shape_list(logits)[-1]]))
    reshaped_logits_values, reshaped_logits_indices = tf.nn.top_k(input=reshaped_logits, k=k, sorted=True)
    choices = tf.multinomial(reshaped_logits_values, 1)
    choices = tf.concat(
        [tf.expand_dims(tf.cast(tf.range(tf.reduce_prod(shape_list(logits)[0:-1])), dtype=tf.int64), axis=-1),
         choices], axis=-1)

    choices = tf.gather_nd(params=reshaped_logits_indices, indices=choices)
    choices = tf.reshape(choices, shape_list(logits)[:logits.get_shape().ndims - 1])
    return tf.cast(choices, dtype=tf.int64)

# encoder_outpur: encoder所有隱層結(jié)果
# encoder_decoder_attention_bias: decoder中計(jì)算enc_dec_atten所涉及的mask偏置
# targets: 目標(biāo)id序列
# decoder_input: decoder輸入序列的嵌入向量
# decoder_self_attention_bias: decoder中self_atten的偏置
targets = common_layers.flatten4d3d(targets)
decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(targets, hparams, features=features)

# decoder_output_tmp: 預(yù)先計(jì)算decoder的最后一層隱層輸出
decoder_output_tmp = self.decode(
    decoder_input,
    encoder_output,
    encoder_decoder_attention_bias,
    decoder_self_attention_bias,
    hparams,
    nonpadding=features_to_nonpadding(features, "targets"),
    losses=losses)

# 將隱層向量映射到字典維度
with tf.variable_scope(self._variable_scopes['model_fn'],reuse=tf.AUTO_REUSE):
    logits_tmp = self.top(decoder_output_tmp, features)

# 防止梯度冗余傳播
logits_tmp = tf.stop_gradient(logits_tmp)

# 采樣得到候選序列
targets_proposal = sample_with_topk(logits_tmp, 10)

# 獲取全局訓(xùn)練步數(shù)
global_steps = tf.cast(tf.train.get_global_step(), dtype=tf.float32)

# 計(jì)算保留概率=1-替換概率,最小保留概率是0.5
p = tf.maximum(1.0 - tf.math.floordiv(global_steps, 10000.) * 0.5 / 75., 0.5)

# 判斷本次是保留, 還是替換;0表示保留, 1表示替換
pred = tf.cond(tf.less(tf.random.uniform(shape=(), minval=0, maxval=1), p),
               true_fn=lambda: 0.,
               false_fn=lambda: 1.)

# 隨機(jī)選擇15%序列中位置做替換
mask = tf.less(tf.random_uniform(tf.shape(features["targets_raw"])), 0.15 * pred)

# 利用mask融合原始目標(biāo)序列和候選目標(biāo)序列
targets_proposal = (cast_like(mask, targets_proposal) * targets_proposal +
                    cast_like(tf.logical_not(mask), targets_proposal) *
                    cast_like(features["targets_raw"], targets_proposal)) * \
                   cast_like(common_layers.weights_nonzero(features["targets_raw"]), targets_proposal)

# 利用融合的目標(biāo)序列作為decoder的輸入序列,計(jì)算decoder隱層向量
with tf.variable_scope(self._variable_scopes['symbol_modality_{}_{}'.format(hparams.problem_hparams.vocab_size["targets"],hparams.hidden_size)],reuse=tf.AUTO_REUSE):
    targets_proposal = self._problem_hparams.modality["targets"].bottom(targets_proposal)
targets_proposal = common_layers.flatten4d3d(targets_proposal)
decoder_input_random, _ = transformer_prepare_decoder(targets_proposal, hparams, features=features)
decoder_output_random = self.decode(
    decoder_input_random,
    encoder_output,
    encoder_decoder_attention_bias,
    decoder_self_attention_bias,
    hparams,
    nonpadding=features_to_nonpadding(features, "targets"),
    losses=losses)

總結(jié)

  • 為了緩解過(guò)度糾正問(wèn)題,選擇Oracle Word策略不能過(guò)于隨機(jī),一定程度上需要考慮當(dāng)前語(yǔ)義。
  • 緩解Exposure Bias問(wèn)題不能以犧牲并行性為代價(jià)。

拓展

  • 蘇神的實(shí)現(xiàn)方式Seq2Seq中Exposure Bias現(xiàn)象的淺析與對(duì)策,并引入對(duì)抗訓(xùn)練的概念來(lái)緩解Exposure Bias問(wèn)題,也發(fā)人深省。
  • 本質(zhì)上Exposure Bias問(wèn)題來(lái)源于自回歸生成中訓(xùn)練和測(cè)試的mismatch,目前利用前綴序列的方式都是離散的,是不是可以連續(xù)的利用前綴序列,從而在不失并行性的前提下統(tǒng)一訓(xùn)練與測(cè)試兩階段還有待探究。
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容