模型訓(xùn)練當(dāng)中 checkpoint 作用是什么

最近在微調(diào)大語言模型的過程中發(fā)現(xiàn)訓(xùn)練時(shí)會(huì)在模型生成的目錄出現(xiàn)很多checkpoint開頭的文件夾,這些文件夾下面基本都是一套完整可用的模型文件,還比較占用空間。這里詳細(xì)總結(jié)一下checkpoint 相關(guān)的使用。


訓(xùn)練中產(chǎn)生的檢查點(diǎn)

checkpoint文件的來源

檢查點(diǎn)(checkpoint)的概念最早出現(xiàn)在高性能計(jì)算領(lǐng)域,長(zhǎng)時(shí)間運(yùn)行的任務(wù)容易因?yàn)橐恍┸浻布栴}而失敗。為了避免從頭開始重新運(yùn)行任務(wù),才有了檢查點(diǎn)的概念。在計(jì)算任務(wù)的某個(gè)時(shí)刻保存當(dāng)前狀態(tài)(稱為檢查點(diǎn)),如果任務(wù)中斷,可以從最近的檢查點(diǎn)恢復(fù)而不是重新開始。

深度學(xué)習(xí)領(lǐng)域?yàn)榱藨?yīng)對(duì)訓(xùn)練過程中可能出現(xiàn)的中斷,也采用了檢查點(diǎn)技術(shù)。以 huggingface 的 transformer 庫為例,假如采用如下訓(xùn)練代碼,將會(huì)產(chǎn)生如上圖所示的一系列檢查點(diǎn)文件夾。

epochs = 10
lr = 2e-5
train_bs = 8
eval_bs = train_bs * 2

training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=epochs,
    learning_rate=lr,
    per_device_train_batch_size=train_bs,
    per_device_eval_batch_size=eval_bs,
    evaluation_strategy="epoch",
    logging_steps=logging_steps
)

這里還會(huì)有一個(gè)疑問,檢查點(diǎn)文件夾的命名規(guī)則是什么,結(jié)尾的數(shù)字可以看出都是 500 的倍數(shù)。

根據(jù) huggingface 的文檔,檢查點(diǎn)的產(chǎn)生跟 TrainingArguments 的以下幾個(gè)參數(shù)有關(guān)

  • save_strategy 決定了檢查點(diǎn)保存的邏輯,有以下 3 個(gè)選項(xiàng),默認(rèn)為 steps
    • no 訓(xùn)練中不保存檢查點(diǎn)
    • epoch 對(duì)每一個(gè)訓(xùn)練周期保存
    • steps 通過 save_steps 定義如何按訓(xùn)練步數(shù)保存
  • save_steps 兩個(gè)檢查點(diǎn)之間經(jīng)歷的訓(xùn)練步數(shù),默認(rèn)為 500 步。

按照上面訓(xùn)練代碼的邏輯,由于這兩個(gè)參數(shù)都沒有制定,因此默認(rèn)采用訓(xùn)練步數(shù)的方式保存檢查點(diǎn),并且每個(gè) 500 步就會(huì)保存一次。

最后還有一個(gè)問題,就是訓(xùn)練步數(shù)的計(jì)算,每處理一個(gè) batch 數(shù)據(jù)并進(jìn)行一次參數(shù)更新就算作一個(gè) step,按照這個(gè)定義計(jì)算的話,總步數(shù) = (樣本數(shù) / 批大小) * epochs。

我的樣本數(shù)為 1624,批大小為 8,周期數(shù)為 10,帶入公式計(jì)算總步數(shù) = (1624 // 8) * 10 = 2030,這樣也就可以解釋為什么最后一個(gè)檢查點(diǎn)的命名為 checkpoint-2000

checkpoint文件相關(guān)的使用方法

斷點(diǎn)續(xù)訓(xùn)

檢查點(diǎn)設(shè)計(jì)的初衷就是為了任務(wù)中斷之后能夠快速恢復(fù),按照前面設(shè)定的邏輯,使用 transformer 庫恢復(fù)訓(xùn)練的方法如下

# Trainer 的定義
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

# 從最近的檢查點(diǎn)恢復(fù)訓(xùn)練
trainer.train(resume_from_checkpoint=True)

加載最好的模型

考慮到訓(xùn)練過程中發(fā)生的過擬合,常常需要選擇在驗(yàn)證集上性能最好的模型,可通過如下設(shè)置load_best_model_at_end 達(dá)到自動(dòng)選擇的目的

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',            # 保存路徑
    num_train_epochs=5,                # 訓(xùn)練周期數(shù)
    per_device_train_batch_size=32,    # 每個(gè)設(shè)備的訓(xùn)練batch大小
    evaluation_strategy="steps",       # 評(píng)估策略
    save_total_limit=3,                # 保留最近的3個(gè)檢查點(diǎn)
    load_best_model_at_end=True,       # 在訓(xùn)練結(jié)束時(shí)加載驗(yàn)證集上最好的模型
    metric_for_best_model="accuracy",  # 用于選擇最佳模型的指標(biāo)
    greater_is_better=True             # 指標(biāo)越高越好
)

其他

分布式檢查點(diǎn)

對(duì)于分布式訓(xùn)練場(chǎng)景下的管理,參考微軟推出的 DeepSpeed

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容