循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)淺析

RNN是兩種神經(jīng)網(wǎng)絡(luò)模型的縮寫,一種是遞歸神經(jīng)網(wǎng)絡(luò)(Recursive Neural Network),一種是循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Network)。雖然這兩種神經(jīng)網(wǎng)絡(luò)有著千絲萬縷的聯(lián)系,但是本文主要討論的是第二種神經(jīng)網(wǎng)絡(luò)模型——循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Network)。

循環(huán)神經(jīng)網(wǎng)絡(luò)是指一個(gè)隨著時(shí)間的推移,重復(fù)發(fā)生的結(jié)構(gòu)。在自然語言處理(NLP),語音圖像等多個(gè)領(lǐng)域均有非常廣泛的應(yīng)用。RNN網(wǎng)絡(luò)和其他網(wǎng)絡(luò)最大的不同就在于RNN能夠?qū)崿F(xiàn)某種“記憶功能”,是進(jìn)行時(shí)間序列分析時(shí)最好的選擇。如同人類能夠憑借自己過往的記憶更好地認(rèn)識這個(gè)世界一樣。RNN也實(shí)現(xiàn)了類似于人腦的這一機(jī)制,對所處理過的信息留存有一定的記憶,而不像其他類型的神經(jīng)網(wǎng)絡(luò)并不能對處理過的信息留存記憶。

RNN原理

循環(huán)神經(jīng)網(wǎng)絡(luò)的原理并不十分復(fù)雜,本節(jié)主要從原理上分析RNN的結(jié)構(gòu)和功能,不涉及RNN的數(shù)學(xué)推導(dǎo)和證明,整個(gè)網(wǎng)絡(luò)只有簡單的輸入輸出和網(wǎng)絡(luò)狀態(tài)參數(shù)。一個(gè)典型的RNN神經(jīng)網(wǎng)絡(luò)如圖所示:


RNN結(jié)構(gòu)圖

由上圖可以看出:一個(gè)典型的RNN網(wǎng)絡(luò)包含一個(gè)輸入x,一個(gè)輸出h和一個(gè)神經(jīng)網(wǎng)絡(luò)單元A。和普通的神經(jīng)網(wǎng)絡(luò)不同的是,RNN網(wǎng)絡(luò)的神經(jīng)網(wǎng)絡(luò)單元A不僅僅與輸入和輸出存在聯(lián)系,其與自身也存在一個(gè)回路。這種網(wǎng)絡(luò)結(jié)構(gòu)就揭示了RNN的實(shí)質(zhì):上一個(gè)時(shí)刻的網(wǎng)絡(luò)狀態(tài)信息將會作用于下一個(gè)時(shí)刻的網(wǎng)絡(luò)狀態(tài)。如果上圖的網(wǎng)絡(luò)結(jié)構(gòu)仍不夠清晰,RNN網(wǎng)絡(luò)還能夠以時(shí)間序列展開成如下形式:

RNN展開圖

等號右邊是RNN的展開形式。由于RNN一般用來處理序列信息,因此下文說明時(shí)都以時(shí)間序列來舉例,解釋。等號右邊的等價(jià)RNN網(wǎng)絡(luò)中最初始的輸入是x0,輸出是h0,這代表著0時(shí)刻RNN網(wǎng)絡(luò)的輸入為x0,輸出為h0,網(wǎng)絡(luò)神經(jīng)元在0時(shí)刻的狀態(tài)保存在A中。當(dāng)下一個(gè)時(shí)刻1到來時(shí),此時(shí)網(wǎng)絡(luò)神經(jīng)元的狀態(tài)不僅僅由1時(shí)刻的輸入x1決定,也由0時(shí)刻的神經(jīng)元狀態(tài)決定。以后的情況都以此類推,直到時(shí)間序列的末尾t時(shí)刻。

上面的過程可以用一個(gè)簡單的例子來論證:假設(shè)現(xiàn)在有一句話“I want to play basketball”,由于自然語言本身就是一個(gè)時(shí)間序列,較早的語言會與較后的語言存在某種聯(lián)系,例如剛才的句子中“play”這個(gè)動詞意味著后面一定會有一個(gè)名詞,而這個(gè)名詞具體是什么可能需要更遙遠(yuǎn)的語境來決定,因此一句話也可以作為RNN的輸入?;氐絼偛诺哪蔷湓挘@句話中的5個(gè)單詞是以時(shí)序出現(xiàn)的,我們現(xiàn)在將這五個(gè)單詞編碼后依次輸入到RNN中。首先是單詞“I”,它作為時(shí)序上第一個(gè)出現(xiàn)的單詞被用作x0輸入,擁有一個(gè)h0輸出,并且改變了初始神經(jīng)元A的狀態(tài)。單詞“want”作為時(shí)序上第二個(gè)出現(xiàn)的單詞作為x1輸入,這時(shí)RNN的輸出和神經(jīng)元狀態(tài)將不僅僅由x1決定,也將由上一時(shí)刻的神經(jīng)元狀態(tài)或者說上一時(shí)刻的輸入x0決定。之后的情況以此類推,直到上述句子輸入到最后一個(gè)單詞“basketball”。

接下來我們需要關(guān)注RNN的神經(jīng)元結(jié)構(gòu):

RNN內(nèi)部結(jié)構(gòu)圖

上圖依然是一個(gè)RNN神經(jīng)網(wǎng)絡(luò)的時(shí)序展開模型,中間t時(shí)刻的網(wǎng)絡(luò)模型揭示了RNN的結(jié)構(gòu)。可以看到,原始的RNN網(wǎng)絡(luò)的內(nèi)部結(jié)構(gòu)非常簡單。神經(jīng)元A在t時(shí)刻的狀態(tài)僅僅是t-1時(shí)刻神經(jīng)元狀態(tài)與t時(shí)刻網(wǎng)絡(luò)輸入的雙曲正切函數(shù)的值,這個(gè)值不僅僅作為該時(shí)刻網(wǎng)絡(luò)的輸出,也作為該時(shí)刻網(wǎng)絡(luò)的狀態(tài)被傳入到下一個(gè)時(shí)刻的網(wǎng)絡(luò)狀態(tài)中,這個(gè)過程叫做RNN的正向傳播(forward propagation)。注:雙曲正切函數(shù)的解析式如下:

雙曲正切函數(shù)的求導(dǎo)如下:

雙曲正切函數(shù)的圖像如下所示:

雙曲正切函數(shù)

這里就帶來一個(gè)問題:為什么RNN網(wǎng)絡(luò)的激活函數(shù)要選用雙曲正切而不是sigmod呢?(RNN的激活函數(shù)除了雙曲正切,RELU函數(shù)也用的非常多)原因在于RNN網(wǎng)絡(luò)在求解時(shí)涉及時(shí)間序列上的大量求導(dǎo)運(yùn)算,使用sigmod函數(shù)容易出現(xiàn)梯度消失,且sigmod的導(dǎo)數(shù)形式較為復(fù)雜。事實(shí)上,即使使用雙曲正切函數(shù),傳統(tǒng)的RNN網(wǎng)絡(luò)依然存在梯度消失問題,無法“記憶”長時(shí)間序列上的信息,這個(gè)bug直到LSTM上引入了單元狀態(tài)后才算較好地解決。

數(shù)學(xué)基礎(chǔ)

這一節(jié)主要介紹與RNN相關(guān)的數(shù)學(xué)推導(dǎo),由于RNN是一個(gè)時(shí)序模型,因此其求解過程可能和一般的神經(jīng)網(wǎng)絡(luò)不太相同。首先需要介紹一下RNN完整的結(jié)構(gòu)圖,上一節(jié)給出的RNN結(jié)構(gòu)圖省去了很多內(nèi)部參數(shù),僅僅作為一個(gè)概念模型給出。

RNN結(jié)構(gòu)圖

上圖表明了RNN網(wǎng)絡(luò)的完整拓?fù)浣Y(jié)構(gòu),從圖中我們可以看到RNN網(wǎng)絡(luò)中的參數(shù)情況。在這里我們只分析t時(shí)刻網(wǎng)絡(luò)的行為與數(shù)學(xué)推導(dǎo)。t時(shí)刻網(wǎng)絡(luò)迎來一個(gè)輸入xt,網(wǎng)絡(luò)此時(shí)刻的神經(jīng)元狀態(tài)st用如下式子表達(dá):

t時(shí)刻的網(wǎng)絡(luò)狀態(tài)st不僅僅要輸入到下一個(gè)時(shí)刻t+1的網(wǎng)絡(luò)狀態(tài)中去,還要作為該時(shí)刻的網(wǎng)絡(luò)輸出。當(dāng)然,st不能直接輸出,在輸出之前還要再乘上一個(gè)系數(shù)V,而且為了誤差逆?zhèn)鞑r(shí)的方便通常還要對輸出進(jìn)行歸一化處理,也就是對輸出進(jìn)行softmax化。因此,t時(shí)刻網(wǎng)絡(luò)的輸出ot表達(dá)為如下形式:

為了表達(dá)方便,筆者將上述兩個(gè)公式做如下變換:

以上,就是RNN網(wǎng)絡(luò)的數(shù)學(xué)表達(dá)了,接下來我們需要求解這個(gè)模型。在論述具體解法之前首先需要明確兩個(gè)問題:優(yōu)化目標(biāo)函數(shù)是什么?待優(yōu)化的量是什么?

只有在明確了這兩個(gè)問題之后才能對模型進(jìn)行具體的推導(dǎo)和求解。關(guān)于第一個(gè)問題,筆者選取模型的損失函數(shù)作為優(yōu)化目標(biāo);關(guān)于第二個(gè)問題,我們從RNN的結(jié)構(gòu)圖中不難發(fā)現(xiàn):只要我們得到了模型的U,V,W這三個(gè)參數(shù)就能完全確定模型的狀態(tài)。因此該優(yōu)化問題的優(yōu)化變量就是RNN的這三個(gè)參數(shù)。順便說一句,RNN模型的U,V,W三個(gè)參數(shù)是全局共享的,也就是說不同時(shí)刻的模型參數(shù)是完全一致的,這個(gè)特性使RNN得參數(shù)變得稍微少了一些。

損失函數(shù)

不做過多的討論,RNN的損失函數(shù)選用交叉熵(Cross Entropy),這是機(jī)器學(xué)習(xí)中使用最廣泛的損失函數(shù)之一了,其通常的表達(dá)式如下所示:

上面式子是交叉熵的標(biāo)量形式,y_i是真實(shí)的標(biāo)簽值,y_i*是模型給出的預(yù)測值,最外面之所以有一個(gè)累加符號是因?yàn)槟P洼敵龅囊话愣际且粋€(gè)多維的向量,只有把n維損失都加和才能得到真實(shí)的損失值。交叉熵在應(yīng)用于RNN時(shí)需要做一些改變:首先,RNN的輸出是向量形式,沒有必要將所有維度都加在一起,直接把損失值用向量表達(dá)就可以了;其次,由于RNN模型處理的是序列問題,因此其模型損失不能只是一個(gè)時(shí)刻的損失,應(yīng)該包含全部N個(gè)時(shí)刻的損失。

故RNN模型在t時(shí)刻的損失函數(shù)寫成如下形式:

全部N個(gè)時(shí)刻的損失函數(shù)(全局損失)表達(dá)為如下形式:

需要說明的是:yt是t時(shí)刻輸入的真實(shí)標(biāo)簽值,ot為模型的預(yù)測值,N代表全部N個(gè)時(shí)刻。下文中為了書寫方便,將Loss簡記為L。在結(jié)束本小節(jié)之前,最后補(bǔ)充一個(gè)softmax函數(shù)的求導(dǎo)公式:

BPTT算法

由于RNN模型與時(shí)間序列有關(guān),因此不能直接使用BP(back propagation)算法。針對RNN問題的特殊情況,提出了BPTT算法。BPTT的全稱是“隨時(shí)間變化的反向傳播算法”(back propagation through time)。這個(gè)方法的基礎(chǔ)仍然是常規(guī)的鏈?zhǔn)角髮?dǎo)法則,接下來開始具體推導(dǎo)。雖然RNN的全局損失是與全部N個(gè)時(shí)刻有關(guān)的,但為了簡單筆者在推導(dǎo)時(shí)只關(guān)注t時(shí)刻的損失函數(shù)。

首先求出t時(shí)刻下?lián)p失函數(shù)關(guān)于o_t*的微分:

求出損失函數(shù)關(guān)于參數(shù)V的微分:

因此,全局損失關(guān)于參數(shù)V的微分為:

求出t時(shí)刻的損失函數(shù)關(guān)于關(guān)于st*的微分:

求出t時(shí)刻的損失函數(shù)關(guān)于s_t-1*的微分:

求出t時(shí)刻損失函數(shù)關(guān)于參數(shù)U的偏微分。注意:由于是時(shí)間序列模型,因此t時(shí)刻關(guān)于U的微分與前t-1個(gè)時(shí)刻都有關(guān),在具體計(jì)算時(shí)可以限定最遠(yuǎn)回溯到前n個(gè)時(shí)刻,但在推導(dǎo)時(shí)需要將前t-1個(gè)時(shí)刻全部帶入:

因此,全局損失關(guān)于U的偏微分為:

求t時(shí)刻損失函數(shù)關(guān)于參數(shù)W的偏微分,和上面相同的道理,在這里仍然要計(jì)算全部前t-1時(shí)刻的情況:

因此,全局損失關(guān)于參數(shù)W的微分結(jié)果為:

至此,全局損失函數(shù)關(guān)于三個(gè)主要參數(shù)的微分都已經(jīng)得到了。整理如下:

接下來進(jìn)一步化簡上述微分表達(dá)式,化簡的主要方向?yàn)閠時(shí)刻的損失函數(shù)關(guān)于ot的微分以及關(guān)于st*的微分。已知t時(shí)刻損失函數(shù)的表達(dá)式,求關(guān)于ot的微分:

softmax函數(shù)求導(dǎo):

因此:

又因?yàn)椋?/p>

且:

有了上面的數(shù)學(xué)推導(dǎo),我們可以得到全局損失關(guān)于U,V,W三個(gè)參數(shù)的梯度公式:

由于參數(shù)U和W的微分公式不僅僅與t時(shí)刻有關(guān),還與前面的t-1個(gè)時(shí)刻都有關(guān),因此無法寫出直接的計(jì)算公式。不過上面已經(jīng)給出了t時(shí)刻的損失函數(shù)關(guān)于s_t-1的微分遞推公式,想來求解這個(gè)式子也是十分簡單的,在這里就不贅述了。

以上就是關(guān)于BPTT算法的全部數(shù)學(xué)推導(dǎo)。從最終結(jié)果可以看出三個(gè)公式的偏微分結(jié)果非常簡單,在具體的優(yōu)化過程中可以直接帶入進(jìn)行計(jì)算。對于這種優(yōu)化問題來說,最常用的方法就是梯度下降法。針對本文涉及的RNN問題,可以構(gòu)造出三個(gè)參數(shù)的梯度更新公式:

依靠上述梯度更新公式就能夠迭代求解三個(gè)參數(shù),直到三個(gè)參數(shù)的值發(fā)生收斂。

后記

這是筆者第一次嘗試推導(dǎo)RNN的數(shù)學(xué)模型,在推導(dǎo)過程中遇到了非常多的bug。非常感謝互聯(lián)網(wǎng)上的一些公開資料和博客,給了我非常大的幫助和指引。接下來筆者將嘗試實(shí)現(xiàn)一個(gè)單隱層的RNN模型用于實(shí)現(xiàn)一個(gè)語義預(yù)測模型。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 經(jīng)常聽到女性朋友感嘆,好男人數(shù)量少??;在自己工作生活圈里能碰到好男人的幾率就更小了;這種情況下,竟然還要碰到能兩...
    嚴(yán)小愛閱讀 747評論 0 0
  • 男人的一雙大手 把孩子高高舉過頭頂 拋在空中再接住 孩子在飛翔中 俯視爸爸 笑瞇瞇的一雙眼 樂成小月牙 爸爸也微笑...
    溪水音閱讀 522評論 0 5
  • 在抹濕膚液時(shí),看著她送的小瓶子,不由得又想起她。 經(jīng)歷了絕望,憤怒,失落,以及舊日事物墜入心湖產(chǎn)生的感動漣漪,我的...
    騎馬藍(lán)閱讀 376評論 0 1
  • 《繁華》 這地方火樹銀花 又有無盡繁華 虹燈閃爍 不似心中 文明的枝椏// 平凡生活 卻遭人歧諷 來吧,告訴我 你...
    紅塵紅塵閱讀 316評論 2 1
  • 《岡仁波齊》看了幾次都沒有看完,今天煮了茶,喝了2壺的時(shí)間看完了它??吹匠霭l(fā)的時(shí)候,第一個(gè)長頭,突然眼淚下來,心里...
    最好的十年里閱讀 235評論 0 0

友情鏈接更多精彩內(nèi)容