這周五最后的時(shí)間看完了蘇神剛更新的CoSENT,主要是為了解決目前Sentence-BERT等模型訓(xùn)練和預(yù)測(cè)目標(biāo)不一致的問(wèn)題,模型細(xì)節(jié)可以參考蘇神的博客。本著學(xué)習(xí)的態(tài)度,用ark-nlp復(fù)現(xiàn)了一下蘇神的實(shí)驗(yàn),并使用CHIP-STS測(cè)試該模型在醫(yī)療場(chǎng)景下的效果。
代碼地址:https://github.com/xiangking/PyTorch_CoSENT
實(shí)驗(yàn)數(shù)據(jù)
- ATEC、BQ、LCQMC和PAWSX:https://github.com/bojone/BERT-whitening/tree/main/chn
- CHIP-STS(平安醫(yī)療科技疾病問(wèn)答遷移學(xué)習(xí)):https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414
CoSENT在形式上和無(wú)監(jiān)督的simcse還是比較相似的,由于筆者在實(shí)現(xiàn)無(wú)監(jiān)督的simcse時(shí)是通過(guò)句子分別通過(guò)模型的形式,所以大部分的結(jié)構(gòu)都與CoSENT,更多的更改在損失這一塊:
# bert_embedding其實(shí)就是句子輸入BERT后生成的向量經(jīng)過(guò)池化或直接使用CLS對(duì)應(yīng)的向量后的表示,
bert_embedding_a = self.get_pooled_embedding(
input_ids_a,
token_type_ids_a,
position_ids_ids_a,
attention_mask_a
)
bert_embedding_b = self.get_pooled_embedding(
input_ids_b,
token_type_ids_b,
position_ids_b,
attention_mask_b
)
cosine_sim = torch.sum(bert_embedding_a * bert_embedding_b, dim=1) * 20
cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]
labels = label_ids[:, None] < label_ids[None, :]
labels = labels.long()
cosine_sim = cosine_sim - (1 - labels) * 1e12
# 上面的代碼對(duì)照公式便可理解,該部分添加0,主要是為了防止例如[-1e12, -1e12, -1e12]經(jīng)過(guò)logsumexp的問(wèn)題
cosine_sim = torch.cat((torch.zeros(1).to(cosine_sim.device), cosine_sim.view(-1)), dim=0)
loss = torch.logsumexp(cosine_sim.view(-1), dim=0)
參數(shù)設(shè)置
代碼參數(shù)設(shè)置如下:
句子截?cái)嚅L(zhǎng)度:64(PAWSX數(shù)據(jù)集截?cái)嚅L(zhǎng)度為128)
batch_size:32
epochs:5
效果
使用spearman系數(shù)作為測(cè)評(píng)指標(biāo),ATEC、BQ、LCQMC和PAWSX使用test集進(jìn)行測(cè)試實(shí)驗(yàn),CHIP-STS則使用驗(yàn)證集
| ATEC | BQ | LCQMC | PAWSX | CHIP-STS | |
|---|---|---|---|---|---|
| BERT+CoSENT(ark-nlp) | 49.80 | 72.46 | 79.00 | 59.17 | 76.22 |
| BERT+CoSENT(bert4keras) | 49.74 | 72.38 | 78.69 | 60.00 | |
| Sentence-BERT(bert4keras) | 46.36 | 70.36 | 78.72 | 46.86 |
PS:上表ark-nlp展示的是5輪里最好的結(jié)果,由于沒(méi)有深入了解bert4keras,所以設(shè)置參數(shù)可能還是存在差異,因此對(duì)比僅供參考
針對(duì)CHIP-STS測(cè)試數(shù)據(jù)集,選擇閾值為0.6生成結(jié)果進(jìn)行提交,結(jié)果如下:
| Precision | Recall | Macro-F1 | |
|---|---|---|---|
| BERT+CoSENT | 82.89 | 81.528 | 81.688 |
| BERT句子對(duì)分類 | 84.331 | 83.799 | 83.924 |
致謝
感謝蘇神無(wú)私的分享