PyTorch Lightning 中的批量測試及其存在的問題

2022-1-5, Wed., 13:37 于鳶尾花基地
可以采用如下方式對之前保存的預(yù)訓(xùn)練模型進(jìn)行批量測試:

for ckpt in ckpt_list:
    model = ptl_module.load_from_checkpoint(ckpt, args=args)
    trainer.test(model, dataloaders=test_dataloader)

然而,在上述循環(huán)中,通過trainer.test每執(zhí)行一次測試,都只是執(zhí)行了一個epoch的測試(也就是執(zhí)行多次ptl_module.test_step和一次ptl_module.test_epoch_end),而不可能把ckpt_list中的多個預(yù)訓(xùn)練模型(checkpoint)當(dāng)做多個epoch,多次執(zhí)行ptl_module.test_epoch_end。

我們期望,對多個checkpoint的測試能像對多個epoch的訓(xùn)練一樣簡潔:

trainer.test(ptl_module, dataloaders=test_dataloader)

怎么做到?在訓(xùn)練過程中,要訓(xùn)練多少個epoch是由參數(shù)max_epochs來決定的;而在測試過程中,怎么辦?PTL并非完整地保存了所有epoch的預(yù)訓(xùn)練模型。

由于在測試過程中對各checkpoint是獨(dú)立測試的,如果要統(tǒng)計多個checkpoint的最優(yōu)性能(如最大PSNR/SSIM),怎么辦?這里的一個關(guān)鍵問題是如何保存每次測試得到的評估結(jié)果,好像PTL并未對此提供接口。

解決方案
PTL提供了“回調(diào)類(Callback)”(在 pytorch_lightning.callbacks 中),可以自定義一個回調(diào)類,并重載on_test_epoch_end方法,來監(jiān)聽ptl_module.test_epoch_end。
如何使用?只需要在定義trainer時,把該自定義的回調(diào)函數(shù)加入其參數(shù)callbacks即可:ptl.Trainer(callbacks=[MetricTracker()])。這里,MetricTracker為自定義的回調(diào)類,具體如下:

class MetricTracker(Callback):

    def __init__(self):
        self.optim_metrics = None

    def on_test_epoch_end(self, trainer, pl_module):
        if self.optim_metrics is None:
            self.optim_metrics = pl_module.metrics_dict
            return

        tensorboard = pl_module.logger.experiment
        metrics_key_list, metrics_val_list = [], []
        for k in pl_module.metrics_dict:
            # comp_fun 是自己定義的比較函數(shù)
            self.optim_metrics[k] = comp_fun(self.optim_metrics[k], pl_module.metrics_dict[k])

評論: 由于MetricTracker具有與Trainer相同的生命周期,因此,在整個測試過程中,MetricTracker能夠維護(hù)一個最優(yōu)的評估結(jié)果optim_metrics。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容