pytorch:model save & model load

pytorch的模型保存與恢復(fù)~

首先pytorch官網(wǎng)doc中推薦兩種方法。link

然而在需要注意的是:

方法一:

保存

torch.save(the_model.state_dict(), PATH)

恢復(fù)

the_model = TheModelClass(*args, **kwargs)

the_model.load_state_dict(torch.load(PATH))

然而這種方法只會保存模型的參數(shù),并不會保存Epoch、optimizer、weight之類。我們需要自己導(dǎo)入模型的結(jié)構(gòu)信息。

方法二:

保存

torch.save(the_model, PATH)

恢復(fù)

the_model = torch.load(PATH)

一個相對完整的例子

保存

torch.save({

? ? ? ? ? ? 'epoch': epoch + 1,

? ? ? ? ? ? 'arch': args.arch,

? ? ? ? ? ? 'state_dict': model.state_dict(),

? ? ? ? ? ? 'best_prec1': best_prec1,

? ? ? ? }, 'checkpoint.tar' )

恢復(fù)

if args.resume:

? ? ? ? if os.path.isfile(args.resume):

? ? ? ? ? ? print("=> loading checkpoint '{}'".format(args.resume))

? ? ? ? ? ? checkpoint = torch.load(args.resume)

? ? ? ? ? ? args.start_epoch = checkpoint['epoch']

? ? ? ? ? ? best_prec1 = checkpoint['best_prec1']

? ? ? ? ? ? model.load_state_dict(checkpoint['state_dict'])

? ? ? ? ? ? print("=> loaded checkpoint '{}' (epoch {})"? ? ? ? ? ? ? ? ?

? ? ? ? ? ? ? ? ? ? ? ? ? ?.format(args.evaluate, checkpoint['epoch']))

獲取模型中某些層的參數(shù)

對于恢復(fù)的模型,如果我們想查看某些層的參數(shù),可以:

# 定義一個網(wǎng)絡(luò)

from collections import Ordered

Dictmodel = nn.Sequential(OrderedDict([? ? ? ? ? ? ? ?

? ? ? ? ? ? ? ? ? ('conv1', nn.Conv2d(1,20,5)),

? ? ? ? ? ? ? ? ? ('relu1', nn.ReLU()),

? ? ? ? ? ? ? ? ? ('conv2', nn.Conv2d(20,64,5)),

? ? ? ? ? ? ? ? ? ('relu2', nn.ReLU())

? ? ? ? ? ? ? ? ]))# 打印網(wǎng)絡(luò)的結(jié)構(gòu)print(model)

OUT:

Sequential (

? (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))

? (relu1): ReLU ()

? (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))

? (relu2): ReLU ()

)

如果我們想獲取conv1的weight和bias:

params=model.state_dict()

for k,v in params.items():

? ? print(k)? ? #打印網(wǎng)絡(luò)中的變量名

print(params['conv1.weight'])? #打印conv1的weight

print(params['conv1.bias']) #打印conv1的bias

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容