PyTorch之HOOK——獲取神經(jīng)網(wǎng)絡(luò)特征和梯度的有效工具

本文首發(fā)于簡書 西北小生_ 的博客:http://www.itdecent.cn/u/898c7641f6ea,未經(jīng)允許,禁止轉(zhuǎn)載!

為了更深入地理解神經(jīng)網(wǎng)絡(luò)模型,有時候我們需要觀察它訓(xùn)練得到的卷積核、特征圖或者梯度等信息,這在CNN可視化研究中經(jīng)常用到。其中,卷積核最易獲取,將模型參數(shù)保存即可得到;特征圖是中間變量,所對應(yīng)的圖像處理完即會被系統(tǒng)清除,否則將嚴(yán)重占用內(nèi)存;梯度跟特征圖類似,除了葉子結(jié)點(diǎn)外,其它中間變量的梯度都被會內(nèi)存釋放,因而不能直接獲取。
最容易想到的獲取方法就是改變模型結(jié)構(gòu),在forward的最后不但返回模型的預(yù)測輸出,還返回所需要的特征圖等信息。

如何在不改變模型結(jié)構(gòu)的基礎(chǔ)上獲取特征圖、梯度等信息呢?

Pytorch的hook編程可以在不改變網(wǎng)絡(luò)結(jié)構(gòu)的基礎(chǔ)上有效獲取、改變模型中間變量以及梯度等信息。
hook可以提取或改變Tensor的梯度,也可以獲取nn.Module的輸出和梯度(這里不能改變)。因此有3個hook函數(shù)用于實(shí)現(xiàn)以上功能:

Tensor.register_hook(hook_fn),
nn.Module.register_forward_hook(hook_fn),
nn.Module.register_backward_hook(hook_fn).

下面對其用法進(jìn)行一一介紹。

1.Tensor.register_hook(hook_fn)

功能:注冊一個反向傳播hook函數(shù),用于自動記錄Tensor的梯度。
PyTorch對中間變量和非葉子節(jié)點(diǎn)的梯度運(yùn)行完后會自動釋放,以減緩內(nèi)存占用。什么是中間變量?什么是非葉子節(jié)點(diǎn)?


Tensor計算

上圖中,a,b,d就是葉子節(jié)點(diǎn),c,e,o是非葉子節(jié)點(diǎn),也是中間變量。

In [18]: a = torch.Tensor([1,2]).requires_grad_() 
    ...: b = torch.Tensor([3,4]).requires_grad_() 
    ...: d = torch.Tensor([2]).requires_grad_() 
    ...: c = a + b 
    ...: e = c * d 
    ...: o = e.sum()     

In [19]: o.backward()

In [20]: print(a.grad)
tensor([2., 2.])

In [21]: print(b.grad)
tensor([2., 2.])

In [22]: print(c.grad)
None

In [23]: print(d.grad)
tensor([10.])

In [24]: print(e.grad)
None

In [25]: print(o.grad)
None

可以從程序的輸出中看到,a,b,d作為葉子節(jié)點(diǎn),經(jīng)過反向傳播后梯度值仍然保留,而其它非葉子節(jié)點(diǎn)的梯度已經(jīng)被自動釋放了,要想得到它們的梯度值,就需要使用hook了。

我們首先自定義一個hook_fn函數(shù),用于記錄對Tensor梯度的操作,然后用Tensor.register_hook(hook_fn)對要獲取梯度的非葉子結(jié)點(diǎn)的Tensor進(jìn)行注冊,然后重新反向傳播一次:

In [44]: def hook_fn(grad):
    ...:     print(grad)
    ...:

In [45]: e.register_hook(hook_fn)
Out[45]: <torch.utils.hooks.RemovableHandle at 0x1d139cf0a88>

In [46]: o.backward()
tensor([1., 1.])

這時就自動輸出了e的梯度。

自定義的hook_fn函數(shù)的函數(shù)名可以是任取的,它的參數(shù)是grad,表示Tensor的梯度。這個自定義函數(shù)主要是用于描述對Tensor梯度值的操作,上例中我們是對梯度直接進(jìn)行輸出,所以是print(grad)。我們也可以把梯度裝在一個列表或字典里,甚至可以修改梯度,這樣如果梯度很小的時候?qū)⑵渥兇笠稽c(diǎn)就可以防止梯度消失的問題了:

In [28]: a = torch.Tensor([1,2]).requires_grad_() 
    ...: b = torch.Tensor([3,4]).requires_grad_() 
    ...: d = torch.Tensor([2]).requires_grad_() 
    ...: c = a + b 
    ...: e = c * d 
    ...: o = e.sum()                                                            

In [29]: grad_list = []                                                         

In [30]: def hook(grad): 
    ...:     grad_list.append(grad)    # 將梯度裝在列表里
    ...:     return 2 * grad    # 將梯度放大兩倍
    ...:                                                                        

In [31]: c.register_hook(hook)                                                  
Out[31]: <torch.utils.hooks.RemovableHandle at 0x7f009b713208>

In [32]: o.backward()                                                           

In [33]: grad_list                                                              
Out[33]: [tensor([2., 2.])]

In [34]: a.grad                                                                 
Out[34]: tensor([4., 4.])

In [35]: b.grad                                                                 
Out[35]: tensor([4., 4.])

上例中,我們定義的hook函數(shù)執(zhí)行了兩個操作:一是將梯度裝進(jìn)列表grad_list中,二是把梯度放大兩倍。從輸出中我們可以看到,執(zhí)行反向傳播后,我們注冊的非葉子節(jié)點(diǎn)c的梯度保存在了列表grad_list中,并且a和b的梯度都變?yōu)樵瓉淼膬杀?。這里需要注意的是,如果要將梯度值裝在一個列表或字典里,那么首先要定義一個同名的全局變量的列表或字典,即使是局部變量,也要在自定義的hook函數(shù)外面。另一個需要注意的點(diǎn)就是如果要改變梯度值,hook函數(shù)要有返回值,返回改變后的梯度。

這里總結(jié)一下,如果要獲取非葉子節(jié)點(diǎn)Tensor的梯度值,我們需要在反向傳播前
1)自定義一個hook函數(shù),描述對梯度的操作,函數(shù)名自擬,參數(shù)只有g(shù)rad,表示Tensor的梯度;
2)對要獲取梯度的Tensor用方法Tensor.register_hook(hook)進(jìn)行注冊。
3)執(zhí)行反向傳播。

2.nn.Module.register_forward_hook(hook_fn)和nn.Module.register_backward_hook(hook_fn)

這兩個的操作對象都是nn.Module類,如神經(jīng)網(wǎng)絡(luò)中的卷積層(nn.Conv2d),全連接層(nn.Linear),池化層(nn.MaxPool2d, nn.AvgPool2d),激活層(nn.ReLU)或者nn.Sequential定義的小模塊等,所以放在一起講。

對于模型的中間模塊,也可以視作中間節(jié)點(diǎn)(非葉子節(jié)點(diǎn)),它的輸出為特征圖或激活值,反向傳播的梯度值都會被系統(tǒng)自動釋放,如果想要獲取它們,就要用到hook功能。

有名字即可看出,register_forward_hook是獲取前向傳播的輸出的,即特征圖或激活值;register_backward_hook是獲取反向傳播的輸出的,即梯度值。它們的用法和上面介紹的register_hook類似。我們先看一下hook_fn的定義:

對于register_forward_hook(hook_fn),其hook_fn函數(shù)定義如下:

def forward_hook(module, input, output):
    operations

這里有3個參數(shù),分別表示:模塊,模塊的輸入,模塊的輸出。函數(shù)用于描述對這些參數(shù)的操作,一般我們都是為了獲取特征圖,即只描述對output的操作即可。

對于register_backward_hook(hook_fn),其hook_fn函數(shù)定義如下:

def backward_hook(module, grad_in, grad_out):
    operations

這里也有3個參數(shù),分別表示:模塊,模塊輸入端的梯度,模塊輸出端的梯度。這里需要特別注意的是,此處的輸入端和輸出端,是前向傳播時的輸入端和輸出端,也就是說,上面的output的梯度對應(yīng)這里的grad_out。例如線性模塊:o=W*x+b,其輸入端為 W,x 和 b,輸出端為 o。

如果模塊有多個輸入或者輸出的話,grad_in和grad_out可以是 tuple 類型。對于線性模塊:o=W*x+b ,它的輸入端包括了W、x 和 b 三部分,因此 grad_input 就是一個包含三個元素的 tuple。

這里注意和 forward hook 的不同:

  1. 在 forward hook 中,input 是 x,而不包括 W 和 b。
  2. 返回 Tensor 或者 None,backward hook 函數(shù)不能直接改變它的輸入變量,但是可以返回新的 grad_in,反向傳播到它上一個模塊。

此處的自定義的函數(shù)hook_fn也可以自擬名稱,但如果兩個hook函數(shù)同時使用的時候注意名稱的區(qū)別,一般在函數(shù)名里添加對應(yīng)的forward和backward就不易搞混了。

下面看一個具體用例:

#-*- utf-8 -*-

'''本程序用于驗(yàn)證hook編程獲取卷積層的輸出特征圖和特征圖的梯度'''

__author__ = 'puxitong from UESTC'

import torch
import torch.nn as nn
import numpy as np 
import torchvision.transforms as transforms


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,6,3,1,1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,9,3,1,1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(8*8*9, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120,10)

    def forward(self, x):
        out = self.pool1(self.relu1(self.conv1(x)))
        out = self.pool2(self.relu2(self.conv2(out)))
        out = out.view(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.fc2(out)

        return out


def backward_hook(module, grad_in, grad_out):
    grad_block['grad_in'] = grad_in
    grad_block['grad_out'] = grad_out


def farward_hook(module, inp, outp):
    fmap_block['input'] = inp
    fmap_block['output'] = outp


loss_func = nn.CrossEntropyLoss()

# 生成一個假標(biāo)簽以便演示
label = torch.empty(1, dtype=torch.long).random_(3)

# 生成一副假圖像以便演示
input_img = torch.randn(1,3,32,32).requires_grad_()  

fmap_block = dict()  # 裝feature map
grad_block = dict()  # 裝梯度

net = Net()

# 注冊hook
net.conv2.register_forward_hook(farward_hook)
net.conv2.register_backward_hook(backward_hook)

outs = net(input_img)
loss = loss_func(outs, label)
loss.backward()

print('End.')

上面的程序中,我們先定義了一個簡單的卷積神經(jīng)網(wǎng)絡(luò)模型,我們對第二層卷積模塊進(jìn)行hook注冊,既獲取它的輸入輸出,又獲取輸入輸出的梯度,并將它們分別裝在字典里。為了達(dá)到驗(yàn)證效果,我們隨機(jī)生成一個假圖像,它的尺寸和cifar-10數(shù)據(jù)集的圖像尺寸一致,并且給這個假圖像定義一個類別標(biāo)簽,用損失函數(shù)進(jìn)行反向傳播,模擬神經(jīng)網(wǎng)絡(luò)的訓(xùn)練過程。

在IPython中運(yùn)行程序后,相應(yīng)的特征圖和梯度就會出現(xiàn)在兩個列表fmap_block和grad_block中了。我們看一下它們的輸入和輸出的維度:

In [17]: len(fmap_block['input'])                                               
Out[17]: 1

In [18]: len(fmap_block['output'])                                              
Out[18]: 1

In [19]: len(grad_block['grad_in'])                                             
Out[19]: 3

In [20]: len(grad_block['grad_out'])                                            
Out[20]: 1

可以看出,第二層卷積模塊的輸入和輸出都只有一個,即相應(yīng)的特征圖。而輸入端的梯度值有3個,分別為權(quán)重的梯度,偏差的梯度,以及輸入特征圖的梯度。輸出端的梯度值只有一個,即輸出特征圖的梯度。正如上面強(qiáng)調(diào)的,輸入端即使有W, X和b三個,對于前項(xiàng)傳播來說只有X是其輸入,而對于反向傳播來說,3個都是輸入。輸出端3項(xiàng)的梯度值排列的順序是什么呢,我們來看一下3項(xiàng)梯度的具體維度:

In [21]: grad_block['grad_in'][0].shape                                         
Out[21]: torch.Size([1, 6, 16, 16])

In [22]: grad_block['grad_in'][1].shape                                         
Out[22]: torch.Size([9, 6, 3, 3])

In [23]: grad_block['grad_in'][2].shape                                         
Out[23]: torch.Size([9])

從輸出端梯度的維度可以判斷,第一個顯然是特征圖的梯度,第二個則是權(quán)重(卷積核/濾波器)的梯度,第三個是偏置的梯度。為了驗(yàn)證梯度和這些參數(shù)具有同樣的維度,我們再來看看這三個值前向傳播時的維度:

In [24]: fmap_block['input'][0].shape                                           
Out[24]: torch.Size([1, 6, 16, 16])

In [25]: net.conv2.weight.shape         
Out[25]: torch.Size([9, 6, 3, 3])

In [26]: net.conv2.bias.shape                                                   
Out[26]: torch.Size([9])

可以看到,我們的判斷是正確的。

最后需要注意的一點(diǎn)是,如果需要獲取輸入圖像的梯度,一定要將輸入Tensor的requires_grad屬性設(shè)為True。

原創(chuàng)不易,有用請點(diǎn)贊支持~

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容