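1. Introduction to BERT
BERT, short for Bidirectional Encoder Representations from Transformers, is a pre-trained language representation model. Earlier approaches pre-trained with a traditional unidirectional language model, or shallowly concatenated two unidirectional language models. BERT instead uses a new masked language model (MLM) objective, which lets it learn deep bidirectional language representations. When the BERT paper was published, it reported new state-of-the-art results on 11 NLP (Natural Language Processing) tasks.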
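The MLM idea can be illustrated with the masking recipe described in the BERT paper:

- Roughly 15% of the token positions in a sequence are chosen as prediction targets.
- Each chosen token is replaced by [MASK] 80% of the time, by a random token 10% of the time, and left unchanged the remaining 10%.
- The model then has to recover the original tokens using context from both sides.

The sketch below is a simplified stand-alone illustration under those assumptions. It is not the pre-training code in the BERT repository, which for example also skips [CLS]/[SEP] and caps the number of predictions per sequence. `mask_tokens` is a hypothetical name.

```python
import random
from typing import List, Optional, Tuple

MASK = "[MASK]"


def mask_tokens(tokens: List[str], vocab: List[str], rng: random.Random,
                mask_prob: float = 0.15) -> Tuple[List[str], List[Optional[str]]]:
    """Return (masked tokens, targets); targets[i] holds the original token at
    each position chosen for prediction and None everywhere else."""
    masked = list(tokens)
    targets: List[Optional[str]] = [None] * len(tokens)
    # At least one target per non-empty sequence, about mask_prob of the tokens
    num_to_mask = min(len(tokens), max(1, round(len(tokens) * mask_prob)))
    for i in rng.sample(range(len(tokens)), num_to_mask):
        targets[i] = tokens[i]
        r = rng.random()
        if r < 0.8:
            masked[i] = MASK               # 80%: replace with [MASK]
        elif r < 0.9:
            masked[i] = rng.choice(vocab)  # 10%: replace with a random token
        # otherwise (10%): keep the original token unchanged
    return masked, targets


rng = random.Random(0)
sentence = "the room was clean and the staff were friendly".split()
print(mask_tokens(sentence, vocab=sentence, rng=rng))
```

The loss is computed only at the positions where `targets` is not None. That is why the 10% "unchanged" case still counts as a prediction target.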
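2. Environment Setup
- Ubuntu 16.04
- Anaconda3
- Python >= 3.6
- TensorFlow >= 1.12.0, but still 1.x: the original BERT code relies on TensorFlow 1 APIs such as tf.contrib and does not run on TensorFlow 2.x
- pandas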
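Install conda first, then list the existing conda environments:
```bash
# List conda environments
conda info -e
```
Create a new conda environment named bert and switch to it:
```bash
# Create the bert environment (python=3.6 is an example; pick a version your TensorFlow 1.x build supports)
conda create -n bert python=3.6
# Switch to the bert environment
conda activate bert
```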
3. The ChnSentiCorp Dataset
We use the ChnSentiCorp dataset, which contains more than 7,000 hotel reviews: more than 5,000 positive and more than 2,000 negative. Each record has two fields, label and review.
Fields:
- label: 1 for a positive review, 0 for a negative review
- review: the review text
Data URL: https://raw.githubusercontent.com/SophonPlus/ChineseNlpCorpus/master/datasets/ChnSentiCorp_htl_all/ChnSentiCorp_htl_all.csv
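Before splitting, you may want to confirm the label balance described above. Below is a minimal standard-library sketch. It assumes the file is UTF-8 with a header row `label,review`; `count_labels` and `count_labels_in_file` are hypothetical helper names.

```python
import csv
from collections import Counter
from typing import Dict, Iterable


def count_labels(rows: Iterable[Dict[str, str]]) -> Counter:
    """Count reviews per label, skipping rows whose review is empty or missing."""
    counts: Counter = Counter()
    for row in rows:
        review = (row.get("review") or "").strip()
        if review:
            counts[(row.get("label") or "").strip()] += 1
    return counts


def count_labels_in_file(path: str) -> Counter:
    """Open the CSV and count labels; utf-8-sig also accepts a leading BOM,
    and newline="" lets csv handle line breaks inside quoted reviews."""
    with open(path, encoding="utf-8-sig", newline="") as f:
        return count_labels(csv.DictReader(f))
```

Usage, after downloading the file into the working directory: `print(count_labels_in_file("ChnSentiCorp_htl_all.csv"))`. Empty reviews are skipped here, which mirrors the `dropna()` step in the split script.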
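Create a script split_data.py that shuffles the data and splits it 8:1:1 into a training set (train.csv), a development set (dev.csv) and a test set (test.csv):
```python
import pandas as pd

# Read every column as a string so labels stay "0"/"1"
df = pd.read_csv('ChnSentiCorp_htl_all.csv', dtype=str)
# Drop rows with a missing label or review
df = df.dropna()
# Strip leading/trailing whitespace in every column
df = df.apply(lambda col: col.str.strip())
# Shuffle the rows; random_state makes the split reproducible
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
# Split train:dev:test as 8:1:1, with boundaries computed from the row count
n = len(df)
train_end = n * 8 // 10
dev_end = n * 9 // 10
train_df = df.iloc[:train_end]
dev_df = df.iloc[train_end:dev_end]
test_df = df.iloc[dev_end:]

train_df.to_csv('train.csv', sep=',', index=False)
dev_df.to_csv('dev.csv', sep=',', index=False)
test_df.to_csv('test.csv', sep=',', index=False)
```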
When the script finishes, the directory contains:
```
├── ChnSentiCorp_htl_all.csv
├── dev.csv
├── split_data.py
├── test.csv
└── train.csv
```
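4. Download the BERT Source Code and Pre-trained Model
- Clone the BERT source code:
```bash
git clone https://github.com/google-research/bert.git
```
- Download the Chinese BERT pre-trained model from: https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip

Unzip it into a directory of your choice. The train.sh script in section 6 expects it at bert-models/chinese_L-12_H-768_A-12. It contains:
```
├── bert_config.json
├── bert_model.ckpt.data-00000-of-00001
├── bert_model.ckpt.index
├── bert_model.ckpt.meta
└── vocab.txt
```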
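vocab.txt holds one token per line, and a token's id is its 0-based line number. For Chinese text, BERT's BasicTokenizer puts spaces around every CJK character, so Chinese input is effectively split into single characters before lookup.

The sketch below illustrates only that path and rests on simplifying assumptions. The real `tokenization.FullTokenizer` also normalizes text and runs WordPiece on non-CJK words, so a word like "wifi" would be split into sub-word pieces there rather than mapped to [UNK] as here. The helper names are hypothetical.

```python
from typing import Dict, List


def load_vocab(lines: List[str]) -> Dict[str, int]:
    """Token id = 0-based line number in vocab.txt."""
    return {line.strip(): idx for idx, line in enumerate(lines)}


def is_cjk(ch: str) -> bool:
    """Simplified check covering only the main CJK Unified Ideographs blocks."""
    cp = ord(ch)
    return 0x4E00 <= cp <= 0x9FFF or 0x3400 <= cp <= 0x4DBF


def split_cjk(text: str) -> List[str]:
    """Put spaces around CJK characters, then split on whitespace."""
    spaced = "".join(f" {ch} " if is_cjk(ch) else ch for ch in text)
    return spaced.split()


def to_ids(tokens: List[str], vocab: Dict[str, int], unk: str = "[UNK]") -> List[int]:
    """Look tokens up, falling back to the [UNK] id for unknown tokens."""
    unk_id = vocab[unk]
    return [vocab.get(tok, unk_id) for tok in tokens]


# Usage with the downloaded model (the path is an example):
# with open("chinese_L-12_H-768_A-12/vocab.txt", encoding="utf-8") as f:
#     vocab = load_vocab(f.readlines())
# print(to_ids(split_cjk("酒店很好"), vocab))
```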
5. Modify the Code
run_classifier.py defines a base class, DataProcessor:
```python
class DataProcessor(object):
  """Base class for data converters for sequence classification data sets."""
  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    raise NotImplementedError()
  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    raise NotImplementedError()
  def get_test_examples(self, data_dir):
    """Gets a collection of `InputExample`s for prediction."""
    raise NotImplementedError()
  def get_labels(self):
    """Gets the list of labels for this data set."""
    raise NotImplementedError()
  @classmethod
  def _read_tsv(cls, input_file, quotechar=None):
    """Reads a tab separated value file."""
    with tf.gfile.Open(input_file, "r") as f:
      reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
      lines = []
      for line in reader:
        lines.append(line)
      return lines
```
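The base class defines a class method, _read_tsv, for reading tab-separated files. It also declares four methods that return the training set, the dev set, the test set and the label list. run_classifier.py already includes processors Google wrote for several public datasets, such as XnliProcessor, MnliProcessor, MrpcProcessor and ColaProcessor, and they are good examples of how to write a processor for your own dataset. Next we define our own data-processing class, named SentimentProcessor. Because our files are comma-separated, it reads them with its own _read_csv method instead of _read_tsv.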
```python
# Add this class to run_classifier.py, which already imports os, csv,
# tokenization and tensorflow as tf, and defines InputExample.
class SentimentProcessor(DataProcessor):
  """Processor for the ChnSentiCorp data set (CSV columns: label, review)."""

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_csv(os.path.join(data_dir, "train.csv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_csv(os.path.join(data_dir, "dev.csv")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_csv(os.path.join(data_dir, "test.csv")), "test")

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  def _create_examples(self, lines, set_type):
    """Creates single-sentence examples from CSV rows [label, review]."""
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0:
        continue  # skip the header row written by pandas
      guid = "%s-%d" % (set_type, i)
      text_a = tokenization.convert_to_unicode(line[1])
      # Labels are already "0"/"1" and must match get_labels()
      label = tokenization.convert_to_unicode(line[0])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
    return examples

  @classmethod
  def _read_csv(cls, input_file, quotechar='"'):
    """Reads a comma separated value file."""
    with tf.gfile.Open(input_file, "r") as f:
      # pandas wraps reviews that contain commas in '"'; quotechar=None would
      # turn quoting off and split such reviews into several columns.
      reader = csv.reader(f, delimiter=",", quotechar=quotechar)
      lines = []
      for line in reader:
        lines.append(line)
      return lines
```
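_read_csv parses the CSV files written by split_data.py. By default, pandas' to_csv wraps any field that contains a comma in double quotes, and csv.reader only removes those quotes when quoting is enabled. Passing quotechar=None turns quoting off.

The stand-alone sketch below uses plain Python csv with no TensorFlow and compares the two settings. The sample review is made up, and `read_rows` is a hypothetical helper.

```python
import csv
import io
from typing import List, Optional


def read_rows(text: str, quotechar: Optional[str]) -> List[List[str]]:
    """Parse CSV text with delimiter "," and the given quotechar."""
    return list(csv.reader(io.StringIO(text), delimiter=",", quotechar=quotechar))


# What pandas writes for a review that contains an ASCII comma (made-up text)
sample = 'label,review\n1,"Clean room, friendly staff"\n'

print(read_rows(sample, quotechar='"'))  # the review stays one field
print(read_rows(sample, quotechar=None))  # quoting off: the review is split in two
```

With quoting off, line[1] would hold only the first half of the review, still carrying a stray quote character.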
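Then register SentimentProcessor in the processors dictionary in main(). The key senti is the value passed as --task_name:
```python
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "senti": SentimentProcessor,
  }
  # ... the rest of main() is unchanged
```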
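main() picks the processor by lower-casing --task_name and looking it up in this dictionary, and raises an error for unknown names. The length of get_labels() then sets the number of output classes.

Below is a stand-alone sketch of that lookup. The classes are hypothetical stubs standing in for the real processors, and `get_processor` is a hypothetical name.

```python
from typing import Dict, List, Type


class DataProcessor:
    """Stub standing in for run_classifier.DataProcessor."""

    def get_labels(self) -> List[str]:
        raise NotImplementedError()


class SentimentProcessor(DataProcessor):
    """Stub with the same labels as the SentimentProcessor above."""

    def get_labels(self) -> List[str]:
        return ["0", "1"]


PROCESSORS: Dict[str, Type[DataProcessor]] = {"senti": SentimentProcessor}


def get_processor(task_name: str) -> DataProcessor:
    """Lower-case the task name, then look it up, as main() does."""
    name = task_name.lower()
    if name not in PROCESSORS:
        raise ValueError("Task not found: %s" % name)
    return PROCESSORS[name]()


labels = get_processor("SENTI").get_labels()
print(labels, len(labels))  # num_labels is len(labels)
```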
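6. Train the BERT Model
Create a script train.sh with the content below. Its log is written to train.log. Follow it with tail -f train.log, and check the GPU status with nvidia-smi.
Parameters:
- data_dir: directory containing train.csv, dev.csv and test.csv
- task_name: the processor's key in the processors dictionary (senti here)
- vocab_file: vocabulary file; use the vocab.txt shipped with the pre-trained model, because token ids must match the checkpoint's embedding table
- bert_config_file: model configuration file
- init_checkpoint: pre-trained checkpoint used to initialize fine-tuning
- output_dir: output directory for the fine-tuned model and the result files
- do_train: whether to fine-tune; defaults to false; if true, get_train_examples must be implemented
- do_eval: whether to evaluate on the dev set; defaults to false; if true, get_dev_examples must be implemented
- do_predict: whether to predict on the test set; defaults to false; if true, get_test_examples must be implemented
- max_seq_length, train_batch_size, learning_rate, num_train_epochs: maximum input length in tokens, training batch size, initial learning rate for Adam, and number of training epochs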
```bash
#!/bin/bash
export BERT_BASE_DIR=bert-models/chinese_L-12_H-768_A-12
export MY_DATASET=data
export OUTPUT_PATH=output
export TASK_NAME=senti

# Python from the bert conda environment; adjust this path to your own installation
nohup /home/peng/anaconda3/envs/bert/bin/python run_classifier.py \
  --data_dir=$MY_DATASET \
  --task_name=$TASK_NAME \
  --output_dir=$OUTPUT_PATH \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --do_train=True \
  --do_eval=True \
  --do_predict=True \
  --max_seq_length=128 \
  --train_batch_size=16 \
  --learning_rate=5e-5 \
  --num_train_epochs=2.0 \
  >train.log 2>&1 &
```
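--max_seq_length=128 caps the input length. For a single-sentence task like this one, run_classifier.py does three things:

- It truncates the tokens to max_seq_length - 2.
- It wraps them as [CLS] tokens [SEP].
- It zero-pads input_ids, input_mask and segment_ids to the fixed length.

Since each Chinese character is one token, only about the first 126 characters of a long review reach the model. Below is a simplified sketch of that step, assuming token ids have already been looked up. The function name and the ids in the example are hypothetical.

```python
from typing import Dict, List


def build_features(token_ids: List[int], cls_id: int, sep_id: int,
                   max_seq_length: int = 128) -> Dict[str, List[int]]:
    """Single-sentence input: [CLS] tokens [SEP], then zero padding."""
    token_ids = token_ids[: max_seq_length - 2]  # leave room for [CLS] and [SEP]
    input_ids = [cls_id] + token_ids + [sep_id]
    input_mask = [1] * len(input_ids)   # 1 = real token, 0 = padding
    segment_ids = [0] * len(input_ids)  # only one sentence, so all zeros
    pad = max_seq_length - len(input_ids)
    return {
        "input_ids": input_ids + [0] * pad,
        "input_mask": input_mask + [0] * pad,
        "segment_ids": segment_ids + [0] * pad,
    }


features = build_features([11, 12, 13], cls_id=101, sep_id=102, max_seq_length=8)
print(features["input_ids"])   # [101, 11, 12, 13, 102, 0, 0, 0]
print(features["input_mask"])  # [1, 1, 1, 1, 1, 0, 0, 0]
```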
7. Training Results
The dev-set evaluation results are in OUTPUT_PATH/eval_results.txt:
```
eval_accuracy = 0.84942085
eval_loss = 0.3728643
global_step = 776
loss = 0.3766538
```
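The global_step value follows from the training settings. run_classifier.py's main() computes the step counts as:

- num_train_steps = int(len(train_examples) / train_batch_size * num_train_epochs)
- num_warmup_steps = int(num_train_steps * warmup_proportion), with warmup_proportion defaulting to 0.1

The sketch below reproduces that arithmetic, assuming train.csv holds 6,212 examples. `train_steps` is a hypothetical name.

```python
from typing import Tuple


def train_steps(num_examples: int, batch_size: int, epochs: float,
                warmup_proportion: float = 0.1) -> Tuple[int, int]:
    """Return (num_train_steps, num_warmup_steps) using main()'s formulas."""
    num_train_steps = int(num_examples / batch_size * epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)
    return num_train_steps, num_warmup_steps


print(train_steps(6212, 16, 2.0))  # (776, 77)
```

6212 / 16 * 2.0 = 776.5, which int() truncates to 776, matching global_step above. global_step counts batches, not examples.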
The test-set predictions are in OUTPUT_PATH/test_results.tsv. Its first 5 rows are shown below; the two columns are the probabilities of labels 0 and 1:
```
0.012343313 0.9876567
0.9637287 0.03627124
0.3622907 0.6377093
0.0120654255 0.9879346
0.41722867 0.5827713
```
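Each row of test_results.tsv holds tab-separated probabilities for the labels in the order returned by get_labels(). The predicted label is therefore the column with the larger probability.

The minimal standard-library sketch below converts those rows into labels and, given gold labels, computes accuracy. It assumes the predictions and the rows of test.csv line up in the same order. `predicted_labels` and `accuracy` are hypothetical helper names.

```python
import csv
import io
from typing import List, Sequence


def predicted_labels(tsv_text: str, labels: Sequence[str] = ("0", "1")) -> List[str]:
    """Pick the label whose probability column is largest in each row."""
    preds = []
    for row in csv.reader(io.StringIO(tsv_text), delimiter="\t"):
        if not row:
            continue
        probs = [float(p) for p in row]
        preds.append(labels[probs.index(max(probs))])
    return preds


def accuracy(preds: Sequence[str], gold: Sequence[str]) -> float:
    """Fraction of positions where prediction and gold label agree."""
    if len(preds) != len(gold):
        raise ValueError("prediction and gold label counts differ")
    return sum(p == g for p, g in zip(preds, gold)) / len(gold)


# The five rows shown above, tab-separated as in test_results.tsv
rows = (
    "0.012343313\t0.9876567\n"
    "0.9637287\t0.03627124\n"
    "0.3622907\t0.6377093\n"
    "0.0120654255\t0.9879346\n"
    "0.41722867\t0.5827713\n"
)
print(predicted_labels(rows))  # ['1', '0', '1', '1', '1']
```

For the full file, read test_results.tsv, collect the label column of test.csv (skipping its header), and pass both lists to accuracy().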
The first 5 rows of test.csv are as follows:
