基于tensorflow的label smoothing實(shí)現(xiàn)

tensorflow實(shí)現(xiàn)

方法1:

tf.losses.softmax_cross_entropy(onehot_labels=y, logits=logit, label_smoothing=0.001)

方法2:

smoothing = 0.001
y -= smoothing * (y - 1. / tf.cast(y.shape[-1], y.dtype))
loss =  tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,
                                                             logits=logit))

label smoothing原理 (標(biāo)簽平滑)

對(duì)于分類問(wèn)題,常規(guī)做法時(shí)將類別做成one-hot vector,然后在網(wǎng)絡(luò)最后一層全鏈接的輸出后接一層softmax,softmax的輸出是歸一的,因此我們認(rèn)為softmax的輸出就是該樣本屬于某一類別的概率。由于標(biāo)簽是類別的one-hot vector, 因此表征我們已知該樣本屬于某一類別是概率為1的確定事件,而其他類別概率都為0。

softmax:p(k|x) = frac{exp(z_k)}{sum_i^{i=K}{exp(z_i)}}

其中 z_i 一般叫做 logits ,即未被歸一化的對(duì)數(shù)概率 。我們用 p 代表 predicted probability,用 q 代表 groundtruth 。在分類問(wèn)題中l(wèi)oss函數(shù)一般用交叉熵,即:

cross entropy loss: loss = -sum_{k=1}^{K}{q(k|x) log(p(k|x))}

交叉熵對(duì)于logits可微,且偏導(dǎo)數(shù)形式簡(jiǎn)單:frac{partial{loss}}{partial{z_k}}=p(k) - q(k) ,顯然梯度時(shí)有界的(-1到1)。

對(duì)于groundtruth為one-hot的情況,即每個(gè)樣本只有惟一的類別,則 q(k) = delta_{k,y} ,y 是真實(shí)類別。其中 delta 是Dirac函數(shù)。要用predicted label 去擬合這樣的函數(shù)具有兩個(gè)問(wèn)題:首先,無(wú)法保證模型的泛化能力(generalizing),容易導(dǎo)致過(guò)擬合; 其次,全概率和零概率將鼓勵(lì)所屬類別和非所屬類別之間的差距盡可能拉大,而由于以上可知梯度有界,因此很難adapt。這種情況源于模型過(guò)于相信預(yù)測(cè)的類別。( Intuitively, this happens because the model becomes too confident about its predictions.)

因此提出一種機(jī)制,即要使得模型可以 less confident 。思路如下:考慮一個(gè)與樣本無(wú)關(guān)的分布 u(k) ,將我們的 label 即真實(shí)標(biāo)簽 q(k) 變成 q^{'}(k) ,其中:

可以理解為,對(duì)于 Dirac 函數(shù)分布的真實(shí)標(biāo)簽,我們將它變成以如下方式獲得:首先從標(biāo)注的真實(shí)標(biāo)簽的Dirac分布中取定,然后,以一定的概率 epsilon ,將其替換為在 u(k) 分布中的隨機(jī)變量。因此可以避免上述的問(wèn)題。而 u(k) 我們可以用先驗(yàn)概率來(lái)充當(dāng)。如果用 uniform distribution 的話就是 1/K 。該操作就叫做 label-smoothing regularization, or LSR 。

對(duì)于該操作的數(shù)學(xué)物理含義可以用交叉熵的概念說(shuō)明:

交叉熵

可以認(rèn)為 loss 函數(shù)分別以不同的權(quán)重對(duì) predicted label 與標(biāo)注的label 的差距 以及 predicted label 與 先驗(yàn)分布的差距 進(jìn)行懲罰,可以對(duì)分類性能有一定程度的提升。(In our ImageNet experiments with K = 1000 classes, we used u(k) = 1/1000 and = 0.1. For ILSVRC 2012, we have found a consistent improvement of about 0.2% absolute both for top-1 error and the top-5 error )

reference:

1. Szegedy C, Vanhoucke V, Ioffe S, et al. Rethinking the Inception Architecture for Computer Vision[C] Computer Vision and Pattern Recognition. IEEE, 2016:2818-2826.

2. https://github.com/tensorflow/cleverhans/blob/master/cleverhans_tutorials/mnist_tutorial_tf.py

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容