BERT(Bidirectional Encoder Representations from Transformers)的MLM(Masked Language Model)損失是這樣設(shè)計(jì)的:在訓(xùn)練過程中,BERT隨機(jī)地將輸入文本中的一些單詞替換為一個(gè)特殊的[MASK]標(biāo)記,然后模型的任務(wù)是預(yù)測(cè)這些被掩蓋的單詞。具體來(lái)說(shuō),它會(huì)預(yù)測(cè)整個(gè)詞匯表中每個(gè)單詞作為掩蓋位置的概率。
MLM損失的計(jì)算方式是使用交叉熵?fù)p失函數(shù)。對(duì)于每個(gè)被掩蓋的單詞,模型會(huì)輸出一個(gè)概率分布,表示每個(gè)可能的單詞是正確單詞的概率。交叉熵?fù)p失函數(shù)會(huì)計(jì)算模型輸出的概率分布與真實(shí)單詞的分布(實(shí)際上是一個(gè)one-hot編碼,其中正確單詞的位置是1,其余位置是0)之間的差異。
具體來(lái)說(shuō),如果你有一個(gè)詞匯表大小為V,對(duì)于一個(gè)被掩蓋的單詞,模型會(huì)輸出一個(gè)V維的向量,表示詞匯表中每個(gè)單詞的概率。如果y是一個(gè)one-hot編碼的真實(shí)分布,而p是模型預(yù)測(cè)的分布,則交叉熵?fù)p失可以表示為(用于衡量模型預(yù)測(cè)概率分布與真實(shí)標(biāo)簽概率分布之間的差異):
其中:
-
表示損失函數(shù)的值
-
表示類別的數(shù)量
-
是第
個(gè)類別的真實(shí)標(biāo)簽,通常為0或1
-
是模型預(yù)測(cè)第
個(gè)類別的概率
-
表示自然對(duì)數(shù)
-
表示對(duì)所有類別求和
在這個(gè)公式中,是真實(shí)分布中的第i個(gè)元素,而
是模型預(yù)測(cè)的分布中的第i個(gè)元素。由于y是one-hot編碼的,所以除了正確單詞對(duì)應(yīng)的位置為1,其余位置都是0,這意味著上面的求和實(shí)際上只在正確單詞的位置計(jì)算。
在實(shí)際操作中,為了提高效率,通常不會(huì)對(duì)整個(gè)詞匯表進(jìn)行預(yù)測(cè),而是使用采樣技術(shù),如負(fù)采樣(negative sampling)或者層次softmax(hierarchical softmax),來(lái)減少每個(gè)訓(xùn)練步驟中需要計(jì)算的輸出數(shù)量。