淺談隨機(jī)梯度下降法在對(duì)數(shù)回歸中的應(yīng)用

寫在前面

這兩天研究了一下機(jī)器學(xué)習(xí)中較為簡(jiǎn)單的一類回歸問題,對(duì)數(shù)回歸。寫一篇博客總結(jié)一下自己的體會(huì)和經(jīng)驗(yàn),若是其中出現(xiàn)錯(cuò)誤,還希望讀者能夠留言幫我指出。

對(duì)數(shù)回歸問題是機(jī)器學(xué)習(xí)中線性模型的一種,也是非?;A(chǔ)的一種分類器模型。線性回歸中最廣為人知的莫過于最小二乘估計(jì),這是高中階段就講過的知識(shí)。本篇文章所要講的是如何將線性模型應(yīng)用到機(jī)器學(xué)習(xí)中,并且如何求解模型。

對(duì)數(shù)回歸

機(jī)器學(xué)習(xí)領(lǐng)域最常見的應(yīng)用就是分類,對(duì)數(shù)回歸進(jìn)行分類的主要思想就是根據(jù)現(xiàn)有標(biāo)記數(shù)據(jù)(訓(xùn)練集)對(duì)分類邊界建立回歸公式,以此進(jìn)行分類?,F(xiàn)在,可以試想我們已經(jīng)得到了一個(gè)有效的分類模型,這個(gè)模型的輸入是待分類的一個(gè)向量,輸出是分類結(jié)果(用0代表反例,1代表正例)。然而,回歸模型的輸出一般都是連續(xù)值。如果要使模型能夠應(yīng)用于分類,就必然需要確定一個(gè)閾值,若回歸模型的輸出大于這個(gè)閾值則返回正例,反之返回反例。很顯然分段函數(shù)能夠滿足上述的需求,然而分段函數(shù)不連續(xù)可微,在最優(yōu)化問題中這不是一個(gè)很好的數(shù)學(xué)性質(zhì),因此我們不能使用分段函數(shù)。我們需要的函數(shù)要既要擁有分段函數(shù)的性質(zhì),又要連續(xù)可微。幸好,存在這樣的函數(shù),這就是sigmod函數(shù),該函數(shù)的圖像如下圖所示:

圖1

sigmod函數(shù)的自變量是線性模型的輸出。寫到這里就帶出了對(duì)數(shù)回歸的概念,對(duì)數(shù)回歸其實(shí)就是在傳統(tǒng)線性回歸的基礎(chǔ)上增加了一層sigmod函數(shù)。

線性模型

這一節(jié)來論述對(duì)數(shù)回歸的核心——線性模型的建立。一個(gè)線性模型可以建立為以下形式:設(shè)權(quán)重列向量為w=(w0;w1;w2…;wn),輸入行向量為x=(x0,x1,x2…,xn),線性模型如下式所示:

公式1

其中x可以為一個(gè)行向量也可以為一個(gè)輸入矩陣。該模型的輸出就是模型的預(yù)測(cè)值,由于訓(xùn)練數(shù)據(jù)都是帶標(biāo)簽的因此我們可以計(jì)算出預(yù)測(cè)值與真實(shí)值的差:

公式2

需要說明的是若輸入x為向量則誤差e為一個(gè)實(shí)數(shù),若x為矩陣則e為一個(gè)列向量。我們的目標(biāo)是讓該模型在所有訓(xùn)練集上的誤差之和最小。然而誤差e可能出現(xiàn)正值也可能出現(xiàn)負(fù)值,如果簡(jiǎn)單相加可能出現(xiàn)正負(fù)抵消現(xiàn)象,因此好的方案是讓誤差平方后相加。這樣,我們就得到了線性模型的優(yōu)化目標(biāo)函數(shù):

公式3

看到這里你可能會(huì)發(fā)現(xiàn)為什么一直沒有用到sigmod函數(shù),這主要是因?yàn)閟igmod函數(shù)只是用于最后輸出,在模型求解過程中sigmod函數(shù)并沒有什么特別的影響,因此出于簡(jiǎn)化目的就將其省去。

梯度下降法

梯度下降法是求解最優(yōu)化的問題時(shí)最常用的一種方法,只是由于其計(jì)算量巨大已經(jīng)很少用到了。不過很多其他方法都是由梯度下降法改進(jìn)或衍生出的,因此還是有必要提一下梯度下降法。該方法適用于求解凸函數(shù)最小值,其思想概括起來就是讓函數(shù)的自變量沿著函數(shù)變化率最大的那個(gè)方向以某一個(gè)步長(zhǎng)移動(dòng),遞歸這一過程,直到函數(shù)值滿足某個(gè)預(yù)設(shè)條件,遞歸返回,迭代更新公式如下式:

公式4

函數(shù)變化率最大的那個(gè)方向就是函數(shù)微分結(jié)果的方向,步長(zhǎng)用戶可以自行指定,但是要注意防止陷入局部最小。

模型求解

目前我們已知模型的目標(biāo)函數(shù),首先我們要求目標(biāo)函數(shù)對(duì)w的微分,由于函數(shù)中的變量都是矩陣形式因此在微分操作時(shí)與普通函數(shù)有一點(diǎn)不同。詳細(xì)過程如下,將原始的目標(biāo)函數(shù)展開如下形式:

公式5

上式共有四項(xiàng),可以看出最后一項(xiàng)不含w因此其微分為0,我們主要對(duì)上式的前三項(xiàng)進(jìn)行微分,第一項(xiàng)的微分:

公式6

第二項(xiàng)的微分:


公式7

第三項(xiàng)的微分:

公式8

將三項(xiàng)微分結(jié)果組合在一起就得到了目標(biāo)函數(shù)的微分結(jié)果:

公式9

這個(gè)結(jié)果中的常數(shù)2可以忽略掉,對(duì)結(jié)果不會(huì)產(chǎn)生什么影響。權(quán)重向量w的迭代公式也就可以確定下來:

公式10

我們只需要在模型中迭代求解w值,直到某一個(gè)迭代次數(shù)或模型輸出達(dá)到某個(gè)標(biāo)準(zhǔn)便可以得到模型的解。

實(shí)現(xiàn)

這一部分筆者將詳細(xì)介紹對(duì)數(shù)回歸的編程實(shí)現(xiàn),在這里要感謝《機(jī)器學(xué)習(xí)實(shí)戰(zhàn)》這本書,這部分所使用到的兩個(gè)數(shù)據(jù)集都來源于這本書的附錄,前一個(gè)數(shù)據(jù)集用于測(cè)試和改進(jìn)算法,后一個(gè)數(shù)據(jù)集用來解決一個(gè)實(shí)際的問題。

首先實(shí)現(xiàn)傳統(tǒng)梯度下降算法以求解對(duì)數(shù)回歸模型,然后再引出本篇博客的另一個(gè)主題——隨機(jī)梯度下降算法。該部分使用第一個(gè)數(shù)據(jù)集,該數(shù)據(jù)集由若干個(gè)二維向量組成,每一個(gè)向量還帶有0或1的標(biāo)記值。傳統(tǒng)梯度下降算法的代碼如圖所示:

圖2

該算法首先初始化權(quán)重值weight全部為1,然后每次計(jì)算在sigmod函數(shù)下與標(biāo)準(zhǔn)分類結(jié)果的誤差error,使用倒數(shù)第二行的代碼對(duì)weight進(jìn)行更新,如此進(jìn)行10000次迭代,最后返回weight值。該函數(shù)最后還返回了一個(gè)res矩陣,該矩陣用來存儲(chǔ)每一次迭代weight的值,用于作圖。運(yùn)行結(jié)果如下圖所示:

圖3

圖中兩種顏色的點(diǎn)代表不同類別的數(shù)據(jù),中間的一條分割線即是分類器給出的最佳分類結(jié)果,可以看出在線性模式下該分類結(jié)果還是比較好的。然而我們分析上文給出的代碼可以看出,傳統(tǒng)的梯度下降算法每一次迭代都需要全部的數(shù)據(jù)集參與計(jì)算,這樣的模式在較小的數(shù)據(jù)集下尚能接受,若數(shù)據(jù)集較大這樣的方法就無法繼續(xù)使用了。

隨機(jī)梯度下降算法

隨機(jī)梯度下降算法是針對(duì)傳統(tǒng)梯度下降算法每一次迭代運(yùn)算量過大的改進(jìn)。在正式引入隨機(jī)梯度下降算法之前筆者先介紹一個(gè)中間算法,該算法也是對(duì)梯度下降算法的改進(jìn)。傳統(tǒng)的梯度下降算法每一次迭代都使用全部數(shù)據(jù)集對(duì)權(quán)重值進(jìn)行更新,能否換一種思路,每次迭代只用一個(gè)向量對(duì)權(quán)重值進(jìn)行更新,迭代的規(guī)模限定在數(shù)據(jù)集規(guī)模以內(nèi)。這樣不但降低了迭代次數(shù),而且這樣的學(xué)習(xí)算法是在線的,即能根據(jù)新的數(shù)據(jù)不斷優(yōu)化模型。實(shí)現(xiàn)代碼如下:

圖4

該算法與傳統(tǒng)梯度下降算法的區(qū)別主要在于單次迭代中對(duì)權(quán)重值的更新只用到了數(shù)據(jù)集中的一個(gè)子集。這樣的設(shè)定減少了算法的時(shí)間復(fù)雜度,提高了運(yùn)行速度,而且算法是在線的。運(yùn)行效果如下:

圖5

可以看出,該算法的分類效果不是很好,需要改進(jìn)。首先我們需要找出算法分類效果變差的原因,然后再制定改進(jìn)方案。算法分類效果的好壞很大程度上與權(quán)重值的計(jì)算有關(guān),我們首先需要確定在改進(jìn)算法的前提下權(quán)重值是否是收斂的。上文所列出的兩個(gè)算法都返回了res矩陣,該矩陣中存儲(chǔ)了權(quán)重值的變化情況,剛好可以用來觀察權(quán)重值的變化趨勢(shì),作圖如下:

圖6

圖中橫坐標(biāo)為迭代次數(shù),縱坐標(biāo)代表權(quán)重向量分量的大小。左圖為傳統(tǒng)梯度下降法下權(quán)重向量三個(gè)分量的變化情況,右圖是改進(jìn)算法下權(quán)重向量的變化情況。從右圖中我們可以看出除w2分量外,另外兩個(gè)分量都還看不出收斂的趨勢(shì),這可能是迭代次數(shù)過少造成的。除此之外,還可以看出右圖的曲線有較多波折,這可能是數(shù)據(jù)集中有一些不能被正確分類的點(diǎn)在每一次迭代中都對(duì)權(quán)重值造成很大干擾造成的。針對(duì)以上兩點(diǎn)筆者介紹一種改良算法——隨機(jī)梯度下降算法。該算法的實(shí)現(xiàn)代碼如下:

圖7

該算法主要有三個(gè)地方的改進(jìn):首先是增加了迭代次數(shù),且迭代次數(shù)可以由用戶傳入,這樣能使權(quán)重向量收斂;其次是對(duì)權(quán)重向量更新時(shí)選取的數(shù)據(jù)是隨機(jī)的,這樣能避免奇異點(diǎn)對(duì)權(quán)重向量的干擾;最后是每次迭代的步長(zhǎng)做了動(dòng)態(tài)調(diào)整,保證隨著迭代的增加步長(zhǎng)是減小且不是嚴(yán)格減小的,這樣的設(shè)定也是為了權(quán)重向量的加速收斂。采用隨機(jī)梯度下降算法的效果如下:

圖8

可以看出使用新的隨機(jī)梯度下降算法訓(xùn)練出的分類器也能取得不錯(cuò)的分類結(jié)果。接下來,檢查一下權(quán)重向量的收斂情況,效果如下圖:

圖9

左圖是最初的改進(jìn)算法的迭代次數(shù)與權(quán)重向量的關(guān)系,右圖是隨機(jī)梯度下降算法下迭代次數(shù)與權(quán)重向量的關(guān)系,可以看出在一共執(zhí)行了15000次迭代后權(quán)重向量已經(jīng)很好的收斂了,且不存在很大的噪聲,說明算法的效果還是不錯(cuò)的。

一個(gè)小應(yīng)用

這個(gè)應(yīng)用的例子也來源于《機(jī)器學(xué)習(xí)實(shí)戰(zhàn)》這本書,真的要感謝這本書給我啟蒙。這個(gè)小demo用對(duì)數(shù)回歸模型預(yù)測(cè)馬匹生病后是否會(huì)死亡,數(shù)據(jù)集分訓(xùn)練集和測(cè)試集,代碼如下:

圖10

算法在讀入測(cè)試數(shù)據(jù)后就分別使用傳統(tǒng)梯度下降算法,改進(jìn)算法和隨機(jī)梯度下降算法對(duì)分類器進(jìn)行訓(xùn)練,最后得到三個(gè)模型的解,然后用這三個(gè)解分別對(duì)測(cè)試數(shù)據(jù)集進(jìn)行測(cè)試,最終得出三個(gè)算法的錯(cuò)誤率,結(jié)果如下:

圖11

從結(jié)果上可以看出效果最好的仍然是傳統(tǒng)梯度下降算法,隨機(jī)梯度下降算法的表現(xiàn)也十分不錯(cuò),最差的是改進(jìn)算法,分錯(cuò)了60%的數(shù)據(jù)。

圖12

從左到右依次是傳統(tǒng)梯度下降算法,改進(jìn)算法和隨機(jī)梯度下降算法的權(quán)重向量的變化情況(只選取了權(quán)重向量的前三個(gè)分量作圖),可以看出梯度下降算法和隨機(jī)梯度下降算法的權(quán)重向量都收斂了,只是梯度下降算法的權(quán)重向量值抖動(dòng)很大,說明訓(xùn)練集中有很多噪聲點(diǎn),隨機(jī)梯度下降算法的權(quán)重向量的抖動(dòng)就小的多,說明隨機(jī)選取數(shù)據(jù)在減小噪聲點(diǎn)對(duì)權(quán)重向量收斂的影響上的效果是顯著的。做為對(duì)比,改進(jìn)算法的效果就差得多,權(quán)重向量沒有任何一個(gè)分量是收斂的。

開源

https://github.com/yhswjtuILMARE/LogisticRegression

2018年1月24日

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容