最近在微調(diào)大語言模型的過程中發(fā)現(xiàn)訓(xùn)練時(shí)會(huì)在模型生成的目錄出現(xiàn)很多checkpoint開頭的文件夾,這些文件夾下面基本都是一套完整可用的模型文件,還比較占用空間。這里詳細(xì)總結(jié)一下checkpoint 相關(guān)的使用。

checkpoint文件的來源
檢查點(diǎn)(checkpoint)的概念最早出現(xiàn)在高性能計(jì)算領(lǐng)域,長(zhǎng)時(shí)間運(yùn)行的任務(wù)容易因?yàn)橐恍┸浻布栴}而失敗。為了避免從頭開始重新運(yùn)行任務(wù),才有了檢查點(diǎn)的概念。在計(jì)算任務(wù)的某個(gè)時(shí)刻保存當(dāng)前狀態(tài)(稱為檢查點(diǎn)),如果任務(wù)中斷,可以從最近的檢查點(diǎn)恢復(fù)而不是重新開始。
深度學(xué)習(xí)領(lǐng)域?yàn)榱藨?yīng)對(duì)訓(xùn)練過程中可能出現(xiàn)的中斷,也采用了檢查點(diǎn)技術(shù)。以 huggingface 的 transformer 庫為例,假如采用如下訓(xùn)練代碼,將會(huì)產(chǎn)生如上圖所示的一系列檢查點(diǎn)文件夾。
epochs = 10
lr = 2e-5
train_bs = 8
eval_bs = train_bs * 2
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=epochs,
learning_rate=lr,
per_device_train_batch_size=train_bs,
per_device_eval_batch_size=eval_bs,
evaluation_strategy="epoch",
logging_steps=logging_steps
)
這里還會(huì)有一個(gè)疑問,檢查點(diǎn)文件夾的命名規(guī)則是什么,結(jié)尾的數(shù)字可以看出都是 500 的倍數(shù)。
根據(jù) huggingface 的文檔,檢查點(diǎn)的產(chǎn)生跟 TrainingArguments 的以下幾個(gè)參數(shù)有關(guān)
-
save_strategy 決定了檢查點(diǎn)保存的邏輯,有以下 3 個(gè)選項(xiàng),默認(rèn)為
steps- no 訓(xùn)練中不保存檢查點(diǎn)
- epoch 對(duì)每一個(gè)訓(xùn)練周期保存
-
steps 通過
save_steps定義如何按訓(xùn)練步數(shù)保存
- save_steps 兩個(gè)檢查點(diǎn)之間經(jīng)歷的訓(xùn)練步數(shù),默認(rèn)為 500 步。
按照上面訓(xùn)練代碼的邏輯,由于這兩個(gè)參數(shù)都沒有制定,因此默認(rèn)采用訓(xùn)練步數(shù)的方式保存檢查點(diǎn),并且每個(gè) 500 步就會(huì)保存一次。
最后還有一個(gè)問題,就是訓(xùn)練步數(shù)的計(jì)算,每處理一個(gè) batch 數(shù)據(jù)并進(jìn)行一次參數(shù)更新就算作一個(gè) step,按照這個(gè)定義計(jì)算的話,總步數(shù) = (樣本數(shù) / 批大小) * epochs。
我的樣本數(shù)為 1624,批大小為 8,周期數(shù)為 10,帶入公式計(jì)算總步數(shù) = (1624 // 8) * 10 = 2030,這樣也就可以解釋為什么最后一個(gè)檢查點(diǎn)的命名為 checkpoint-2000
checkpoint文件相關(guān)的使用方法
斷點(diǎn)續(xù)訓(xùn)
檢查點(diǎn)設(shè)計(jì)的初衷就是為了任務(wù)中斷之后能夠快速恢復(fù),按照前面設(shè)定的邏輯,使用 transformer 庫恢復(fù)訓(xùn)練的方法如下
# Trainer 的定義
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
# 從最近的檢查點(diǎn)恢復(fù)訓(xùn)練
trainer.train(resume_from_checkpoint=True)
加載最好的模型
考慮到訓(xùn)練過程中發(fā)生的過擬合,常常需要選擇在驗(yàn)證集上性能最好的模型,可通過如下設(shè)置load_best_model_at_end 達(dá)到自動(dòng)選擇的目的
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
output_dir='./results', # 保存路徑
num_train_epochs=5, # 訓(xùn)練周期數(shù)
per_device_train_batch_size=32, # 每個(gè)設(shè)備的訓(xùn)練batch大小
evaluation_strategy="steps", # 評(píng)估策略
save_total_limit=3, # 保留最近的3個(gè)檢查點(diǎn)
load_best_model_at_end=True, # 在訓(xùn)練結(jié)束時(shí)加載驗(yàn)證集上最好的模型
metric_for_best_model="accuracy", # 用于選擇最佳模型的指標(biāo)
greater_is_better=True # 指標(biāo)越高越好
)
其他
分布式檢查點(diǎn)
對(duì)于分布式訓(xùn)練場(chǎng)景下的管理,參考微軟推出的 DeepSpeed