【可視化】網(wǎng)絡(luò)Attention層

1. 前言

準(zhǔn)備中期答辯,補(bǔ)充了一個(gè)實(shí)驗(yàn),需要對(duì)網(wǎng)絡(luò)結(jié)構(gòu)中的attention層進(jìn)行可視化,觀察序列輸入的哪些詞或者詞組合是網(wǎng)絡(luò)比較care的。在小論文中主要研究了關(guān)于詞性POS對(duì)輸入序列的注意力機(jī)制。同時(shí)對(duì)比實(shí)驗(yàn)采取的是words的self-attention機(jī)制。

基于POS-Attention的層次化模型

2. 效果對(duì)比

下圖主要包含兩列:word_attention是self-attention機(jī)制的模型訓(xùn)練結(jié)果,POS_attention是詞性模型的訓(xùn)練結(jié)果。
可以看出,相對(duì)于word_attention,POS的注意力機(jī)制不僅能夠捕捉到評(píng)價(jià)的aspect,也能根據(jù)aspect關(guān)聯(lián)的詞借助情感語(yǔ)義表達(dá)的詞性分布,care到相關(guān)詞性的情感詞。

Attention可視化對(duì)比結(jié)果

3. 核心代碼

3.1 可視化樣例

# coding: utf-8
def highlight(word, attn):
    html_color = '#%02X%02X%02X' % (255, int(255*(1 - attn)), int(255*(1 - attn)))
    return '<span style="background-color: {}">{}</span>'.format(html_color, word)

def mk_html(seq, attns):
    html = ""
    for ix, attn in zip(seq, attns):
        html += ' ' + highlight(
            ix,
            attn
        )
    return html + "<br>"

from IPython.display import HTML, display
batch_size = 1
seqs = [["這", "是", "一個(gè)", "測(cè)試", "樣例", "而已"]]
attns = [[0.01, 0.19, 0.12, 0.7, 0.2, 0.1]]

for i in range(batch_size):
    text = mk_html(seqs[i], attns[i])
    display(HTML(text))

3.2 接入model

需要在model的返回列表中,添加attention_weight的輸出,理論上維度應(yīng)該和輸入序列的長(zhǎng)度是一致的。

# load model
import torch
# if you train on gpu, you need to move onto cpu
model = torch.load("../docs/model_chk/2018-11-07-02:45:37", map_location=lambda storage, location: storage)

from torch.autograd import Variable
for batch_idx, samples in enumerate(test_loader, 0):
    v_word = Variable(samples['word_vec'])
    v_final_label = samples['top_label']

    model.eval()
    final_probs, att_weight = model(v_word, v_pos)

    batch_words = toWords(samples["word_vec"].numpy(), idx_word)  # id轉(zhuǎn)化為word
    batch_att = getAtten(batch_words, att_weight.data.numpy())    # 去除padding詞,根據(jù)words的長(zhǎng)度截取attention
    labels = toLabel(samples['top_label'].numpy())  # 真實(shí)標(biāo)簽
    pre_labels = toLabel(final_probs.data.numpy() >= 0.5)   # 預(yù)測(cè)標(biāo)簽

    for i in range(len(batch_words)):
        text = mk_html(batch_words[i], batch_att[i])
        print(labels[i], pre_labels[i])
        display(HTML(text))

4. 總結(jié)

  • 建議把可視化獨(dú)立出來,用jupyter-notebook編輯,方便分段調(diào)試和copy;同時(shí)因?yàn)槭墙柚鷋tml渲染的,所以需要notebook
  • 項(xiàng)目代碼我后期后同步到github上
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容