最近在使用華為云的模型訓(xùn)練服務(wù)平臺(tái)時(shí),發(fā)現(xiàn)一個(gè)問題:當(dāng)算法需要一個(gè)預(yù)訓(xùn)練模型的時(shí)候,怎么去上傳呢?在平臺(tái)上面只找到了一個(gè)數(shù)據(jù)集上傳的接口,但這個(gè)接口最大上傳容量時(shí)80M,而算法需要的預(yù)訓(xùn)練模型遠(yuǎn)遠(yuǎn)大于80M, 那該怎么辦呢?既然沒辦法直接上傳,只能考慮分塊上傳再合并,但是由于這個(gè)平臺(tái)和本地操作還是有一些區(qū)別,對(duì)于不熟悉的平臺(tái)的新手來說(比如我),中間可能需要踩很多次坑。為了避免下一次我使用的時(shí)時(shí)候踩很多坑,所以做個(gè)筆記簡單記錄一下模型上傳的整個(gè)流程。下面以pytorch的預(yù)訓(xùn)練模型resnet50.pth.tar為例子,展示整個(gè)操作流程。
1. 本地切分模型
利用以下代碼在本地將resnet50的模型以50MB為單位進(jìn)行切分
def split(orgFile, chunkSize, saveDir):
'''
:param orgFile: 原模型存儲(chǔ)的位置, 比如 ./checkpoint/resnet50.pth.tar
:param chunkSize: 文件切分字節(jié)數(shù),比如以50M為單位切分原模型,chunkSize 1024*1024*50
:param saveDir: 切分后的文件儲(chǔ)存路徑, 比如 ./resnet50
:return:
'''
if not os.path.exists(orgFile):
print("Cannot find orgFile:{0}".format(orgFile))
return -1
if not os.path.exists(saveDir):
os.makedirs(saveDir)
partNum = 0
with open(orgFile, 'rb') as inputFile:
while True:
chunk = inputFile.read(chunkSize)
if not chunk:
break
partNum += 1
fileName = os.path.join(saveDir, 'part%04d' % partNum)
with open(fileName, 'wb') as f:
f.write(chunk)
return partNum

切分后的數(shù)據(jù).png
2. 分批上傳模型
如下圖所示,分批上傳分割后的模型,數(shù)據(jù)類別一定要選擇其他。

模型上傳.png
上傳完畢之后,可以在數(shù)據(jù)集的目錄下看到已經(jīng)上傳好的模型

模型上傳成功.png
3. 合并模型
運(yùn)行下面的代碼進(jìn)行合并
# -*- coding: utf-8 -*-
from __future__ import print_function # do not delete this line if you want to save your log file.
import softcomai as sai
import numpy
import os
import shutil
import torch
def joinFile():
# 獲取各個(gè)文件的路徑
filePaths = []
for i in range(1,5):
entityName = "resnet50_{0}".format(i)
data_reference = sai.data_reference.get_data_reference(dataset="Default", dataset_entity=entityName)
filePaths.append(data_reference.get_files_paths()) # 注意data_reference.get_files_paths()返回的是一個(gè)列表
# 獲取平臺(tái)的sdk返回模型保存路徑
saveDir = sai.context.param(sai.context.MODEL_PATH)
fileName = 'resnet50.pth.tar'
savePath = os.path.join(saveDir, fileName)
outFile = open(savePath, 'wb')
# 合并文件
for filePath in filePaths:
inFile = open(filePath[0], 'rb')
data = inFile.read()
outFile.write(data)
inFile.close()
outFile.close()
return savePath
if __name__ == "__main__":
savePath = joinFile()
# 測試合并后的文件是否可用
checkpoint = torch.load(savePath,map_location='cpu')
print(checkpoint.keys())
運(yùn)行完后,可以在任務(wù)目錄下找到已經(jīng)保存好的預(yù)訓(xùn)練模型

文件合并成功.png