手算梯度下降法,詳解神經(jīng)網(wǎng)絡(luò)迭代訓(xùn)練過(guò)程

神經(jīng)網(wǎng)絡(luò)本質(zhì)上是一個(gè)計(jì)算流程,在前端接收輸入信號(hào)后,經(jīng)過(guò)一層層復(fù)雜的運(yùn)算,在最末端輸出結(jié)果。然后將計(jì)算結(jié)果和正確結(jié)果相比較,得到誤差,再根據(jù)誤差通過(guò)相應(yīng)計(jì)算方法改進(jìn)網(wǎng)絡(luò)內(nèi)部的相關(guān)參數(shù),使得網(wǎng)絡(luò)下次再接收到同樣的數(shù)據(jù)時(shí),最終計(jì)算輸出得到的結(jié)果與正確結(jié)果之間的誤差能越來(lái)越小。

這里需要搞清楚一個(gè)重要概念,就是如何計(jì)算誤差,我們列一個(gè)表,展示一個(gè)在最外層有三個(gè)節(jié)點(diǎn)的網(wǎng)絡(luò)對(duì)誤差的三種計(jì)算情況:

這里寫(xiě)圖片描述

上表列出三種誤差處理情況,第一種計(jì)算誤差的方式是將簡(jiǎn)單的將網(wǎng)絡(luò)計(jì)算結(jié)果與正確結(jié)果相減,但采用這種做法,如果我們把所有誤差相加在一起,結(jié)果居然為零,因?yàn)榈谝粋€(gè)節(jié)點(diǎn)的結(jié)果與正確結(jié)果的差值和第二個(gè)節(jié)點(diǎn)結(jié)果與正確結(jié)果的差值剛好相反,于是誤差就相互抵消掉,由此直接將兩者相減不是一種理想的誤差計(jì)算方式。

第二種是相減后求絕對(duì)值。這樣一來(lái)每個(gè)節(jié)點(diǎn)間的誤差在加總時(shí)就不會(huì)相互抵消,但絕對(duì)值的存在使得函數(shù)圖像會(huì)變成一個(gè)"V"字型,在最低點(diǎn)處是一個(gè)箭頭,于是這個(gè)函數(shù)在最低點(diǎn)處不連續(xù),梯度下降法不能運(yùn)用于不連續(xù)的函數(shù)。

第三者是兩者相減后求平方,這種做法使得誤差函數(shù)變成一條光滑的曲線(xiàn),這是梯度下降法運(yùn)用的最佳場(chǎng)景。在上一節(jié)中我們講過(guò),我們要根據(jù)數(shù)據(jù)點(diǎn)所在的切線(xiàn)斜率來(lái)“適當(dāng)”的調(diào)整變量的值,后面我們會(huì)看到,這里的“適當(dāng)”就得依賴(lài)切線(xiàn)的斜率大小,一條光滑曲線(xiàn),也就是一條“連續(xù)”曲線(xiàn),它在最低點(diǎn)附件切線(xiàn)的斜率會(huì)越來(lái)越小,這樣的話(huà)變量改變的幅度也會(huì)越來(lái)越小,進(jìn)而使得我們能夠準(zhǔn)確的定位到最低點(diǎn)。這里的”連續(xù)“指的就是高等數(shù)學(xué)或微積分上的”連續(xù)“。

一個(gè)神經(jīng)網(wǎng)絡(luò)本質(zhì)上是一個(gè)含有多個(gè)變量的函數(shù),其中每條鏈路上的權(quán)重對(duì)應(yīng)著一個(gè)變量,任何一條鏈路權(quán)重的改變會(huì)對(duì)網(wǎng)絡(luò)末端的多個(gè)節(jié)點(diǎn)輸出產(chǎn)生影響,可謂是牽一發(fā)而動(dòng)全身。如果我們把第三中誤差計(jì)算方法,也就是error_sum = (節(jié)點(diǎn)輸出的結(jié)果-正確結(jié)果)^2加總,作為最終誤差,那么我們的目的就是不斷的修改網(wǎng)絡(luò)中每條鏈路權(quán)重值,使得erro_sum的值最小,這與我們上一節(jié)所講的求一個(gè)復(fù)雜函數(shù)最小值的目的是一致的。

前面我們說(shuō)過(guò),如果一個(gè)函數(shù)擁有兩個(gè)變量,那么函數(shù)的值就會(huì)在三維空間形成一個(gè)曲面。如果一個(gè)函數(shù)擁有多個(gè)變量,那么它的結(jié)果就會(huì)形成一個(gè)多維度的超平面,這已經(jīng)超出我們?nèi)四X的想象范圍,我們最多可以想象一個(gè)三維的物體。如果我們沿著某個(gè)變量的方向?qū)@個(gè)超平面切一刀,在切面的邊緣就會(huì)形成一條曲線(xiàn),例如你拿刀把一個(gè)蘋(píng)果切開(kāi),在切開(kāi)的平面邊緣對(duì)應(yīng)著一條曲線(xiàn),如下圖:

這里寫(xiě)圖片描述

大家注意看,切面曲線(xiàn)的最低點(diǎn)處,對(duì)應(yīng)著整個(gè)蘋(píng)果的最低點(diǎn)處。同理我們對(duì)一個(gè)包含多個(gè)變量構(gòu)成的函數(shù)所形成的超平面,我們沿著某個(gè)變量的方向?qū)ζ矫媲幸坏?,在切面的邊緣也?huì)有一條曲線(xiàn):

這里寫(xiě)圖片描述

我們前面所的error_sum,它是由(節(jié)點(diǎn)計(jì)算那結(jié)果-正確結(jié)果)^2加總構(gòu)成的,而“節(jié)點(diǎn)計(jì)算結(jié)果”卻是受到網(wǎng)絡(luò)中每一條鏈路權(quán)重的影響,因此我們可以認(rèn)為error_sum是一個(gè)含有多個(gè)變量的函數(shù),每個(gè)變量對(duì)應(yīng)著網(wǎng)絡(luò)中每條鏈路的權(quán)重。如果我們以某條鏈路的權(quán)重為準(zhǔn),往這個(gè)超平面切一刀,那么切面的邊緣就是一條一維曲線(xiàn),這個(gè)曲線(xiàn)的最低點(diǎn)就對(duì)應(yīng)著整個(gè)超平面的最低點(diǎn),假設(shè)這條曲線(xiàn)如上圖,那么我們通過(guò)上一節(jié)講解的梯度下降法調(diào)整這條鏈路的權(quán)重值,就會(huì)使得error_sum的值向最低點(diǎn)走去。

假設(shè)一個(gè)神經(jīng)網(wǎng)絡(luò)只含有兩條路徑,也就是說(shuō)error_sum對(duì)應(yīng)著兩個(gè)變量,這意味著erro_sum是的結(jié)果是三維空間上的一個(gè)曲面,那么我們對(duì)每一個(gè)變量做曲面的切面,根據(jù)切面的邊緣曲線(xiàn)做切線(xiàn),進(jìn)而得到每個(gè)變量該如何變化才能走到曲面的最低點(diǎn),這兩個(gè)變量各自的變化合在一起形成了曲面上一個(gè)點(diǎn)走到最低點(diǎn)的路徑,如下圖:

這里寫(xiě)圖片描述

接下來(lái)的問(wèn)題是,如何沿著某個(gè)變量的方向?qū)η媲幸坏逗螅业礁狞c(diǎn)在切面邊緣曲線(xiàn)上的斜率,在數(shù)學(xué)上對(duì)應(yīng)著對(duì)根據(jù)某個(gè)變量對(duì)函數(shù)求偏導(dǎo)數(shù),公式如下:

這里寫(xiě)圖片描述

偏導(dǎo)數(shù)的結(jié)果就是鏈路權(quán)重在error_sum函數(shù)這個(gè)超平面上做切面后,切面邊緣處的切線(xiàn),根據(jù)切線(xiàn)斜率,我們就可以調(diào)整鏈路W(j,k)的值,從而使得error_sumn變小。接下來(lái)我們通過(guò)一個(gè)具體實(shí)例,看看如何通過(guò)偏導(dǎo)數(shù)求得error_sum的最小值,假設(shè)我們有如下網(wǎng)絡(luò):

這里寫(xiě)圖片描述

網(wǎng)絡(luò)的輸出層有兩個(gè)節(jié)點(diǎn),k1和k2,他們輸出的值由O1和O2表示,相應(yīng)的誤差由e1和e2表示。根據(jù)前面描述,error_sum等于e12+e22,也就是(t1-o1)2+(t2-o2)2。由于O1與O2是由中間層與最外層節(jié)點(diǎn)間的鏈路權(quán)重決定的,于是調(diào)整這兩層節(jié)點(diǎn)間鏈路權(quán)重就能影響最外層的輸出結(jié)果,上圖已經(jīng)把影響最終輸出的四條鏈路標(biāo)注出來(lái)。于是我們分別根據(jù)這四個(gè)權(quán)重變量求偏導(dǎo)數(shù),這樣我們才能確定這些變量如何變化才會(huì)影響最終輸出結(jié)果:

這里寫(xiě)圖片描述

我們一定要注意,最外層節(jié)點(diǎn)O(k),只與內(nèi)層連接到它的鏈路權(quán)重w(jk)相關(guān),其他未跟它連接的鏈路權(quán)重?zé)o論如何變化,都不會(huì)影響最外層節(jié)點(diǎn)O(k)的輸出結(jié)果。同時(shí)求偏導(dǎo)數(shù)時(shí),除了參與求導(dǎo)的變量會(huì)留下來(lái),其他無(wú)關(guān)變量會(huì)在求導(dǎo)的過(guò)程中被消除掉,上面公式中,參與求導(dǎo)的是變量w(jk),與該變量對(duì)應(yīng)的就是O(k),所以上面的公式可以簡(jiǎn)化如下:

這里寫(xiě)圖片描述

接下來(lái)我們根據(jù)微積分原理,對(duì)上面的求導(dǎo)運(yùn)算進(jìn)行展開(kāi),由于t(k)對(duì)應(yīng)的是正確數(shù)值,因此它是個(gè)常量,于是變量w(jk)與它沒(méi)有關(guān)聯(lián),而節(jié)點(diǎn)輸出O(k)與權(quán)重w(jk)是緊密相關(guān)的,因?yàn)樾盘?hào)從中間層節(jié)點(diǎn)j輸出后,經(jīng)過(guò)鏈路w(jk)后進(jìn)入節(jié)點(diǎn)k才產(chǎn)生了輸出O(k)。也就是說(shuō)O(k)是將w(jk)經(jīng)由某種函數(shù)運(yùn)算后所得的結(jié)果,于是根據(jù)求導(dǎo)的鏈?zhǔn)椒▌t,我們有:

這里寫(xiě)圖片描述

結(jié)合上下兩個(gè)公式,我們可以把對(duì)變量O(k)的求導(dǎo)做進(jìn)一步展開(kāi)后如下:


這里寫(xiě)圖片描述

接下來(lái)我們得看上邊公式右邊,對(duì)W(jk)的求導(dǎo)如何展開(kāi),前面我們?cè)缫蚜私猓琌(k)的值是由進(jìn)入它的鏈路權(quán)重乘以經(jīng)過(guò)鏈路的信號(hào)量,加總后再經(jīng)過(guò)激活函數(shù)運(yùn)算后所得的結(jié)果,于是上邊公式右邊對(duì)變量w(jk)求導(dǎo)的部分就可以展開(kāi)如下:

這里寫(xiě)圖片描述

上面的變量O(j)就是中間層節(jié)點(diǎn)j輸出到鏈路jk上的信號(hào)量?,F(xiàn)在的問(wèn)題是,如何對(duì)激活函數(shù)求導(dǎo),我們完全可以根據(jù)求導(dǎo)數(shù)的方法,一步一步的算出來(lái),這里我們忽略這些繁瑣機(jī)械的流程,直接給出激活函數(shù)求導(dǎo)后的結(jié)果:

這里寫(xiě)圖片描述

于是我們把這幾步連續(xù)求導(dǎo)的結(jié)果結(jié)合起來(lái),得到如下公式:

這里寫(xiě)圖片描述

這里要注意到,我們是把經(jīng)過(guò)鏈路jk的信號(hào)量與鏈路權(quán)重做乘積之后再傳入激活函數(shù),而所謂的“jk的信號(hào)量與鏈路權(quán)重做乘積"實(shí)際上對(duì)應(yīng)的正是一個(gè)有關(guān)權(quán)重w(jk)的函數(shù)f=w(jk)*O(j),因此根據(jù)求導(dǎo)的鏈?zhǔn)椒▌t,我們對(duì)f也要做一次導(dǎo)數(shù),求導(dǎo)結(jié)果正好是O(j)。我們可以把上面式子里的2拿掉,因?yàn)槲覀冴P(guān)心的是切線(xiàn)斜率的方向,也就是上面求導(dǎo)結(jié)果是正是負(fù),這涉及到我們是應(yīng)該增加w(jk)還是應(yīng)該減少w(jk),正負(fù)確定了,至于具體值是多少,并不影響我們最后的運(yùn)算,所以經(jīng)過(guò)鏈?zhǔn)椒▌t一系列求導(dǎo)后,我們得到最終結(jié)果如下:


這里寫(xiě)圖片描述

上面所得結(jié)果可以分解成三部分,第一部分是正確結(jié)果與節(jié)點(diǎn)輸出結(jié)果的差值,也就是誤差,紅色部分對(duì)應(yīng)的是節(jié)點(diǎn)的激活函數(shù),所有輸入該節(jié)點(diǎn)的鏈路把經(jīng)過(guò)其上的信號(hào)與鏈路權(quán)重做乘積后加總,在把加總結(jié)果進(jìn)行激活函數(shù)運(yùn)算,最后一部分是鏈路w(jk)前端節(jié)點(diǎn)輸出的信號(hào)值。

我們這里談到的數(shù)學(xué)是涉及神經(jīng)網(wǎng)絡(luò)最核心的部分,除了這里有些數(shù)學(xué)知識(shí)需要掌握外,其他的就都是有關(guān)工程實(shí)踐的問(wèn)題了。我們這里運(yùn)算的是中間層和最外層節(jié)點(diǎn)間的鏈路權(quán)重求偏導(dǎo)數(shù)結(jié)果,那么輸入層和中間層之間鏈路權(quán)重的求偏導(dǎo)數(shù)過(guò)程其實(shí)是完全一模一樣的!我們只需要把上面等式中的k換成j,j換成i就可以了,所以輸入層和中間層間,鏈路權(quán)重的偏導(dǎo)公式如下:

這里寫(xiě)圖片描述

前面我們講梯度下降法時(shí)說(shuō),要根據(jù)變量對(duì)應(yīng)切線(xiàn)的斜率對(duì)變量做”適度“調(diào)整,調(diào)整的方向與斜率的方向相反,我們可以根據(jù)下面公式進(jìn)行權(quán)重調(diào)整:

這里寫(xiě)圖片描述

公式中的變量a,表示學(xué)習(xí)率,它決定了調(diào)整步伐的大小,前面的符號(hào)用于表示調(diào)整的方向與斜率的方向相反,如果斜率是賦值,那么我們就增加變量w(jk)的值,如果斜率是正的,我們就減少變量w(jk)的值。無(wú)論是中間層和輸出層,還是輸入層和中間層,我們都使用上面的公式修改鏈路權(quán)重。

我們把上邊公式中右邊減號(hào)后面的部分當(dāng)做鏈路調(diào)整的增量,記作△w(jk),那么就有:
△w(jk) = E(k) * (S(k)*(1 - S(k)))*O(j)
其中E(k) = (T(k) - O(k))也就是節(jié)點(diǎn)k對(duì)應(yīng)的誤差,S(k)對(duì)應(yīng)的就是節(jié)點(diǎn)k對(duì)輸入的信號(hào)量求和后做激活函數(shù)的結(jié)果,O(j)是節(jié)點(diǎn)j的輸出信號(hào)量,這幾部分分別對(duì)應(yīng)上面求偏導(dǎo)公式中的紫色,紅色,和綠色部分,如此一來(lái),每個(gè)節(jié)點(diǎn)的增量就可以對(duì)應(yīng)成矩陣運(yùn)算:

這里寫(xiě)圖片描述

我們 一眼上面的矩陣運(yùn)算會(huì)讓我們眼花,為了讓大家更清楚上面公式中各部分分量的組成,我們把前面的求導(dǎo)公式再次做個(gè)變換:

這里寫(xiě)圖片描述

紅色部分就是矩陣運(yùn)算右邊對(duì)應(yīng)的分量S(k)。接下來(lái)我們對(duì)一個(gè)實(shí)例的手算進(jìn)一步對(duì)加深對(duì)推導(dǎo)過(guò)程的理解。下面三層網(wǎng)絡(luò)是我們前幾節(jié)運(yùn)算過(guò)的例子,我們重新把它拿出來(lái),根據(jù)我們前面推導(dǎo)的權(quán)重變換流程,手動(dòng)做一次網(wǎng)絡(luò)的訓(xùn)練流程:

這里寫(xiě)圖片描述

我們要計(jì)算中間層節(jié)點(diǎn)1與輸出層節(jié)點(diǎn)1之間鏈路權(quán)重的增量,根據(jù)最外層節(jié)點(diǎn)得到的誤差1.5,中間層節(jié)點(diǎn)1對(duì)應(yīng)的信號(hào)量是0.4,中間層節(jié)點(diǎn)2對(duì)應(yīng)的信號(hào)量是0.5,當(dāng)前兩個(gè)節(jié)點(diǎn)間的鏈路權(quán)重w(11)是2.0,我們直接套入前面推導(dǎo)的偏導(dǎo)公式進(jìn)行計(jì)算:

這里寫(xiě)圖片描述

讓我們一步一步的套入公式進(jìn)行計(jì)算,如果我們要更改w(11)的值,計(jì)算步驟如下:
第一步t(k) - O(k) 對(duì)應(yīng)最外層輸出節(jié)點(diǎn)的誤差,例如e1 = 1.5。

第二步就是求和
這里寫(xiě)圖片描述
,對(duì)應(yīng)的計(jì)算就是(2.0*0.4)+(0.5*3.0) = 0.8+1.5=2.3
第三步計(jì)算sigmod:1/1+exp(-2.3)對(duì)應(yīng)的值為0.909,于是中間部分對(duì)應(yīng)為9.909*(1-0.909) = 0.082

第四步計(jì)算O(j),也就是0.4.
把所有結(jié)果合在一起算就是:-1.5*0.082*0.4 = -0.0492,如果我們把學(xué)習(xí)率設(shè)置為1,那么w(11)的修改量為:-1*(1)*(-0.0492) = 0.0492,修改后的W(11)值就是2.0+0.0492=2.0492。

我們看到一次變動(dòng)的步伐很小,但實(shí)際應(yīng)用時(shí),上面的修改步驟會(huì)進(jìn)行成千上萬(wàn)次,于是最后W(11)有可能會(huì)產(chǎn)生很明顯的變化。從下一節(jié)開(kāi)始,我們就進(jìn)入到使用python編碼實(shí)現(xiàn)我們這幾節(jié)所講的算法理論。

更多技術(shù)信息,包括操作系統(tǒng),編譯器,面試算法,機(jī)器學(xué)習(xí),人工智能,請(qǐng)關(guān)照我的公眾號(hào):


這里寫(xiě)圖片描述
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀(guān)點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容