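Problem addressed in this post:
When a model takes multiple inputs, torchsummary reports incorrect values in its output, such as the input size. This post describes a fix.
Once we have built our own deep learning model in PyTorch, we usually want to inspect the network layer by layer and check its parameter count. torchsummary is the handy tool for this. Basic usage looks like this: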
# pip install torchsummary
from torchsummary import summary
model = OurOwnModel()
summary(model, input_size=(3, 224, 224), device='cpu')
If everything works, it prints the following:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 224, 224] 1,792
ReLU-2 [-1, 64, 224, 224] 0
Conv2d-3 [-1, 64, 224, 224] 36,928
ReLU-4 [-1, 64, 224, 224] 0
MaxPool2d-5 [-1, 64, 112, 112] 0
Conv2d-6 [-1, 128, 112, 112] 73,856
ReLU-7 [-1, 128, 112, 112] 0
Conv2d-8 [-1, 128, 112, 112] 147,584
ReLU-9 [-1, 128, 112, 112] 0
MaxPool2d-10 [-1, 128, 56, 56] 0
Conv2d-11 [-1, 256, 56, 56] 295,168
ReLU-12 [-1, 256, 56, 56] 0
Conv2d-13 [-1, 256, 56, 56] 590,080
ReLU-14 [-1, 256, 56, 56] 0
Conv2d-15 [-1, 256, 56, 56] 590,080
ReLU-16 [-1, 256, 56, 56] 0
MaxPool2d-17 [-1, 256, 28, 28] 0
Conv2d-18 [-1, 512, 28, 28] 1,180,160
ReLU-19 [-1, 512, 28, 28] 0
Conv2d-20 [-1, 512, 28, 28] 2,359,808
ReLU-21 [-1, 512, 28, 28] 0
Conv2d-22 [-1, 512, 28, 28] 2,359,808
ReLU-23 [-1, 512, 28, 28] 0
MaxPool2d-24 [-1, 512, 14, 14] 0
Conv2d-25 [-1, 512, 14, 14] 2,359,808
ReLU-26 [-1, 512, 14, 14] 0
Conv2d-27 [-1, 512, 14, 14] 2,359,808
ReLU-28 [-1, 512, 14, 14] 0
Conv2d-29 [-1, 512, 14, 14] 2,359,808
ReLU-30 [-1, 512, 14, 14] 0
MaxPool2d-31 [-1, 512, 7, 7] 0
Linear-32 [-1, 4096] 102,764,544
ReLU-33 [-1, 4096] 0
Dropout-34 [-1, 4096] 0
Linear-35 [-1, 4096] 16,781,312
ReLU-36 [-1, 4096] 0
Dropout-37 [-1, 4096] 0
Linear-38 [-1, 1000] 4,097,000
================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.59
Params size (MB): 527.79
Estimated Total Size (MB): 746.96
----------------------------------------------------------------
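As a quick sanity check, the size lines above can be reproduced by hand. The sketch below assumes 4 bytes per value (float32) and a batch size of 1, which matches the factor of 4 in the torchsummary source quoted later in this post. `to_mb` is a helper name made up for this illustration, not part of torchsummary.

```python
BYTES_PER_VALUE = 4  # assumption: every value is stored as float32


def to_mb(num_values):
    """Convert a count of float32 values to megabytes (1 MB = 1024**2 bytes)."""
    return num_values * BYTES_PER_VALUE / (1024 ** 2)


# Parameter count taken from the "Total params" line above.
print(f"Params size (MB): {to_mb(138_357_544):.2f}")  # 527.79

# A single (3, 224, 224) input with batch size 1.
print(f"Input size (MB): {to_mb(3 * 224 * 224):.2f}")  # 0.57
```

Both values agree with the summary, so for a single input the estimate behaves as expected.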
Everything looks fine so far. But once you start building a multi-input network, as I did, the trouble begins.
from torchsummary import summary
model = OurOwnModel()
summary(model, input_size=[(3, 224, 224), (3, 224, 224), (3, 123)], device='cpu')
The output now contains errors.
# (the correct per-layer rows above are omitted)
================================================================
Total params: 49,365,761
Trainable params: 49,365,761
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 25169045225472.00 # the input size is clearly wrong
Forward/backward pass size (MB): 22975.86
Params size (MB): 188.32
Estimated Total Size (MB): 25169045248636.18 # the total is clearly wrong as well
----------------------------------------------------------------
The Input size (MB) and Estimated Total Size (MB) entries above are clearly wrong.
The fix is as follows:
import torchsummary
print(torchsummary.__file__)
The code above prints the path where torchsummary is installed. In my environment it is:
/home/guangkun/anaconda3/envs/jet/lib/python3.7/site-packages/torchsummary/__init__.py
With the location of torchsummary known, go to that folder. Its contents are:
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-37.pyc
│ └── torchsummary.cpython-37.pyc
└── torchsummary.py
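If you would rather not copy the path by hand, the file can also be located from Python. This is a minimal sketch using only the standard library; `sibling_source` is a hypothetical helper name, and it assumes the package is laid out as in the listing above (an `__init__.py` with the target file next to it).

```python
import importlib.util
from pathlib import Path


def sibling_source(package_name, filename):
    """Return the path of `filename` inside an installed package, or None."""
    spec = importlib.util.find_spec(package_name)
    if spec is None or spec.origin is None:
        return None  # package is not installed (or has no file on disk)
    candidate = Path(spec.origin).parent / filename
    return candidate if candidate.exists() else None


# Prints the full path of torchsummary.py, or None if it cannot be found.
print(sibling_source("torchsummary", "torchsummary.py"))
```

Keep in mind that editing a file under site-packages changes it for every project in that environment, and reinstalling the package will overwrite the change.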
Edit torchsummary.py (around lines 100-103), which currently reads:
total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients
total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
total_size = total_params_size + total_output_size + total_input_size
Change it to the following (only the first line differs):
total_input_size = abs(np.sum([np.prod(in_tuple) for in_tuple in input_size]) * batch_size * 4. / (1024 ** 2.))
total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients
total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
total_size = total_params_size + total_output_size + total_input_size
After saving and re-running, the summary is correct. The output now looks like this:
================================================================
Total params: 49,365,761
Trainable params: 49,365,761
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.64
Forward/backward pass size (MB): 179.50
Params size (MB): 188.32
Estimated Total Size (MB): 369.45
----------------------------------------------------------------
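To see why the one-line change matters, the two formulas can be compared without torchsummary. The sketch below is plain Python (3.8+ for `math.prod`) and the function names are made up for this illustration. The "buggy" version imitates what `np.prod` does when every input shape has the same length: it multiplies every dimension of every input together. With shapes of different lengths, NumPy's behavior may depend on the version, so that case is not reproduced here.

```python
from math import prod


def input_mb_buggy(input_size, batch_size=1):
    """Multiply all dimensions of all inputs together (the original behavior)."""
    all_dims = [dim for shape in input_size for dim in shape]
    return abs(prod(all_dims) * batch_size * 4.0 / (1024 ** 2))


def input_mb_fixed(input_size, batch_size=1):
    """Sum the element counts of the individual inputs (the patched behavior)."""
    total_values = sum(prod(shape) for shape in input_size)
    return abs(total_values * batch_size * 4.0 / (1024 ** 2))


shapes = [(3, 224, 224), (3, 224, 224)]  # two image inputs, chosen for illustration
print(f"buggy: {input_mb_buggy(shapes):.2f} MB")  # buggy: 86436.00 MB
print(f"fixed: {input_mb_fixed(shapes):.2f} MB")  # fixed: 1.15 MB
```

Even with only two inputs, the product-based estimate is tens of thousands of times too large, while the sum-based one is simply twice the size of a single image.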