CAM系列(一)之CAM(原理講解和PyTorch代碼實(shí)現(xiàn))

本文首發(fā)自【簡(jiǎn)書】作者【西北小生_】的博客,轉(zhuǎn)載請(qǐng)私聊作者!


圖1 CAM實(shí)現(xiàn)示意圖

一、什么是CAM?

CAM的全稱是Class Activation MappingClass Activation Map,即類激活映射類激活圖。

論文《Learning Deep Features for Discriminative Localization》發(fā)現(xiàn)了CNN分類模型的一個(gè)有趣的現(xiàn)象:
CNN的最后一層卷積輸出的特征圖,對(duì)其通道進(jìn)行加權(quán)疊加后,其激活值(ReLU激活后的非零值)所在的區(qū)域,即為圖像中的物體所在區(qū)域。而將這一疊加后的單通道特征圖覆蓋到輸入圖像上,即可高亮圖像中物體所在位置區(qū)域。如圖1中的輸入圖像和輸出圖像所示。

該文章作者將實(shí)現(xiàn)這一現(xiàn)象的方法命名為類激活映射,并將特征圖疊加在原始輸入圖像上生成的新圖片命名為類激活圖。

二、CAM有什么用?

CAM一般有兩種用途:

  • 可視化模型特征圖,以便觀察模型是通過圖像中的哪些區(qū)域特征來區(qū)分物體類別的;
  • 利用卷積神經(jīng)網(wǎng)絡(luò)分類模型進(jìn)行弱監(jiān)督的圖像目標(biāo)定位

第一種用途是最直接的用途,根據(jù)CAM高亮的圖像區(qū)域,可以直觀地解釋CNN是如何區(qū)分不同類別的物體的。

對(duì)于第二種用途,一般的目標(biāo)定位方法,都需要專門對(duì)圖像中的物體位置區(qū)域進(jìn)行標(biāo)注,并將標(biāo)注信息作為圖像標(biāo)簽的一部分,然后通過訓(xùn)練帶標(biāo)簽的圖像和專門的目標(biāo)定位模型才能實(shí)現(xiàn)定位,是一種強(qiáng)監(jiān)督的方法。而CAM方法不需要物體在圖像中的位置信息,僅僅依靠圖像整體的類別標(biāo)簽訓(xùn)練分類模型,即可找到圖像中物體所在的大致位置并高亮之,因此可以作為一種弱監(jiān)督的目標(biāo)定位方法。

三、CAM原理

圖2 輸出結(jié)構(gòu)示意圖

如圖2所示,CNN最后一層卷積層輸出的特征圖是三維的:[C, H, W ],設(shè)特征圖的第k個(gè)通道可表示為f_k(x,y),其中x,y分別是寬和高維度上的索引。若最后一個(gè)卷積層連接一個(gè)全局平均池化層,然后再由一個(gè)全連接層輸出分類結(jié)果,則由最后一個(gè)卷積層的輸出特征圖到輸出層中的第c個(gè)類別的置信分?jǐn)?shù)(未進(jìn)行Softmax映射前)的計(jì)算過程可表示為:
S_c=\sum_{k}w_{k}^{c} \sum_{x,y}f_{k}(x,y)=\sum_{x,y} \sum_{k} w_{k}^{c} f_{k}(x,y) \tag{1}
其中\sum_{x,y}f_{k}(x,y)為全局平均池化(省略了除以元素總數(shù)),由于只對(duì)空間上到寬和高兩個(gè)維度求和,結(jié)果就是這兩個(gè)維度坍塌,只剩通道維度保持不變,即計(jì)算結(jié)果為C個(gè)數(shù)值,每個(gè)值代表著該通道上所有值的平均值。w_{k}^{c}表示全連接輸出層中第c類對(duì)應(yīng)的C個(gè)權(quán)重中的第k個(gè):即全連接層的權(quán)重矩陣W[N_o,C]維的(N_o即輸出類別數(shù),C是最后一層卷積層的輸出通道數(shù)),那么第c類對(duì)應(yīng)的權(quán)重w^c就應(yīng)該是W[c,:],w^c有著C個(gè)權(quán)重參數(shù),對(duì)應(yīng)著每個(gè)輸入值(即全局平均池化的結(jié)果),w^c_k就是這C個(gè)權(quán)重參數(shù)中的第k個(gè)數(shù)。

\sum_{k}w_{k}^{c} \sum_{x,y}f_{k}(x,y)表示特征圖的每個(gè)輸出通道首先被平均為一個(gè)值,C個(gè)通道得到C個(gè)值,然后這些值再被加權(quán)相加得到一個(gè)數(shù),這個(gè)數(shù)就是第c類的置信分?jǐn)?shù),表征著輸入圖像的類別是c的可能性大小。

\sum_{x,y} \sum_{k} w_{k}^{c} f_{k}(x,y)表示首先對(duì)特征圖的每個(gè)通道進(jìn)行加權(quán)求和(\sum_{k} w_{k}^{c} f_{k}(x,y)),得到一個(gè)二維的特征圖(通道維坍塌),然后再對(duì)這個(gè)二維特征圖求平均值,得到第c類的置信分?jǐn)?shù)。

由公式(1)的推導(dǎo)可知,先對(duì)特征圖進(jìn)行全局平均池化,再進(jìn)行加權(quán)求和得到類別的置信分?jǐn)?shù),等價(jià)于先對(duì)特征圖進(jìn)行通道維度的加權(quán)求和,再進(jìn)行全局平均池化。

經(jīng)過這一等價(jià)變換,就突顯了特征圖通道加權(quán)和\sum_{k} w_{k}^{c} f_{k}(x,y)的重要性了:一方面,特征圖的通道加權(quán)和直接編碼了類別信息;另一方面,也是最重要的,特征圖的通道加權(quán)和是二維的,還保留著圖像的空間位置信息。我們可以通過可視化方法觀察到圖像中的相對(duì)位置信息與CNN編碼的類別信息的關(guān)系。

這里的特征圖的通道加權(quán)之和\sum_{k} w_{k}^{c} f_{k}(x,y)就叫做類別激活圖。

四、CAM的PyTorch實(shí)現(xiàn)

本文以PyTorch自帶的ResNet-18為例,分步驟講解并用代碼實(shí)現(xiàn)CAM的整個(gè)流程和細(xì)節(jié)。

1.準(zhǔn)備工作

首先導(dǎo)入需要用到的包:

import math
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from typing import Optional, List
import torchvision.transforms as transforms
from PIL import Image
import torchvision.models as models
from torch import Tensor
from matplotlib import cm
from torchvision.transforms.functional import to_pil_image

定義輸入圖片路徑,和保存輸出的類激活圖的路徑:

img_path = '/home/dell/img/1.JPEG'     # 輸入圖片的路徑
save_path = '/home/dell/cam/CAM1.png'    # 類激活圖保存路徑

定義輸入圖片預(yù)處理方式。由于本文用的輸入圖片來自ILSVRC-2012驗(yàn)證集,因此采用PyTorch官方文檔提供的ImageNet驗(yàn)證集處理流程:

preprocess = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
2.獲取CNN最后一層卷積層的輸出特征圖

本文選用的CNN模型是PyTorch自帶的ResNet-18,首先導(dǎo)入預(yù)訓(xùn)練模型:

net = models.resnet18(pretrained=True).cuda()   # 導(dǎo)入模型

由于特征圖是模型前向傳播時(shí)的中間變量,不能直接從模型中獲取,需要用到PyTorch提供的hook工具,補(bǔ)課請(qǐng)參考我的這兩篇博客:hook1,hook2。

通過輸出模型(print(net))我們就能看到ResNet-18輸出最后一層特征圖的層為net.layer4(或者net.layer4[1]、net.layer4[1].bn2都可)。我們用hook工具注冊(cè)這一層,以便獲得它的輸出特征圖:

feature_map = []     # 建立列表容器,用于盛放輸出特征圖

def forward_hook(module, inp, outp):     # 定義hook
    feature_map.append(outp)    # 把輸出裝入字典feature_map

net.layer4.register_forward_hook(forward_hook)    # 對(duì)net.layer4這一層注冊(cè)前向傳播

做好了hook的定義和注冊(cè)工作,現(xiàn)在只需要對(duì)輸入圖片進(jìn)行預(yù)處理,然后執(zhí)行一次模型前向傳播即可獲得CNN最后一層卷積層的輸出特征圖:

orign_img = Image.open(img_path).convert('RGB')    # 打開圖片并轉(zhuǎn)換為RGB模型
img = preprocess(orign_img)     # 圖片預(yù)處理
img = torch.unsqueeze(img, 0)     # 增加batch維度 [1, 3, 224, 224]

with torch.no_grad():
    out = net(img.cuda())     # 前向傳播

這時(shí)我們想要的特征圖已經(jīng)裝在列表feature_map中了。我們輸出尺寸來驗(yàn)證一下:

In [10]: print(feature_map[0].size())
torch.Size([1, 512, 7, 7])
3.獲取權(quán)重

CAM使用的權(quán)重是全連接輸出層中,對(duì)應(yīng)這張圖像所屬類別的權(quán)重。文字表述可能存在歧義或不清楚,直接看本文最上面的圖中全連接層被著色的連接??梢钥吹?,每個(gè)連接對(duì)應(yīng)一個(gè)權(quán)重值,左邊和特征圖的每個(gè)通道(全局平均池化后)一一連接,右邊全都連接著輸出類別所對(duì)應(yīng)的那個(gè)神經(jīng)元。

由于我也不知道這張圖的類別標(biāo)簽,這里假設(shè)模型對(duì)這張圖像分類正確,我們來獲得其輸出類別所對(duì)應(yīng)的權(quán)重:

cls = torch.argmax(out).item()    # 獲取預(yù)測(cè)類別編碼
weights = net._modules.get('fc').weight.data[cls,:]    # 獲取類別對(duì)應(yīng)的權(quán)重
4.對(duì)特征圖的通道進(jìn)行加權(quán)疊加,獲得CAM
cam = (weights.view(*weights.shape, 1, 1) * feature_map[0].squeeze(0)).sum(0)

這里的代碼比較簡(jiǎn)單,擴(kuò)充權(quán)重的維度([512, ]\rightarrow[512, 1, 1])是為了使之在通道上與特征圖相乘;去除特征圖的batch維([1, 512, 7, 7]\rightarrow[512, 7, 7])是為了使其維度和weights擴(kuò)充后的維度相同以相乘。最后在第一維(通道維)上相加求和,得到一個(gè)7\times 7的類激活圖。

5.對(duì)CAM進(jìn)行ReLU激活和歸一化

這一步有兩個(gè)細(xì)節(jié)需要注意:

  • 上步得到的類激活圖像素值分布雜亂,要想確定目標(biāo)位置,須先進(jìn)行ReLU激活,將正值保留,負(fù)值置零。像素值正值所在的(一個(gè)或多個(gè))區(qū)域即為目標(biāo)定位區(qū)域。
  • 上步獲得的激活圖還只是一個(gè)普通矩陣,需要變換成圖像規(guī)格,將其值歸一化到[0,1]之間。

我們首先定義歸一化函數(shù):

def _normalize(cams: Tensor) -> Tensor:
        """CAM normalization"""
        cams.sub_(cams.flatten(start_dim=-2).min(-1).values.unsqueeze(-1).unsqueeze(-1))
        cams.div_(cams.flatten(start_dim=-2).max(-1).values.unsqueeze(-1).unsqueeze(-1))

        return cams

然后對(duì)類激活圖執(zhí)行ReLU激活和歸一化,并利用PyTorch的 to_pil_image函數(shù)將其轉(zhuǎn)換為PIL格式以便下步處理:

cam = _normalize(F.relu(cam, inplace=True)).cpu()
mask = to_pil_image(cam.detach().numpy(), mode='F')

將類激活圖轉(zhuǎn)換成PIL格式是為了方便下一步和輸入圖像融合,因?yàn)楸纠形覀冞x用的PIL庫將輸入圖像打開,選用PIL庫也是因?yàn)镻yTorch處理圖像時(shí)默認(rèn)的圖像格式是PIL格式的。

6.將類激活圖覆蓋到輸入圖像上,實(shí)現(xiàn)目標(biāo)定位

這一步也有很多細(xì)節(jié)需要注意:

  • 上步得到的類激活圖只有7\times 7的尺寸,想要將其覆蓋在輸入圖像上顯示,就需將其用插值的方法擴(kuò)大到和輸入圖像相同大小。
  • 我們的目的是用類激活圖中被激活(非零值)的位置區(qū)域,來高亮原始圖像中相應(yīng)的位置區(qū)域,這一高亮的方法就是將激活圖變換為熱力圖的形式:值越大的像素顏色越紅,值越小的像素顏色越藍(lán)。
  • 如果直接將熱力圖覆蓋到原始輸入圖像上,會(huì)遮蔽圖像中的內(nèi)容導(dǎo)致不容易觀察,因此需要設(shè)置兩個(gè)圖像融合的比例(透明度),即在兩種圖像融合在一起時(shí),將原始輸入圖像的像素值權(quán)重設(shè)置大一些,而把熱力圖的像素值權(quán)重設(shè)置小一些,這樣就會(huì)使生成圖像中原始輸入圖像的內(nèi)容更加清晰,易于觀察。(mixup方法同理)
  • 兩種圖像融合后的像素值會(huì)超出圖像規(guī)格像素值的范圍[0,1],因此還需要將其轉(zhuǎn)換為圖像規(guī)格。

我們將兩個(gè)圖像交疊融合的過程封裝成了函數(shù):

def overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = 'jet', alpha: float = 0.6) -> Image.Image:
    """Overlay a colormapped mask on a background image

    Args:
        img: background image
        mask: mask to be overlayed in grayscale
        colormap: colormap to be applied on the mask
        alpha: transparency of the background image

    Returns:
        overlayed image
    """

    if not isinstance(img, Image.Image) or not isinstance(mask, Image.Image):
        raise TypeError('img and mask arguments need to be PIL.Image')

    if not isinstance(alpha, float) or alpha < 0 or alpha >= 1:
        raise ValueError('alpha argument is expected to be of type float between 0 and 1')

    cmap = cm.get_cmap(colormap)    
    # Resize mask and apply colormap
    overlay = mask.resize(img.size, resample=Image.BICUBIC)
    overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, 1:]).astype(np.uint8)
    # Overlay the image with the mask
    overlayed_img = Image.fromarray((alpha * np.asarray(img) + (1 - alpha) * overlay).astype(np.uint8))

    return overlayed_img

接下來就是激動(dòng)人心的時(shí)刻了?。。㈩惣せ顖D作為掩碼,以一定的比例覆蓋到原始輸入圖像上,生成類激活圖:

result = overlay_mask(orign_img, mask) 

這里的變量result已經(jīng)是有著PIL圖片格式的類激活圖了,我們可以通過:

result.show()

可視化輸出,也可以通過:

result.save(save_path)

將圖片保存在本地查看。我們?cè)谶@里展示一下輸入圖像和輸出定位圖像的對(duì)比:


(左)輸入圖像;(右)定位圖像
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容