使用Torch實(shí)現(xiàn)CoSENT實(shí)驗(yàn)

這周五最后的時(shí)間看完了蘇神剛更新的CoSENT,主要是為了解決目前Sentence-BERT等模型訓(xùn)練和預(yù)測(cè)目標(biāo)不一致的問(wèn)題,模型細(xì)節(jié)可以參考蘇神的博客。本著學(xué)習(xí)的態(tài)度,用ark-nlp復(fù)現(xiàn)了一下蘇神的實(shí)驗(yàn),并使用CHIP-STS測(cè)試該模型在醫(yī)療場(chǎng)景下的效果。

代碼地址https://github.com/xiangking/PyTorch_CoSENT

實(shí)驗(yàn)數(shù)據(jù)

CoSENT在形式上和無(wú)監(jiān)督的simcse還是比較相似的,由于筆者在實(shí)現(xiàn)無(wú)監(jiān)督的simcse時(shí)是通過(guò)句子分別通過(guò)模型的形式,所以大部分的結(jié)構(gòu)都與CoSENT,更多的更改在損失這一塊:

# bert_embedding其實(shí)就是句子輸入BERT后生成的向量經(jīng)過(guò)池化或直接使用CLS對(duì)應(yīng)的向量后的表示,
bert_embedding_a = self.get_pooled_embedding(
    input_ids_a,
    token_type_ids_a,
    position_ids_ids_a,
    attention_mask_a
)

bert_embedding_b = self.get_pooled_embedding(
    input_ids_b,
    token_type_ids_b,
    position_ids_b,
    attention_mask_b
)

cosine_sim = torch.sum(bert_embedding_a * bert_embedding_b, dim=1) * 20
cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]
    
labels = label_ids[:, None] < label_ids[None, :]
labels = labels.long()
    
cosine_sim = cosine_sim - (1 - labels) * 1e12

# 上面的代碼對(duì)照公式便可理解,該部分添加0,主要是為了防止例如[-1e12, -1e12, -1e12]經(jīng)過(guò)logsumexp的問(wèn)題
cosine_sim = torch.cat((torch.zeros(1).to(cosine_sim.device), cosine_sim.view(-1)), dim=0)
loss = torch.logsumexp(cosine_sim.view(-1), dim=0)

參數(shù)設(shè)置

代碼參數(shù)設(shè)置如下:

句子截?cái)嚅L(zhǎng)度:64(PAWSX數(shù)據(jù)集截?cái)嚅L(zhǎng)度為128)
batch_size:32
epochs:5

效果

使用spearman系數(shù)作為測(cè)評(píng)指標(biāo),ATEC、BQ、LCQMC和PAWSX使用test集進(jìn)行測(cè)試實(shí)驗(yàn),CHIP-STS則使用驗(yàn)證集

ATEC BQ LCQMC PAWSX CHIP-STS
BERT+CoSENT(ark-nlp) 49.80 72.46 79.00 59.17 76.22
BERT+CoSENT(bert4keras) 49.74 72.38 78.69 60.00
Sentence-BERT(bert4keras) 46.36 70.36 78.72 46.86

PS:上表ark-nlp展示的是5輪里最好的結(jié)果,由于沒(méi)有深入了解bert4keras,所以設(shè)置參數(shù)可能還是存在差異,因此對(duì)比僅供參考

針對(duì)CHIP-STS測(cè)試數(shù)據(jù)集,選擇閾值為0.6生成結(jié)果進(jìn)行提交,結(jié)果如下:

Precision Recall Macro-F1
BERT+CoSENT 82.89 81.528 81.688
BERT句子對(duì)分類 84.331 83.799 83.924

致謝

感謝蘇神無(wú)私的分享

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容