Hugging Face 加載數(shù)據(jù)集與 BERT 模型初探

一、數(shù)據(jù)集介紹與加載

本次使用的數(shù)據(jù)集來(lái)自 Hugging Face 平臺(tái),是一個(gè)中文商品評(píng)價(jià)數(shù)據(jù)集,包含正向(positive)和負(fù)向(negative)兩類(lèi)標(biāo)簽,適用于文本分類(lèi)任務(wù)。數(shù)據(jù)集通過(guò)人工標(biāo)注完成,標(biāo)注過(guò)程中對(duì)模糊語(yǔ)句進(jìn)行了單獨(dú)處理與審核,確保數(shù)據(jù)質(zhì)量。

數(shù)據(jù)集包含三個(gè)部分:

  • 訓(xùn)練集:9600 條文本/標(biāo)簽對(duì)
  • 驗(yàn)證集:1200 條
  • 測(cè)試集:1200 條

標(biāo)簽為 0(負(fù)向)和 1(正向),例如:

  • 負(fù)向示例:“入住的時(shí)候剛裝修味道很濃,房間內(nèi)的床墊坐上去好像有點(diǎn)不穩(wěn)?!?→ 標(biāo)簽 0
  • 正向示例:“趕上五一促銷(xiāo)入手,做工還行,接口夠用,價(jià)格能接受,電池強(qiáng)勁…” → 標(biāo)簽 1

第一步:在線加載

from datasets import load_dataset

ds = load_dataset("lansinuote/ChnSentiCorp", cache_dir=cache_dir)

這種方式會(huì)默認(rèn)從 Hugging Face 平臺(tái)下載數(shù)據(jù)集到本地,且必須聯(lián)網(wǎng),即使本地已經(jīng)有有數(shù)據(jù)集,也需要聯(lián)網(wǎng),但不會(huì)重復(fù)下載。數(shù)據(jù)集格式為 Arrow(平臺(tái)自定義的加密格式),無(wú)法直接查看,可通過(guò)代碼訪問(wèn)。

第二步:本地加載(推薦)

在第一步中已把數(shù)據(jù)集已下載到本地,然后可使用 Dataset.from_file 加載,并使用 PyTorch 創(chuàng)建自定義數(shù)據(jù)集:

  1. PyTorch 中自定 Dataset 類(lèi)必須實(shí)現(xiàn) __init__(self)、 __getitem__(self, index) 以及 __len__(self) 三個(gè)方法;
  2. __init__(self) 方法中首先加載父類(lèi)的方法,然后通過(guò) Dataset.from_file 加載數(shù)據(jù)集;
  3. __getitem__(self, index) 返回指定索引的數(shù)據(jù),這里返回評(píng)論的文本和標(biāo)簽。
from datasets import Dataset
from torch.utils import data

class MyDataset(data.Dataset):
    def __init__(self):
        super().__init__()
        self.train_dataset = Dataset.from_file(r'/Users/Desktop/huggingface/data/lansinuote___chn_senti_corp/default/0.0.0/b0c4c119c3fb33b8e735969202ef9ad13d717e5a/chn_senti_corp-train.arrow')
        self.validation_dataset = Dataset.from_file(r'/Users/Desktop/huggingface/data/lansinuote___chn_senti_corp/default/0.0.0/b0c4c119c3fb33b8e735969202ef9ad13d717e5a/chn_senti_corp-validation.arrow')
        self.test_dataset = Dataset.from_file(r'/Users/Desktop/huggingface/data/lansinuote___chn_senti_corp/default/0.0.0/b0c4c119c3fb33b8e735969202ef9ad13d717e5a/chn_senti_corp-test.arrow')

    def __len__(self):
        return len(self.train_dataset)

    def __getitem__(self, item):
        text = self.train_dataset[item]["text"]
        label = self.train_dataset[item]["label"]
        return text, label

if __name__ == '__main__':
    data = MyDataset()
    for d in data:
        print(d)

注意:

  • 路徑應(yīng)為包含 dataset_info.json 的根目錄。
  • 使用絕對(duì)路徑,并建議在路徑字符串前加 r 防止轉(zhuǎn)義。

加載后可查看數(shù)據(jù)集信息:

{'text': '酒店的位置不錯(cuò),附近都靠近購(gòu)物中心和寫(xiě)字樓區(qū)。以前來(lái)大連一直都住,但感覺(jué)比較陳舊了。住的期間,酒店在進(jìn)行裝修,翻新和升級(jí)房間設(shè)備。好是好,希望到時(shí)房?jī)r(jià)別漲太多了。', 'label': 1}
{'text': '位置不很方便,周?chē)鷣y哄哄的,衛(wèi)生條件也不如其他如家的店。以后絕不會(huì)再住在這里。', 'label': 0}
{'text': '抱著很大興趣買(mǎi)的,買(mǎi)來(lái)粗粗一翻排版很不錯(cuò),姐姐還說(shuō)快看吧,如果好我也買(mǎi)一本??墒钦娴目戳?,實(shí)在不怎么樣。就是中文里夾英文單詞說(shuō)話(huà),才翻了2頁(yè)實(shí)在不想勉強(qiáng)自己了。我想說(shuō)的是,練習(xí)英文單詞,靠這本書(shū)肯定沒(méi)有效果,其它好的方法比這強(qiáng)多了。', 'label': 0}
{'text': '東西不錯(cuò),不過(guò)有人不太喜歡鏡面的,我個(gè)人比較喜歡,總之還算滿(mǎn)意。', 'label': 1}
{'text': '房間不錯(cuò),只是上網(wǎng)速度慢得無(wú)法忍受,打開(kāi)一個(gè)網(wǎng)頁(yè)要等半小時(shí),連郵件都無(wú)法收。另前臺(tái)工作人員服務(wù)態(tài)度是很好,只是效率有得改善。', 'label': 1}
{'text': '挺失望的,還不如買(mǎi)一本張愛(ài)玲文集呢,以<色戒>命名,可這篇文章僅僅10多頁(yè),且無(wú)頭無(wú)尾的,完全比不上里面的任意一篇其它文章.', 'label': 0}

二、模型和分詞器的本地加載

2.1 模型來(lái)源

同樣通過(guò) Hugging Face 平臺(tái)獲取 bert-base-chinese,這是一個(gè)中文文本分類(lèi)模型。

下載模型的同時(shí)需要下載對(duì)應(yīng)的分詞器,原因在于模型不能直接識(shí)別文字,分詞器的作用是把每個(gè)文字轉(zhuǎn)為模型可識(shí)別的數(shù)字,再把數(shù)字輸入給模型,不同的模型有不同的分詞器,因此模型與分詞器必須匹配。

from transformers import AutoModel, BertTokenizerFast

model_name = 'ckiplab/bert-base-chinese'
cache_dir = 'model'
# 下載分詞器
BertTokenizerFast.from_pretrained('bert-base-chinese', cache_dir=cache_dir)
print("tokenizer done")
# 下載模型
AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
print("model done")

平臺(tái)上的模型大多存儲(chǔ)在 Google 云盤(pán),下載時(shí)需聯(lián)網(wǎng),但可緩存到本地供后續(xù)離線使用。

from transformers import BertModel, BertTokenizerFast
# 從本地加載模型和分詞器
tokenizer = BertTokenizerFast.from_pretrained(r'/Users/Desktop/huggingface/model/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0ea')
model = BertModel.from_pretrained(r'/Users/Desktop/huggingface/model/models--ckiplab--bert-base-chinese/snapshots/efe27bb4a9373384e0120ffe1cf327714ceb61bf')
print(model)

2.2 模型結(jié)構(gòu)簡(jiǎn)介

打印 BERT 模型,其結(jié)構(gòu)主要分為三部分:

  1. Embedding 層:將輸入的位置編碼(即分詞后的數(shù)字序列)轉(zhuǎn)換為 768 維的詞向量。
  2. Encoder 層:由多層 Transformer 編碼器組成,用于提取特征。BERT-base 包含 12 層 Encoder,這是 Transformer 模型有效的最低層數(shù)要求。
  3. Pooler 層:由一個(gè)全聯(lián)接層和一個(gè)激活函數(shù)組成,其中全聯(lián)接層的輸出維度為 768,這個(gè)很重要,關(guān)系到后面模型的搭建。

三、模型設(shè)計(jì)

定義模型時(shí),只需要定義全連接層,再追加到 BERT 模型的輸出。因此整體流程是:BERT —> 自定義全連接層 —> softmax。

  • BERT 模型:作為通用語(yǔ)言理解基座,提供高質(zhì)量的文本特征表示
  • 自定義全連接層:作為專(zhuān)用任務(wù)頭,將通用特征映射到具體業(yè)務(wù)分類(lèi)
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # in_features=768 是BERT的輸出維度,out_features=2 是分類(lèi)數(shù)
        self.fc = torch.nn.Linear(in_features=768, out_features=2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # BERT 模型不需要訓(xùn)練
        with torch.no_grad():
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # 取[CLS]位置的隱藏狀態(tài)
        out = self.fc(out.last_hidden_state[:, 0])
        out = out.softmax(dim=1)
        return out

BERT 的輸入?yún)?shù)包含:

  • input_ids: 將文本轉(zhuǎn)換成的token索引序列。例如“我喜歡自然語(yǔ)言處理” → [101, 2769, 4263, 3315, 4476, 1566, 5543, 102]
  • attention_mask: 標(biāo)識(shí)哪些token是有效內(nèi)容,哪些是填充(padding)。因?yàn)?BERT 要求輸入是固定長(zhǎng)度的,所以對(duì)于短于最大長(zhǎng)度的序列,用 0 進(jìn)行填充,會(huì)忽略這些填充詞。
  • token_type_ids: 區(qū)分句子對(duì)中的第一個(gè)句子和第二個(gè)句子

BERT 的這些輸入將在 Dataloader 中,使用分詞器提取。

BERT 的輸出參數(shù)包含:

  • last_hidden_state:最后一層的隱藏狀態(tài) [batch_size, seq_len, hidden_size]
  • pooler_output:[CLS] token 經(jīng)過(guò)線性層和 tanh 激活后的表示 [batch_size, hidden_size]

取 [CLS] 位置的隱藏狀態(tài),是因?yàn)?[CLS] 包含了整個(gè)句子的信息,并且預(yù)訓(xùn)練任務(wù)中已經(jīng)訓(xùn)練它做類(lèi)似的任務(wù)。其他 token 的表示則更多關(guān)注局部信息,不適合直接用于整個(gè)句子的分類(lèi)。


四、創(chuàng)建自定義 Dataloader

Dataloader 需要去從 Dataset 讀取數(shù)據(jù),它提供了一種簡(jiǎn)便的方式來(lái)迭代數(shù)據(jù)集。

這里需要設(shè)置批處理數(shù)據(jù)讀取,有一個(gè)很重要的操作就是要先定義一個(gè) collate_fn 函數(shù),在這個(gè)函數(shù)中我們將 Dataset 原始文本通過(guò)分詞器(tokenizer)進(jìn)行詞向量的轉(zhuǎn)換,轉(zhuǎn)換為模型可理解的參數(shù)輸入給模型。

collate_fn 是一個(gè)可選參數(shù),允許用戶(hù)自定義如何將多個(gè)數(shù)據(jù)樣本合并成一個(gè) batch,由于 bert 模型需要三個(gè)入?yún)ⅲ╥nput_ids, attention_mask, token_type_ids),因此需要把數(shù)據(jù)集通過(guò)分詞器轉(zhuǎn)化為這三個(gè)參數(shù)。

import torch
from my_dataset import MyDataset
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

# 加載分詞器
token_path = r'/Users/Desktop/huggingface/model/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0ea'
tokenizer = BertTokenizerFast.from_pretrained(token_path)

def collate_fn(data):
    sentes = [i[0] for i in data]
    label = [i[1] for i in data]
    data = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=sentes,
        truncation=True,
        padding='max_length',
        max_length=512,
        return_tensors='pt', # 輸出數(shù)據(jù)將作為 PyTorch 張量返回,而不是 NumPy 數(shù)組或其他格式
        return_length=True
    )
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    token_type_ids = data["token_type_ids"]
    label = torch.LongTensor(label)
    return input_ids, attention_mask, token_type_ids, label

# 初始化數(shù)據(jù)集
train_data = MyDataset()
# 數(shù)據(jù)加載
train_loader = DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn
)

五、模型訓(xùn)練

定義模型、優(yōu)化器、損失函數(shù),開(kāi)啟訓(xùn)練模式,在訓(xùn)練批次中,從 train_loader 加載數(shù)據(jù),并把數(shù)據(jù)(input_ids, attention_mask, token_type_ids)傳入給模型。

model = Model().to('cpu')
optimizer = AdamW(model.parameters(), lr=0.001)
loss_func = torch.nn.CrossEntropyLoss()

model.train()

for epoch in range(10):
    for i, (input_ids, attention_mask, token_type_ids, label) in enumerate(train_loader):
        input_ids, attention_mask, token_type_ids, label = input_ids.to('cpu'), attention_mask.to('cpu'), token_type_ids.to('cpu'), label.to('cpu')

        out = model(input_ids, attention_mask, token_type_ids)
        loss = loss_func(out, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 5 == 0:
            out = out.argmax(dim=1)
            acc = (out == label).sum().item() / len(label)
            print(f'epoch={epoch}, i={i}, loss={loss.item()}, acc={acc}')

    torch.save(model.state_dict(), f'param/{epoch}bert.pt')
    print(epoch, 'save success')

開(kāi)始訓(xùn)練,打印準(zhǔn)確率:

epoch = 0, i=0, loss=0.7204890847206116, acc=0.4375
epoch = 0, i=5, loss=0.6873999834060669, acc=0.65625
epoch = 0, i=10, loss=0.5797467827796936, acc=0.78125
epoch = 0, i=15, loss=0.6594133377075195, acc=0.625
epoch = 0, i=20, loss=0.6064494252204895, acc=0.71875
epoch = 0, i=25, loss=0.649083137512207, acc=0.625
epoch = 0, i=30, loss=0.5833480358123779, acc=0.75
epoch = 0, i=35, loss=0.5145970582962036, acc=0.875
……

六、模型預(yù)測(cè)

預(yù)測(cè)的流程與訓(xùn)練流程類(lèi)似,也需要把輸入的文字先通過(guò)分詞器轉(zhuǎn)為模型能理解的數(shù)字,再輸入給模型。

def collate_fn2(data):
    sentes = []
    sentes.append(data)
    data = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=sentes,
        truncation=True,
        padding='max_length',
        max_length=512,
        return_tensors='pt',
        return_length=True
    )
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    token_type_ids = data["token_type_ids"]

    return input_ids, attention_mask, token_type_ids


def test():
    name = ["差評(píng)", "好評(píng)"]
    model.load_state_dict(torch.load("param/1bert.pt"))
    model.eval()
    while True:
        data = input("input content: ")
        input_ids, attention_mask, token_type_ids = collate_fn2(data)
        input_ids, attention_mask, token_type_ids = input_ids.to('cpu'), attention_mask.to('cpu'), token_type_ids.to('cpu')
        with torch.no_grad():
            out = model(input_ids, attention_mask, token_type_ids)
            out = out.argmax(dim=1)
            print(f"result={name[out]}\n")
input content: 5月8日付款成功,當(dāng)當(dāng)網(wǎng)顯示5月10日發(fā)貨,可是至今還沒(méi)看到貨物,也沒(méi)收到任何通知,簡(jiǎn)不知怎么說(shuō)好?。。?result=差評(píng)

input content: 下次不會(huì)再買(mǎi)了,不喜歡
result=差評(píng)

input content: 質(zhì)量挺不錯(cuò)的,下次再來(lái)
result=好評(píng)
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容