Google released BERT quite a while ago, but I only recently got to use it in a real text classification task, so here is a record of the process. Before that, a brief walkthrough of the BERT code.
BERT source code
First, clone the source from the official bert repository and look at the directory layout:
.
├── CONTRIBUTING.md
├── create_pretraining_data.py # builds the pre-training data
├── extract_features.py
├── __init__.py
├── LICENSE
├── modeling.py # the model definition (the BERT architecture)
├── modeling_test.py
├── multilingual.md
├── optimization.py # optimizer choice, learning rate and related settings
├── optimization_test.py
├── predicting_movie_reviews_with_bert_on_tf_hub.ipynb
├── README.md
├── requirements.txt
├── run_classifier.py # fine-tuning script (the one we customize)
├── run_classifier_with_tfhub.py
├── run_pretraining.py # pre-training script
├── run_squad.py
├── sample_text.txt
├── tokenization.py # tokenization utilities
└── tokenization_test.py
The only scripts that concern us here are run_classifier.py and run_pretraining.py.
run_pretraining
run_pretraining.py is the pre-training script. Honestly, just use a model Google has already trained; hardly anyone has the resources to pre-train one from scratch. Find the Chinese model (chinese_L-12_H-768_A-12); after downloading and extracting it, the directory looks like this:
├── bert_config.json # basic BERT hyperparameters
├── bert_model.ckpt.data-00000-of-00001 # the pre-trained model weights
├── bert_model.ckpt.index
├── bert_model.ckpt.meta
└── vocab.txt # the vocabulary (token-to-id mapping)
Every downstream NLP task can start from this model. I actually used the Chinese pre-trained BERT-wwm model from Harbin Institute of Technology (HIT) instead: because its pre-training uses Whole Word Masking, it is reported to outperform Google's official Chinese pre-trained model. If you're interested, follow that link for details.
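Before fine-tuning it's worth a quick look at bert_config.json to confirm the hyperparameters of the model you are loading; a minimal sketch (the path is wherever you extracted the model):

import json

# Hypothetical path to the extracted model directory
with open("chinese_L-12_H-768_A-12/bert_config.json") as f:
    config = json.load(f)

# The base Chinese model should report 12 layers, hidden size 768,
# 12 attention heads and a vocabulary of about 21128 tokens
for key in ("num_hidden_layers", "hidden_size", "num_attention_heads", "vocab_size"):
    print(key, "=", config[key])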
run_classifier
The fine-tuning stage is the core part, and the key code is the custom Processor. The source already contains Processor examples for four NLP tasks: XnliProcessor, MnliProcessor, MrpcProcessor and ColaProcessor. Each Processor implements the functions below; take MnliProcessor as the example:
get_train_examples: loads the training data; expects a "train.tsv" file in the data directory
def get_train_examples(self, data_dir):
  """See base class."""
  return self._create_examples(
      self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
get_dev_examples: loads the validation data; expects a "dev_matched.tsv" file in the data directory
def get_dev_examples(self, data_dir):
  """See base class."""
  return self._create_examples(
      self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
      "dev_matched")
get_test_examples: loads the test data; expects a "test_matched.tsv" file in the data directory
def get_test_examples(self, data_dir):
  """See base class."""
  return self._create_examples(
      self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
get_labels: returns the list of classification labels
def get_labels(self):
  """See base class."""
  return ["contradiction", "entailment", "neutral"]
_create_examples: builds the InputExample objects
def _create_examples(self, lines, set_type):
  """Creates examples for the training and dev sets."""
  examples = []
  for (i, line) in enumerate(lines):
    if i == 0:
      continue
    guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
    text_a = tokenization.convert_to_unicode(line[8])
    text_b = tokenization.convert_to_unicode(line[9])
    if set_type == "test":
      label = "contradiction"
    else:
      label = tokenization.convert_to_unicode(line[-1])
    examples.append(
        InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  return examples
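For context, the InputExample objects collected above are plain containers defined near the top of run_classifier.py (abridged here): a guid, one or two text segments, and a label:

class InputExample(object):
  """A single training/test example for simple sequence classification."""

  def __init__(self, guid, text_a, text_b=None, label=None):
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label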
Getting started
First make a backup copy of run_classifier.py, then edit the original directly. The task here is multi-class text classification (screening text for pornographic and politically sensitive content). Start by writing our own Processor:
class TextProcessor(DataProcessor):
  """Processor for the text classification task."""

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")

  def get_labels(self):
    """
    0: normal text
    1: pornographic text
    2: politically sensitive text
    """
    return ["0", "1", "2"]

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      # Mind your own sample format: my label is in the first column
      # and the text in the second
      text_a = tokenization.convert_to_unicode(line[1])
      label = tokenization.convert_to_unicode(line[0])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
    return examples
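Before registering it, you can sanity-check the new Processor from an interactive shell; a minimal sketch, assuming data/train.tsv already exists in the format shown in the data-preparation section below:

# Hypothetical quick check, run from the bert repo root
from run_classifier import TextProcessor

processor = TextProcessor()
examples = processor.get_train_examples("data")  # expects data/train.tsv
print(len(examples), "training examples")
ex = examples[0]
print(ex.guid, ex.label, ex.text_a[:20])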
Then register the new Processor in the processors dict inside main. The key text is our custom task name; we will need it when launching the script:
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "text": TextProcessor,  # our text classification task
  }
  ...
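For reference, main then resolves the Processor from the --task_name flag (lowercased), so the dict key must match what you pass on the command line; the relevant lines already in run_classifier.py are:

task_name = FLAGS.task_name.lower()

if task_name not in processors:
  raise ValueError("Task not found: %s" % (task_name))

processor = processors[task_name]()

label_list = processor.get_labels()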
One more change: for later use, we add a few lines to convert_single_example that write the label-to-id mapping to a file. If you forget, it's not a problem; you can recreate the file yourself afterwards.
import pickle
...

def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer):
  """Converts a single `InputExample` into a single `InputFeatures`."""
  if isinstance(example, PaddingInputExample):
    return InputFeatures(
        input_ids=[0] * max_seq_length,
        input_mask=[0] * max_seq_length,
        segment_ids=[0] * max_seq_length,
        label_id=0,
        is_real_example=False)

  label_map = {}
  for (i, label) in enumerate(label_list):
    label_map[label] = i

  ########################## added ###################################
  # Dump the label-to-id mapping so we can recover labels at serving time
  output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl")
  if not os.path.exists(output_label2id_file):
    with open(output_label2id_file, 'wb') as w:
      pickle.dump(label_map, w)
  ########################## added ###################################

  tokens_a = tokenizer.tokenize(example.text_a)
  tokens_b = None
  if example.text_b:
    tokens_b = tokenizer.tokenize(example.text_b)

  if tokens_b:
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    ...
That completes the changes to run_classifier.py.
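If you want to double-check what was written, the pickle can be loaded back in a couple of lines (a minimal sketch; the output path is whatever you pass as --output_dir):

import pickle

# Hypothetical training output path
with open("model_output/label2id.pkl", "rb") as f:
    print(pickle.load(f))  # e.g. {'0': 0, '1': 1, '2': 2}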
Data preparation
Create a data folder:
.
├── dev_matched.tsv # validation set
├── test_matched.tsv # test set
└── train.tsv # training set
The data looks like this, label first and then the text:
0 絲襪被撕開,一片雪白的肌膚乍現(xiàn)...
0 保姆接下來快速地在我那根手指上面揉搓,頓時(shí)感到疼痛...
1 她再也忍不住,本能占據(jù)了她的矜持,伸手探向了唐楓的身...
2 北京天安門...
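Note that, unlike MnliProcessor, the TextProcessor above does not skip a header row, so the .tsv files must contain data only: one label and one text per line, tab-separated. A minimal sketch for producing the three files, assuming your samples are already collected in a Python list:

import csv
import random

# Hypothetical list of (label, text) pairs
samples = [("0", "normal text ..."), ("1", "..."), ("2", "...")]
random.shuffle(samples)

# 80/10/10 split into the three files the Processor expects
n = len(samples)
splits = {
    "data/train.tsv": samples[:int(n * 0.8)],
    "data/dev_matched.tsv": samples[int(n * 0.8):int(n * 0.9)],
    "data/test_matched.tsv": samples[int(n * 0.9):],
}
for path, rows in splits.items():
    with open(path, "w", encoding="utf-8", newline="") as f:
        csv.writer(f, delimiter="\t").writerows(rows)  # no header row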
Training
For convenience, put the launch command in a shell script. A comment cannot follow the line-continuation backslash, so the flag explanations go above the command:
export DATA_DIR="absolute path of the data folder"
export BERT_BASE_DIR="path of the pre-trained BERT model"
export OUTPUT_DIR="model output path"

# --task_name:        the key registered in `processors` above
# --do_train:         run training; data dir must contain train.tsv
# --do_eval:          run evaluation; data dir must contain dev_matched.tsv
# --do_predict:       run prediction; data dir must contain test_matched.tsv
# --max_seq_length:   maximum sequence length
# --train_batch_size: too large may exhaust GPU memory; too small may fit poorly
# --learning_rate:    the default
# --num_train_epochs: number of training epochs
python run_classifier.py \
  --task_name=text \
  --do_train=true \
  --do_eval=true \
  --do_predict=true \
  --data_dir=${DATA_DIR}/ \
  --vocab_file=${BERT_BASE_DIR}/vocab.txt \
  --bert_config_file=${BERT_BASE_DIR}/bert_config.json \
  --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=16 \
  --learning_rate=2e-5 \
  --num_train_epochs=20 \
  --output_dir=${OUTPUT_DIR}
Now launch the script and training begins; you should see logs like:
WARNING:tensorflow:Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x7fd4b2c01488>) includes params argument, but params are not passed to Estimator.
INFO:tensorflow:Using config: {'_model_dir': '/home/SanJunipero/rd/dujihan/book_content_review/model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 1000, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fd4b1b80518>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=1000, num_shards=8, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}
INFO:tensorflow:_TPUContext: eval_on_tpu True
WARNING:tensorflow:eval_on_tpu ignored because use_tpu is False.
INFO:tensorflow:Writing example 0 of 70000
...
...
...
Errors I ran into:
- Version issues: on GPU, stick to a TensorFlow 1.x build; on CPU, check whether some API module paths need updating.
- Batch size: I first set train_batch_size=32 and the job crashed with out-of-memory; lowering it to 16 fixed the error.
The output folder will also contain events.* files recording how the training metrics evolved. If TensorBoard is installed, you can watch the training run in a browser; start TensorBoard from the command line:
tensorboard --logdir="output path"
Output like the following means it started successfully:
W1031 17:09:47.128739 140197393516288 plugin_event_accumulator.py:294] Found more than one graph event per run, or there was a metagraph containing a graph_def, as well as one or more graph events. Overwriting the graph with the newest event.
W1031 17:09:47.147893 140197393516288 plugin_event_accumulator.py:302] Found more than one metagraph event per run. Overwriting the metagraph with the newest event.
TensorBoard 1.13.1 at http://127.0.0.1:6006 (Press CTRL+C to quit)
Open the address above in a browser to see how the model is training.

Training complete
When training finishes, the output directory contains the following files:
.
├── checkpoint
├── eval # present only if --do_eval was true during training
│ └── events.out.tfevents.1572485638.localhost
├── eval_results.txt # the model's performance on the validation set
├── eval.tf_record
├── events.out.tfevents.localhost
├── graph.pbtxt
├── label2id.pkl # the label-to-id mapping
├── model.ckpt-84000.data-00000-of-00001 # trained checkpoints; the 5 most recent are kept by default
├── model.ckpt-84000.index
├── model.ckpt-84000.meta
├── model.ckpt-85000.data-00000-of-00001
├── model.ckpt-85000.index
├── model.ckpt-85000.meta
├── model.ckpt-86000.data-00000-of-00001
├── model.ckpt-86000.index
├── model.ckpt-86000.meta
├── model.ckpt-87000.data-00000-of-00001
├── model.ckpt-87000.index
├── model.ckpt-87000.meta
├── model.ckpt-87500.data-00000-of-00001
├── model.ckpt-87500.index
├── model.ckpt-87500.meta
├── predict.tf_record
├── test_results.tsv # the model's predictions on the test set
└── train.tf_record
Open eval_results.txt to see the final metrics:
eval_accuracy = 0.8573
eval_loss = 1.4312192
global_step = 87500
loss = 1.4312192
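test_results.tsv contains one row per test example, with tab-separated probabilities in label_list order. A minimal sketch for mapping the rows back to label strings via the label2id.pkl written during training:

import os
import pickle

output_dir = "model_output"  # hypothetical training output path
with open(os.path.join(output_dir, "label2id.pkl"), "rb") as f:
    label2id = pickle.load(f)
id2label = {v: k for k, v in label2id.items()}  # invert the mapping

with open(os.path.join(output_dir, "test_results.tsv")) as f:
    for line in f:
        probs = [float(p) for p in line.strip().split("\t")]
        pred = max(range(len(probs)), key=probs.__getitem__)  # argmax
        print(id2label[pred], probs[pred])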
With that, the model is trained.
Deployment
Before deploying, we freeze the model (fold the checkpoint into a single .pb graph) with a script someone else wrote, freeze_graph.py; place it in the same directory as run_classifier.py. The full code follows, or grab it from my GitHub:
import json
import os
from enum import Enum
from termcolor import colored
import sys
import modeling
import logging
import pickle
import tensorflow as tf
import argparse


def set_logger(context, verbose=False):
    if os.name == 'nt':  # for Windows
        return NTLogger(context, verbose)
    logger = logging.getLogger(context)
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)
    formatter = logging.Formatter(
        '%(levelname)-.1s:' + context + ':[%(filename).3s:%(funcName).3s:%(lineno)3d]:%(message)s',
        datefmt='%m-%d %H:%M:%S')
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
    console_handler.setFormatter(formatter)
    logger.handlers = []
    logger.addHandler(console_handler)
    return logger


class NTLogger:
    def __init__(self, context, verbose):
        self.context = context
        self.verbose = verbose

    def info(self, msg, **kwargs):
        print('I:%s:%s' % (self.context, msg), flush=True)

    def debug(self, msg, **kwargs):
        if self.verbose:
            print('D:%s:%s' % (self.context, msg), flush=True)

    def error(self, msg, **kwargs):
        print('E:%s:%s' % (self.context, msg), flush=True)

    def warning(self, msg, **kwargs):
        print('W:%s:%s' % (self.context, msg), flush=True)


def create_classification_model(bert_config, is_training, input_ids, input_mask,
                                segment_ids, labels, num_labels):
    # Run the inputs through BERT to get their representation
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
    )
    embedding_layer = model.get_sequence_output()
    output_layer = model.get_pooled_output()
    hidden_size = output_layer.shape[-1].value

    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())

    with tf.variable_scope("loss"):
        if is_training:
            # I.e., 0.1 dropout
            output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        probabilities = tf.nn.softmax(logits, axis=-1)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        if labels is not None:
            one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
            loss = tf.reduce_mean(per_example_loss)
        else:
            loss, per_example_loss = None, None
    return (loss, per_example_loss, logits, probabilities)


def init_predict_var(path):
    # Assumes label2id.pkl exists at the given path
    label2id_file = os.path.join(path, 'label2id.pkl')
    if os.path.exists(label2id_file):
        with open(label2id_file, 'rb') as rf:
            label2id = pickle.load(rf)
            id2label = {value: key for key, value in label2id.items()}
            num_labels = len(label2id.items())
    return num_labels, label2id, id2label


def optimize_class_model(args, logger=None):
    if not logger:
        logger = set_logger(colored('CLASSIFICATION_MODEL, Loading...', 'cyan'), args.verbose)
    try:
        # If the PB file already exists, return its path; otherwise freeze
        # the checkpoint into a PB file and return the path it was written to
        if args.model_pb_dir is None:
            tmp_file = args.model_dir
        else:
            tmp_file = args.model_pb_dir
        pb_file = os.path.join(tmp_file, 'classification_model.pb')
        if os.path.exists(pb_file):
            print('pb_file exists', pb_file)
            return pb_file

        # Added: read num_labels from label2id.pkl so the num_labels
        # argument can be omitted; 2019/4/17
        if not args.num_labels:
            num_labels, label2id, id2label = init_predict_var(args.model_dir)
        else:
            num_labels = args.num_labels

        graph = tf.Graph()
        with graph.as_default():
            with tf.Session() as sess:
                input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids')
                input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask')
                bert_config = modeling.BertConfig.from_json_file(
                    os.path.join(args.bert_model_dir, 'bert_config.json'))
                loss, per_example_loss, logits, probabilities = create_classification_model(
                    bert_config=bert_config, is_training=False,
                    input_ids=input_ids, input_mask=input_mask,
                    segment_ids=None, labels=None, num_labels=num_labels)
                # pred_ids = tf.argmax(probabilities, axis=-1, output_type=tf.int32, name='pred_ids')
                # pred_ids = tf.identity(pred_ids, 'pred_ids')
                probabilities = tf.identity(probabilities, 'pred_prob')
                saver = tf.train.Saver()

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                latest_checkpoint = tf.train.latest_checkpoint(args.model_dir)
                logger.info('loading... %s' % latest_checkpoint)
                saver.restore(sess, latest_checkpoint)
                logger.info('freeze...')
                from tensorflow.python.framework import graph_util
                tmp_g = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['pred_prob'])
                logger.info('predict cut finished !!!')

        # Serialize the frozen graph to disk
        logger.info('write graph to a tmp file: %s' % pb_file)
        with tf.gfile.GFile(pb_file, 'wb') as f:
            f.write(tmp_g.SerializeToString())
        return pb_file
    except Exception as e:
        logger.error('fail to optimize the graph! %s' % e, exc_info=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Trans ckpt file to .pb file')
    parser.add_argument('-bert_model_dir', type=str, required=True,
                        help='chinese google bert model path')
    parser.add_argument('-model_dir', type=str, required=True,
                        help='directory of a pretrained BERT model')
    parser.add_argument('-model_pb_dir', type=str, default=None,
                        help='directory of a pretrained BERT model, default = model_dir')
    parser.add_argument('-max_seq_len', type=int, default=128,
                        help='maximum length of a sequence, default: 128')
    parser.add_argument('-num_labels', type=int, default=None,
                        help='length of all labels, default=2')
    parser.add_argument('-verbose', action='store_true', default=False,
                        help='turn on tensorflow logging for debug')
    args = parser.parse_args()
    optimize_class_model(args, logger=None)
Run it:
# -max_seq_len must match the max_seq_length used during training
# -num_labels is the number of labels
python freeze_graph.py \
    -bert_model_dir="path of the pre-trained BERT model" \
    -model_dir="training output path (same as OUTPUT_DIR above)" \
    -max_seq_len=128 \
    -num_labels=3
Afterwards a classification_model.pb file appears in the output folder: the frozen model. Now we can deploy it.
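Before standing up a full service, you can also load the frozen graph directly and classify a sentence locally. A minimal TF1-style sketch, relying on the tensor names defined in freeze_graph.py above (input_ids, input_mask, pred_prob) and the repo's tokenization.py; all paths are placeholders:

import numpy as np
import tensorflow as tf
import tokenization

MAX_SEQ_LEN = 128  # must match the value used when freezing

# Load the frozen graph (hypothetical path)
with tf.gfile.GFile("model_output/classification_model.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name="")

tokenizer = tokenization.FullTokenizer(vocab_file="chinese_L-12_H-768_A-12/vocab.txt")

def encode(text):
    # [CLS] tokens [SEP], then pad with zeros up to MAX_SEQ_LEN
    tokens = ["[CLS]"] + tokenizer.tokenize(text)[:MAX_SEQ_LEN - 2] + ["[SEP]"]
    ids = tokenizer.convert_tokens_to_ids(tokens)
    mask = [1] * len(ids)
    pad = [0] * (MAX_SEQ_LEN - len(ids))
    return ids + pad, mask + pad

with tf.Session(graph=graph) as sess:
    ids, mask = encode("我愛北京天安門")
    probs = sess.run("pred_prob:0", feed_dict={
        "input_ids:0": np.array([ids]),
        "input_mask:0": np.array([mask]),
    })
    print(probs)  # per-class probabilities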
For serving we use the open-source BERT-BiLSTM-CRF-NER project. Despite the NER in its name, it can also serve a BERT classification model. Install it:
pip install bert-base==0.0.7 -i https://pypi.python.org/simple
Again, for convenience, wrap the server start-up in a script:
# -mode CLASS:     classification mode (this is a classification task)
# -max_seq_len:    keep consistent with the value used above
# -port/-port_out: pick ports that don't clash with other programs
bert-base-serving-start \
    -model_dir "path of the fine-tuned model" \
    -bert_model_dir "path of the pre-trained BERT model" \
    -model_pb_dir "path of classification_model.pb" \
    -mode CLASS \
    -max_seq_len 128 \
    -port 7006 \
    -port_out 7007
For installation details and parameter meanings, see BERT-BiLSTM-CRF-NER and bert-as-service.
After starting the service you will see log output like:
usage: xxxx/bin/bert-base-serving-start -model_dir xxxx/model -bert_model_dir xxxx/bert_model -model_pb_dir xxxxx/model -mode CLASS -max_seq_len 128 -port 7006 -port_out 7007
ARG VALUE
__________________________________________________
bert_model_dir = xxxx
ckpt_name = bert_model.ckpt
config_name = bert_config.json
cors = *
cpu = False
device_map = []
fp16 = False
gpu_memory_fraction = 0.5
http_max_connect = 10
http_port = None
mask_cls_sep = False
max_batch_size = 1024
max_seq_len = 128
mode = CLASS
model_dir = xxxxx
model_pb_dir = xxxxx
num_worker = 1
pooling_layer = [-2]
pooling_strategy = REDUCE_MEAN
port = 7006
port_out = 7007
prefetch_size = 10
priority_batch_size = 16
tuned_model_dir = None
verbose = False
xla = False
I:VENTILATOR:[__i:__i:104]:lodding classification predict, could take a while...
I:VENTILATOR:[__i:__i:111]:contain 0 labels:dict_values(['0', '1', '2'])
pb_file exits xxxx/model/classification_model.pb
I:VENTILATOR:[__i:__i:114]:optimized graph is stored at: xxxxx/model/classification_model.pb
I:VENTILATOR:[__i:_ru:148]:bind all sockets
I:VENTILATOR:[__i:_ru:153]:open 8 ventilator-worker sockets, ipc://tmp0cZQ9R/socket,ipc://tmp6uxbcD/socket,ipc://tmpu7Xxeo/socket,ipc://tmpsF2Ug9/socket,ipc://tmpMJTkjU/socket,ipc://tmpkvoLlF/socket,ipc://tmpefSdoq/socket,ipc://tmpW60Iqb/socket
I:VENTILATOR:[__i:_ru:157]:start the sink
I:VENTILATOR:[__i:_ge:239]:get devices
I:SINK:[__i:_ru:317]:ready
I:VENTILATOR:[__i:_ge:271]:device map:
worker 0 -> gpu 0
I:WORKER-0:[__i:_ru:497]:use device gpu: 0, load graph from xxxx/model/classification_model.pb
I:WORKER-0:[__i:gen:537]:ready and listening!
The BERT service is now deployed.
Usage example
In [1]: from bert_base.client import BertClient
In [2]: str1="我愛北京天安門"
In [3]: str2 = "哈哈哈哈"
In [4]: with BertClient(show_server_config=False, check_version=False, check_length=False,
...: mode="CLASS", port=7006, port_out=7007) as bc:
...: res = bc.encode([str1, str2])
...:
In [5]: print(res)
[{'pred_label': ['2', '1'], 'score': [0.9999899864196777, 0.9999299049377441]}]
References:
BERT source code annotated (run_classifier.py) - the most detailed walkthrough
BERT fine-tune: the ultimate hands-on tutorial
Deploying a BERT classification model as a service