pytorch多GPU并行計算及模型的保存與載入

1.多gpu并行計算

class Net(nn.Module):
    def __init__(input, output):
        pass
        #define your network 
net = Net(input, output) #實例化模型
net = nn.DataParallel(net) #數(shù)據(jù)并行
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #初始化計算設備
net .to(device)

2.模型的保存

if
torch.save(net.)

3.模型的重載

checkpoint = torch.load(resume)
state_dict =checkpoint['state_dict']

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove 'module.' of dataparallel
    new_state_dict[name]=v

model.load_state_dict(new_state_dict)

4.模型的遷移

# cpu or gpu
torch.load('model/path', map_location='cpu')

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

友情鏈接更多精彩內容