1. 到底什么是知識蒸餾?
一般地,大模型往往是單個復(fù)雜網(wǎng)絡(luò)或者是若干網(wǎng)絡(luò)的集合,擁有良好的性能和泛化能力,而小模型因為網(wǎng)絡(luò)規(guī)模較小,表達(dá)能力有限。因此,可以利用大模型學(xué)習(xí)到的知識去指導(dǎo)小模型訓(xùn)練,使得小模型具有與大模型相當(dāng)?shù)男阅埽菂?shù)數(shù)量大幅降低,從而實現(xiàn)模型壓縮與加速,這就是知識蒸餾與遷移學(xué)習(xí)在模型優(yōu)化中的應(yīng)用。
Hinton的文章"Distilling the Knowledge in a Neural Network"首次提出了知識蒸餾(暗知識提取)的概念,通過引入與教師網(wǎng)絡(luò)(teacher network:復(fù)雜、但推理性能優(yōu)越)相關(guān)的軟目標(biāo)(soft-target)作為total loss的一部分,以誘導(dǎo)學(xué)生網(wǎng)絡(luò)(student network:精簡、低復(fù)雜度)的訓(xùn)練,實現(xiàn)知識遷移(knowledge transfer)。

2.Hard-target 和 Soft-target
傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)訓(xùn)練方法是定義一個損失函數(shù),目標(biāo)是使預(yù)測值盡可能接近于真實值(Hard- target),損失函數(shù)就是使神經(jīng)網(wǎng)絡(luò)的損失值和盡可能小。這種訓(xùn)練過程是對ground truth求極大似然。在知識蒸餾中,是使用大模型的類別概率作為Soft-target的訓(xùn)練過程。

如在MNIST數(shù)據(jù)集中做手寫體數(shù)字識別任務(wù),假設(shè)某個輸入的“2”更加形似"3",softmax的輸出值中"3"對應(yīng)的概率會比其他負(fù)標(biāo)簽類別高;而另一個"2"更加形似"7",則這個樣本分配給"7"對應(yīng)的概率會比其他負(fù)標(biāo)簽類別高。這兩個"2"對應(yīng)的Hard-target的值是相同的,但是它們的Soft-target卻是不同的,由此我們可見Soft-target蘊含著比Hard-target更多的信息。

使用軟標(biāo)簽就是修改了softmax函數(shù),增加溫度系數(shù)T;

其中 Pi 是每個類別輸出的概率,Zi 是每個類別輸出的 logits,T 就是溫度。當(dāng)溫度 T=1 時,這就是標(biāo)準(zhǔn)的 Softmax 公式。 T越高,softmax 的output probability distribution越趨于平滑,其分布的熵越大,負(fù)標(biāo)簽攜帶的信息會被相對地放大,模型訓(xùn)練將更加關(guān)注負(fù)標(biāo)簽。
關(guān)于溫度T的影響:
image.png
溫度的高低改變的是Student模型訓(xùn)練過程中對負(fù)標(biāo)簽的關(guān)注程度。當(dāng)溫度較低時,對負(fù)標(biāo)簽的關(guān)注,尤其是那些顯著低于平均值的負(fù)標(biāo)簽的關(guān)注較少;而溫度較高時,負(fù)標(biāo)簽相關(guān)的值會相對增大,Student模型會相對更多地關(guān)注到負(fù)標(biāo)簽。
實際上,負(fù)標(biāo)簽中包含一定的信息,尤其是那些負(fù)標(biāo)簽概率值顯著高于平均值的負(fù)標(biāo)簽。但由于Teacher模型的訓(xùn)練過程決定了負(fù)標(biāo)簽部分概率值都比較小,并且負(fù)標(biāo)簽的值越低,其信息就越不可靠。因此溫度的選取需要進(jìn)行實際實驗的比較,本質(zhì)上就是在下面兩種情況之中取舍:
- 當(dāng)想從負(fù)標(biāo)簽中學(xué)到一些信息量的時候,溫度T應(yīng)調(diào)高一些;
- 當(dāng)想減少負(fù)標(biāo)簽的干擾的時候,溫度T應(yīng)調(diào)低一些;
總的來說,T的選擇和Student模型的大小有關(guān),Student模型參數(shù)量比較小的時候,相對比較低的溫度就可以了。因為參數(shù)量小的模型不能學(xué)到所有Teacher模型的知識,所以可以適當(dāng)忽略掉一些負(fù)標(biāo)簽的信息。
如果還不懂硬目標(biāo)和軟目標(biāo)區(qū)別,可以點擊查看跳擊查看,作者舉了很好的一個實例。
3. 知識蒸餾訓(xùn)練的具體方法:

訓(xùn)練Teacher的過程很簡單,我們把第2步和第3步過程統(tǒng)一稱為:高溫蒸餾的過程。高溫蒸餾過程的目標(biāo)函數(shù)由distill loss(對應(yīng)Soft-target)和Student loss(對應(yīng)Hard-target)加權(quán)得到。如下所示:

采用軟標(biāo)簽的知識蒸餾方法,一方面壓縮了模型,另一方面,增強(qiáng)了模型的泛化能力(因為 SN 在訓(xùn)練集上的效果肯定沒 TN 好)
