BasicDecoder類和dynamic_decode
decoder文件中定義了Decoder抽象類和dynamic_decode函數(shù),dynamic_decode可以視為整個解碼過程的入口,需要傳入的參數(shù)就是Decoder的一個實例,他會動態(tài)的調(diào)用Decoder的step函數(shù)按步執(zhí)行decode,可以理解為Decoder類定義了單步解碼(根據(jù)輸入求出輸出,并將該輸出當(dāng)做下一時刻輸入)
basic_decoder文件定義了一個基本的Decoder類實例BasicDecoder,其初始化函數(shù):
def __init__(self, cell, helper, initial_state, output_layer=None):
需要傳入的參數(shù)就是cell類型、helper類型、初始化狀態(tài)(encoder的最后一個隱層狀態(tài))、輸出層(輸出映射層,將rnn_size轉(zhuǎn)化為vocab_size維)
AttentionWrapper
AttentionWrapper在原本RNNCell的基礎(chǔ)上在封裝一層attention
# 分為三步,第一步是定義attention機制,第二步是定義要是用的基礎(chǔ)的RNNCell,第三步是使用AttentionWrapper進行封裝
#定義要使用的attention機制。
attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=self.rnn_size, memory=encoder_outputs, memory_sequence_length=encoder_inputs_length)
#attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=self.rnn_size, memory=encoder_outputs, memory_sequence_length=encoder_inputs_length)
# 定義decoder階段要是用的LSTMCell,然后為其封裝attention wrapper
decoder_cell = self._create_rnn_cell()
decoder_cell = tf.contrib.seq2seq.AttentionWrapper(cell=decoder_cell, attention_mechanism=attention_mechanism, attention_layer_size=self.rnn_size, name='Attention_Wrapper')
Helper類
helper其實就是decode階段如何根據(jù)預(yù)測結(jié)果得到下一時刻的輸入,比如訓(xùn)練過程中應(yīng)該直接使用上一時刻的真實值作為下一時刻輸入,預(yù)測過程中可以使用貪婪搜索選擇概率最大的那個值作為下一時刻等等。所以Helper也就可以大致分為訓(xùn)練時helper和預(yù)測時helper兩種
“TrainingHelper”:訓(xùn)練過程中最常使用的Helper,下一時刻輸入就是上一時刻target的真實值
“GreedyEmbeddingHelper”:預(yù)測階段最常使用的Helper,下一時刻輸入是上一時刻概率最大的單詞通過embedding之后的向量
#分為四步,第一步是定義cell類型,第二步是定義helper類型,第三步是定義BasicDecoder類實例,第四步是調(diào)用dynamic_decode函數(shù)進行解碼
decoder_cell = ***(上面的代碼)
training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_inputs_embedded,
sequence_length=self.decoder_targets_length,
time_major=False, name='training_helper')
training_decoder = tf.contrib.seq2seq.BasicDecoder(cell=decoder_cell, helper=training_helper,
initial_state=decoder_initial_state, output_layer=output_layer)
#調(diào)用dynamic_decode進行解碼,decoder_outputs是一個namedtuple,里面包含兩項(rnn_outputs, sample_id)
# rnn_output: [batch_size, decoder_targets_length, vocab_size],保存decode每個時刻每個單詞的概率,可以用來計算loss
# sample_id: [batch_size], tf.int32,保存最終的編碼結(jié)果??梢员硎咀詈蟮拇鸢? decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=training_decoder, impute_finished=True,
maximum_iterations=self.max_target_sequence_length)
Beam search decoder類
BeamSearchDecoder類,其實是一個Decoder的實例,跟BasicDecoder在一個等級上,但是二者又存在著不同,因為BasicDecoder需要指定helper參數(shù),也就是指定decode階段如何根據(jù)上一時刻輸出獲得下一時刻輸入。但是BeamSearchDecoder不需要,因為其在內(nèi)部實現(xiàn)了beam_search的功能,也就包含了helper的效果。
所以解碼器有兩種方式,直接調(diào)用BeamSearchDecoder,或者使用調(diào)用GreedyEmbeddingHelper+BasicDecoder的組合進行貪婪式解碼
#分為三步,第一步是定義cell,第二步是定義BeamSearchDecoder,第三步是調(diào)用dynamic_decode函數(shù)進行解碼
docoder_cell = ***(上面代碼)
if self.beam_search:
inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(cell=decoder_cell, embedding=embedding,
start_tokens=start_tokens, end_token=end_token,
initial_state=decoder_initial_state,
beam_width=self.beam_size,
output_layer=output_layer)
else:
decoding_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding=embedding,
start_tokens=start_tokens, end_token=end_token)
inference_decoder = tf.contrib.seq2seq.BasicDecoder(cell=decoder_cell, helper=decoding_helper,
initial_state=decoder_initial_state,
output_layer=output_layer)
decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=inference_decoder,
maximum_iterations=self.max_target_sequence_length)