用“模型想象出來的target”來訓(xùn)練可以提高分類的效果

LearnFromPapers系列——用“模型想象出來的target”來訓(xùn)練可以提高分類的效果

<center>作者:郭必?fù)P</center>
<center>時間:2020年最后一天</center>

前言:今天是2020年最后一天,這篇文章也是我的SimpleAI公眾號2020年的最后一篇推文,感謝大家一直以來的陪伴和支持,希望SimpleAI曾帶給各位可愛的讀者們一點(diǎn)點(diǎn)的收獲吧~這么特殊的一天,我也來介紹一篇特殊的論文,那就是今年我和組里幾位老師合作的一篇AAAI論文:“Label Confusion Learning to Enhance Text Classification Models”。這篇文章的主要思想是通過構(gòu)造一個“標(biāo)簽混淆模型”來實(shí)時地“想象”一個比one-hot更好的標(biāo)簽分布,從而使得各種深度學(xué)習(xí)模型(LSTM、CNN、BERT)在分類問題上都能得到更好的效果。個人感覺,還是有、意思的。

  • 論文標(biāo)題:Label Confusion Learning to Enhance Text Classification Models
  • 會議/期刊:AAAI-21
  • 團(tuán)隊:上海財經(jīng)大學(xué) 信息管理與工程學(xué)院 AI Lab

一、主要貢獻(xiàn)

本文的主要貢獻(xiàn)有這么幾點(diǎn):

  • 構(gòu)造了一個插件--"Label Confusion Model(LCM)",可以在模型訓(xùn)練的時候?qū)崟r計算樣本和標(biāo)簽間的關(guān)系,從而生成一個標(biāo)簽分布,作為訓(xùn)練的target,實(shí)驗(yàn)證明,這個新的target比one-hot標(biāo)簽更好;
  • 這個插件不需要任何外部的知識,也僅僅在訓(xùn)練的時候才需要,不會增加模型預(yù)測時的時間,不改變原模型的結(jié)構(gòu)。所以LCM的應(yīng)用范圍很廣;
  • 實(shí)驗(yàn)發(fā)現(xiàn)LCM還具有出色的抗噪性和抗干擾能力,對于有錯標(biāo)的數(shù)據(jù)集,或者標(biāo)簽間相似度很高的數(shù)據(jù)集,有更好的表現(xiàn)。

二、問題背景、相關(guān)工作

1. 用one-hot來訓(xùn)練不夠好

本文主要是從文本分類的角度出發(fā)的,但文本分類和圖像分類實(shí)際上在訓(xùn)練模式上是類似的,基本都遵循這樣的一個流程:

step 1. 一個深度網(wǎng)絡(luò)(DNN,諸如LSTM、CNN、BERT等)來得到向量表示
step 2. 一個softmax分類器來輸出預(yù)測的標(biāo)簽概率分布p
step 3. 使用Cross-entropy來計算真實(shí)標(biāo)簽(one-hot表示)與p之間的損失,從而優(yōu)化

這里使用cross-entropy loss(簡稱CE-loss)基本上成了大家訓(xùn)練模型的默認(rèn)方法,但它實(shí)際上存在一些問題。下面我舉個例子:

比如有一個六個類別的分類任務(wù),CE-loss是如何計算當(dāng)前某個預(yù)測概率p相對于y的損失呢:


可以看出,根據(jù)CE-loss的公式,只有y中為1的那一維度參與了loss的計算,其他的都忽略了。這樣就會造成一些后果

  • 真實(shí)標(biāo)簽跟其他標(biāo)簽之間的關(guān)系被忽略了,很多有用的知識無法學(xué)到;比如:“鳥”和“飛機(jī)”本來也比較像,因此如果模型預(yù)測覺得二者更接近,那么應(yīng)該給予更小的loss
  • 傾向于讓模型更加“武斷”,成為一個“非黑即白”的模型,導(dǎo)致泛化性能差;
  • 面對易混淆的分類任務(wù)、有噪音(誤打標(biāo))的數(shù)據(jù)集時,更容易受影響

總之,這都是由one-hot的不合理表示造成的,因?yàn)閛ne-hot只是對真實(shí)情況的一種簡化。

2. 一些可能的解決辦法

LDL
既然one-hot不合理,那我們就使用更合理的標(biāo)簽分布來訓(xùn)練嘛。比如下圖所示:


如果我們能獲取真實(shí)的標(biāo)簽分布來訓(xùn)練,那該多好啊。

這種使用標(biāo)簽的分布來學(xué)習(xí)模型的方法,稱為LDL(Label Distribution Learning),東南大學(xué)耿新團(tuán)隊專門研究這個方面,大家可以去了解一下。

但是,真實(shí)的標(biāo)簽分布,往往很難獲取,甚至不可獲取,只能模擬。比如找很多人來投票,或者通過觀察進(jìn)行統(tǒng)計。比如在耿新他們最初的LDL論文中,提出了很多生物數(shù)據(jù)集,是通過實(shí)驗(yàn)觀察來得到的標(biāo)簽分布。然而,大多數(shù)的現(xiàn)有的數(shù)據(jù)集,尤其是文本、圖像分類,幾乎都是one-hot的,所以LDL并無法直接使用。

Label Enhancement
Label Enhancement,機(jī)標(biāo)簽增強(qiáng)技術(shù),則是一類從通過樣本特征空間來生成標(biāo)簽分布的方法,我在前面的論文解讀中有介紹,這些方法都很有趣。

然而,使用這些方法來訓(xùn)練模型,都比較麻煩,因?yàn)槲覀冃枰ㄟ^“兩步走”來訓(xùn)練,第一步使用LE的方法來構(gòu)造標(biāo)簽分布,第二步再使用標(biāo)簽分布來訓(xùn)練。

Loss Correction
面對one-hot可能帶來的容易過擬合的問題,有研究提出了Label Smoothing方法:

label smoothing就是把原來的one-hot表示,在每一維上都添加了一個隨機(jī)噪音。這是一種簡單粗暴,但又十分有效的方法,目前已經(jīng)使用在很多的圖像分類模型中了。

這種方法,一定程度上,可以緩解模型過于武斷的問題,也有一定的抗噪能力。但是單純地添加隨機(jī)噪音,也無法反映標(biāo)簽之間的關(guān)系,因此對模型的提升有限,甚至有欠擬合的風(fēng)險。

當(dāng)然還有一些其他的Loss Correction方法,可以參考我前面的一個介紹。

三、我們的思想&模型設(shè)計

我們最終的目標(biāo),是能夠使用更加合理的標(biāo)簽分布來代替one-hot分布訓(xùn)練模型,最好這個過程能夠和模型的訓(xùn)練同步進(jìn)行。

首先我們思考,一個合理的標(biāo)簽分布,應(yīng)該有什么樣的性質(zhì)。

① 很自然地,標(biāo)簽分布應(yīng)該可以反映標(biāo)簽之間的相似性。
比方下面這個例子:


② 標(biāo)簽間的相似性是相對的,要根據(jù)具體的樣本內(nèi)容來看。
比方下面這個例子,同樣的標(biāo)簽,對于不同的句子,標(biāo)簽之間的相似度也是不一樣的:


③ 構(gòu)造得到的標(biāo)簽分布,在01化之后應(yīng)該跟原one-hot表示相同。
啥意思呢,就是我們不能構(gòu)造出了一個標(biāo)簽分布,最大值對應(yīng)的標(biāo)簽跟原本的one-hot標(biāo)簽還不一致,我們最終的標(biāo)簽分布,還是要以one-hot為標(biāo)桿來構(gòu)造。

根據(jù)上面的思考,我們這樣來設(shè)計模型:

使用一個Label Encoder來學(xué)習(xí)各個label的表示,與input sample的向量表示計算相似度,從而得到一個反映標(biāo)簽之間的混淆/相似程度的分布。最后,使用該混淆分布來調(diào)整原來的one-hot分布,從而得到一個更好的標(biāo)簽分布。

設(shè)計出來的模型結(jié)構(gòu)如圖:


這個結(jié)構(gòu)分兩部分,左邊是一個Basic Predictor,就是各種我們常用的分類模型。右邊的則是LCM的模型。注意LCM是一個插件,所以左側(cè)可以更換成任何深度學(xué)習(xí)模型。

Basic Predictor的過程可以用如下公式表達(dá):
\begin{aligned} v^{(i)} &=f^{I}(x)=f^{I}\left(\left[x_{1}, x_{2}, \ldots, x_{n}\right]\right) \\ &=\left[v_{1}^{(i)}, v_{2}^{(i)}, \ldots, v_{n}^{(i)}\right] \\ y^{(p)} &=\operatorname{softmax}\left(v^{(i)}\right) \end{aligned}
其中v^i就是輸入的文本的通過Input Decoder得到的表示。y^p則是predicted label distribution(PLD)。

LCM的過程可以表達(dá)為:
\begin{aligned} V^{(l)} &=f^{L}(l)=f^{L}\left(\left[l_{1}, l_{2}, \ldots, l_{C}\right]\right) \\ &=\left[v_{1}^{(l)}, v_{2}^{(l)}, \ldots, v_{C}^{(l)}\right] \\ y^{(c)} &=\operatorname{softmax}\left(v^{(i)^{\top}} V^{(l)} W+b\right) \\ y^{(s)} &=\operatorname{softmax}\left(\alpha y^{(t)}+y^{(c)}\right) \end{aligned}
其中V^l代表label通過Label Encoder得到的標(biāo)簽表示矩陣,y^c是標(biāo)簽和輸入文本的相似度得到的標(biāo)簽混淆分布,y^t是真實(shí)的one-hot表示,二者通過一個超參數(shù)結(jié)合再歸一化,得到最終的y^s,即模擬標(biāo)簽分布,simulated label distribution(SLD)。

最后,我們使用KL散度來計算loss:
\begin{aligned} \text {loss} &=K L \text {-divergence}\left(y^{(s)}, y^{(p)}\right) \\ &=\sum_{c}^{C} y_{c}^{(s)} \log \left(\frac{y_{c}^{(s)}}{y_{c}^{(p)}}\right) \end{aligned}

總體來說還是比較簡單的,很好復(fù)現(xiàn),其實(shí)也存在更優(yōu)的模型結(jié)構(gòu),我們還在探究。

四、實(shí)驗(yàn)&結(jié)果分析

1. Benchmark數(shù)據(jù)集上的測試

我們使用了2個中文數(shù)據(jù)集和3個英文數(shù)據(jù)集,在LSTM、CNN、BERT三種模型架構(gòu)上進(jìn)行測試,實(shí)驗(yàn)表明LCM可以在絕大多數(shù)情況下,提升主流模型的分類效果。


下面這個圖展示了不同水平的α超參數(shù)對模型的影響:



從圖中可以看出,不管α水平如何,LCM加成的模型,都可以顯著提高收斂速度,最終的準(zhǔn)確率也更高。針對不同的數(shù)據(jù)集特征,我們可以使用不同的α(比如數(shù)據(jù)集混淆程度大,可以使用較小的α),另外,論文中我們還介紹了在使用較小α的時候,可以使用early-stop策略來防止過擬合。

而下面這個圖則展示了LCM確實(shí)可以學(xué)習(xí)到label之間的一些相似性關(guān)系,而且是從完全隨機(jī)的初始狀態(tài)開始學(xué)到的:


2. 難以區(qū)分的數(shù)據(jù)集(標(biāo)簽易混淆)

我們構(gòu)造了幾個“簡單的”和“困難的”數(shù)據(jù)集,通過實(shí)驗(yàn)標(biāo)簽,LCM更適合那些容易混淆的數(shù)據(jù)集:


3. 有噪音的數(shù)據(jù)集

我們還測試了在不同噪音水平下的數(shù)據(jù)集上的效果,并跟Label Smoothing方法做了對比,發(fā)現(xiàn)是顯著好于LS方法的。


下面這個圖展示了另外一組更細(xì)致的實(shí)驗(yàn)結(jié)果:


4. 在圖像分類上也有效果

最后,我們在圖像任務(wù)上也簡單測試了一下,發(fā)現(xiàn)也有效果:


總結(jié):

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容