[Repost] BERT Series (1): Running the Demos

Original article: http://www.itdecent.cn/p/3d0bb34c488a

Google's BERT has been making waves lately; anyone working in natural language processing has probably heard of it. It reportedly achieved state-of-the-art results on 11 tasks, including SQuAD. For how it works, see the original paper or one of the many write-ups available online. Google has open-sourced the code on GitHub, and I suspect every NLP practitioner is as eager as I am to dig into it.

For me, learning an open-source project takes three steps: first install it, then run the demo, and only then read the source code. Installing BERT is easy: just pull it from GitHub. Running the demo is not hard either; just follow README.md step by step. But when I actually tried it, I hit a small pitfall, so I am writing this post to record it for other readers.

Without further ado, let's get to it. This post covers just two demos: a sentence-pair classification task on MRPC (Microsoft Research Paraphrase Corpus), and a reading-comprehension task on the SQuAD corpus. Running the demos takes the following steps:

1. Download the BERT source code

Nothing special here; just clone it:

git clone https://github.com/google-research/bert.git

2. Download a pretrained model

BERT-Base, Uncased

Why BERT-Base, Uncased? Three reasons: (1) the training corpora here are English, so the Chinese and multilingual models are out; (2) hardware is limited: if your GPU has less than 16 GB of memory, stick with Base and don't wrestle with Large; (3) "cased" means case is preserved, "uncased" means everything is lowercased. Unless you know for sure your task is case-sensitive (e.g. named entity recognition or part-of-speech tagging), uncased usually performs better.
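As a rough illustration of what "uncased" means in practice, here is a minimal sketch of the usual lowercase-and-strip-accents normalization; this is my own illustrative function, not BERT's actual tokenizer code:

```python
import unicodedata

def uncased_normalize(text):
    # Lowercase, then decompose accented characters and drop the
    # combining marks (category "Mn"), as uncased preprocessing does
    text = text.lower()
    text = unicodedata.normalize("NFD", text)
    return "".join(ch for ch in text if unicodedata.category(ch) != "Mn")

print(uncased_normalize("Héllo World"))  # hello world
```

With this normalization, "Héllo" and "hello" map to the same token string, which is why uncased models generalize better when case carries no signal.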

3. Download the training data

(1) Download the MRPC corpus:

The officially recommended way is to run the script download_glue_data.py to fetch the GLUE data. With the data directory set to glue_data and the task set to MRPC, run (every python3 command in this post also works with python):

python3 download_glue_data.py --data_dir glue_data --tasks MRPC

The original script targets Python 3, but the Python on my machine is Python 2, so running the code as-is fails. Below is my modified download_glue_data.py, and the command accordingly becomes:

python download_glue_data.py --data_dir glue_data --tasks MRPC
# download_glue_data.py (modified for Python 2)
import os
import sys
import shutil
import argparse
import tempfile
import urllib
import zipfile
import codecs

TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
             "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
             "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
             "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
             "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
             "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
             "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
             "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
             "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
             "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
             "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}

MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'

def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
    data_file = "%s.zip" % task
    urllib.urlretrieve(TASK2PATH[task], data_file)
    with zipfile.ZipFile(data_file) as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(data_file)
    print("\tCompleted!")

def format_mrpc(data_dir, path_to_data):
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    if not os.path.isdir(mrpc_dir):
        os.mkdir(mrpc_dir)
    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
        mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
        urllib.urlretrieve(MRPC_TRAIN, mrpc_train_file)
        urllib.urlretrieve(MRPC_TEST, mrpc_test_file)
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
    urllib.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))

    dev_ids = []
    with codecs.open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf-8") as ids_fh:
        for row in ids_fh:
            dev_ids.append(row.strip().split('\t'))

    with codecs.open(mrpc_train_file, encoding="utf-8") as data_fh, \
            codecs.open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf-8") as train_fh, \
            codecs.open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf-8") as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split('\t')
            if [id1, id2] in dev_ids:
                dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
            else:
                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

    with codecs.open(mrpc_test_file, encoding="utf-8") as data_fh, \
            codecs.open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf-8") as test_fh:
        header = data_fh.readline()
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            label, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
    print("\tCompleted!")

def download_diagnostic(data_dir):
    print("Downloading and extracting diagnostic...")
    if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
        os.mkdir(os.path.join(data_dir, "diagnostic"))
    data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
    urllib.urlretrieve(TASK2PATH["diagnostic"], data_file)
    print("\tCompleted!")
    return

def get_tasks(task_names):
    task_names = task_names.split(',')
    if "all" in task_names:
        tasks = TASKS
    else:
        tasks = []
        for task_name in task_names:
            assert task_name in TASKS, "Task %s not found!" % task_name
            tasks.append(task_name)
    return tasks

def main(arguments):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt',
                        type=str, default='')
    args = parser.parse_args(arguments)

    if not os.path.isdir(args.data_dir):
        os.mkdir(args.data_dir)
    tasks = get_tasks(args.tasks)

    for task in tasks:
        if task == 'MRPC':
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
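If you'd rather keep a single script that runs under both interpreters instead of maintaining a separate Python 2 copy, the bare `import urllib` above could be replaced with a version-agnostic import. This shim is my own suggestion, not part of the original repo:

```python
# urlretrieve lives in different modules in Python 2 and Python 3;
# try the Python 2 location first, then fall back to Python 3's.
try:
    from urllib import urlretrieve          # Python 2
except ImportError:
    from urllib.request import urlretrieve  # Python 3

# Calls like urllib.urlretrieve(url, filename) in the script would
# then become plain urlretrieve(url, filename).
```

With this in place, the same file works with both `python` and `python3`.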

If the above doesn't work, here is a Baidu Cloud share someone posted:
Link: https://pan.baidu.com/s/1-b4I3ocYhiuhu3bpSmCJ_Q
Extraction code: z6mk

(2) Download the SQuAD corpus:

Nothing eventful here: download the three SQuAD v1.1 files (train-v1.1.json, dev-v1.1.json, and evaluate-v1.1.py) directly and place them under $SQUAD_DIR.
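For orientation, the SQuAD v1.1 JSON files are nested as data → paragraphs → qas. The tiny hand-made example below (not real SQuAD data) sketches that layout and counts the questions in it:

```python
# Minimal sketch of the SQuAD v1.1 JSON layout (hand-made example,
# not taken from the actual dataset).
sample = {
    "version": "1.1",
    "data": [{
        "title": "Example",
        "paragraphs": [{
            "context": "BERT was released by Google in 2018.",
            "qas": [{
                "id": "q1",
                "question": "Who released BERT?",
                # answer_start is the character offset of the answer span
                "answers": [{"text": "Google", "answer_start": 21}],
            }],
        }],
    }],
}

def count_questions(squad):
    # Walk articles -> paragraphs and sum the qas entries
    return sum(len(p["qas"])
               for article in squad["data"]
               for p in article["paragraphs"])

print(count_questions(sample))  # 1
```

run_squad.py consumes exactly this structure from train-v1.1.json and dev-v1.1.json.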

4. Run the demos

(1) Sentence-pair classification on the MRPC corpus

Training:

Set environment variables pointing to the pretrained model files and the corpus:

export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue_data

From the BERT source directory, run run_classifier.py to fine-tune on top of the pretrained model:

python run_classifier.py \
  --task_name=MRPC \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=/tmp/mrpc_output/

The model is saved to output_dir; my eval results were:

# This took me about 3 hours on a single machine... use a GPU if you have one
INFO:tensorflow:***** Eval results *****
INFO:tensorflow:  eval_accuracy = 0.86519605
INFO:tensorflow:  eval_loss = 0.40176657
INFO:tensorflow:  global_step = 343
INFO:tensorflow:  loss = 0.40176657

Prediction:

Point to the directory containing the fine-tuned model:

export TRAINED_CLASSIFIER=/path/to/fine/tuned/classifier

Run the following to generate predictions; the results are written to the output_dir folder:

python run_classifier.py \
  --task_name=MRPC \
  --do_predict=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=128 \
  --output_dir=/tmp/mrpc_output/
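Per the BERT README, the predictions land in a test_results.tsv file in output_dir, one tab-separated row of class probabilities per test example. A small sketch of turning such rows into predicted label indices; the sample rows below are made up for illustration:

```python
# Hypothetical sample of test_results.tsv rows for MRPC:
# two columns of probabilities (not_paraphrase, paraphrase) per example.
rows = [["0.12", "0.88"],
        ["0.95", "0.05"]]

def predicted_labels(prob_rows):
    # argmax over the probability columns gives the predicted class index
    return [max(range(len(r)), key=lambda i: float(r[i]))
            for r in prob_rows]

print(predicted_labels(rows))  # [1, 0]
```

In a real run you would read the rows with `row.strip().split('\t')` from test_results.tsv instead of hard-coding them.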

(2) Reading comprehension on the SQuAD corpus

With $SQUAD_DIR set to the folder containing the corpus, run:

python run_squad.py \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --do_train=True \
  --train_file=$SQUAD_DIR/train-v1.1.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v1.1.json \
  --train_batch_size=12 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=/tmp/squad_base/

A predictions.json file is written to the output_dir folder. Evaluate it with:

python3 $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json predictions.json

If you see output like the following, everything ran correctly:

{"f1": 88.41249612335034, "exact_match": 81.2488174077578}
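The f1 figure above is a token-overlap F1 between the predicted and gold answer strings. Here is a simplified sketch in the spirit of evaluate-v1.1.py, omitting that script's lowercasing, article-stripping, and punctuation normalization:

```python
from collections import Counter

def token_f1(prediction, ground_truth):
    # Token-overlap F1 between two answer strings (simplified:
    # whitespace tokenization only, no text normalization)
    pred_toks = prediction.split()
    gt_toks = ground_truth.split()
    common = Counter(pred_toks) & Counter(gt_toks)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_toks)
    recall = overlap / len(gt_toks)
    return 2 * precision * recall / (precision + recall)

print(token_f1("Google AI", "Google"))  # ~0.667
```

The official script averages this (and exact match) over all questions, taking the max over the gold answers for each question.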

5. Summary

This post addressed two things:

(1) running the demos for MRPC sentence-pair classification and SQuAD reading comprehension, largely a walkthrough of the relevant parts of the repo's README.md;

(2) alternative sources for corpora that fail to download. Later posts in this series will walk through the BERT source code. Stay tuned.

Reference
1.https://github.com/google-research/bert
