tensorflow 中使用 tf.py_func 靈活地自定義張量計算

背景是這樣的,在看完論文《A Deep Reinforced Model for Abstractive Summarization》之后想用 tensorflow 實現(xiàn)一下。論文中的一個關(guān)鍵點是根據(jù)摘要生成的一個評測指標 ROUGE 作為強化學習的 reward,參與到損失函數(shù)的計算中,需要優(yōu)化的目標函數(shù)如下(詳細一些的介紹可參考我的另一篇文章http://www.itdecent.cn/p/38a2d3e04272):

強化學習目標函數(shù).png

去 github 找了一下這個方法的開源實現(xiàn),發(fā)現(xiàn)了兩種實現(xiàn)方式。

方式一(有問題)

https://github.com/weili-ict/SelfCriticalSequenceTraining-tensorflow
此方式的實現(xiàn)思路是先通過一次 sess.run 根據(jù)最大概率生成序列作為 baseline 以及進行采樣得到序列 ys。生成序列后計算 baseline 和 ys 的 reward。然后將 reward 放到 placeholder 中再跑一次 sess.run 計算 Lrl 并優(yōu)化。此方法的問題是兩次 sess.run 產(chǎn)生的 baseline 序列是不變的,但 ys 序列是根據(jù)概率分布采樣生成的兩次生成的序列完全不同,不能將第一次計算出的 reward 用于第二次的損失函數(shù)計算。
主要邏輯如下:

captions_batch = np.array(captions[i * self.batch_size:(i + 1) * self.batch_size])
image_idxs_batch = np.array(image_idxs[i * self.batch_size:(i + 1) * self.batch_size])
features_batch = np.array(features[image_idxs_batch])

ground_truths = [captions[image_idxs == image_idxs_batch[j]] for j in
                 range(len(image_idxs_batch))]
ref_decoded = [decode_captions(ground_truths[j], self.model.idx_to_word) for j in range(len(ground_truths))]

feed_dict = {self.model.features: features_batch, self.model.captions: captions_batch}
# first run to get 2 different serials
samples, greedy_words = sess.run([sampled_captions, greedy_caption],
                                         feed_dict)
masks, all_decoded = decode_captions_for_blue(samples, self.model.idx_to_word)
_, greedy_decoded = decode_captions_for_blue(greedy_words, self.model.idx_to_word)
# calculate the rewards of 2 serials
r = [evaluate_captions([k], [v])  for k, v in zip(ref_decoded, all_decoded)]
b = [evaluate_captions([k], [v]) for k, v in zip(ref_decoded, greedy_decoded)]

b_for_eval.extend(b)
feed_dict = {grad_mask: masks, rewards: r, base_line: b,
             self.model.features: features_batch, self.model.captions: captions_batch
             } 
# calculate loss and train
_ = sess.run([train_op], feed_dict)

方式二

https://github.com/ne7ermore/deeping-flow/tree/master/deep-reinforced-sum-model
由方式一可見,進行一次序列生成 -> 計算序列的 reward -> 進行第二次序列生成,利用已計算的 reward 計算 loss 并優(yōu)化,這一模式是行不通的,由于隨機采樣的的存在,必須保證 reward 計算與 loss 計算優(yōu)化在同一次 sess.run 中進行。
但是,困難在于 reward 的計算是一個較為復(fù)雜的過程, tensorflow 中必然沒有提供這樣的計算 API,那么如何在 sess.run 中進行這樣復(fù)雜的自定義的計算呢?
可以利用 tf.py_func 函數(shù)。py_func 函數(shù)接收一個用戶自定義的函數(shù) f 和一個輸入張量 input 作為參數(shù),返回一個利用 f 對 input 轉(zhuǎn)換后的輸出張量。
主要邏輯如下:

b_words = model.sample()
s_words, props = model.sample(False)

s_props = tf.reshape(_gather_index(tf.reshape(
    props, [-1, args.tgt_vs]), s_words, model.prev), [args.batch_size, args.l_max_len])
# calculate rewards in session
baseline = tf.py_func(rouge_l, [b_words, model.tgt], tf.float32)
reward = tf.py_func(rouge_l, [s_words, model.tgt], tf.float32)
advantage = reward - baseline

mask = pad_mask(model.tgt, EOS, [args.batch_size, args.l_max_len])

loss = -tf.reduce_sum(s_props *
                      mask * advantage[:, None]) / tf.reduce_sum(mask)

可以看到此方式將序列生成、reward 計算、loss 計算放在同一個 sess.run 中進行,解決了方式一中兩次 sess.run 生成的采樣候選序列不同的問題。

參考:
1、《A Deep Reinforced Model for Abstractive Summarization》
2、《Self-critical Sequence Training for Image Captioning》
3、https://github.com/weili-ict/SelfCriticalSequenceTraining-tensorflow
4、https://github.com/ne7ermore/deeping-flow/tree/master/deep-reinforced-sum-model

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 簡單線性回歸 import tensorflow as tf import numpy # 創(chuàng)造數(shù)據(jù) x_dat...
    CAICAI0閱讀 3,664評論 0 49
  • 強化學習(Reinforcement Learing),機器學習重要分支,解決連續(xù)決策問題。強化學習問題三概念,環(huán)...
    利炳根閱讀 3,844評論 0 4
  • DayOne 帕勞是太平洋上的一個熱帶島國。 一大早,拉開厚重的窗簾,張望小路對面頗有風情的民居和酒店,高大的棕櫚...
    瑤_8765閱讀 1,534評論 0 0
  • 1.1字符串字面量 1.2轉(zhuǎn)義字符 ' 單引號 " 雙引號 \t 制表符 \n 換行 \ 倒斜杠 1.3 原始字符...
    illaclv閱讀 364評論 0 0
  • 12年的5月,這是您第一次生病嚴重到自殺的地步,那也是我第一次感覺手足無措,第一次覺得天塌了的心慌難受。我跟哥哥第...
    我是胖小香閱讀 579評論 1 0

友情鏈接更多精彩內(nèi)容