15 分鐘搭建一個(gè)基于XLNET的文本分類模型——keras實(shí)戰(zhàn)

今天筆者將簡要介紹一下后bert 時(shí)代中一個(gè)又一比較重要的預(yù)訓(xùn)練的語言模型——XLNET ,下圖是XLNET在中文問答數(shù)據(jù)集CMRC 2018數(shù)據(jù)集(哈工大訊飛聯(lián)合實(shí)驗(yàn)室發(fā)布的中文機(jī)器閱讀理解數(shù)據(jù),形式與SQuAD相同)上的表現(xiàn)。我們可以看到XLNET的實(shí)力略勝于BERT。

XLNET 的一些表現(xiàn)

這里筆者會(huì)先簡單地介紹一下XLNET精妙的算法設(shè)計(jì),當(dāng)然我盡量采用通俗的語言去表達(dá)那些深?yuàn)W的數(shù)學(xué)表達(dá)式,整個(gè)行文過程會(huì)直接采用原論文的行文流程:Observition—>Motivition—>Contribution。然后我會(huì)介紹一下如何用python在15分鐘之內(nèi)搭建一個(gè)基于XLNET的文本分類模型。

XLNET的原理

Observision

XLNET的原論文將預(yù)訓(xùn)練的語言模型分為兩類:

1. 自回歸:根據(jù)上文預(yù)測下文將要出現(xiàn)的單詞,讓模型在預(yù)訓(xùn)練階段去做補(bǔ)充句子任務(wù),其中代表模型就是GPT。

把句子補(bǔ)充完整。

1.秋天到了,________________________。

2 .自編碼:根據(jù)上下文去預(yù)測中間詞,其實(shí)就是讓模型在預(yù)訓(xùn)練階段去做完形填空任務(wù),其中的代表模型就是BERT,在實(shí)際預(yù)訓(xùn)練過程中用[MASK]這個(gè)字符替代要被預(yù)測的目標(biāo)單詞。

請(qǐng)選擇最合適的詞填入空缺處

1.這種夢(mèng)說明你正在______挑戰(zhàn),但你還沒有做好準(zhǔn)備。
A.面臨 B.逃避 C.參考 D.承擔(dān)

但這兩種語言模型都有各自的不足之處,如下圖所示:

  • 自回歸語言模型:由于是標(biāo)準(zhǔn)的單向語言模型,無法看到下文的信息,可能導(dǎo)致信息利用不充分的問題,實(shí)際的實(shí)驗(yàn)結(jié)果也證明利用上下文信息的BERT的效果要強(qiáng)于只利用上文信息的GPT。
  • 自編碼語言模型:1.預(yù)訓(xùn)練階段(存在[Mask] 字符)和finetuning 階段(無[Mask] 字符)文本分布不一致,會(huì)影響到下游任務(wù)的funetuning效果。2.被預(yù)測的token(詞或者字)之間彼此獨(dú)立,沒有任何語義關(guān)聯(lián),


    兩種語言模型

Motivation

  • 解決自回歸語言模型無法獲取下文信息(預(yù)知未來)的問題
  • 優(yōu)化自編碼語言模型存在的兩個(gè)缺點(diǎn):1.預(yù)訓(xùn)練階段(存在[Mask] 字符)和finetuning 階段(無[Mask] 字符)文本分布不一致,2.被預(yù)測Mask詞之間彼此獨(dú)立,不符合人的認(rèn)知。

Contribution

有了上面兩個(gè)Motivation,作者就提出了XLNET去解決這兩個(gè)問題,這里我不會(huì)引入公式,只是簡單的解釋一下XLNET的三個(gè)比較創(chuàng)新的設(shè)計(jì)。

  • Permutation Language Model:這個(gè)部分是XLNET最精彩的設(shè)計(jì)。主要目的就是為了解決單向語言模型無法利用下文信息的缺陷,如下圖所示,其實(shí)就是對(duì)句子的順序做隨機(jī)打亂后,然后作為單向語言模型的輸入。預(yù)訓(xùn)練時(shí)需預(yù)測當(dāng)前位置的token。我們來詳細(xì)體會(huì)一下 Permutation Language Model的過程。
    1 .假設(shè)我們有一個(gè)序列[1,2,3,4],預(yù)測目標(biāo)是3。
    2 .先對(duì)該序列進(jìn)行因式分解,對(duì)句子的順序做隨機(jī)打亂,最終會(huì)有24種排列方式,下圖是其中可能的四種情況。
    3 .其中右上的圖中,3的左邊還包括了2與4,以為著我們?cè)陬A(yù)測3的時(shí)候看到3這個(gè)token上下文的信息,同時(shí)依然保持著單向的語言模型。

    這部分的設(shè)計(jì)就是為了解決之前的自回歸和自編碼語言模型的缺陷,也就是Motivation中的兩個(gè)問題。所以這個(gè)設(shè)計(jì)是整片文章最精妙的部分。

接下來的兩個(gè)設(shè)計(jì)主要為了解決構(gòu)建Permutation Language Model帶來的兩個(gè)問題:
1.序列進(jìn)行因式分解使得token失去位置感。
2.序列進(jìn)行因式分解使得進(jìn)行預(yù)測時(shí)模型能夠提前看到答案。

Permutation Language Model
  • Reparameterization with positions:由于Permutation Language Model對(duì)句子順序做了隨機(jī)打亂,可能會(huì)導(dǎo)致模型失去對(duì)某個(gè)token在句子中的原始位置的感知,導(dǎo)致模型退化成一個(gè)詞袋模型,于是作者利用Reparameterization with positions去解決這個(gè)問題。

  • Two-stream attention
    這部分的設(shè)計(jì)依然是為了解決Permutation Language Model留下的另外一個(gè)問題,細(xì)心的讀者會(huì)發(fā)現(xiàn)序列因式分解會(huì)使得被預(yù)測的token被提起預(yù)知,回到上圖:有一個(gè)序列[1,2,3,4],預(yù)測目標(biāo)是3??墒瞧渲杏袀€(gè)因式分解的順序是[3,2,4,1] 導(dǎo)致預(yù)測3的時(shí)候提前拿到了3的信息,這樣就使得預(yù)訓(xùn)練的過程變得無意義了,于是作者設(shè)計(jì)了Two-stream attention:

    1.一個(gè)query stream只編碼目標(biāo)單詞以外的上下文的信息以及目標(biāo)單詞的位置信息。
    2.一個(gè)content stream既編碼目標(biāo)單詞自己的信息,又編碼上下文的信息供。

    query stream 的作用就是避免了模型在預(yù)訓(xùn)練時(shí)提前預(yù)知答案。 而做下游任務(wù) fine-tune 時(shí)將query stream去掉,這樣就完美的解決了這個(gè)問題。

下圖就是文章作者的PPT中對(duì)XLNET的總結(jié)。具體如果實(shí)作的建議讀者去看看原文和作者對(duì)XLNET的講解視頻,需要在提一嘴的是XLNET的網(wǎng)絡(luò)架構(gòu)使用了transfomer-xl,這都使得XLNET在長文本的處理能力變得更強(qiáng)。

XLNET的行文邏輯

XLNET實(shí)戰(zhàn)部分

好的,經(jīng)過了不是特別深刻的XLNET原理簡介,來到我們激動(dòng)人心的實(shí)戰(zhàn)部分。(其實(shí)上文原理部分只是希望大家在感性上知道XLNET到底做了些什么,至于如何做到的還請(qǐng)拜讀原文和源代碼)。

準(zhǔn)備工作

定義一下超參數(shù)和預(yù)訓(xùn)練模型的路徑

import os
import sys
from collections import namedtuple
import numpy as np
import pandas as pd
from keras_xlnet.backend import keras
from keras_bert.layers import Extract
from keras_xlnet import Tokenizer, load_trained_model_from_checkpoint, ATTENTION_TYPE_BI
from keras_radam import RAdam

### 預(yù)訓(xùn)練模型的路徑
pretrained_path  = "/opt/developer/wp/xlnet/pretrain"
EPOCH = 10
BATCH_SIZE = 16
SEQ_LEN = 256

MODEL_NAME = 'xlnet_cls.h5'
PretrainedPaths = namedtuple('PretrainedPaths', ['config', 'model', 'vocab'])

config_path = os.path.join(pretrained_path, 'xlnet_config.json')
model_path = os.path.join(pretrained_path, 'xlnet_model.ckpt')
vocab_path = os.path.join(pretrained_path, 'spiece.model')
paths = PretrainedPaths(config_path, model_path, vocab_path)
tokenizer = Tokenizer(paths.vocab)

數(shù)據(jù)準(zhǔn)備

這里我選擇了一個(gè)語義相似度判斷的任務(wù)——及判斷兩個(gè)問題是否是同一個(gè)問題。數(shù)據(jù)集如下圖所示:


數(shù)據(jù)集

如上圖所示:數(shù)據(jù)集共2萬條,主要任務(wù)就是判斷question1 和 question2 是否為同一問題。label為1則是同一個(gè)問題,label為0則不是同一個(gè)問題。其中還有列標(biāo)注了問題的類型。

在由于XLNET的輸入為單輸入,于是筆者將數(shù)據(jù)預(yù)處理成 "問題類型:'question1是否和question2是同一個(gè)問題?"。舉個(gè)例了,上圖數(shù)據(jù)集的第一條數(shù)據(jù)最終變成如下格式:

  • data: aids:艾滋病窗口期會(huì)出現(xiàn)腹瀉癥狀嗎是否和頭疼腹瀉四肢無力是不是艾滋病是同一個(gè)問題?
  • label:0
# Read data
class DataSequence(keras.utils.Sequence):

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return (len(self.y) + BATCH_SIZE - 1) // BATCH_SIZE

    def __getitem__(self, index):
        s = slice(index * BATCH_SIZE, (index + 1) * BATCH_SIZE)
        return [item[s] for item in self.x], self.y[s]


def generate_sequence(df):
    tokens, classes = [], []
    for _, row in df.iterrows():
        ###這里筆者將數(shù)據(jù)進(jìn)行拼接  類型+問題1+問題2
        text, cls = row["category"] + ":"+ row['question1'] + "是否和" + row['question2'] + "是同一個(gè)問題?", row['label']
        encoded = tokenizer.encode(text)[:SEQ_LEN - 1]
        encoded = [tokenizer.SYM_PAD] * (SEQ_LEN - 1 - len(encoded)) + encoded + [tokenizer.SYM_CLS]
        tokens.append(encoded)
        classes.append(int(cls))
    tokens, classes = np.array(tokens), np.array(classes)
    segments = np.zeros_like(tokens)
    segments[:, -1] = 1
    lengths = np.zeros_like(tokens[:, :1])
    return DataSequence([tokens, segments, lengths], classes)


### 讀取數(shù)據(jù),然后將數(shù)據(jù)
data_path = "/opt/developer/wp/xlnet/data/train.csv"
data = pd.read_csv(data_path)
test = data.sample(2000)
train = data.loc[list(set(data.index)-set(test.index))]
### 生成訓(xùn)練集和測試集
train_g = generate_sequence(train)
test_g = generate_sequence(test)

加載模型

加載事先下載好的xlnet的預(yù)訓(xùn)練語言模型,然后再接兩個(gè)dense層和一個(gè)bn層構(gòu)建一個(gè)類別為2的分類器。

# Load pretrained model
model = load_trained_model_from_checkpoint(
    config_path=paths.config,
    checkpoint_path=paths.model,
    batch_size=BATCH_SIZE,
    memory_len=0,
    target_len=SEQ_LEN,
    in_train_phase=False,
    attention_type=ATTENTION_TYPE_BI,
)

#### 加載預(yù)訓(xùn)練權(quán)重
# Build classification model
last = model.output
extract = Extract(index=-1, name='Extract')(last)
dense = keras.layers.Dense(units=768, name='Dense')(extract)
norm = keras.layers.BatchNormalization(name='Normal')(dense)
output = keras.layers.Dense(units=2, activation='softmax', name='Softmax')(norm)
model = keras.models.Model(inputs=model.inputs, outputs=output)
model.summary()

針對(duì)下游任務(wù)fine-tuning

接下來只需要定義好優(yōu)化器,學(xué)習(xí)率,損失函數(shù),評(píng)估函數(shù),以及一些回調(diào)函數(shù),就可以開始針對(duì)語義相似度判斷任務(wù)進(jìn)行模型微調(diào)了。

# 定義優(yōu)化器,loss和metrics
model.compile(
    optimizer=RAdam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy'],
)
### 定義callback函數(shù),只保留val_sparse_categorical_accuracy 得分最高的模型
from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint("./model/best_xlnet.h5", monitor='val_sparse_categorical_accuracy', verbose=1, save_best_only=True,
                            mode='max')

模型訓(xùn)練
model.fit_generator(
    generator=train_g,
    validation_data=test_g,
    epochs=EPOCH,
    callbacks=[checkpoint],
)

下發(fā)截圖是模型訓(xùn)練過程,看起來還不錯(cuò),loss下降和sparse_categorical_accuracy得分上升都很穩(wěn)定。

模型訓(xùn)練

至于最后效果,我只和bert做了一個(gè)簡單的對(duì)比。在這個(gè)任務(wù)上,xlnet相較于bert稍微弱了2個(gè)百分點(diǎn),可能是由于文本過短的原因或者我自己打開的方式不對(duì),對(duì)應(yīng)xlnet的finetune還是需要好好研究。最后的感覺就是深度學(xué)習(xí)的煉丹之路任重而道遠(yuǎn),大家且行且珍惜。

結(jié)語

預(yù)訓(xùn)練的語言模型成為NLP的熱點(diǎn),通過無監(jiān)督的預(yù)訓(xùn)練讓模型學(xué)習(xí)領(lǐng)域基礎(chǔ)知識(shí),之后專注在這個(gè)領(lǐng)域的某個(gè)下游任務(wù)上才能有所建樹。這個(gè)過程和我們大學(xué)教育培養(yǎng)模式很像,大學(xué)四年海納百川式的學(xué)習(xí)一些基本的科學(xué)知識(shí),研究生之后專注于某個(gè)方向深入研究。所以這個(gè)過程make sense。筆者認(rèn)為以下三個(gè)方向是預(yù)訓(xùn)練模型還可以更進(jìn)一步的方向:

  • 更好,更難的的無監(jiān)督預(yù)訓(xùn)練的任務(wù)(算法)
  • 更精妙的網(wǎng)絡(luò)設(shè)計(jì)(算法)
  • 更大,更高質(zhì)量的數(shù)據(jù)(算力)

參考文獻(xiàn)

xlnet原文
https://github.com/CyberZHG/keras-xlnet
https://github.com/ymcui/Chinese-PreTrained-XLNet
https://www.bilibili.com/video/av67474438?from=search&seid=12927135467023026176

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容