torchsummary:計(jì)算神經(jīng)網(wǎng)絡(luò)模型各層輸出特征圖尺寸及參數(shù)量

之前寫過一篇自動(dòng)計(jì)算模型參數(shù)量、FLOPs、乘加數(shù)以及所需內(nèi)存等數(shù)據(jù)的博客,介紹了torchstat的用法?,F(xiàn)介紹一款更為輕量的工具:torchsummary。使用方法如下:

1:安裝

pip install torchsummary

2:導(dǎo)入和使用

【注意】:此工具是針對(duì)PyTorch的,需配合PyTorch使用!
使用順序可概括如下:
(1)導(dǎo)入torchsummary中的summary對(duì)象;
(2)建立神經(jīng)網(wǎng)絡(luò)模型;
(3)輸入 模型(model)、輸入尺寸(input_size)、批次大?。╞atch_size)、運(yùn)行平臺(tái)(device)信息,運(yùn)行后即可得到summary函數(shù)的返回值。

summary函數(shù)的接口信息如下:
summary(model, input_size, batch_size, device)

4個(gè)參數(shù)在(3)中已進(jìn)行了解釋,其中device是指cpu或gpu.

3:使用實(shí)例

import torch
import torchvision
# 導(dǎo)入torchsummary
from torchsummary import summary

# 需要使用device來指定網(wǎng)絡(luò)在GPU還是CPU運(yùn)行
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 建立神經(jīng)網(wǎng)絡(luò)模型,這里直接導(dǎo)入已有模型
# model = model().to(device)
model = torchvision.models.vgg11_bn().to(device)
# 使用summary,注意輸入維度的順序
summary(model, input_size=(3, 224, 224))

輸出如下:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
         MaxPool2d-4         [-1, 64, 112, 112]               0
            Conv2d-5        [-1, 128, 112, 112]          73,856
       BatchNorm2d-6        [-1, 128, 112, 112]             256
              ReLU-7        [-1, 128, 112, 112]               0
         MaxPool2d-8          [-1, 128, 56, 56]               0
            Conv2d-9          [-1, 256, 56, 56]         295,168
      BatchNorm2d-10          [-1, 256, 56, 56]             512
             ReLU-11          [-1, 256, 56, 56]               0
           Conv2d-12          [-1, 256, 56, 56]         590,080
      BatchNorm2d-13          [-1, 256, 56, 56]             512
             ReLU-14          [-1, 256, 56, 56]               0
        MaxPool2d-15          [-1, 256, 28, 28]               0
           Conv2d-16          [-1, 512, 28, 28]       1,180,160
      BatchNorm2d-17          [-1, 512, 28, 28]           1,024
             ReLU-18          [-1, 512, 28, 28]               0
           Conv2d-19          [-1, 512, 28, 28]       2,359,808
      BatchNorm2d-20          [-1, 512, 28, 28]           1,024
             ReLU-21          [-1, 512, 28, 28]               0
        MaxPool2d-22          [-1, 512, 14, 14]               0
           Conv2d-23          [-1, 512, 14, 14]       2,359,808
      BatchNorm2d-24          [-1, 512, 14, 14]           1,024
             ReLU-25          [-1, 512, 14, 14]               0
           Conv2d-26          [-1, 512, 14, 14]       2,359,808
      BatchNorm2d-27          [-1, 512, 14, 14]           1,024
             ReLU-28          [-1, 512, 14, 14]               0
        MaxPool2d-29            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-30            [-1, 512, 7, 7]               0
           Linear-31                 [-1, 4096]     102,764,544
             ReLU-32                 [-1, 4096]               0
          Dropout-33                 [-1, 4096]               0
           Linear-34                 [-1, 4096]      16,781,312
             ReLU-35                 [-1, 4096]               0
          Dropout-36                 [-1, 4096]               0
           Linear-37                 [-1, 1000]       4,097,000
================================================================
Total params: 132,868,840
Trainable params: 132,868,840
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 182.03
Params size (MB): 506.85
Estimated Total Size (MB): 689.46
----------------------------------------------------------------

可以看出,batch_size可以不指定,默認(rèn)為-1。summary函數(shù)會(huì)對(duì)模型中的每層輸出特征圖尺寸進(jìn)行計(jì)算,并計(jì)算每層含有的參數(shù)量以及模型的參數(shù)總量等信息,對(duì)于逐層統(tǒng)計(jì)計(jì)算和分析非常直觀和簡潔。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容