torchsummary 中input size 異常的問題


本文解決問題

torchsummary針對多個輸入模型的時候,其輸出信息中input size等存在著錯誤,這里提供方案解決這個錯誤。


當我們使用pytorch搭建好我們自己的深度學(xué)習(xí)模型的的時候,我們總想看看具體的網(wǎng)絡(luò)信息以及參數(shù)量大小,這時候就要請出我們的神器 torchsummary了,torchsummary的簡單使用如下所示:

# pip install torchsummary
from torchsummary import summary

model = OurOwnModel()
summary(model, input_size=(3, 224, 224), device='cpu')

此時一切正常的話將會輸出下面的信息:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]       1,180,160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       2,359,808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]       2,359,808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]       2,359,808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
           Linear-32                 [-1, 4096]     102,764,544
             ReLU-33                 [-1, 4096]               0
          Dropout-34                 [-1, 4096]               0
           Linear-35                 [-1, 4096]      16,781,312
             ReLU-36                 [-1, 4096]               0
          Dropout-37                 [-1, 4096]               0
           Linear-38                 [-1, 1000]       4,097,000
================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.59
Params size (MB): 527.79
Estimated Total Size (MB): 746.96
----------------------------------------------------------------

你發(fā)現(xiàn)一切安好,nice。但是當你像我一樣開始搭建一個多輸入網(wǎng)絡(luò)的時候,這時候麻煩就來了。

from torchsummary import summary

model = OurOwnModel()
summary(model, input_size=[(3, 224, 224), (3, 224, 224), (3, 123)], device='cpu')

此時輸出的信息就會有錯誤了。

# 上面正確的信息省略了
================================================================
Total params: 49,365,761
Trainable params: 49,365,761
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 25169045225472.00  # 輸入的大小顯然不對啊
Forward/backward pass size (MB): 22975.86
Params size (MB): 188.32
Estimated Total Size (MB): 25169045248636.18 # 看起來整個數(shù)據(jù)也是顯然有錯誤的
----------------------------------------------------------------

上面的 Input Size(MB) Estimated Total Size (MB)這兩項顯然是有錯誤的。

這里提供如下的解決辦法:

import torchsummary
print(torchsummary.__file__)

上面代碼會輸出torchsummary的安裝路徑,這里得到的如下:

/home/guangkun/anaconda3/envs/jet/lib/python3.7/site-packages/torchsummary/__init__.py

我們知道了torchsummary的地址之后,進入該文件夾,同級目錄如下:

├── __init__.py
├── __pycache__
│   ├── __init__.cpython-37.pyc
│   └── torchsummary.cpython-37.pyc
└── torchsummary.py

修改 torchsummary.py文件(大概在100行-103行):

  total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
  total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
  total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
  total_size = total_params_size + total_output_size + total_input_size

修改為:

total_input_size = abs(np.sum([np.prod(in_tuple) for in_tuple in input_size]) * batch_size * 4. / (1024 ** 2.))
total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
total_size = total_params_size + total_output_size + total_input_size

保存后再運行即可發(fā)現(xiàn)正常了,正常的輸出信息如下:

================================================================
Total params: 49,365,761
Trainable params: 49,365,761
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.64
Forward/backward pass size (MB): 179.50
Params size (MB): 188.32
Estimated Total Size (MB): 369.45
----------------------------------------------------------------
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容