給gpt喂自己的數(shù)據(jù)!

上一篇文章說完了安裝,這下我要喂自己的數(shù)據(jù)了。

1. 準備數(shù)據(jù)

首先要按照給的格式創(chuàng)建自己的json數(shù)據(jù),這個比較好創(chuàng)建,之前沒用過json搜一下就行了。
原文的格式如下:

[
    {
        "instruction": "Give three tips for staying healthy.",
        "input": "",
        "output": "1. Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
    },
    {
        "instruction": "What are the three primary colors?",
        "input": "",
        "output": "The three primary colors are red, blue, and yellow."
    },
]

這個格式就是一個存了dict的list,換成自己的數(shù)據(jù),代碼如下:

import json
data = []
for p in x:   ##這里將自己的數(shù)據(jù)每個換成了對應的dict,然后用list存儲所有的dict
    a={
        "instruction":x[0],
        "input":x[1],
        "output":x[2]
    }
    data.append(a)

直接將data全部以json的格式存到文件里。

with open('./file.json',"w",encoding="utf-8") as  f:
    # ensure_ascii 顯示中文,不以ASCII的方式顯示
    json.dump(data,f, ensure_ascii=False, indent=2)  ##縮進2格,dump函數(shù)將數(shù)據(jù)格式成json類型

得到的格式就跟源碼一樣拉,直接喂給gpt學把!本文采用了40W條指令,batch_size=128,結果gpu超出20G了,重新把size調小了試試,可憐兮兮...

2. 下載參數(shù)

LLaMA-7B-HF 大模型下載:

python
>>> from huggingface_hub import snapshot_download
>>> snapshot_download(repo_id="decapoda-research/llama-7b-hf")

Lora 參數(shù)下載:

>>> snapshot_download(repo_id="tloen/alpaca-lora-7b")

調整finetun.py里的base_model字符串,改成上面的LLaMA-7B-HF大模型的地址就行。
另外在運行的時候發(fā)現(xiàn)程序在驗證會報GPU爆掉的錯誤,搜了一圈說是因為在驗證的時候梯度累積了(但是鏈接的庫太多了不知道在哪里執(zhí)行了驗證程序),后來我在對應的錯誤代碼上加了以下的代碼:

torch.cuda.empty_cache();

nvidia-smi查看gpu使用情況,發(fā)現(xiàn)確實是會少,但是仍然爆顯存,后來我把batch_size改成了10,運行的時候一直查看gpu,一到驗證測試的時候就開始飆升,前面幾輪勉勉強強過去,差那么1G就爆了,結果還是在800的時候爆掉,后來看了下代碼,發(fā)現(xiàn)一個參數(shù):

gradient_accumulation_steps = 2

查閱了下資料,表示的是梯度累積的步數(shù),正常來說是一次batch_size進行一次反向傳播,設置了gradient_accumulation_steps為2,那就是2次進行一次反向傳播,這樣我們就可以每次處理的batch_size少一點了,從而可以減少顯存的使用。
比如我目前的batch_size是10,gradient_acc_steps為5,表示每處理10 * 5條命令就進行一次反向傳播(源代碼是100*2)。每次僅需處理10條命令就可,累積到次數(shù)后反向傳播修正參數(shù)。(照理說不是直接的原因,因為我是到驗證的時候就爆顯存,但是改了gradient后真的不超了,后面再研究研究)。
跑起來了后面再繼續(xù)構建自己的實驗~加油?。?!

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容