PyTorch如何打印模型詳細(xì)信息

我們以resnet18為例,介紹幾種獲取模型摘要的方法。

import torchvistion
model = torchvision.models.resnet18()

1.直接使用PrettyTable

from prettytable import PrettyTable

table = PrettyTable(['Modules', 'Parameters']) 
total_params = 0 
for name, parameter in model.named_parameters():
    if not parameter.requires_grad: continue
    params = parameter.numel()
    table.add_row([name, params])
    total_params+=params
print(table) 
print(f'Total Trainable Params: {total_params}') 

效果如下:


PrettyTable

比較簡單,也沒有模型的輸入輸出情況。

2. TorchSummary

from torchsummary import summary
summary(model, input_size = (3, 64, 64), batch_size = -1)
TorchSummary

整體看美觀了很多,也有了輸出的維度。但是如果能打印出模型的層次結(jié)構(gòu)就更好了。

3. torchinfo

import torchinfo 
torchinfo.summary(model, (3, 224, 224), batch_dim = 0, col_names = ('input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'), verbose = 0)
torchinfo

這種方式更加美觀,且內(nèi)容詳細(xì),灰常棒。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容