pytorch之保存與加載模型

pytorch之保存與加載模型

本篇筆記譯自pytorch官網(wǎng)tutorial,用于方便查看。
pytorch與保存、加載模型有關(guān)的常用函數(shù)3個(gè):

  • torch.save(): 保存一個(gè)序列化的對(duì)象到磁盤,使用的是Pythonpickle庫(kù)來(lái)實(shí)現(xiàn)的。
  • torch.load(): 解序列化一個(gè)pickled對(duì)象并加載到內(nèi)存當(dāng)中。
  • torch.nn.Module.load_state_dict(): 加載一個(gè)解序列化的state_dict對(duì)象

1. state_dict

PyTorch中所有可學(xué)習(xí)的參數(shù)保存在model.parameters()中。state_dict是一個(gè)Python字典。保存了各層與其參數(shù)張量之間的映射。torch.optim對(duì)象也有一個(gè)state_dict,它包含了optimizerstate,以及一些超參數(shù)。

2. 保存&加載模型來(lái)inference(recommended)

save

torch.save(model.state_dict(), PATH)

load

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()  # 當(dāng)用于inference時(shí)不要忘記添加
  • 保存的文件名后綴可以是.pt.pth
  • 當(dāng)用于inference時(shí)不要忘記添加model.eval()

3. 保存&加載整個(gè)模型(not recommended)

save

torch.save(model, PATH)

load

# Model class must be defined somewhere
model = torch.load()
model.eval()

4. 保存&加載帶checkpoint的模型用于inferenceresuming training

save

torch.save({
  'epoch': epoch,
  'model_state_dict': model.state_dict(),
  'optimizer_state_dict': optimizer.state_dict(),
  'loss': loss,
  ...
  }, PATH)

load

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# or
model.train()

5. 保存多個(gè)模型到一個(gè)文件中

save

torch.save({
  'modelA_state_dict': modelA.state_dict(),
  'modelB_state_dict': modelB.state_dict(),
  'optimizerA_state_dict': optimizerA.state_dict(),
  'optimizerB_state_dict': optimizerB.state_dict(),
  ...
  }, PATH)

load

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelAClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict']
modelB.load_state_dict(checkpoint['modelB_state_dict']
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']

modelA.eval()
modelB.eval()
# or
modelA.train()
modelB.train()
  • 此情況可能在GAN,Sequence-to-sequence,或ensemble models中使用
  • 保存checkpoint常用.tar文件擴(kuò)展名

6. Warmstarting Model Using Parameters From A Different Model

save

torch.save(modelA.state_dict(), PATH)

load

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
  • 在遷移訓(xùn)練時(shí),可能希望只加載部分模型參數(shù),此時(shí)可置strict參數(shù)為False來(lái)忽略那些沒有匹配到的keys

7. 保存&加載模型跨設(shè)備

(1) Save on GPU, Load on CPU
save

torch.save(model.state_dict(), PATH)

load

device = torch.device("cpu")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

(2) Save on GPU, Load on GPU
save

torch.save(model.state_dict(), PATH)

load

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

(3) Save on CPU, Load on GPU
save

torch.save(model.state_dict(), PATH)

load

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)

8. 保存torch.nn.DataParallel模型

save

torch.save(model.module.state_dict(), PATH)

load

# Load to whatever device you want

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容