改代碼將Bert的Tensorflow 檢查點(diǎn)轉(zhuǎn)換為 Pytorch的檢查點(diǎn),整理Transformers的代碼得到,為了方便使用同時記錄踩的坑。
Tensorflow檢查點(diǎn)文件解析。
1. 包括以下3個文件
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta
2. 其中model.ckpt為checkpoint的文件前綴,在命令行調(diào)用該代碼提供 --tf_checkpoint_path 時需要同時提供checkpoint 前綴,例如 --tf_checkpoint_path model_checkpoint/model.ckpt
同時提供模型Config文件,名字通常為bert_config.json。
調(diào)用該代碼命令行為:
# 依賴自行下載
# $checkpoint_path 為TF-checkpoint路徑
# $save_file 為pytorch-checkpoint 保存文件
python3 convert_bert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path $checkpoint_path/model.ckpt --bert_config_file $checkpoint_path/bert_config.json --pytorch_dump_path $save_file
保存后得到一個 pytorch-checkpoint, 需要同 bert_config.json 和 vocab.txt在同一個文件夾,同時需要將Bert_config.json增加一個命名為config.json的文件,Transformers加載Pytorch模型時會自動調(diào)用,之后可以通過Transformers正常使用。
目前該代碼已經(jīng)保存至 https://github.com/YaoXinZhi/Convert-Bert-TF-checkpoint-to-Pytorch