一、simbert介紹和下載
simbert模型,是由蘇劍林開發(fā)的模型,以Google開源的BERT模型為基礎(chǔ),基于微軟的UniLM思想設(shè)計(jì)了融檢索與生成于一體的任務(wù),來進(jìn)一步微調(diào)后得到的模型,所以它同時(shí)具備相似問生成和相似句檢索能力。
SimBERT屬于有監(jiān)督訓(xùn)練,訓(xùn)練語(yǔ)料是自行收集到的相似句對(duì),通過一句來預(yù)測(cè)另一句的相似句生成任務(wù)來構(gòu)建Seq2Seq部分,然后前面也提到過[CLS]的向量事實(shí)上就代表著輸入的句向量,所以可以同時(shí)用它來訓(xùn)練一個(gè)檢索任務(wù)。
開源地址:https://github.com/ZhuiyiTechnology/simbert
項(xiàng)目介紹:https://kexue.fm/archives/7427
已預(yù)訓(xùn)練的模型包含 Tiny(26M)、Small(49M)、Base(344M)三個(gè)模型。
下載:
https://github.com/ZhuiyiTechnology/pretrained-models
適合創(chuàng)建相似文本集
>>> gen_synonyms(u'微信和支付寶哪個(gè)好?')
[
u'微信和支付寶,哪個(gè)好?',
u'微信和支付寶哪個(gè)好',
u'支付寶和微信哪個(gè)好',
u'支付寶和微信哪個(gè)好啊',
u'微信和支付寶那個(gè)好用?',
u'微信和支付寶哪個(gè)好用',
u'支付寶和微信那個(gè)更好',
u'支付寶和微信哪個(gè)好用',
u'微信和支付寶用起來哪個(gè)好?',
...........
]
二、項(xiàng)目使用
2.1 創(chuàng)建conda 環(huán)境
1)查看虛擬環(huán)境
(base) C:\Users\user>conda info -e
# conda environments:
#
base * D:\Programs\Anaconda3
pytorch D:\Programs\Anaconda3\envs\pytorch
2)創(chuàng)建虛擬環(huán)境
(base) C:\Users\user>conda create -n simbert
3)安裝依賴包
因bert4keras的版本要求,推薦版本見:tensorflow 1.14 + keras 2.3.1 + bert4keras 0.7.7
(base) C:\Users\user>conda activate simbert
(simbert) C:\Users\user>conda install keras
(simbert) C:\Users\user>pip install bert4keras
(simbert) C:\Users\user>pip install tensorflow
卸載安裝錯(cuò)誤的情況:
(simbert) C:\Users\user>pip uninstall tensorflow
(simbert) C:\Users\user>pip uninstall bert4keras
重新安裝
conda create -n simbert python=3.6
(simbert) C:\Users\user>conda install tensorflow==1.14
(simbert) C:\Users\user>conda install keras==2.3.1
(simbert) C:\Users\user>pip install bert4keras==0.7.7
4)驗(yàn)證安裝情況
from keras.layers import *
from bert4keras.backend import keras, K
2.2使用simbert進(jìn)行測(cè)試驗(yàn)證
1)下載測(cè)試代碼
https://github.com/ZhuiyiTechnology/pretrained-models/blob/master/examples/simbert_base.py
#! -*- coding: utf-8 -*-
# SimBERT base 基本例子
# 測(cè)試環(huán)境:tensorflow 1.14 + keras 2.3.1 + bert4keras 0.7.7
import numpy as np
from collections import Counter
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import sequence_padding, AutoRegressiveDecoder
from bert4keras.snippets import uniout
from keras.layers import *
maxlen = 32
# bert配置
config_path = './bert/chinese_simbert_L-12_H-768_A-12/bert_config.json'
checkpoint_path = './bert/chinese_simbert_L-12_H-768_A-12/bert_model.ckpt'
dict_path = './bert/chinese_simbert_L-12_H-768_A-12/vocab.txt'
# 建立分詞器
tokenizer = Tokenizer(dict_path, do_lower_case=True) # 建立分詞器
# 建立加載模型
bert = build_transformer_model(
config_path,
checkpoint_path,
with_pool='linear',
application='unilm',
return_keras_model=False,
)
encoder = keras.models.Model(bert.model.inputs, bert.model.outputs[0])
seq2seq = keras.models.Model(bert.model.inputs, bert.model.outputs[1])
class SynonymsGenerator(AutoRegressiveDecoder):
"""seq2seq解碼器
"""
@AutoRegressiveDecoder.set_rtype('probas')
def predict(self, inputs, output_ids, step):
token_ids, segment_ids = inputs
token_ids = np.concatenate([token_ids, output_ids], 1)
segment_ids = np.concatenate(
[segment_ids, np.ones_like(output_ids)], 1)
return seq2seq.predict([token_ids, segment_ids])[:, -1]
def generate(self, text, n=1, topk=5):
token_ids, segment_ids = tokenizer.encode(text, max_length=maxlen)
output_ids = self.random_sample([token_ids, segment_ids], n, topk) # 基于隨機(jī)采樣
return [tokenizer.decode(ids) for ids in output_ids]
synonyms_generator = SynonymsGenerator(start_id=None,
end_id=tokenizer._token_end_id,
maxlen=maxlen)
def gen_synonyms(text, n=100, k=20):
""""含義: 產(chǎn)生sent的n個(gè)相似句,然后返回最相似的k個(gè)。
做法:用seq2seq生成,并用encoder算相似度并排序。
"""
r = synonyms_generator.generate(text, n)
r = [i for i in set(r) if i != text]
r = [text] + r
X, S = [], []
for t in r:
x, s = tokenizer.encode(t)
X.append(x)
S.append(s)
X = sequence_padding(X)
S = sequence_padding(S)
Z = encoder.predict([X, S])
Z /= (Z**2).sum(axis=1, keepdims=True)**0.5
argsort = np.dot(Z[1:], -Z[0]).argsort()
return [r[i + 1] for i in argsort[:k]]
"""
gen_synonyms(u'微信和支付寶哪個(gè)好?')
[
u'微信和支付寶,哪個(gè)好?',
u'微信和支付寶哪個(gè)好',
u'支付寶和微信哪個(gè)好',
u'支付寶和微信哪個(gè)好啊',
u'微信和支付寶那個(gè)好用?',
u'微信和支付寶哪個(gè)好用',
u'支付寶和微信那個(gè)更好',
u'支付寶和微信哪個(gè)好用',
u'微信和支付寶用起來哪個(gè)好?',
u'微信和支付寶選哪個(gè)好',
u'微信好還是支付寶比較用',
u'微信與支付寶哪個(gè)',
u'支付寶和微信哪個(gè)好用一點(diǎn)?',
u'支付寶好還是微信',
u'微信支付寶究竟哪個(gè)好',
u'支付寶和微信哪個(gè)實(shí)用性更好',
u'好,支付寶和微信哪個(gè)更安全?',
u'微信支付寶哪個(gè)好用?有什么區(qū)別',
u'微信和支付寶有什么區(qū)別?誰(shuí)比較好用',
u'支付寶和微信哪個(gè)好玩'
]
"""
# print(gen_synonyms(u'微信和支付寶哪個(gè)好?'))
print("-----------------------------------")
print(gen_synonyms(u'apache-ywn-int部署在哪一臺(tái)主機(jī)呢?', n=100, k=10))
2)目錄結(jié)構(gòu)
將下載的simbert預(yù)訓(xùn)練模型,放在項(xiàng)目中
(simbert_test) D:\tmp\20220508\simbert_base>tree /F
卷 新加卷 的文件夾 PATH 列表
卷序列號(hào)為 549E-27B0
D:.
│ simbert_base.py
│
└─bert
│ chinese_simbert_L-12_H-768_A-12.zip
│
└─chinese_simbert_L-12_H-768_A-12
bert_config.json
bert_model.ckpt.data-00000-of-00001
bert_model.ckpt.index
checkpoint
vocab.txt
3)執(zhí)行腳本
python simbert_base.py
['apache-ywn-int的部署在哪一臺(tái)主機(jī)上呢', 'apache-ywn-int部署在哪一臺(tái)主機(jī)', 'apache-ywn-int部署在哪臺(tái)主機(jī)', 'apache ywn-int部署在哪一個(gè)主機(jī)上', 'apache-ywn-int部署到哪一臺(tái)主機(jī)', 'apache中ywn-int部署在哪個(gè)主機(jī)', 'apache的ywn-int在哪一臺(tái)主機(jī)', '如何查看apache-ywn-int是否部署在哪個(gè)主機(jī)里', 'apache-ywn-int部署到哪個(gè)服務(wù)器上', 'apache ywn-int是怎么部署在哪一個(gè)服務(wù)器上']
有些輸出并不一定準(zhǔn)確,有一定的參考意義。
由于tensorflow的版本較低,會(huì)出現(xiàn)以下過期提醒信息,不影響運(yùn)行
D:\Programs\Anaconda3\envs\simbert\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint8 = np.dtype([("qint8", np.int8, 1)])
說明:性能并不高,如果要用于生產(chǎn)環(huán)境,還需要基于實(shí)際硬件設(shè)備進(jìn)行驗(yàn)證。