機(jī)器學(xué)習(xí)基礎(chǔ)(11)條件隨機(jī)場的理解及BI-LSTM+CRF實戰(zhàn)

在NLP領(lǐng)域,在神經(jīng)網(wǎng)絡(luò)興起之前,條件隨機(jī)場(CRF)一直是作為主力模型的存在,就算是在RNN系(包括BERT系)的模型興起之后,也通常會在模型的最后添加一個CRF層,以提高準(zhǔn)確率。因此,CRF是所有NLPer必須要精通且掌握的一個模型,本文將優(yōu)先闡述清楚與CRF有關(guān)的全部基本概念,并詳細(xì)對比HMM,最后獻(xiàn)上BI-LSTM+CRF的實戰(zhàn)代碼及理解。相信讀完本文,將對CRF的認(rèn)識有一個新的高度。

在閱讀本文之前,務(wù)必對概率圖模型基礎(chǔ)有一個全盤的掌握,若對此沒有信心的,可以先參考我之前的一篇總結(jié)文:概率圖模型基礎(chǔ)

本文的基本目錄如下:

  1. 基礎(chǔ)知識
    1.1 CRF到底是什么?
    1.2 如何用CRF建模?
    1.3 CRF與HMM的區(qū)別是什么?

  2. BILSTM+CRF實戰(zhàn)
    2.1 為什么需要添加CRF層?
    2.2 如何計算損失函數(shù)?
    2.3 實戰(zhàn)環(huán)節(jié)

------------------第一菇 - 基礎(chǔ)知識------------------

1.1 CRF到底是什么?

本段主要用于講述與CRF有關(guān)的基礎(chǔ)概念。

大部分人理解CRF都會被帶到一個奇怪的誤區(qū)里面(包括我之前),因為總是理解完了HMM以后,就會立馬投入到CRF的學(xué)習(xí)里面,所以就會理所當(dāng)然的認(rèn)為CRF就是HMM的升級版(確實從模型效果上可以這么理解),然后一直把HMM的各自概念往CRF上套,之后兩廂一對比,就會有點犯迷糊了,好多東西也對不上啊??~然后,再一看書里的結(jié)論,啥?CRF竟然是判別式模型,HMM是生成式模型!這是什么鬼啦?CRF不就是HMM解除各自限制(有向圖變無向圖,箭頭指的方向更多了嗎??????),怎么突然就變成判別式模型啦???廢話不多說,如果有此疑問的同學(xué),就說明看對文章了,不要心急慢慢看,我將一一解釋清楚;而對此毫無疑問的同學(xué),那可以直接跳到實戰(zhàn)環(huán)節(jié)了哈。

CRF真的是判別式模型!準(zhǔn)確說是,判別式無向圖模型!

大家要牢記,區(qū)別判別式模型與生成式模型最基本的就是去判斷模型是對聯(lián)合分布進(jìn)行建模,還是對條件分布進(jìn)行建模。HMM中很顯然,模型是對x,y的聯(lián)合分布進(jìn)行建模(不清楚的同學(xué),還請移步HMM專區(qū)),而CRF則不然,其試圖對多個變量在給定觀測值后的條件概率進(jìn)行建模,因此屬于判別式模型。(各位抱著學(xué)新東西的心態(tài)來學(xué)CRF,把HMM拋在腦后把)

具體展開來看,若令x = \{x_1, x_2, x_3, ..., x_n\}為觀測序列,y = \{y_1, y_2, ..., y_n\}為與之對應(yīng)的標(biāo)記序列,則條件隨機(jī)場的目標(biāo)是構(gòu)建條件概率模型P(y | x)。值得注意的是,標(biāo)記變量y可以是結(jié)構(gòu)型變量,即其分量之間具有某種相關(guān)性。就比如在NLP領(lǐng)域中的詞性標(biāo)注任務(wù),觀測數(shù)據(jù)為單詞序列(即為x),標(biāo)記為相應(yīng)的詞性序列(即為y),且其具有線性序列結(jié)構(gòu)。

1.2 如何用CRF建模?

G = <V,E>表示結(jié)點與標(biāo)記變量y中元素一一對應(yīng)的無向圖,y_v表示與結(jié)點v對應(yīng)的標(biāo)記變量,n(v)表示結(jié)點v的鄰接結(jié)點,若圖G的每個變量y_v都滿足馬爾可夫性(即只與其相鄰的結(jié)點有關(guān)),即,

P(y_v | x, y_{V\setminus \{v\}}) = P(y_v | x, y_{n(v)})

(y,x)構(gòu)成一個條件隨機(jī)場。而理論上來說,圖G可具有任意結(jié)構(gòu),只要能表示標(biāo)記變量之間的條件獨立性關(guān)系即可,但在現(xiàn)實應(yīng)用中,尤其是對標(biāo)記序列建模時候,最常用的仍然是鏈?zhǔn)浇Y(jié)構(gòu),即“鏈?zhǔn)綏l件隨機(jī)場(chain-structured CRF)”,也是我們接下來主要要討論的一種條件隨機(jī)場。

鏈?zhǔn)綏l件隨機(jī)場

那我們該如何定義P(y|x)呢?

參考《機(jī)器學(xué)習(xí)》第14章的原文,其定義的方式類似馬爾可夫隨機(jī)場模型定義的聯(lián)合概率。條件隨機(jī)場使用勢函數(shù)和圖結(jié)構(gòu)上的團(tuán)來定義條件概率P(y|x)!如上圖所示,該鏈?zhǔn)綏l件隨機(jī)場主要包含兩種關(guān)于標(biāo)記變量的團(tuán),即單個標(biāo)記變量\{y_i\}以及相鄰的標(biāo)記變量\{y_{i-1}, y_i\}。選擇合適的勢函數(shù),即可得到形如馬爾可夫隨機(jī)場中聯(lián)合概率的定義。

在條件隨機(jī)場中,通過選用指數(shù)勢函數(shù)并引入特征函數(shù),條件概率被定義如下,

P(Y|X) = \frac{1}{Z}exp\left (\sum_{j}\sum_{i=1}^{n-1}\lambda_jt_j(y_{i+1}, y_i,X,i) + \sum_k\sum_{i=1}^{n}\mu_ks_k(y_i,X,i)\right )

注意哈,這里的X指的是整一個觀測序列!而且這里定義的條件概率計算方式,就只是將觀測序列X作為條件,并不對其作任何獨立性假設(shè)?。。。ㄟ@點很重要!也是其是判別式模型的重要依據(jù))

其中,t_j(y_{i+1}, y_i,X,i)是定義在觀測序列的倆個相鄰標(biāo)記位置上的轉(zhuǎn)移特征函數(shù),用于刻畫相鄰標(biāo)記變量之間的相關(guān)關(guān)系以及觀測序列對他們的影響。即給定觀測序列X,其標(biāo)注序列在ii-1位置上標(biāo)記的轉(zhuǎn)移概率!而特征函數(shù)的定義往往不止一種,因此會有一個下標(biāo)j代表要遍歷計算每一種特征函數(shù)的取值。

另外,s_k(y_i,X,i)是定義在觀測序列的標(biāo)記位置i上的狀態(tài)特征函數(shù),用于刻畫觀測序列對標(biāo)記變量的影響。即表示對于觀察序列Xi位置的標(biāo)記概率。同理,也有多種特征函數(shù),所以會有下標(biāo)k。

剩下的就比較簡單理解,\lambda_j和\mu_k都是參數(shù),Z為規(guī)范化因子,用于確保上式是被正確定義的概率(可以理解為類似softmax的操作)。

總結(jié)一下上式,可以理解為如下圖,

概率定義理解圖

至此,整個概率的定義想必大家已經(jīng)爛熟于心了~顯然,要運用好條件隨機(jī)場,最重要的就是要去定義合適的特征函數(shù)了。特征函數(shù)通常是實值函數(shù),以刻畫數(shù)據(jù)的一些很可能成立或期望成立的經(jīng)驗特性。因此定義特征函數(shù)的時候,一般都可以定義一組關(guān)于觀察序列的\{0,1\}二值特征b(X,i)來表示訓(xùn)練樣本中某些分布特性,比如詞性標(biāo)注任務(wù),

b(X,i) = \left\{\begin{matrix} 1, & X的\ i \ 位置為某個特定的詞\\ 0, & 否則 \end{matrix}\right.

等等類似的特征函數(shù),能定義好多出來的。因此,小小總結(jié)一下,CRF與馬爾可夫隨機(jī)場均使用團(tuán)上的勢函數(shù)定義概率,兩者在形式上并沒有顯著區(qū)別,只不過CRF處理的是條件概率,而馬爾可夫隨機(jī)場處理的是聯(lián)合概率。至此,整個CRF的建模已經(jīng)講明白了。

1.3 CRF與HMM的區(qū)別是什么?

CRF與HMM的一些基本定義的概念區(qū)別這邊在講概率圖模型和上面的基礎(chǔ)定義時已經(jīng)表述的很清楚了,本段就不繼續(xù)展開了~這里主要講一下HMM的標(biāo)注偏置問題,以及CRF為何能解決這個問題。

其實要想解釋清楚標(biāo)注偏置問題,大家只要看如下貼的一張圖即可,

標(biāo)注偏置問題

大家可以發(fā)現(xiàn),狀態(tài)1傾向于轉(zhuǎn)移到狀態(tài)2,狀態(tài)2傾向于轉(zhuǎn)移到狀態(tài)2本身,但是實際計算得到的最大概率路徑是1>1>1>1,狀態(tài)1并沒有轉(zhuǎn)移到狀態(tài)2!這其實是與我們的直覺相悖的~究其本質(zhì)原因,從狀態(tài)2轉(zhuǎn)移出去可能的狀態(tài)包括1,2,3,4,5,概率在可能的狀態(tài)上分散了,而狀態(tài)1轉(zhuǎn)移出去的可能狀態(tài)僅僅為狀態(tài)1和2,概率更加集中?。ù蠹铱梢阅霉P算一下,是不是這么個理~加深理解)由于局部歸一化的影響,隱狀態(tài)會傾向于轉(zhuǎn)移到那些后續(xù)狀態(tài)可能更少的狀態(tài)上,以提高整體的后驗概率!這就是標(biāo)注偏置問題!

而CRF如上所述,因為有歸一化因子(Z)的存在,其在全局范圍內(nèi)進(jìn)行了歸一化,枚舉了整個隱狀態(tài)序列的全部可能,從而解決了局部歸一化帶來的標(biāo)注偏置問題。而這也是CRF在很多問題上,表現(xiàn)比HMM優(yōu)秀的原因~

------------------第二菇 - BILSTM+CRF實戰(zhàn)------------------

介紹了這么多CRF有關(guān)的東西,想必各位也是躍躍欲試,我這邊也獻(xiàn)上一份BILSTM+CRF的實戰(zhàn)解析,包括對此模型架構(gòu)的理解以及源碼的解讀~

這里貼一個鏈接,是一個外國小哥寫的博客,當(dāng)初就是看這個博客明白其原理的,所以也特地在這邊貼出來,英文好的同學(xué)也可以直接看這個鏈接,我就不單獨放在參考文獻(xiàn)里了。

為了便于理解,代碼都是用Pytorch寫的,且還是以命名實體識別任務(wù)為具體例子。

如果看到這篇文章的是初學(xué)者,也不用慌,就簡單理解BILSTM和CRF為一個命名實體識別模型中的兩個層。

為了便于理解下面的圖示,這邊假設(shè)我們的數(shù)據(jù)集有兩大類,人名地名,與之相對應(yīng)在我們的訓(xùn)練數(shù)據(jù)集中,有五類標(biāo)簽:

* B-Person
* I-Person
* B-Organization
* I-Organization
* O

假設(shè)句子x由5個字符組成,即x = (w_0, w_1, w_2, w_3, w_4, w_5),其中[w_0, w_1]人名實體,[w_3]組織實體,其他字符的標(biāo)簽為"O"。

2.1 為什么需要添加CRF層?

這里先直接貼一張BILSTM-CRF的模型結(jié)構(gòu)圖,方便大家理解。

BiLSTM+CRF模型結(jié)構(gòu)圖

從下往上看,最下面就是輸入(字或詞向量),由于是序列模型,因此,在“時間”緯度上進(jìn)行展開,就可以得到如圖所示的模型表示,對應(yīng)于一個時刻,就是輸入一個字/詞向量(一般都是預(yù)先訓(xùn)練得出的)。

首先,是經(jīng)過BiLSTM的結(jié)構(gòu)單元。這個比較好理解,本質(zhì)上就是倆個LSTM層,只不過一次是正序輸入,一次是倒序輸入,然后把倆個結(jié)果進(jìn)行concact(拼接),并輸入到CRF層,最后由CRF層輸出每一個詞的標(biāo)簽~如果沒有CRF層的話,傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)都會加一層softmax層用于歸一化并輸出每個標(biāo)簽概率。

為了更容易理解CRF層的作用,我們還是先要理清Bi-LSTM的輸出。這里再貼一張圖,方便大家理解,

BiLSTM層的輸出.png

大家可以看到,其輸出十分簡單清晰,就是對于每一個單詞,其對應(yīng)每一個標(biāo)簽的分值(score)。因此,就算沒有CRF層,該模型依舊有效,我們只需要挑選每一個標(biāo)簽對應(yīng)最大的分值就可以,比如,w_0就是B-Person。因此,在原有模型本身就有效的情況下,我們再添加一層CRF的目的肯定只有一個,即提高模型的準(zhǔn)確率。

接下來,我們就要重點分析一下,CRF層的作用。先上結(jié)論,CRF層的主要作用是為最后預(yù)測的標(biāo)簽添加一些約束來保證預(yù)測標(biāo)簽的合理性!比如,在命名實體識別任務(wù)中,我們可以想到的約束可以是,

1)開頭的標(biāo)簽只能是B, O,而不可能是I-
2)B-Person開頭的標(biāo)簽,后面不可能接一個I-Organization
。。。

有了如上約束,我們就能保證,最終預(yù)測生成的標(biāo)簽序列的不合理性就會大大降低,而單憑BiLSTM的輸出來預(yù)測是無法保證標(biāo)簽序列的合理性的~

2.2 損失函數(shù)的定義及計算

弄清楚了CRF層的作用以后,我們就要來仔細(xì)研究研究CRF層的運行原理了,主要從其損失函數(shù)的角度來理解。

2.2.1 CRF中的兩種分?jǐn)?shù)

在CRF層的損失函數(shù)中,有兩種形式的score(分?jǐn)?shù)),第一個就是emission score(發(fā)射分?jǐn)?shù)),主要就是來自于BiLSTM層的輸出(如上圖所示),假設(shè)我們給每一個標(biāo)簽一個索引,那么第一個單詞的emission score就是,[1.5, 0.9, 0.1, 0.08, 0.05]

:大家可千萬注意了,不要把這里的score和HMM里面的發(fā)射概率矩陣相混淆!兩者可是完全不一樣的,在HMM中,每一個單詞的發(fā)射概率,僅與當(dāng)前的隱狀態(tài)層有關(guān)!是由隱狀態(tài)決定了當(dāng)前單詞的發(fā)射概率!而這里的發(fā)射分?jǐn)?shù),是由當(dāng)前輸入的序列,決定的當(dāng)前狀態(tài)的概率!這里是整一個序列哦!若大家還能記得CRF層的定義和概率模型圖(往上翻一翻),想必對此并不會驚訝!而這,也是CRF層能接在神經(jīng)網(wǎng)絡(luò)最后一層的主要原因!大家對此一定要有深刻的理解和認(rèn)識。

第二個就是transition score(轉(zhuǎn)移分?jǐn)?shù)),這個倒是跟HMM中的狀態(tài)概率轉(zhuǎn)移矩陣相類似,也很好理解,也是模型中主要學(xué)習(xí)的參數(shù)!而且為了使模型更具有魯棒性,我們額外增加了倆個標(biāo)簽,STARTENDSTART代表句子的開始位置,而非第一個詞,同理END代表句子的結(jié)束位置,這里也貼一張transition score矩陣的圖,方便大家理解,

transition score 圖.png

大家從圖中應(yīng)該很清楚可以看到,其能學(xué)到很多約束規(guī)則!因此,該矩陣也是模型主要訓(xùn)練的一個參數(shù),一般一開始都會初始化一個概率轉(zhuǎn)移矩陣,隨著訓(xùn)練的迭代,逐漸合理~因此,接下來,我們就要來看看,其損失函數(shù)是如何設(shè)計的,才能學(xué)到合理的參數(shù)~

2.2.2 損失函數(shù)的設(shè)計

先明確一點,損失函數(shù)就是我們要優(yōu)化的目標(biāo),那對于這樣一個序列標(biāo)注問題,我們肯定是希望,正確的序列,是所有的可能序列中,得分最高的!就如同我們作HMM解碼的時候,利用維特比算法解碼,我們返回的肯定是概率最大的那條路徑一般,那反過來,我們訓(xùn)練的時候,自然希望得到的參數(shù),能使得正確路徑的概率最大~

有了上述的核心思想,我們再來想一想,如何求解。假設(shè)一共有N種可能的標(biāo)簽序列組合,記第i個標(biāo)簽序列的得分為P_i,那么所有可能標(biāo)簽序列組合的總得分為,

P_{total} = P_1 + P_2 + ... + P_N = e^{S_1} +e^{S_2} +... + e^{S_N}

因此,我們可以設(shè)想出一個損失函數(shù),就是真實序列的分?jǐn)?shù)在所有可能的序列中占比最高,即,

L = \frac{P_{real}}{P_{total}}

由這個損失函數(shù),引出2個問題,

1)如何定義計算每一個序列的得分?
2)如何計算所有標(biāo)簽序列的總得分?

2.2.3 求解一個序列的得分

先來看第一個問題,上述曾提過倆個分?jǐn)?shù)概念,emission score和transition score,因此,一個序列的得分也有這倆個構(gòu)成。

S_i = EmissionScore + TransitionScore

看到這個公式,大家再與CRF定義相聯(lián)系,有木有看出點什么花頭?沒錯,這個跟CRF定義的條件概率幾乎是類似的哦,基本上可以理解為,BiLSTM的輸出(也就是EmissionScore)取代來狀態(tài)特征函數(shù)的位置,而我們要學(xué)習(xí)的參數(shù)也就是轉(zhuǎn)移特征函數(shù)及其權(quán)重。所有,CRF與神經(jīng)網(wǎng)絡(luò)的配套組合并不是強(qiáng)行加上或者巧合,而是有理論作強(qiáng)支撐的哈哈~

我們逐一來理解每一個分?jǐn)?shù)的計算過程,假設(shè)我們有一個正確的序列標(biāo)注為,[START, B-Person, I-Person, O, B-Organization, O, END]

那么,
EmissionScore = x_{0, START} + x_{1, B-Person} + ... x_{6, END}

其中,x_{index,label}就表示第index個詞被標(biāo)記為label的得分(直接是從神經(jīng)網(wǎng)絡(luò)的輸出能拿到的)。

而另一個轉(zhuǎn)移分?jǐn)?shù)即為,
TransitionScore = t_{START, B-Person} + t_{B-Person, I-Person} + ... t_{O, END}

其中,t_{label1, label2}就表示label1 到 label2的轉(zhuǎn)移概率,也就是模型要學(xué)習(xí)的參數(shù)~

2.2.4 計算所有序列總分?jǐn)?shù)的方法

至此,每一條路徑的總得分,就可以根據(jù)上面的式子很輕松的計算得出~但顯然,如果真實計算也是如此遍歷操作的話,時間復(fù)雜度會吃不消的,因此我們需要一個高效的算法來計算~

我們先簡化一下?lián)p失函數(shù),

損失函數(shù)的簡化.png

簡化之后,可以很輕松的看出,前半部分的計算是固定的,我們只需要高效的計算出后半部分即可,

e^{S_1} + e^{S_2} + ... + e^{S_N}

很明顯,這里會運用到動態(tài)規(guī)劃的思想(不懂動態(tài)規(guī)劃的,直接去看一下維特比算法,加深理解)來進(jìn)行求解,即利用w_0的總得分來推出w_1的總得分,最后以此類推,每一次計算都需要利用到上一步計算得到的結(jié)果。

這里舉一個簡單的示例,假設(shè)句子長度為3([w_0, w_1, w_2]),標(biāo)簽有2個([l_1, l_2]),我們學(xué)到的Emission Score 矩陣如下(BiLSTM輸出),

l_1 l_2
w_0 x_{01} x_{02}
w_1 x_{11} x_{12}
w_2 x_{21} x_{22}

學(xué)習(xí)到的Transition Score矩陣如下,

l_1 l_2
l_1 t_{11} t_{12}
l_2 t_{21} t_{22}

接下來,將演示如何計算總得分,因為是動態(tài)規(guī)劃的思想,只需演繹出其中一步即可~

針對w_0,很輕松,因為沒有轉(zhuǎn)移分?jǐn)?shù),僅有發(fā)射分?jǐn)?shù),因此在第一個位置的總得分即為兩種路徑的分?jǐn)?shù)總和,而現(xiàn)在兩種路徑就是兩種可能性,要么就是標(biāo)簽1要么就是標(biāo)簽2,而這也是整個動態(tài)規(guī)劃開始的初始條件~)

S_{w0} = log(e^{x_{01}} + e^{x_{02}})

接下來,我們要求在第二個位置w_1的總得分,注意我們的推導(dǎo)式子就是由w_0推出的w_1,因此我們直接利用w_0計算得出的分?jǐn)?shù)即可~如下圖所演示的~

路徑求和示意圖.png

上述的動態(tài)核心式子即為,

S_{ij} = S_{i-1j} + t_{ij} + x_{ij}

最終,將在w1位置的所有狀態(tài)求和得分相加即是總路徑的得分~

有人可能會問,那你這不是只有w_1一個位置的得分了嗎?我們不是要求得總路徑得分嗎?有這個疑惑的同學(xué)應(yīng)該就是還沒有領(lǐng)悟到動態(tài)規(guī)劃的精髓,建議自己手動推導(dǎo)一遍,便可迎刃而解~至此,整一個所有序列路徑的求和方法,我們已經(jīng)大致了解清楚了~(多提一句,預(yù)測階段的解碼思路,其實就是維特比算法,也是動態(tài)規(guī)劃的思路,十分簡單,這里就不多說了~)

2.3 實戰(zhàn)環(huán)節(jié)

上面兩節(jié)已經(jīng)把BiLSTM+CRF講的清清楚楚了~光看理論還不夠,我們要深入代碼實戰(zhàn)環(huán)節(jié)(注:此乃網(wǎng)上找的一個Pytorch的版本,個人覺得是寫的比較好的,只限用于理解理論,并非商業(yè)應(yīng)用)

我們首先導(dǎo)入相應(yīng)的包和定義一些后面要用到的輔助函數(shù),如下,

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(1)

# some helper functions
def argmax(vec):
    # return the argmax as a python int
    # 第1維度上最大值的下標(biāo)
    # input: tensor([[2,3,4]])
    # output: 2
    _, idx = torch.max(vec,1)
    return idx.item()

def prepare_sequence(seq,to_ix):
    # 文本序列轉(zhuǎn)化為index的序列形式
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

def log_sum_exp(vec):
    #compute log sum exp in a numerically stable way for the forward algorithm
    # 用數(shù)值穩(wěn)定的方法計算正演算法的對數(shù)和exp
    # input: tensor([[2,3,4]])
    # max_score_broadcast: tensor([[4,4,4]])
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1,-1).expand(1,vec.size()[1])
    return max_score+torch.log(torch.sum(torch.exp(vec-max_score_broadcast)))

START_TAG = "<s>"
END_TAG = "<e>"

這里定義的幾個輔助函數(shù)都比較直觀,唯獨log_sum_exp可能會對大家造成一點困擾,但其實這是一種考慮數(shù)值穩(wěn)定性的求解辦法,具體大家參考這篇博文即可,深究一下也是好事情,不深究的就明白這個函數(shù)是為了求即可~

log(e^{S_1} + e^{S_2} ... e^{S_n})

我們接著看模型的定義,

# create model
class BiLSTM_CRF(nn.Module):
    def __init__(self,vocab_size, tag2ix, embedding_dim, hidden_dim):
        super(BiLSTM_CRF,self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.tag2ix = tag2ix
        self.tagset_size = len(tag2ix)

        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim//2, num_layers=1, bidirectional=True)

        # maps output of lstm to tog space
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

        # matrix of transition parameters
        # entry i, j is the score of transitioning to i from j
        # tag間的轉(zhuǎn)移矩陣,是CRF層的參數(shù)
        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))

        # these two statements enforce the constraint that we never transfer to the start tag
        # and we never transfer from the stop tag
        self.transitions.data[tag2ix[START_TAG], :] = -10000
        self.transitions.data[:, tag2ix[END_TAG]] = -10000

        self.hidden = self.init_hidden()

    def init_hidden(self):
        return (torch.randn(2, 1,self.hidden_dim//2),
                torch.randn(2, 1,self.hidden_dim//2))

    def _forward_alg(self, feats):
        # to compute partition function
        # 求歸一化項的值,應(yīng)用動態(tài)歸化算法
        init_alphas = torch.full((1,self.tagset_size), -10000.)# tensor([[-10000.,-10000.,-10000.,-10000.,-10000.]])
        # START_TAG has all of the score
        init_alphas[0][self.tag2ix[START_TAG]] = 0#tensor([[-10000.,-10000.,-10000.,0,-10000.]])

        forward_var = init_alphas

        for feat in feats:
            #feat指Bi-LSTM模型每一步的輸出,大小為tagset_size
            alphas_t = []
            for next_tag in range(self.tagset_size):
                # 取其中的某個tag對應(yīng)的值進(jìn)行擴(kuò)張至(1,tagset_size)大小
                # 如tensor([3]) -> tensor([[3,3,3,3,3]])
                emit_score = feat[next_tag].view(1,-1).expand(1,self.tagset_size)
                # 增維操作
                trans_score = self.transitions[next_tag].view(1,-1)
                # 上一步的路徑和+轉(zhuǎn)移分?jǐn)?shù)+發(fā)射分?jǐn)?shù)
                next_tag_var = forward_var + trans_score + emit_score
                # log_sum_exp求和
                alphas_t.append(log_sum_exp(next_tag_var).view(1))
            # 增維
            forward_var = torch.cat(alphas_t).view(1,-1)
        terminal_var = forward_var+self.transitions[self.tag2ix[END_TAG]]
        alpha = log_sum_exp(terminal_var)
        #歸一項的值
        return alpha

    def _get_lstm_features(self,sentence):
        self.hidden = self.init_hidden()
        embeds = self.word_embeds(sentence).view(len(sentence),1,-1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def _score_sentence(self,feats,tags):
        # gives the score of a provides tag sequence
        # 求某一路徑的值
        score = torch.zeros(1)
        tags = torch.cat([torch.tensor([self.tag2ix[START_TAG]], dtype=torch.long), tags])
        for i , feat in enumerate(feats):
            score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        score = score + self.transitions[self.tag2ix[END_TAG], tags[-1]]
        return score

    def _viterbi_decode(self, feats):
        # 當(dāng)參數(shù)確定的時候,求解最佳路徑
        backpointers = []

        init_vars = torch.full((1,self.tagset_size),-10000.)# tensor([[-10000.,-10000.,-10000.,-10000.,-10000.]])
        init_vars[0][self.tag2ix[START_TAG]] = 0#tensor([[-10000.,-10000.,-10000.,0,-10000.]])

        forward_var = init_vars
        for feat in feats:
            bptrs_t = [] # holds the back pointers for this step
            viterbivars_t = [] # holds the viterbi variables for this step

            for next_tag in range(self.tagset_size):
                next_tag_var = forward_var + self.transitions[next_tag]
                best_tag_id = argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)

        # Transition to STOP_TAG
        terminal_var = forward_var + self.transitions[self.tag2ix[END_TAG]]
        best_tag_id = argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]

        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # Pop off the start tag (we dont want to return that to the caller)
        start = best_path.pop()
        assert start == self.tag2ix[START_TAG]  # Sanity check
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, sentence, tags):
        # 由lstm層計算得的每一時刻屬于某一tag的值
        feats = self._get_lstm_features(sentence)
        # 歸一項的值
        forward_score = self._forward_alg(feats)
        # 正確路徑的值
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score# -(正確路徑的分值  -  歸一項的值)

    def forward(self, sentence):  # dont confuse this with _forward_alg above.
        # Get the emission scores from the BiLSTM
        lstm_feats = self._get_lstm_features(sentence)

        # Find the best path, given the features.
        score, tag_seq = self._viterbi_decode(lstm_feats)
        return score, tag_seq

上面的注釋應(yīng)該說是很詳細(xì)了,一開始的初始化定義也都是Pytorch的常規(guī)寫法(簡書的代碼顯示的略詭異,大家將就看看吧)~LSTM層也是直接掉的nn里面的,只有CRF層是自己手?jǐn)]上來的~所以,大家重點關(guān)注一下_forward_alg這個函數(shù),就是我們上面講的求解路徑總得分的函數(shù)~其中feats就是序列步長,自然是要順序遍歷每一個feat,其中每一個feat又要遍歷每一種tag的情況,利用forward_var記錄每一個路徑的總得分(實時更新),最后在求和即可!應(yīng)該說看懂了上面的解釋的同學(xué),在看這個代碼,簡直是太簡單了哈哈~其他的函數(shù)也沒啥好特地強(qiáng)調(diào)的,大家掃一眼明白即可,對解碼不清楚的,直接看代碼也難,手動推演一遍,理解的更快~

最后,我們再來看一下主函數(shù),

if __name__ == "__main__":
    EMBEDDING_DIM = 5
    HIDDEN_DIM = 4

    # Make up some training data
    training_data = [(
        "the wall street journal reported today that apple corporation made money".split(),
        "B I I I O O O B I O O".split()
    ), (
        "georgia tech is a university in georgia".split(),
        "B I O O O O B".split()
    )]

    word2ix = {}
    for sentence, tags in training_data:
        for word in sentence:
            if word not in word2ix:
                word2ix[word] = len(word2ix)

    tag2ix = {"B": 0, "I": 1, "O": 2, START_TAG: 3, END_TAG: 4}

    model = BiLSTM_CRF(len(word2ix), tag2ix, EMBEDDING_DIM, HIDDEN_DIM)
    optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

    # Check predictions before training
    # 輸出訓(xùn)練前的預(yù)測序列
    with torch.no_grad():
        precheck_sent = prepare_sequence(training_data[0][0], word2ix)
        precheck_tags = torch.tensor([tag2ix[t] for t in training_data[0][1]], dtype=torch.long)
        print(model(precheck_sent))

    # Make sure prepare_sequence from earlier in the LSTM section is loaded
    for epoch in range(300):  # again, normally you would NOT do 300 epochs, it is toy data
        for sentence, tags in training_data:
            # Step 1. Remember that Pytorch accumulates gradients.
            # We need to clear them out before each instance
            model.zero_grad()

            # Step 2. Get our inputs ready for the network, that is,
            # turn them into Tensors of word indices.
            sentence_in = prepare_sequence(sentence, word2ix)
            targets = torch.tensor([tag2ix[t] for t in tags], dtype=torch.long)

            # Step 3. Run our forward pass.
            loss = model.neg_log_likelihood(sentence_in, targets)

            # Step 4. Compute the loss, gradients, and update the parameters by
            # calling optimizer.step()
            loss.backward()
            optimizer.step()

    # Check predictions after training
    with torch.no_grad():
        precheck_sent = prepare_sequence(training_data[0][0], word2ix)
        print(model(precheck_sent))

    # 輸出結(jié)果
    # (tensor(-9996.9365), [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
    # (tensor(-9973.2725), [0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2])

也是比較常規(guī)的寫法,還帶了示例~大家應(yīng)該很容易理解的!

至此,整一套跟CRF有關(guān)的知識點和代碼解釋已經(jīng)全部弄清楚了。簡單總結(jié)一下本文,先是詳細(xì)解釋了一下與CRF有關(guān)的一些誤區(qū)和知識點,接著展示了與CRF有關(guān)的用法和計算損失函數(shù)的方法,最后獻(xiàn)上了詳細(xì)的代碼解讀~希望大家讀完本文后對CRF的一些概念會有一個全新的認(rèn)識。有說的不對的地方也請大家指出,多多交流,大家一起進(jìn)步~??

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容