2022-1-5, Wed., 13:37 于鳶尾花基地
可以采用如下方式對之前保存的預(yù)訓(xùn)練模型進(jìn)行批量測試:
for ckpt in ckpt_list:
model = ptl_module.load_from_checkpoint(ckpt, args=args)
trainer.test(model, dataloaders=test_dataloader)
然而,在上述循環(huán)中,通過trainer.test每執(zhí)行一次測試,都只是執(zhí)行了一個epoch的測試(也就是執(zhí)行多次ptl_module.test_step和一次ptl_module.test_epoch_end),而不可能把ckpt_list中的多個預(yù)訓(xùn)練模型(checkpoint)當(dāng)做多個epoch,多次執(zhí)行ptl_module.test_epoch_end。
我們期望,對多個checkpoint的測試能像對多個epoch的訓(xùn)練一樣簡潔:
trainer.test(ptl_module, dataloaders=test_dataloader)
怎么做到?在訓(xùn)練過程中,要訓(xùn)練多少個epoch是由參數(shù)max_epochs來決定的;而在測試過程中,怎么辦?PTL并非完整地保存了所有epoch的預(yù)訓(xùn)練模型。
由于在測試過程中對各checkpoint是獨(dú)立測試的,如果要統(tǒng)計多個checkpoint的最優(yōu)性能(如最大PSNR/SSIM),怎么辦?這里的一個關(guān)鍵問題是如何保存每次測試得到的評估結(jié)果,好像PTL并未對此提供接口。
解決方案
PTL提供了“回調(diào)類(Callback)”(在 pytorch_lightning.callbacks 中),可以自定義一個回調(diào)類,并重載on_test_epoch_end方法,來監(jiān)聽ptl_module.test_epoch_end。
如何使用?只需要在定義trainer時,把該自定義的回調(diào)函數(shù)加入其參數(shù)callbacks即可:ptl.Trainer(callbacks=[MetricTracker()])。這里,MetricTracker為自定義的回調(diào)類,具體如下:
class MetricTracker(Callback):
def __init__(self):
self.optim_metrics = None
def on_test_epoch_end(self, trainer, pl_module):
if self.optim_metrics is None:
self.optim_metrics = pl_module.metrics_dict
return
tensorboard = pl_module.logger.experiment
metrics_key_list, metrics_val_list = [], []
for k in pl_module.metrics_dict:
# comp_fun 是自己定義的比較函數(shù)
self.optim_metrics[k] = comp_fun(self.optim_metrics[k], pl_module.metrics_dict[k])
評論: 由于MetricTracker具有與Trainer相同的生命周期,因此,在整個測試過程中,MetricTracker能夠維護(hù)一個最優(yōu)的評估結(jié)果optim_metrics。