在NLP領(lǐng)域,在神經(jīng)網(wǎng)絡(luò)興起之前,條件隨機(jī)場(CRF)一直是作為主力模型的存在,就算是在RNN系(包括BERT系)的模型興起之后,也通常會在模型的最后添加一個CRF層,以提高準(zhǔn)確率。因此,CRF是所有NLPer必須要精通且掌握的一個模型,本文將優(yōu)先闡述清楚與CRF有關(guān)的全部基本概念,并詳細(xì)對比HMM,最后獻(xiàn)上BI-LSTM+CRF的實戰(zhàn)代碼及理解。相信讀完本文,將對CRF的認(rèn)識有一個新的高度。
在閱讀本文之前,務(wù)必對概率圖模型基礎(chǔ)有一個全盤的掌握,若對此沒有信心的,可以先參考我之前的一篇總結(jié)文:概率圖模型基礎(chǔ)
本文的基本目錄如下:
基礎(chǔ)知識
1.1 CRF到底是什么?
1.2 如何用CRF建模?
1.3 CRF與HMM的區(qū)別是什么?BILSTM+CRF實戰(zhàn)
2.1 為什么需要添加CRF層?
2.2 如何計算損失函數(shù)?
2.3 實戰(zhàn)環(huán)節(jié)
------------------第一菇 - 基礎(chǔ)知識------------------
1.1 CRF到底是什么?
本段主要用于講述與CRF有關(guān)的基礎(chǔ)概念。
大部分人理解CRF都會被帶到一個奇怪的誤區(qū)里面(包括我之前),因為總是理解完了HMM以后,就會立馬投入到CRF的學(xué)習(xí)里面,所以就會理所當(dāng)然的認(rèn)為CRF就是HMM的升級版(確實從模型效果上可以這么理解),然后一直把HMM的各自概念往CRF上套,之后兩廂一對比,就會有點犯迷糊了,好多東西也對不上啊??~然后,再一看書里的結(jié)論,啥?CRF竟然是判別式模型,HMM是生成式模型!這是什么鬼啦?CRF不就是HMM解除各自限制(有向圖變無向圖,箭頭指的方向更多了嗎??????),怎么突然就變成判別式模型啦???廢話不多說,如果有此疑問的同學(xué),就說明看對文章了,不要心急慢慢看,我將一一解釋清楚;而對此毫無疑問的同學(xué),那可以直接跳到實戰(zhàn)環(huán)節(jié)了哈。
CRF真的是判別式模型!準(zhǔn)確說是,判別式無向圖模型!
大家要牢記,區(qū)別判別式模型與生成式模型最基本的就是去判斷模型是對聯(lián)合分布進(jìn)行建模,還是對條件分布進(jìn)行建模。HMM中很顯然,模型是對的聯(lián)合分布進(jìn)行建模(不清楚的同學(xué),還請移步HMM專區(qū)),而CRF則不然,其試圖對多個變量在給定觀測值后的條件概率進(jìn)行建模,因此屬于判別式模型。(各位抱著學(xué)新東西的心態(tài)來學(xué)CRF,把HMM拋在腦后把)
具體展開來看,若令為觀測序列,
為與之對應(yīng)的標(biāo)記序列,則條件隨機(jī)場的目標(biāo)是構(gòu)建條件概率模型
。值得注意的是,標(biāo)記變量
可以是結(jié)構(gòu)型變量,即其分量之間具有某種相關(guān)性。就比如在NLP領(lǐng)域中的詞性標(biāo)注任務(wù),觀測數(shù)據(jù)為單詞序列(即為
),標(biāo)記為相應(yīng)的詞性序列(即為
),且其具有線性序列結(jié)構(gòu)。
1.2 如何用CRF建模?
令表示結(jié)點與標(biāo)記變量
中元素一一對應(yīng)的無向圖,
表示與結(jié)點
對應(yīng)的標(biāo)記變量,
表示結(jié)點
的鄰接結(jié)點,若圖
的每個變量
都滿足馬爾可夫性(即只與其相鄰的結(jié)點有關(guān)),即,
則構(gòu)成一個條件隨機(jī)場。而理論上來說,圖G可具有任意結(jié)構(gòu),只要能表示標(biāo)記變量之間的條件獨立性關(guān)系即可,但在現(xiàn)實應(yīng)用中,尤其是對標(biāo)記序列建模時候,最常用的仍然是鏈?zhǔn)浇Y(jié)構(gòu),即“鏈?zhǔn)綏l件隨機(jī)場(chain-structured CRF)”,也是我們接下來主要要討論的一種條件隨機(jī)場。

那我們該如何定義呢?
參考《機(jī)器學(xué)習(xí)》第14章的原文,其定義的方式類似馬爾可夫隨機(jī)場模型定義的聯(lián)合概率。條件隨機(jī)場使用勢函數(shù)和圖結(jié)構(gòu)上的團(tuán)來定義條件概率!如上圖所示,該鏈?zhǔn)綏l件隨機(jī)場主要包含兩種關(guān)于標(biāo)記變量的團(tuán),即單個標(biāo)記變量
以及相鄰的標(biāo)記變量
。選擇合適的勢函數(shù),即可得到形如馬爾可夫隨機(jī)場中聯(lián)合概率的定義。
在條件隨機(jī)場中,通過選用指數(shù)勢函數(shù)并引入特征函數(shù),條件概率被定義如下,
注意哈,這里的指的是整一個觀測序列!而且這里定義的條件概率計算方式,就只是將觀測序列
作為條件,并不對其作任何獨立性假設(shè)?。。。ㄟ@點很重要!也是其是判別式模型的重要依據(jù))
其中,是定義在觀測序列的倆個相鄰標(biāo)記位置上的轉(zhuǎn)移特征函數(shù),用于刻畫相鄰標(biāo)記變量之間的相關(guān)關(guān)系以及觀測序列對他們的影響。即給定觀測序列
,其標(biāo)注序列在
及
位置上標(biāo)記的轉(zhuǎn)移概率!而特征函數(shù)的定義往往不止一種,因此會有一個下標(biāo)
代表要遍歷計算每一種特征函數(shù)的取值。
另外,是定義在觀測序列的標(biāo)記位置
上的狀態(tài)特征函數(shù),用于刻畫觀測序列對標(biāo)記變量的影響。即表示對于觀察序列
其
位置的標(biāo)記概率。同理,也有多種特征函數(shù),所以會有下標(biāo)
。
剩下的就比較簡單理解,都是參數(shù),
為規(guī)范化因子,用于確保上式是被正確定義的概率(可以理解為類似softmax的操作)。
總結(jié)一下上式,可以理解為如下圖,

至此,整個概率的定義想必大家已經(jīng)爛熟于心了~顯然,要運用好條件隨機(jī)場,最重要的就是要去定義合適的特征函數(shù)了。特征函數(shù)通常是實值函數(shù),以刻畫數(shù)據(jù)的一些很可能成立或期望成立的經(jīng)驗特性。因此定義特征函數(shù)的時候,一般都可以定義一組關(guān)于觀察序列的二值特征
來表示訓(xùn)練樣本中某些分布特性,比如詞性標(biāo)注任務(wù),
等等類似的特征函數(shù),能定義好多出來的。因此,小小總結(jié)一下,CRF與馬爾可夫隨機(jī)場均使用團(tuán)上的勢函數(shù)定義概率,兩者在形式上并沒有顯著區(qū)別,只不過CRF處理的是條件概率,而馬爾可夫隨機(jī)場處理的是聯(lián)合概率。至此,整個CRF的建模已經(jīng)講明白了。
1.3 CRF與HMM的區(qū)別是什么?
CRF與HMM的一些基本定義的概念區(qū)別這邊在講概率圖模型和上面的基礎(chǔ)定義時已經(jīng)表述的很清楚了,本段就不繼續(xù)展開了~這里主要講一下HMM的標(biāo)注偏置問題,以及CRF為何能解決這個問題。
其實要想解釋清楚標(biāo)注偏置問題,大家只要看如下貼的一張圖即可,

大家可以發(fā)現(xiàn),狀態(tài)1傾向于轉(zhuǎn)移到狀態(tài)2,狀態(tài)2傾向于轉(zhuǎn)移到狀態(tài)2本身,但是實際計算得到的最大概率路徑是1>1>1>1,狀態(tài)1并沒有轉(zhuǎn)移到狀態(tài)2!這其實是與我們的直覺相悖的~究其本質(zhì)原因,從狀態(tài)2轉(zhuǎn)移出去可能的狀態(tài)包括1,2,3,4,5,概率在可能的狀態(tài)上分散了,而狀態(tài)1轉(zhuǎn)移出去的可能狀態(tài)僅僅為狀態(tài)1和2,概率更加集中?。ù蠹铱梢阅霉P算一下,是不是這么個理~加深理解)由于局部歸一化的影響,隱狀態(tài)會傾向于轉(zhuǎn)移到那些后續(xù)狀態(tài)可能更少的狀態(tài)上,以提高整體的后驗概率!這就是標(biāo)注偏置問題!
而CRF如上所述,因為有歸一化因子(Z)的存在,其在全局范圍內(nèi)進(jìn)行了歸一化,枚舉了整個隱狀態(tài)序列的全部可能,從而解決了局部歸一化帶來的標(biāo)注偏置問題。而這也是CRF在很多問題上,表現(xiàn)比HMM優(yōu)秀的原因~
------------------第二菇 - BILSTM+CRF實戰(zhàn)------------------
介紹了這么多CRF有關(guān)的東西,想必各位也是躍躍欲試,我這邊也獻(xiàn)上一份BILSTM+CRF的實戰(zhàn)解析,包括對此模型架構(gòu)的理解以及源碼的解讀~
這里貼一個鏈接,是一個外國小哥寫的博客,當(dāng)初就是看這個博客明白其原理的,所以也特地在這邊貼出來,英文好的同學(xué)也可以直接看這個鏈接,我就不單獨放在參考文獻(xiàn)里了。
為了便于理解,代碼都是用Pytorch寫的,且還是以命名實體識別任務(wù)為具體例子。
如果看到這篇文章的是初學(xué)者,也不用慌,就簡單理解BILSTM和CRF為一個命名實體識別模型中的兩個層。
為了便于理解下面的圖示,這邊假設(shè)我們的數(shù)據(jù)集有兩大類,人名和地名,與之相對應(yīng)在我們的訓(xùn)練數(shù)據(jù)集中,有五類標(biāo)簽:
* B-Person
* I-Person
* B-Organization
* I-Organization
* O
假設(shè)句子由5個字符組成,即
,其中
為人名實體,
為組織實體,其他字符的標(biāo)簽為"O"。
2.1 為什么需要添加CRF層?
這里先直接貼一張BILSTM-CRF的模型結(jié)構(gòu)圖,方便大家理解。

從下往上看,最下面就是輸入(字或詞向量),由于是序列模型,因此,在“時間”緯度上進(jìn)行展開,就可以得到如圖所示的模型表示,對應(yīng)于一個時刻,就是輸入一個字/詞向量(一般都是預(yù)先訓(xùn)練得出的)。
首先,是經(jīng)過BiLSTM的結(jié)構(gòu)單元。這個比較好理解,本質(zhì)上就是倆個LSTM層,只不過一次是正序輸入,一次是倒序輸入,然后把倆個結(jié)果進(jìn)行concact(拼接),并輸入到CRF層,最后由CRF層輸出每一個詞的標(biāo)簽~如果沒有CRF層的話,傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)都會加一層softmax層用于歸一化并輸出每個標(biāo)簽概率。
為了更容易理解CRF層的作用,我們還是先要理清Bi-LSTM的輸出。這里再貼一張圖,方便大家理解,

大家可以看到,其輸出十分簡單清晰,就是對于每一個單詞,其對應(yīng)每一個標(biāo)簽的分值(score)。因此,就算沒有CRF層,該模型依舊有效,我們只需要挑選每一個標(biāo)簽對應(yīng)最大的分值就可以,比如,就是
。因此,在原有模型本身就有效的情況下,我們再添加一層CRF的目的肯定只有一個,即提高模型的準(zhǔn)確率。
接下來,我們就要重點分析一下,CRF層的作用。先上結(jié)論,CRF層的主要作用是為最后預(yù)測的標(biāo)簽添加一些約束來保證預(yù)測標(biāo)簽的合理性!比如,在命名實體識別任務(wù)中,我們可以想到的約束可以是,
1)開頭的標(biāo)簽只能是B, O,而不可能是I-
2)B-Person開頭的標(biāo)簽,后面不可能接一個I-Organization
。。。
有了如上約束,我們就能保證,最終預(yù)測生成的標(biāo)簽序列的不合理性就會大大降低,而單憑BiLSTM的輸出來預(yù)測是無法保證標(biāo)簽序列的合理性的~
2.2 損失函數(shù)的定義及計算
弄清楚了CRF層的作用以后,我們就要來仔細(xì)研究研究CRF層的運行原理了,主要從其損失函數(shù)的角度來理解。
2.2.1 CRF中的兩種分?jǐn)?shù)
在CRF層的損失函數(shù)中,有兩種形式的score(分?jǐn)?shù)),第一個就是emission score(發(fā)射分?jǐn)?shù)),主要就是來自于BiLSTM層的輸出(如上圖所示),假設(shè)我們給每一個標(biāo)簽一個索引,那么第一個單詞的emission score就是,~
注:大家可千萬注意了,不要把這里的score和HMM里面的發(fā)射概率矩陣相混淆!兩者可是完全不一樣的,在HMM中,每一個單詞的發(fā)射概率,僅與當(dāng)前的隱狀態(tài)層有關(guān)!是由隱狀態(tài)決定了當(dāng)前單詞的發(fā)射概率!而這里的發(fā)射分?jǐn)?shù),是由當(dāng)前輸入的序列,決定的當(dāng)前狀態(tài)的概率!這里是整一個序列哦!若大家還能記得CRF層的定義和概率模型圖(往上翻一翻),想必對此并不會驚訝!而這,也是CRF層能接在神經(jīng)網(wǎng)絡(luò)最后一層的主要原因!大家對此一定要有深刻的理解和認(rèn)識。
第二個就是transition score(轉(zhuǎn)移分?jǐn)?shù)),這個倒是跟HMM中的狀態(tài)概率轉(zhuǎn)移矩陣相類似,也很好理解,也是模型中主要學(xué)習(xí)的參數(shù)!而且為了使模型更具有魯棒性,我們額外增加了倆個標(biāo)簽,和
,
代表句子的開始位置,而非第一個詞,同理
代表句子的結(jié)束位置,這里也貼一張transition score矩陣的圖,方便大家理解,

大家從圖中應(yīng)該很清楚可以看到,其能學(xué)到很多約束規(guī)則!因此,該矩陣也是模型主要訓(xùn)練的一個參數(shù),一般一開始都會初始化一個概率轉(zhuǎn)移矩陣,隨著訓(xùn)練的迭代,逐漸合理~因此,接下來,我們就要來看看,其損失函數(shù)是如何設(shè)計的,才能學(xué)到合理的參數(shù)~
2.2.2 損失函數(shù)的設(shè)計
先明確一點,損失函數(shù)就是我們要優(yōu)化的目標(biāo),那對于這樣一個序列標(biāo)注問題,我們肯定是希望,正確的序列,是所有的可能序列中,得分最高的!就如同我們作HMM解碼的時候,利用維特比算法解碼,我們返回的肯定是概率最大的那條路徑一般,那反過來,我們訓(xùn)練的時候,自然希望得到的參數(shù),能使得正確路徑的概率最大~
有了上述的核心思想,我們再來想一想,如何求解。假設(shè)一共有種可能的標(biāo)簽序列組合,記第i個標(biāo)簽序列的得分為
,那么所有可能標(biāo)簽序列組合的總得分為,
因此,我們可以設(shè)想出一個損失函數(shù),就是真實序列的分?jǐn)?shù)在所有可能的序列中占比最高,即,
由這個損失函數(shù),引出2個問題,
1)如何定義計算每一個序列的得分?
2)如何計算所有標(biāo)簽序列的總得分?
2.2.3 求解一個序列的得分
先來看第一個問題,上述曾提過倆個分?jǐn)?shù)概念,emission score和transition score,因此,一個序列的得分也有這倆個構(gòu)成。
看到這個公式,大家再與CRF定義相聯(lián)系,有木有看出點什么花頭?沒錯,這個跟CRF定義的條件概率幾乎是類似的哦,基本上可以理解為,BiLSTM的輸出(也就是EmissionScore)取代來狀態(tài)特征函數(shù)的位置,而我們要學(xué)習(xí)的參數(shù)也就是轉(zhuǎn)移特征函數(shù)及其權(quán)重。所有,CRF與神經(jīng)網(wǎng)絡(luò)的配套組合并不是強(qiáng)行加上或者巧合,而是有理論作強(qiáng)支撐的哈哈~
我們逐一來理解每一個分?jǐn)?shù)的計算過程,假設(shè)我們有一個正確的序列標(biāo)注為,
那么,
其中,就表示第index個詞被標(biāo)記為label的得分(直接是從神經(jīng)網(wǎng)絡(luò)的輸出能拿到的)。
而另一個轉(zhuǎn)移分?jǐn)?shù)即為,
其中,就表示label1 到 label2的轉(zhuǎn)移概率,也就是模型要學(xué)習(xí)的參數(shù)~
2.2.4 計算所有序列總分?jǐn)?shù)的方法
至此,每一條路徑的總得分,就可以根據(jù)上面的式子很輕松的計算得出~但顯然,如果真實計算也是如此遍歷操作的話,時間復(fù)雜度會吃不消的,因此我們需要一個高效的算法來計算~
我們先簡化一下?lián)p失函數(shù),

簡化之后,可以很輕松的看出,前半部分的計算是固定的,我們只需要高效的計算出后半部分即可,
很明顯,這里會運用到動態(tài)規(guī)劃的思想(不懂動態(tài)規(guī)劃的,直接去看一下維特比算法,加深理解)來進(jìn)行求解,即利用的總得分來推出
的總得分,最后以此類推,每一次計算都需要利用到上一步計算得到的結(jié)果。
這里舉一個簡單的示例,假設(shè)句子長度為3(),標(biāo)簽有2個(
),我們學(xué)到的Emission Score 矩陣如下(BiLSTM輸出),
學(xué)習(xí)到的Transition Score矩陣如下,
接下來,將演示如何計算總得分,因為是動態(tài)規(guī)劃的思想,只需演繹出其中一步即可~
針對,很輕松,因為沒有轉(zhuǎn)移分?jǐn)?shù),僅有發(fā)射分?jǐn)?shù),因此在第一個位置的總得分即為兩種路徑的分?jǐn)?shù)總和,而現(xiàn)在兩種路徑就是兩種可能性,要么就是標(biāo)簽1要么就是標(biāo)簽2,而這也是整個動態(tài)規(guī)劃開始的初始條件~)
接下來,我們要求在第二個位置的總得分,注意我們的推導(dǎo)式子就是由
推出的
,因此我們直接利用
計算得出的分?jǐn)?shù)即可~如下圖所演示的~

上述的動態(tài)核心式子即為,
最終,將在位置的所有狀態(tài)求和得分相加即是總路徑的得分~
有人可能會問,那你這不是只有一個位置的得分了嗎?我們不是要求得總路徑得分嗎?有這個疑惑的同學(xué)應(yīng)該就是還沒有領(lǐng)悟到動態(tài)規(guī)劃的精髓,建議自己手動推導(dǎo)一遍,便可迎刃而解~至此,整一個所有序列路徑的求和方法,我們已經(jīng)大致了解清楚了~(多提一句,預(yù)測階段的解碼思路,其實就是維特比算法,也是動態(tài)規(guī)劃的思路,十分簡單,這里就不多說了~)
2.3 實戰(zhàn)環(huán)節(jié)
上面兩節(jié)已經(jīng)把BiLSTM+CRF講的清清楚楚了~光看理論還不夠,我們要深入代碼實戰(zhàn)環(huán)節(jié)(注:此乃網(wǎng)上找的一個Pytorch的版本,個人覺得是寫的比較好的,只限用于理解理論,并非商業(yè)應(yīng)用)
我們首先導(dǎo)入相應(yīng)的包和定義一些后面要用到的輔助函數(shù),如下,
import torch
import torch.nn as nn
import torch.optim as optim
torch.manual_seed(1)
# some helper functions
def argmax(vec):
# return the argmax as a python int
# 第1維度上最大值的下標(biāo)
# input: tensor([[2,3,4]])
# output: 2
_, idx = torch.max(vec,1)
return idx.item()
def prepare_sequence(seq,to_ix):
# 文本序列轉(zhuǎn)化為index的序列形式
idxs = [to_ix[w] for w in seq]
return torch.tensor(idxs, dtype=torch.long)
def log_sum_exp(vec):
#compute log sum exp in a numerically stable way for the forward algorithm
# 用數(shù)值穩(wěn)定的方法計算正演算法的對數(shù)和exp
# input: tensor([[2,3,4]])
# max_score_broadcast: tensor([[4,4,4]])
max_score = vec[0, argmax(vec)]
max_score_broadcast = max_score.view(1,-1).expand(1,vec.size()[1])
return max_score+torch.log(torch.sum(torch.exp(vec-max_score_broadcast)))
START_TAG = "<s>"
END_TAG = "<e>"
這里定義的幾個輔助函數(shù)都比較直觀,唯獨log_sum_exp可能會對大家造成一點困擾,但其實這是一種考慮數(shù)值穩(wěn)定性的求解辦法,具體大家參考這篇博文即可,深究一下也是好事情,不深究的就明白這個函數(shù)是為了求即可~
我們接著看模型的定義,
# create model
class BiLSTM_CRF(nn.Module):
def __init__(self,vocab_size, tag2ix, embedding_dim, hidden_dim):
super(BiLSTM_CRF,self).__init__()
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.tag2ix = tag2ix
self.tagset_size = len(tag2ix)
self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, hidden_dim//2, num_layers=1, bidirectional=True)
# maps output of lstm to tog space
self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)
# matrix of transition parameters
# entry i, j is the score of transitioning to i from j
# tag間的轉(zhuǎn)移矩陣,是CRF層的參數(shù)
self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))
# these two statements enforce the constraint that we never transfer to the start tag
# and we never transfer from the stop tag
self.transitions.data[tag2ix[START_TAG], :] = -10000
self.transitions.data[:, tag2ix[END_TAG]] = -10000
self.hidden = self.init_hidden()
def init_hidden(self):
return (torch.randn(2, 1,self.hidden_dim//2),
torch.randn(2, 1,self.hidden_dim//2))
def _forward_alg(self, feats):
# to compute partition function
# 求歸一化項的值,應(yīng)用動態(tài)歸化算法
init_alphas = torch.full((1,self.tagset_size), -10000.)# tensor([[-10000.,-10000.,-10000.,-10000.,-10000.]])
# START_TAG has all of the score
init_alphas[0][self.tag2ix[START_TAG]] = 0#tensor([[-10000.,-10000.,-10000.,0,-10000.]])
forward_var = init_alphas
for feat in feats:
#feat指Bi-LSTM模型每一步的輸出,大小為tagset_size
alphas_t = []
for next_tag in range(self.tagset_size):
# 取其中的某個tag對應(yīng)的值進(jìn)行擴(kuò)張至(1,tagset_size)大小
# 如tensor([3]) -> tensor([[3,3,3,3,3]])
emit_score = feat[next_tag].view(1,-1).expand(1,self.tagset_size)
# 增維操作
trans_score = self.transitions[next_tag].view(1,-1)
# 上一步的路徑和+轉(zhuǎn)移分?jǐn)?shù)+發(fā)射分?jǐn)?shù)
next_tag_var = forward_var + trans_score + emit_score
# log_sum_exp求和
alphas_t.append(log_sum_exp(next_tag_var).view(1))
# 增維
forward_var = torch.cat(alphas_t).view(1,-1)
terminal_var = forward_var+self.transitions[self.tag2ix[END_TAG]]
alpha = log_sum_exp(terminal_var)
#歸一項的值
return alpha
def _get_lstm_features(self,sentence):
self.hidden = self.init_hidden()
embeds = self.word_embeds(sentence).view(len(sentence),1,-1)
lstm_out, self.hidden = self.lstm(embeds, self.hidden)
lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
lstm_feats = self.hidden2tag(lstm_out)
return lstm_feats
def _score_sentence(self,feats,tags):
# gives the score of a provides tag sequence
# 求某一路徑的值
score = torch.zeros(1)
tags = torch.cat([torch.tensor([self.tag2ix[START_TAG]], dtype=torch.long), tags])
for i , feat in enumerate(feats):
score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
score = score + self.transitions[self.tag2ix[END_TAG], tags[-1]]
return score
def _viterbi_decode(self, feats):
# 當(dāng)參數(shù)確定的時候,求解最佳路徑
backpointers = []
init_vars = torch.full((1,self.tagset_size),-10000.)# tensor([[-10000.,-10000.,-10000.,-10000.,-10000.]])
init_vars[0][self.tag2ix[START_TAG]] = 0#tensor([[-10000.,-10000.,-10000.,0,-10000.]])
forward_var = init_vars
for feat in feats:
bptrs_t = [] # holds the back pointers for this step
viterbivars_t = [] # holds the viterbi variables for this step
for next_tag in range(self.tagset_size):
next_tag_var = forward_var + self.transitions[next_tag]
best_tag_id = argmax(next_tag_var)
bptrs_t.append(best_tag_id)
viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
backpointers.append(bptrs_t)
# Transition to STOP_TAG
terminal_var = forward_var + self.transitions[self.tag2ix[END_TAG]]
best_tag_id = argmax(terminal_var)
path_score = terminal_var[0][best_tag_id]
# Follow the back pointers to decode the best path.
best_path = [best_tag_id]
for bptrs_t in reversed(backpointers):
best_tag_id = bptrs_t[best_tag_id]
best_path.append(best_tag_id)
# Pop off the start tag (we dont want to return that to the caller)
start = best_path.pop()
assert start == self.tag2ix[START_TAG] # Sanity check
best_path.reverse()
return path_score, best_path
def neg_log_likelihood(self, sentence, tags):
# 由lstm層計算得的每一時刻屬于某一tag的值
feats = self._get_lstm_features(sentence)
# 歸一項的值
forward_score = self._forward_alg(feats)
# 正確路徑的值
gold_score = self._score_sentence(feats, tags)
return forward_score - gold_score# -(正確路徑的分值 - 歸一項的值)
def forward(self, sentence): # dont confuse this with _forward_alg above.
# Get the emission scores from the BiLSTM
lstm_feats = self._get_lstm_features(sentence)
# Find the best path, given the features.
score, tag_seq = self._viterbi_decode(lstm_feats)
return score, tag_seq
上面的注釋應(yīng)該說是很詳細(xì)了,一開始的初始化定義也都是Pytorch的常規(guī)寫法(簡書的代碼顯示的略詭異,大家將就看看吧)~LSTM層也是直接掉的nn里面的,只有CRF層是自己手?jǐn)]上來的~所以,大家重點關(guān)注一下_forward_alg這個函數(shù),就是我們上面講的求解路徑總得分的函數(shù)~其中feats就是序列步長,自然是要順序遍歷每一個feat,其中每一個feat又要遍歷每一種tag的情況,利用forward_var記錄每一個路徑的總得分(實時更新),最后在求和即可!應(yīng)該說看懂了上面的解釋的同學(xué),在看這個代碼,簡直是太簡單了哈哈~其他的函數(shù)也沒啥好特地強(qiáng)調(diào)的,大家掃一眼明白即可,對解碼不清楚的,直接看代碼也難,手動推演一遍,理解的更快~
最后,我們再來看一下主函數(shù),
if __name__ == "__main__":
EMBEDDING_DIM = 5
HIDDEN_DIM = 4
# Make up some training data
training_data = [(
"the wall street journal reported today that apple corporation made money".split(),
"B I I I O O O B I O O".split()
), (
"georgia tech is a university in georgia".split(),
"B I O O O O B".split()
)]
word2ix = {}
for sentence, tags in training_data:
for word in sentence:
if word not in word2ix:
word2ix[word] = len(word2ix)
tag2ix = {"B": 0, "I": 1, "O": 2, START_TAG: 3, END_TAG: 4}
model = BiLSTM_CRF(len(word2ix), tag2ix, EMBEDDING_DIM, HIDDEN_DIM)
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)
# Check predictions before training
# 輸出訓(xùn)練前的預(yù)測序列
with torch.no_grad():
precheck_sent = prepare_sequence(training_data[0][0], word2ix)
precheck_tags = torch.tensor([tag2ix[t] for t in training_data[0][1]], dtype=torch.long)
print(model(precheck_sent))
# Make sure prepare_sequence from earlier in the LSTM section is loaded
for epoch in range(300): # again, normally you would NOT do 300 epochs, it is toy data
for sentence, tags in training_data:
# Step 1. Remember that Pytorch accumulates gradients.
# We need to clear them out before each instance
model.zero_grad()
# Step 2. Get our inputs ready for the network, that is,
# turn them into Tensors of word indices.
sentence_in = prepare_sequence(sentence, word2ix)
targets = torch.tensor([tag2ix[t] for t in tags], dtype=torch.long)
# Step 3. Run our forward pass.
loss = model.neg_log_likelihood(sentence_in, targets)
# Step 4. Compute the loss, gradients, and update the parameters by
# calling optimizer.step()
loss.backward()
optimizer.step()
# Check predictions after training
with torch.no_grad():
precheck_sent = prepare_sequence(training_data[0][0], word2ix)
print(model(precheck_sent))
# 輸出結(jié)果
# (tensor(-9996.9365), [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
# (tensor(-9973.2725), [0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2])
也是比較常規(guī)的寫法,還帶了示例~大家應(yīng)該很容易理解的!
至此,整一套跟CRF有關(guān)的知識點和代碼解釋已經(jīng)全部弄清楚了。簡單總結(jié)一下本文,先是詳細(xì)解釋了一下與CRF有關(guān)的一些誤區(qū)和知識點,接著展示了與CRF有關(guān)的用法和計算損失函數(shù)的方法,最后獻(xiàn)上了詳細(xì)的代碼解讀~希望大家讀完本文后對CRF的一些概念會有一個全新的認(rèn)識。有說的不對的地方也請大家指出,多多交流,大家一起進(jìn)步~??