將TF-checkpoint 文件轉(zhuǎn)換為 pytorch-checkpoint 踩坑

  1. 改代碼將Bert的Tensorflow 檢查點(diǎn)轉(zhuǎn)換為 Pytorch的檢查點(diǎn),整理Transformers的代碼得到,為了方便使用同時記錄踩的坑。

  2. Tensorflow檢查點(diǎn)文件解析。

1. 包括以下3個文件
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta
2. 其中model.ckpt為checkpoint的文件前綴,在命令行調(diào)用該代碼提供 --tf_checkpoint_path 時需要同時提供checkpoint 前綴,例如 --tf_checkpoint_path model_checkpoint/model.ckpt
  1. 同時提供模型Config文件,名字通常為bert_config.json。

  2. 調(diào)用該代碼命令行為:

# 依賴自行下載
# $checkpoint_path 為TF-checkpoint路徑
# $save_file 為pytorch-checkpoint 保存文件
python3 convert_bert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path $checkpoint_path/model.ckpt --bert_config_file $checkpoint_path/bert_config.json --pytorch_dump_path $save_file
  1. 保存后得到一個 pytorch-checkpoint, 需要同 bert_config.json 和 vocab.txt在同一個文件夾,同時需要將Bert_config.json增加一個命名為config.json的文件,Transformers加載Pytorch模型時會自動調(diào)用,之后可以通過Transformers正常使用。

  2. 目前該代碼已經(jīng)保存至 https://github.com/YaoXinZhi/Convert-Bert-TF-checkpoint-to-Pytorch

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容