一、數(shù)據(jù)集介紹與加載
本次使用的數(shù)據(jù)集來(lái)自 Hugging Face 平臺(tái),是一個(gè)中文商品評(píng)價(jià)數(shù)據(jù)集,包含正向(positive)和負(fù)向(negative)兩類(lèi)標(biāo)簽,適用于文本分類(lèi)任務(wù)。數(shù)據(jù)集通過(guò)人工標(biāo)注完成,標(biāo)注過(guò)程中對(duì)模糊語(yǔ)句進(jìn)行了單獨(dú)處理與審核,確保數(shù)據(jù)質(zhì)量。
數(shù)據(jù)集包含三個(gè)部分:
- 訓(xùn)練集:9600 條文本/標(biāo)簽對(duì)
- 驗(yàn)證集:1200 條
- 測(cè)試集:1200 條
標(biāo)簽為 0(負(fù)向)和 1(正向),例如:
- 負(fù)向示例:“入住的時(shí)候剛裝修味道很濃,房間內(nèi)的床墊坐上去好像有點(diǎn)不穩(wěn)?!?→ 標(biāo)簽 0
- 正向示例:“趕上五一促銷(xiāo)入手,做工還行,接口夠用,價(jià)格能接受,電池強(qiáng)勁…” → 標(biāo)簽 1
第一步:在線加載
from datasets import load_dataset
ds = load_dataset("lansinuote/ChnSentiCorp", cache_dir=cache_dir)
這種方式會(huì)默認(rèn)從 Hugging Face 平臺(tái)下載數(shù)據(jù)集到本地,且必須聯(lián)網(wǎng),即使本地已經(jīng)有有數(shù)據(jù)集,也需要聯(lián)網(wǎng),但不會(huì)重復(fù)下載。數(shù)據(jù)集格式為 Arrow(平臺(tái)自定義的加密格式),無(wú)法直接查看,可通過(guò)代碼訪問(wèn)。
第二步:本地加載(推薦)
在第一步中已把數(shù)據(jù)集已下載到本地,然后可使用 Dataset.from_file 加載,并使用 PyTorch 創(chuàng)建自定義數(shù)據(jù)集:
- PyTorch 中自定 Dataset 類(lèi)必須實(shí)現(xiàn)
__init__(self)、__getitem__(self, index)以及__len__(self)三個(gè)方法; -
__init__(self)方法中首先加載父類(lèi)的方法,然后通過(guò)Dataset.from_file加載數(shù)據(jù)集; -
__getitem__(self, index)返回指定索引的數(shù)據(jù),這里返回評(píng)論的文本和標(biāo)簽。
from datasets import Dataset
from torch.utils import data
class MyDataset(data.Dataset):
def __init__(self):
super().__init__()
self.train_dataset = Dataset.from_file(r'/Users/Desktop/huggingface/data/lansinuote___chn_senti_corp/default/0.0.0/b0c4c119c3fb33b8e735969202ef9ad13d717e5a/chn_senti_corp-train.arrow')
self.validation_dataset = Dataset.from_file(r'/Users/Desktop/huggingface/data/lansinuote___chn_senti_corp/default/0.0.0/b0c4c119c3fb33b8e735969202ef9ad13d717e5a/chn_senti_corp-validation.arrow')
self.test_dataset = Dataset.from_file(r'/Users/Desktop/huggingface/data/lansinuote___chn_senti_corp/default/0.0.0/b0c4c119c3fb33b8e735969202ef9ad13d717e5a/chn_senti_corp-test.arrow')
def __len__(self):
return len(self.train_dataset)
def __getitem__(self, item):
text = self.train_dataset[item]["text"]
label = self.train_dataset[item]["label"]
return text, label
if __name__ == '__main__':
data = MyDataset()
for d in data:
print(d)
注意:
- 路徑應(yīng)為包含
dataset_info.json的根目錄。 - 使用絕對(duì)路徑,并建議在路徑字符串前加
r防止轉(zhuǎn)義。
加載后可查看數(shù)據(jù)集信息:
{'text': '酒店的位置不錯(cuò),附近都靠近購(gòu)物中心和寫(xiě)字樓區(qū)。以前來(lái)大連一直都住,但感覺(jué)比較陳舊了。住的期間,酒店在進(jìn)行裝修,翻新和升級(jí)房間設(shè)備。好是好,希望到時(shí)房?jī)r(jià)別漲太多了。', 'label': 1}
{'text': '位置不很方便,周?chē)鷣y哄哄的,衛(wèi)生條件也不如其他如家的店。以后絕不會(huì)再住在這里。', 'label': 0}
{'text': '抱著很大興趣買(mǎi)的,買(mǎi)來(lái)粗粗一翻排版很不錯(cuò),姐姐還說(shuō)快看吧,如果好我也買(mǎi)一本??墒钦娴目戳?,實(shí)在不怎么樣。就是中文里夾英文單詞說(shuō)話(huà),才翻了2頁(yè)實(shí)在不想勉強(qiáng)自己了。我想說(shuō)的是,練習(xí)英文單詞,靠這本書(shū)肯定沒(méi)有效果,其它好的方法比這強(qiáng)多了。', 'label': 0}
{'text': '東西不錯(cuò),不過(guò)有人不太喜歡鏡面的,我個(gè)人比較喜歡,總之還算滿(mǎn)意。', 'label': 1}
{'text': '房間不錯(cuò),只是上網(wǎng)速度慢得無(wú)法忍受,打開(kāi)一個(gè)網(wǎng)頁(yè)要等半小時(shí),連郵件都無(wú)法收。另前臺(tái)工作人員服務(wù)態(tài)度是很好,只是效率有得改善。', 'label': 1}
{'text': '挺失望的,還不如買(mǎi)一本張愛(ài)玲文集呢,以<色戒>命名,可這篇文章僅僅10多頁(yè),且無(wú)頭無(wú)尾的,完全比不上里面的任意一篇其它文章.', 'label': 0}
二、模型和分詞器的本地加載
2.1 模型來(lái)源
同樣通過(guò) Hugging Face 平臺(tái)獲取 bert-base-chinese,這是一個(gè)中文文本分類(lèi)模型。
下載模型的同時(shí)需要下載對(duì)應(yīng)的分詞器,原因在于模型不能直接識(shí)別文字,分詞器的作用是把每個(gè)文字轉(zhuǎn)為模型可識(shí)別的數(shù)字,再把數(shù)字輸入給模型,不同的模型有不同的分詞器,因此模型與分詞器必須匹配。
from transformers import AutoModel, BertTokenizerFast
model_name = 'ckiplab/bert-base-chinese'
cache_dir = 'model'
# 下載分詞器
BertTokenizerFast.from_pretrained('bert-base-chinese', cache_dir=cache_dir)
print("tokenizer done")
# 下載模型
AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
print("model done")
平臺(tái)上的模型大多存儲(chǔ)在 Google 云盤(pán),下載時(shí)需聯(lián)網(wǎng),但可緩存到本地供后續(xù)離線使用。
from transformers import BertModel, BertTokenizerFast
# 從本地加載模型和分詞器
tokenizer = BertTokenizerFast.from_pretrained(r'/Users/Desktop/huggingface/model/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0ea')
model = BertModel.from_pretrained(r'/Users/Desktop/huggingface/model/models--ckiplab--bert-base-chinese/snapshots/efe27bb4a9373384e0120ffe1cf327714ceb61bf')
print(model)
2.2 模型結(jié)構(gòu)簡(jiǎn)介
打印 BERT 模型,其結(jié)構(gòu)主要分為三部分:
- Embedding 層:將輸入的位置編碼(即分詞后的數(shù)字序列)轉(zhuǎn)換為 768 維的詞向量。
- Encoder 層:由多層 Transformer 編碼器組成,用于提取特征。BERT-base 包含 12 層 Encoder,這是 Transformer 模型有效的最低層數(shù)要求。
- Pooler 層:由一個(gè)全聯(lián)接層和一個(gè)激活函數(shù)組成,其中全聯(lián)接層的輸出維度為 768,這個(gè)很重要,關(guān)系到后面模型的搭建。
三、模型設(shè)計(jì)
定義模型時(shí),只需要定義全連接層,再追加到 BERT 模型的輸出。因此整體流程是:BERT —> 自定義全連接層 —> softmax。
- BERT 模型:作為通用語(yǔ)言理解基座,提供高質(zhì)量的文本特征表示
- 自定義全連接層:作為專(zhuān)用任務(wù)頭,將通用特征映射到具體業(yè)務(wù)分類(lèi)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
# in_features=768 是BERT的輸出維度,out_features=2 是分類(lèi)數(shù)
self.fc = torch.nn.Linear(in_features=768, out_features=2)
def forward(self, input_ids, attention_mask, token_type_ids):
# BERT 模型不需要訓(xùn)練
with torch.no_grad():
out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
# 取[CLS]位置的隱藏狀態(tài)
out = self.fc(out.last_hidden_state[:, 0])
out = out.softmax(dim=1)
return out
BERT 的輸入?yún)?shù)包含:
- input_ids: 將文本轉(zhuǎn)換成的token索引序列。例如“我喜歡自然語(yǔ)言處理” → [101, 2769, 4263, 3315, 4476, 1566, 5543, 102]
- attention_mask: 標(biāo)識(shí)哪些token是有效內(nèi)容,哪些是填充(padding)。因?yàn)?BERT 要求輸入是固定長(zhǎng)度的,所以對(duì)于短于最大長(zhǎng)度的序列,用 0 進(jìn)行填充,會(huì)忽略這些填充詞。
- token_type_ids: 區(qū)分句子對(duì)中的第一個(gè)句子和第二個(gè)句子
BERT 的這些輸入將在 Dataloader 中,使用分詞器提取。
BERT 的輸出參數(shù)包含:
- last_hidden_state:最后一層的隱藏狀態(tài) [batch_size, seq_len, hidden_size]
- pooler_output:[CLS] token 經(jīng)過(guò)線性層和 tanh 激活后的表示 [batch_size, hidden_size]
取 [CLS] 位置的隱藏狀態(tài),是因?yàn)?[CLS] 包含了整個(gè)句子的信息,并且預(yù)訓(xùn)練任務(wù)中已經(jīng)訓(xùn)練它做類(lèi)似的任務(wù)。其他 token 的表示則更多關(guān)注局部信息,不適合直接用于整個(gè)句子的分類(lèi)。
四、創(chuàng)建自定義 Dataloader
Dataloader 需要去從 Dataset 讀取數(shù)據(jù),它提供了一種簡(jiǎn)便的方式來(lái)迭代數(shù)據(jù)集。
這里需要設(shè)置批處理數(shù)據(jù)讀取,有一個(gè)很重要的操作就是要先定義一個(gè) collate_fn 函數(shù),在這個(gè)函數(shù)中我們將 Dataset 原始文本通過(guò)分詞器(tokenizer)進(jìn)行詞向量的轉(zhuǎn)換,轉(zhuǎn)換為模型可理解的參數(shù)輸入給模型。
collate_fn 是一個(gè)可選參數(shù),允許用戶(hù)自定義如何將多個(gè)數(shù)據(jù)樣本合并成一個(gè) batch,由于 bert 模型需要三個(gè)入?yún)ⅲ╥nput_ids, attention_mask, token_type_ids),因此需要把數(shù)據(jù)集通過(guò)分詞器轉(zhuǎn)化為這三個(gè)參數(shù)。
import torch
from my_dataset import MyDataset
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
# 加載分詞器
token_path = r'/Users/Desktop/huggingface/model/models--bert-base-chinese/snapshots/8f23c25b06e129b6c986331a13d8d025a92cf0ea'
tokenizer = BertTokenizerFast.from_pretrained(token_path)
def collate_fn(data):
sentes = [i[0] for i in data]
label = [i[1] for i in data]
data = tokenizer.batch_encode_plus(
batch_text_or_text_pairs=sentes,
truncation=True,
padding='max_length',
max_length=512,
return_tensors='pt', # 輸出數(shù)據(jù)將作為 PyTorch 張量返回,而不是 NumPy 數(shù)組或其他格式
return_length=True
)
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
label = torch.LongTensor(label)
return input_ids, attention_mask, token_type_ids, label
# 初始化數(shù)據(jù)集
train_data = MyDataset()
# 數(shù)據(jù)加載
train_loader = DataLoader(
dataset=train_data,
batch_size=32,
shuffle=True,
drop_last=True,
collate_fn=collate_fn
)
五、模型訓(xùn)練
定義模型、優(yōu)化器、損失函數(shù),開(kāi)啟訓(xùn)練模式,在訓(xùn)練批次中,從 train_loader 加載數(shù)據(jù),并把數(shù)據(jù)(input_ids, attention_mask, token_type_ids)傳入給模型。
model = Model().to('cpu')
optimizer = AdamW(model.parameters(), lr=0.001)
loss_func = torch.nn.CrossEntropyLoss()
model.train()
for epoch in range(10):
for i, (input_ids, attention_mask, token_type_ids, label) in enumerate(train_loader):
input_ids, attention_mask, token_type_ids, label = input_ids.to('cpu'), attention_mask.to('cpu'), token_type_ids.to('cpu'), label.to('cpu')
out = model(input_ids, attention_mask, token_type_ids)
loss = loss_func(out, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % 5 == 0:
out = out.argmax(dim=1)
acc = (out == label).sum().item() / len(label)
print(f'epoch={epoch}, i={i}, loss={loss.item()}, acc={acc}')
torch.save(model.state_dict(), f'param/{epoch}bert.pt')
print(epoch, 'save success')
開(kāi)始訓(xùn)練,打印準(zhǔn)確率:
epoch = 0, i=0, loss=0.7204890847206116, acc=0.4375
epoch = 0, i=5, loss=0.6873999834060669, acc=0.65625
epoch = 0, i=10, loss=0.5797467827796936, acc=0.78125
epoch = 0, i=15, loss=0.6594133377075195, acc=0.625
epoch = 0, i=20, loss=0.6064494252204895, acc=0.71875
epoch = 0, i=25, loss=0.649083137512207, acc=0.625
epoch = 0, i=30, loss=0.5833480358123779, acc=0.75
epoch = 0, i=35, loss=0.5145970582962036, acc=0.875
……
六、模型預(yù)測(cè)
預(yù)測(cè)的流程與訓(xùn)練流程類(lèi)似,也需要把輸入的文字先通過(guò)分詞器轉(zhuǎn)為模型能理解的數(shù)字,再輸入給模型。
def collate_fn2(data):
sentes = []
sentes.append(data)
data = tokenizer.batch_encode_plus(
batch_text_or_text_pairs=sentes,
truncation=True,
padding='max_length',
max_length=512,
return_tensors='pt',
return_length=True
)
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
return input_ids, attention_mask, token_type_ids
def test():
name = ["差評(píng)", "好評(píng)"]
model.load_state_dict(torch.load("param/1bert.pt"))
model.eval()
while True:
data = input("input content: ")
input_ids, attention_mask, token_type_ids = collate_fn2(data)
input_ids, attention_mask, token_type_ids = input_ids.to('cpu'), attention_mask.to('cpu'), token_type_ids.to('cpu')
with torch.no_grad():
out = model(input_ids, attention_mask, token_type_ids)
out = out.argmax(dim=1)
print(f"result={name[out]}\n")
input content: 5月8日付款成功,當(dāng)當(dāng)網(wǎng)顯示5月10日發(fā)貨,可是至今還沒(méi)看到貨物,也沒(méi)收到任何通知,簡(jiǎn)不知怎么說(shuō)好?。。?result=差評(píng)
input content: 下次不會(huì)再買(mǎi)了,不喜歡
result=差評(píng)
input content: 質(zhì)量挺不錯(cuò)的,下次再來(lái)
result=好評(píng)