pytorch的模型保存與恢復(fù)~
首先pytorch官網(wǎng)doc中推薦兩種方法。link
然而在需要注意的是:
方法一:
保存
torch.save(the_model.state_dict(), PATH)
恢復(fù)
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
然而這種方法只會保存模型的參數(shù),并不會保存Epoch、optimizer、weight之類。我們需要自己導(dǎo)入模型的結(jié)構(gòu)信息。
方法二:
保存
torch.save(the_model, PATH)
恢復(fù)
the_model = torch.load(PATH)
一個相對完整的例子
保存
torch.save({
? ? ? ? ? ? 'epoch': epoch + 1,
? ? ? ? ? ? 'arch': args.arch,
? ? ? ? ? ? 'state_dict': model.state_dict(),
? ? ? ? ? ? 'best_prec1': best_prec1,
? ? ? ? }, 'checkpoint.tar' )
恢復(fù)
if args.resume:
? ? ? ? if os.path.isfile(args.resume):
? ? ? ? ? ? print("=> loading checkpoint '{}'".format(args.resume))
? ? ? ? ? ? checkpoint = torch.load(args.resume)
? ? ? ? ? ? args.start_epoch = checkpoint['epoch']
? ? ? ? ? ? best_prec1 = checkpoint['best_prec1']
? ? ? ? ? ? model.load_state_dict(checkpoint['state_dict'])
? ? ? ? ? ? print("=> loaded checkpoint '{}' (epoch {})"? ? ? ? ? ? ? ? ?
? ? ? ? ? ? ? ? ? ? ? ? ? ?.format(args.evaluate, checkpoint['epoch']))
獲取模型中某些層的參數(shù)
對于恢復(fù)的模型,如果我們想查看某些層的參數(shù),可以:
# 定義一個網(wǎng)絡(luò)
from collections import Ordered
Dictmodel = nn.Sequential(OrderedDict([? ? ? ? ? ? ? ?
? ? ? ? ? ? ? ? ? ('conv1', nn.Conv2d(1,20,5)),
? ? ? ? ? ? ? ? ? ('relu1', nn.ReLU()),
? ? ? ? ? ? ? ? ? ('conv2', nn.Conv2d(20,64,5)),
? ? ? ? ? ? ? ? ? ('relu2', nn.ReLU())
? ? ? ? ? ? ? ? ]))# 打印網(wǎng)絡(luò)的結(jié)構(gòu)print(model)
OUT:
Sequential (
? (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
? (relu1): ReLU ()
? (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
? (relu2): ReLU ()
)
如果我們想獲取conv1的weight和bias:
params=model.state_dict()
for k,v in params.items():
? ? print(k)? ? #打印網(wǎng)絡(luò)中的變量名
print(params['conv1.weight'])? #打印conv1的weight
print(params['conv1.bias']) #打印conv1的bias