BERT MLM LOSS2024-05-30

BERT(Bidirectional Encoder Representations from Transformers)的MLM(Masked Language Model)損失是這樣設(shè)計(jì)的:在訓(xùn)練過程中,BERT隨機(jī)地將輸入文本中的一些單詞替換為一個(gè)特殊的[MASK]標(biāo)記,然后模型的任務(wù)是預(yù)測(cè)這些被掩蓋的單詞。具體來(lái)說(shuō),它會(huì)預(yù)測(cè)整個(gè)詞匯表中每個(gè)單詞作為掩蓋位置的概率。

MLM損失的計(jì)算方式是使用交叉熵?fù)p失函數(shù)。對(duì)于每個(gè)被掩蓋的單詞,模型會(huì)輸出一個(gè)概率分布,表示每個(gè)可能的單詞是正確單詞的概率。交叉熵?fù)p失函數(shù)會(huì)計(jì)算模型輸出的概率分布與真實(shí)單詞的分布(實(shí)際上是一個(gè)one-hot編碼,其中正確單詞的位置是1,其余位置是0)之間的差異。

具體來(lái)說(shuō),如果你有一個(gè)詞匯表大小為V,對(duì)于一個(gè)被掩蓋的單詞,模型會(huì)輸出一個(gè)V維的向量,表示詞匯表中每個(gè)單詞的概率。如果y是一個(gè)one-hot編碼的真實(shí)分布,而p是模型預(yù)測(cè)的分布,則交叉熵?fù)p失可以表示為(用于衡量模型預(yù)測(cè)概率分布與真實(shí)標(biāo)簽概率分布之間的差異):

L = -\sum_{i=1}^{V} y_i \log(p_i)

其中:

  • L 表示損失函數(shù)的值
  • V 表示類別的數(shù)量
  • y_i 是第 i 個(gè)類別的真實(shí)標(biāo)簽,通常為0或1
  • p_i 是模型預(yù)測(cè)第 i 個(gè)類別的概率
  • \log 表示自然對(duì)數(shù)
  • \sum 表示對(duì)所有類別求和

在這個(gè)公式中,y_i是真實(shí)分布中的第i個(gè)元素,而p_i是模型預(yù)測(cè)的分布中的第i個(gè)元素。由于y是one-hot編碼的,所以除了正確單詞對(duì)應(yīng)的位置為1,其余位置都是0,這意味著上面的求和實(shí)際上只在正確單詞的位置計(jì)算。

在實(shí)際操作中,為了提高效率,通常不會(huì)對(duì)整個(gè)詞匯表進(jìn)行預(yù)測(cè),而是使用采樣技術(shù),如負(fù)采樣(negative sampling)或者層次softmax(hierarchical softmax),來(lái)減少每個(gè)訓(xùn)練步驟中需要計(jì)算的輸出數(shù)量。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容