tensorflow實(shí)現(xiàn)
方法1:
tf.losses.softmax_cross_entropy(onehot_labels=y, logits=logit, label_smoothing=0.001)
方法2:
smoothing = 0.001
y -= smoothing * (y - 1. / tf.cast(y.shape[-1], y.dtype))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,
logits=logit))
label smoothing原理 (標(biāo)簽平滑)
對(duì)于分類問(wèn)題,常規(guī)做法時(shí)將類別做成one-hot vector,然后在網(wǎng)絡(luò)最后一層全鏈接的輸出后接一層softmax,softmax的輸出是歸一的,因此我們認(rèn)為softmax的輸出就是該樣本屬于某一類別的概率。由于標(biāo)簽是類別的one-hot vector, 因此表征我們已知該樣本屬于某一類別是概率為1的確定事件,而其他類別概率都為0。
softmax:
其中 一般叫做 logits ,即未被歸一化的對(duì)數(shù)概率 。我們用 p 代表 predicted probability,用 q 代表 groundtruth 。在分類問(wèn)題中l(wèi)oss函數(shù)一般用交叉熵,即:
cross entropy loss:
交叉熵對(duì)于logits可微,且偏導(dǎo)數(shù)形式簡(jiǎn)單: ,顯然梯度時(shí)有界的(-1到1)。
對(duì)于groundtruth為one-hot的情況,即每個(gè)樣本只有惟一的類別,則 ,
是真實(shí)類別。其中
是Dirac函數(shù)。要用predicted label 去擬合這樣的函數(shù)具有兩個(gè)問(wèn)題:首先,無(wú)法保證模型的泛化能力(generalizing),容易導(dǎo)致過(guò)擬合; 其次,全概率和零概率將鼓勵(lì)所屬類別和非所屬類別之間的差距盡可能拉大,而由于以上可知梯度有界,因此很難adapt。這種情況源于模型過(guò)于相信預(yù)測(cè)的類別。( Intuitively, this happens because the model becomes too confident about its predictions.)
因此提出一種機(jī)制,即要使得模型可以 less confident 。思路如下:考慮一個(gè)與樣本無(wú)關(guān)的分布 ,將我們的 label 即真實(shí)標(biāo)簽
變成
,其中:
可以理解為,對(duì)于 Dirac 函數(shù)分布的真實(shí)標(biāo)簽,我們將它變成以如下方式獲得:首先從標(biāo)注的真實(shí)標(biāo)簽的Dirac分布中取定,然后,以一定的概率 ,將其替換為在
分布中的隨機(jī)變量。因此可以避免上述的問(wèn)題。而
我們可以用先驗(yàn)概率來(lái)充當(dāng)。如果用 uniform distribution 的話就是 1/K 。該操作就叫做 label-smoothing regularization, or LSR 。
對(duì)于該操作的數(shù)學(xué)物理含義可以用交叉熵的概念說(shuō)明:

可以認(rèn)為 loss 函數(shù)分別以不同的權(quán)重對(duì) predicted label 與標(biāo)注的label 的差距 以及 predicted label 與 先驗(yàn)分布的差距 進(jìn)行懲罰,可以對(duì)分類性能有一定程度的提升。(In our ImageNet experiments with K = 1000 classes, we used u(k) = 1/1000 and = 0.1. For ILSVRC 2012, we have found a consistent improvement of about 0.2% absolute both for top-1 error and the top-5 error )
reference:
1. Szegedy C, Vanhoucke V, Ioffe S, et al. Rethinking the Inception Architecture for Computer Vision[C] Computer Vision and Pattern Recognition. IEEE, 2016:2818-2826.
2. https://github.com/tensorflow/cleverhans/blob/master/cleverhans_tutorials/mnist_tutorial_tf.py