知識蒸餾源自Hinton et al.于2014年發(fā)表在NIPS的一篇文章:Distilling the Knowledge in a Neural Network。
1. 背景
一般情況下,我們在訓練模型的時候使用了大量訓練數(shù)據(jù)和計算資源來提取知識,但這不方便在工業(yè)中部署,原因有二:
(1)大模型推理速度慢
(2)對設(shè)備的資源要求高(大內(nèi)存)
因此我們希望對訓練好的模型進行壓縮,在保證推理效果的前提下減小模型的體量,知識蒸餾(Knownledge Distillation)屬于模型壓縮的一種方法 [1]。
2. 知識蒸餾
名詞解釋:
cumbersome model:原始模型或者說大模型,但在后續(xù)的論文中一般稱它為teacher model;
distilled model:蒸餾后的小模型,在后續(xù)的論文中一般稱它為stududent model;
hard targets:像[1, 0, 0]這樣的標簽,也叫做ground-truth label;
soft targets:像[0.7, 0.2, 0.1]這樣的標簽;
transfer set:訓練student model的數(shù)據(jù)
好模型的目標不是擬合訓練數(shù)據(jù),而是學習如何泛化到新的數(shù)據(jù)。所以蒸餾的目標是讓student學習到teacher的泛化能力,理論上得到的結(jié)果會比單純擬合訓練數(shù)據(jù)的student要好 [3]。顯然,soft target可以提供更大的信息熵,所以studetn model可以學習到更多的信息。
通俗的來講,粗暴的使用one-hot編碼把原本有幫助的類內(nèi)variance和類間distance都忽略了,比如貓和狗的相似性要比貓與摩托車的相似性要多,狗的某些特征可能對識別貓也會有幫助(比如毛發(fā)),因此使用soft target可以恢復(fù)被one-hot編碼丟棄的信息 [2]。
在Hinton et al. 發(fā)表的這篇論文中,作者提出了"softmax temperature"的概念,其公式為:
Python代碼:
import numpy as np
def softmax_t(x,t):
x_exp = np.exp(x / t)
return x_exp / np.sum(x_exp)
代表第
類的輸出概率,
和
為softmax的輸入,即上一層神經(jīng)元的輸出(logits),T表示temperature參數(shù)。通常情況下,我們使用的softmax函數(shù)T為1,但
可以控制輸出soft的程度。比如對于
,我們分別取
,然后畫出softmax函數(shù)的輸出可以看到,
越小,輸出的預(yù)測結(jié)果越“硬”(曲線更加曲折),T越大輸出的結(jié)果越“軟”(曲線更加平和)。

插一句題外話,為什么這里的參數(shù)是叫溫度(temperature)呢?這和蒸餾(distillation)這一熱力學工藝有關(guān)。在蒸餾工藝中,溫度越高提取到的物質(zhì)越純越濃縮。而在知識蒸餾中,參數(shù)T越大(溫度越高),teacher model產(chǎn)生的label越"soft",信息熵就越高,提煉的知識更具有一般性(generalization)。所以說作者將這一參數(shù)取名temperature十分有趣。

知識蒸餾的實現(xiàn)過程可以概括為:
- 訓練teacher model;
- 使用高溫T將teacher model中的知識蒸餾到student model(在測試時溫度T設(shè)為1)。
student modeld的目標函數(shù)由一下兩項的加權(quán)平均組成:
- distillation loss:soft targets(由teacher model產(chǎn)生) 和student model的soft predictions的交叉熵,這里的T使用的是和訓練teacher model相同的值。(保證student model和teacher model的結(jié)果盡可能一致)
- student loss:hard targets 和student model的輸出數(shù)據(jù)的交叉熵,但T設(shè)置為1。(保證student model的結(jié)果和實際類別標簽盡可能一致)
總體的損失函數(shù)可以寫作:
其中,表示輸入,
表示student model的參數(shù),
是ground-truth label,
是交叉熵損失函數(shù),
是剛剛提到的softmax temperature激活函數(shù),
和
分別表示student和teacher model神經(jīng)元的輸出(logits),
和
表示兩個權(quán)重參數(shù) [4].
原論文指出,要比
相對小一些可以取得更好的結(jié)果,因為在求梯度時soft targets被縮放了
,所以第2項要乘以一個更小的權(quán)值來平衡二者在優(yōu)化時的比重 [1].
換一個角度來想,這里的知識蒸餾其實是相對于對于原始交叉熵添加了一個正則項:
利用teacher model的先驗知識對student model進行正則化 [5]。
References:
[1] Distilling the Knowledge in a Neural Network.
[2] # Distilling the Knowledge in a Neural Network 論文筆記
[3] 深度神經(jīng)網(wǎng)絡(luò)模型蒸餾Distillation
[4] Knowledge Distillation
[5] 神經(jīng)網(wǎng)絡(luò)知識蒸餾 Knowledge Distillation