Pytorch_LSTM與GRU

RNN循環(huán)網(wǎng)絡(luò)在序列問(wèn)題處理中得到了廣泛的應(yīng)用。但使用標(biāo)準(zhǔn)版本的RNN模型時(shí),常遇到梯度消失gradient vanishing和梯度爆炸gradient explosion問(wèn)題。

RNN的缺點(diǎn)

RNN的梯度消失和梯度爆炸不同于其它網(wǎng)絡(luò),全連接網(wǎng)絡(luò)和卷積網(wǎng)絡(luò)每一層有不同參數(shù),而RNN 的每個(gè)處理單元Cell(處理單個(gè)序列元素的操作稱(chēng)為處理單元Cell)共用同一組權(quán)重矩陣W。在上一篇介紹RNN網(wǎng)絡(luò)算法時(shí)可以看到,處理單元之間是全連接關(guān)系,序列向前傳播的過(guò)程中將不斷乘以權(quán)重矩陣W,從而構(gòu)成了連乘Wn,當(dāng)W<1時(shí),如果序列很長(zhǎng),則結(jié)果趨近0;當(dāng)w>1時(shí),經(jīng)過(guò)多次迭代,數(shù)值將迅速增長(zhǎng)。反向傳播也有同樣問(wèn)題。

梯度爆炸問(wèn)題一般通過(guò)“梯度裁剪”方法改善,而梯度消失則使得序列前面的數(shù)據(jù)無(wú)法起到應(yīng)有的作用,造成“長(zhǎng)距離依賴(lài)”(Long-Term Dependencies)問(wèn)題,也就是說(shuō)RNN只能處理短距離的依賴(lài)關(guān)系。

這類(lèi)似于卷積神經(jīng)網(wǎng)絡(luò)在處理圖像問(wèn)題時(shí)加深網(wǎng)絡(luò)層數(shù),無(wú)法改進(jìn)效果。盡管理論上可以通過(guò)調(diào)參改進(jìn),但難度很大,最后圖像處理通過(guò)修改網(wǎng)絡(luò)結(jié)構(gòu)使用殘差網(wǎng)絡(luò)解決了這一問(wèn)題。同樣,RNN也改進(jìn)了結(jié)構(gòu),使用LSTM和GRU網(wǎng)絡(luò)。作為RNN的變種,它們使用率更高。

LSTM長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)

LSTM是Long Short Term Memory Networks的縮寫(xiě),即長(zhǎng)短時(shí)記憶網(wǎng)絡(luò),該方法在1997年被提出,主要用于解決“長(zhǎng)距離依賴(lài)”問(wèn)題。不同于RNN用單一的隱藏層描述規(guī)律,LSTM新增加了細(xì)胞狀態(tài)Cell state,簡(jiǎn)稱(chēng)c,并用多個(gè)門(mén)控參數(shù)分別控制讀、寫(xiě)、遺忘操作。

門(mén)控gate

門(mén)控理論源于生物學(xué),指脊髓中的一些細(xì)胞像門(mén)一樣(門(mén)開(kāi)了才能通過(guò)),切斷和阻止一些痛覺(jué)信號(hào)進(jìn)入大腦。在神經(jīng)網(wǎng)絡(luò)中通常是使用激活函數(shù)控制數(shù)據(jù)的傳輸,如激活函數(shù)sigmoid常被用于控制信號(hào)是否通過(guò),它的取值范圍從0-1,0表示阻斷,1表示完全通過(guò),0-1之間數(shù)據(jù)部分通過(guò),從而實(shí)現(xiàn)有選擇的輸入、有選擇的輸出、有選擇的記憶。

算法

上圖描述了LSTM網(wǎng)絡(luò)對(duì)輸入Xt(序列中每個(gè)元素)處理生成輸出ht的前向傳播過(guò)程。筆者將其分為六步,在圖中用圓圈加數(shù)字表示。

第一步:計(jì)算遺忘門(mén),遺忘門(mén)forget gate簡(jiǎn)稱(chēng)f,用于控制是否遺忘上一層的狀態(tài)Cell state。該門(mén)的輸入是前一個(gè)隱藏層的狀態(tài)h(t-1)以及當(dāng)前的xt,通過(guò)一個(gè)sigmoid(用σ表示)激活函數(shù),得到當(dāng)前時(shí)間t的遺忘門(mén)的值ft ,W和b是該門(mén)的參數(shù)和偏置。比如:當(dāng)輸入詞為“但是”時(shí),認(rèn)為前面的記憶不再重要,ft值為0,清空之前的記憶(只是舉例,不要較真兒)。其公式為:

第二步:計(jì)算輸入門(mén),輸入門(mén)input gate簡(jiǎn)稱(chēng)i,它用于向Cell state中增加新的內(nèi)容,該門(mén)的輸入也是前一個(gè)隱藏層的狀態(tài)h(t-1)以及當(dāng)前的xt計(jì)算it。例如:當(dāng)輸入是“,”時(shí),認(rèn)為該輸入沒(méi)有攜帶有貢獻(xiàn)的信息,it值為0,忽略該輸入。

第三步:計(jì)算輸入值,這一步類(lèi)似于RNN中計(jì)算隱藏層參數(shù)的算法,輸入也是前一個(gè)隱藏層的狀態(tài)h(t-1)以及當(dāng)前的xt計(jì)算gt,它是這一步輸入產(chǎn)生的具體影響,此處的激活函數(shù)使用tanh。

第四步:計(jì)算輸出門(mén),輸出門(mén)output gate,簡(jiǎn)稱(chēng)o,在用Cell state值計(jì)算輸出值ht的過(guò)程中用ot控制輸出,該門(mén)的輸入也是前一個(gè)隱藏層的狀態(tài)h(t-1)以及當(dāng)前的xt。

第五步:計(jì)算當(dāng)前的細(xì)胞狀態(tài)Cell state,用遺忘門(mén)f控制上一步的狀態(tài)c(t-1),用輸入門(mén)i控

制當(dāng)前輸入g,從而計(jì)算當(dāng)前狀態(tài)ct(遺忘了部分以往信息,加入了部分新信息)。

第六步:通過(guò)當(dāng)前細(xì)胞狀態(tài)c和輸出門(mén)o計(jì)算隱藏層h,最后兩步將通過(guò)各個(gè)門(mén)的數(shù)據(jù)組織起來(lái)。

標(biāo)準(zhǔn)的RNN模型比較粗糙,只調(diào)節(jié)一組參數(shù),而LSTM把問(wèn)題細(xì)化成幾個(gè)子問(wèn)題,需要反復(fù)迭代計(jì)算多組W參數(shù),運(yùn)算量比普通的RNN大很多。LSTM的核心原理是保持信息的完整性,它假設(shè)每一個(gè)狀態(tài)都是由上一狀態(tài)疊加一個(gè)變化得來(lái)的(類(lèi)似于殘差網(wǎng)絡(luò)),即兩組信息做加法,它不同于RNN的逐層做乘法,由此改進(jìn)了梯度爆炸/梯度消失的問(wèn)題。它對(duì)于較長(zhǎng)的序列效果更好。

用法

Pytorch提供的LSTM調(diào)用方法與RNN類(lèi)似,只要把上篇例程中的“RNN”改成“LSTM”即可,不需要其它調(diào)整。

與RNN不同的是,在調(diào)用前向傳遞函數(shù)forward時(shí),傳入和傳出的參數(shù)都可包含h和c兩組值,其格式為:

其中input是輸入,output是輸出,第二個(gè)參數(shù)(h0,c0)為T(mén)uple類(lèi)型,h0和c0分別是兩個(gè)隱藏層的初始值;同樣LSTM也將計(jì)算后隱藏層的值(hn,cn)作為返回值。h和c的維度是(num_layers, batch_size, hidden_size)。

GRU門(mén)控循環(huán)單元

GRU是門(mén)控循環(huán)單元Gated Recurrent Unit的縮寫(xiě),該方法在2014年被提出,是LSTM網(wǎng)絡(luò)的變體,它比LSTM網(wǎng)絡(luò)結(jié)構(gòu)更簡(jiǎn)單,邏輯更加清晰,速度更快,且效果也很好。GRU模型只有兩個(gè)門(mén):更新門(mén)和重置門(mén)。它的網(wǎng)絡(luò)結(jié)構(gòu)與RNN更為相似,在每一步接收序列中的數(shù)據(jù)輸入,上一個(gè)隱藏層的輸出,并輸出隱藏層。

上圖描述了對(duì)GRU處理輸入Xt生成輸出ht的前向傳播過(guò)程。筆者將其分為四步,在圖中用圓圈加數(shù)字表示。

第一步:計(jì)算更新門(mén),更新門(mén)update gate簡(jiǎn)稱(chēng)為z,它的功能類(lèi)似于LSTM中的遺忘門(mén),用于控制以往信息和新輸入數(shù)據(jù)的在當(dāng)前狀態(tài)中的比例,該門(mén)的輸入也是前一個(gè)隱藏層的狀態(tài)ht-1以及當(dāng)前的輸入xt,省略了偏置參數(shù)b。

第二步:計(jì)算重置門(mén),重置門(mén)reset gate常簡(jiǎn)稱(chēng)為r,它的功能類(lèi)似于LSTM中的輸入門(mén),該門(mén)的輸入也是前一個(gè)隱藏層的狀態(tài)ht-1以及當(dāng)前的輸入xt。

第三步:計(jì)算輸入值,輸入值由前一個(gè)隱藏層的狀態(tài)ht-1,當(dāng)前的xt以及重置門(mén)rt計(jì)算得來(lái)。可視為當(dāng)前輸入對(duì)狀態(tài)的影響。

第四步:計(jì)算當(dāng)前狀態(tài),當(dāng)前狀態(tài)由兩部分組成,前一部分是以往信息的影響,后一部分是當(dāng)前輸入的影響,參數(shù)zt是更新門(mén)的值,它經(jīng)過(guò)激活函數(shù)sigmoid,取值在0-1之間,也就是說(shuō),前后兩部分的權(quán)重之和為1,通過(guò)更新門(mén)均衡二者的占比。

與LSTM相比,狀態(tài)層State cell被省略,由隱藏層h實(shí)現(xiàn)它的功能,并省略了輸出門(mén)o,去掉了各層的偏置參數(shù)b,在多個(gè)步驟進(jìn)行了簡(jiǎn)化,占用資源更少。

Pytorch的具體調(diào)用方法和RNN類(lèi)似,此處不再贅述。

優(yōu)化RNN網(wǎng)絡(luò)

深度學(xué)習(xí)工具一般都提供API直接調(diào)用RNN模型,像Keras工具只使用一條語(yǔ)句即建立一個(gè)LSTM模型,在建模過(guò)程中除了調(diào)用API程序員需要做哪些工作呢?

循環(huán)神經(jīng)的網(wǎng)絡(luò)每一個(gè)處理單元都通過(guò)一個(gè)或者多個(gè)全連接網(wǎng)絡(luò)與下一個(gè)單元相連,類(lèi)似于CNN的多層網(wǎng)絡(luò),因此序列越長(zhǎng),計(jì)算越復(fù)雜,設(shè)計(jì)網(wǎng)絡(luò)時(shí)需要考慮模型復(fù)雜度,估計(jì)訓(xùn)練時(shí)間,涉及:迭代次數(shù)、序列長(zhǎng)度,如何切分序列,隱藏層數(shù),隱藏層元素個(gè)數(shù),學(xué)習(xí)率,是否將隱藏層狀態(tài)傳入下一次迭代,超參數(shù)、以及參數(shù)初值等因素。

比如:RNN的誤差往往不是平滑收斂的,尤其是序列較長(zhǎng)時(shí),學(xué)習(xí)率很難固定下來(lái),建議使用Adam優(yōu)化器自動(dòng)調(diào)節(jié)學(xué)習(xí)參數(shù)。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀(guān)點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容