本文首發(fā)自【簡(jiǎn)書】作者【西北小生_】的博客,轉(zhuǎn)載請(qǐng)私聊作者!

一、什么是CAM?
CAM的全稱是Class Activation Mapping或Class Activation Map,即類激活映射或類激活圖。
論文《Learning Deep Features for Discriminative Localization》發(fā)現(xiàn)了CNN分類模型的一個(gè)有趣的現(xiàn)象:
CNN的最后一層卷積輸出的特征圖,對(duì)其通道進(jìn)行加權(quán)疊加后,其激活值(ReLU激活后的非零值)所在的區(qū)域,即為圖像中的物體所在區(qū)域。而將這一疊加后的單通道特征圖覆蓋到輸入圖像上,即可高亮圖像中物體所在位置區(qū)域。如圖1中的輸入圖像和輸出圖像所示。
該文章作者將實(shí)現(xiàn)這一現(xiàn)象的方法命名為類激活映射,并將特征圖疊加在原始輸入圖像上生成的新圖片命名為類激活圖。
二、CAM有什么用?
CAM一般有兩種用途:
- 可視化模型特征圖,以便觀察模型是通過圖像中的哪些區(qū)域特征來區(qū)分物體類別的;
- 利用卷積神經(jīng)網(wǎng)絡(luò)分類模型進(jìn)行弱監(jiān)督的圖像目標(biāo)定位。
第一種用途是最直接的用途,根據(jù)CAM高亮的圖像區(qū)域,可以直觀地解釋CNN是如何區(qū)分不同類別的物體的。
對(duì)于第二種用途,一般的目標(biāo)定位方法,都需要專門對(duì)圖像中的物體位置區(qū)域進(jìn)行標(biāo)注,并將標(biāo)注信息作為圖像標(biāo)簽的一部分,然后通過訓(xùn)練帶標(biāo)簽的圖像和專門的目標(biāo)定位模型才能實(shí)現(xiàn)定位,是一種強(qiáng)監(jiān)督的方法。而CAM方法不需要物體在圖像中的位置信息,僅僅依靠圖像整體的類別標(biāo)簽訓(xùn)練分類模型,即可找到圖像中物體所在的大致位置并高亮之,因此可以作為一種弱監(jiān)督的目標(biāo)定位方法。
三、CAM原理

如圖2所示,CNN最后一層卷積層輸出的特征圖是三維的:[C, H, W ],設(shè)特征圖的第
其中
表示特征圖的每個(gè)輸出通道首先被平均為一個(gè)值,
個(gè)通道得到
個(gè)值,然后這些值再被加權(quán)相加得到一個(gè)數(shù),這個(gè)數(shù)就是第
類的置信分?jǐn)?shù),表征著輸入圖像的類別是
的可能性大小。
表示首先對(duì)特征圖的每個(gè)通道進(jìn)行加權(quán)求和(
),得到一個(gè)二維的特征圖(通道維坍塌),然后再對(duì)這個(gè)二維特征圖求平均值,得到第
類的置信分?jǐn)?shù)。
由公式(1)的推導(dǎo)可知,先對(duì)特征圖進(jìn)行全局平均池化,再進(jìn)行加權(quán)求和得到類別的置信分?jǐn)?shù),等價(jià)于先對(duì)特征圖進(jìn)行通道維度的加權(quán)求和,再進(jìn)行全局平均池化。
經(jīng)過這一等價(jià)變換,就突顯了特征圖通道加權(quán)和的重要性了:一方面,特征圖的通道加權(quán)和直接編碼了類別信息;另一方面,也是最重要的,特征圖的通道加權(quán)和是二維的,還保留著圖像的空間位置信息。我們可以通過可視化方法觀察到圖像中的相對(duì)位置信息與CNN編碼的類別信息的關(guān)系。
這里的特征圖的通道加權(quán)之和就叫做類別激活圖。
四、CAM的PyTorch實(shí)現(xiàn)
本文以PyTorch自帶的ResNet-18為例,分步驟講解并用代碼實(shí)現(xiàn)CAM的整個(gè)流程和細(xì)節(jié)。
1.準(zhǔn)備工作
首先導(dǎo)入需要用到的包:
import math
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from typing import Optional, List
import torchvision.transforms as transforms
from PIL import Image
import torchvision.models as models
from torch import Tensor
from matplotlib import cm
from torchvision.transforms.functional import to_pil_image
定義輸入圖片路徑,和保存輸出的類激活圖的路徑:
img_path = '/home/dell/img/1.JPEG' # 輸入圖片的路徑
save_path = '/home/dell/cam/CAM1.png' # 類激活圖保存路徑
定義輸入圖片預(yù)處理方式。由于本文用的輸入圖片來自ILSVRC-2012驗(yàn)證集,因此采用PyTorch官方文檔提供的ImageNet驗(yàn)證集處理流程:
preprocess = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
2.獲取CNN最后一層卷積層的輸出特征圖
本文選用的CNN模型是PyTorch自帶的ResNet-18,首先導(dǎo)入預(yù)訓(xùn)練模型:
net = models.resnet18(pretrained=True).cuda() # 導(dǎo)入模型
由于特征圖是模型前向傳播時(shí)的中間變量,不能直接從模型中獲取,需要用到PyTorch提供的hook工具,補(bǔ)課請(qǐng)參考我的這兩篇博客:hook1,hook2。
通過輸出模型(print(net))我們就能看到ResNet-18輸出最后一層特征圖的層為net.layer4(或者net.layer4[1]、net.layer4[1].bn2都可)。我們用hook工具注冊(cè)這一層,以便獲得它的輸出特征圖:
feature_map = [] # 建立列表容器,用于盛放輸出特征圖
def forward_hook(module, inp, outp): # 定義hook
feature_map.append(outp) # 把輸出裝入字典feature_map
net.layer4.register_forward_hook(forward_hook) # 對(duì)net.layer4這一層注冊(cè)前向傳播
做好了hook的定義和注冊(cè)工作,現(xiàn)在只需要對(duì)輸入圖片進(jìn)行預(yù)處理,然后執(zhí)行一次模型前向傳播即可獲得CNN最后一層卷積層的輸出特征圖:
orign_img = Image.open(img_path).convert('RGB') # 打開圖片并轉(zhuǎn)換為RGB模型
img = preprocess(orign_img) # 圖片預(yù)處理
img = torch.unsqueeze(img, 0) # 增加batch維度 [1, 3, 224, 224]
with torch.no_grad():
out = net(img.cuda()) # 前向傳播
這時(shí)我們想要的特征圖已經(jīng)裝在列表feature_map中了。我們輸出尺寸來驗(yàn)證一下:
In [10]: print(feature_map[0].size())
torch.Size([1, 512, 7, 7])
3.獲取權(quán)重
CAM使用的權(quán)重是全連接輸出層中,對(duì)應(yīng)這張圖像所屬類別的權(quán)重。文字表述可能存在歧義或不清楚,直接看本文最上面的圖中全連接層被著色的連接??梢钥吹?,每個(gè)連接對(duì)應(yīng)一個(gè)權(quán)重值,左邊和特征圖的每個(gè)通道(全局平均池化后)一一連接,右邊全都連接著輸出類別所對(duì)應(yīng)的那個(gè)神經(jīng)元。
由于我也不知道這張圖的類別標(biāo)簽,這里假設(shè)模型對(duì)這張圖像分類正確,我們來獲得其輸出類別所對(duì)應(yīng)的權(quán)重:
cls = torch.argmax(out).item() # 獲取預(yù)測(cè)類別編碼
weights = net._modules.get('fc').weight.data[cls,:] # 獲取類別對(duì)應(yīng)的權(quán)重
4.對(duì)特征圖的通道進(jìn)行加權(quán)疊加,獲得CAM
cam = (weights.view(*weights.shape, 1, 1) * feature_map[0].squeeze(0)).sum(0)
這里的代碼比較簡(jiǎn)單,擴(kuò)充權(quán)重的維度([512, ][512, 1, 1])是為了使之在通道上與特征圖相乘;去除特征圖的batch維([1, 512, 7, 7]
[512, 7, 7])是為了使其維度和weights擴(kuò)充后的維度相同以相乘。最后在第一維(通道維)上相加求和,得到一個(gè)
的類激活圖。
5.對(duì)CAM進(jìn)行ReLU激活和歸一化
這一步有兩個(gè)細(xì)節(jié)需要注意:
- 上步得到的類激活圖像素值分布雜亂,要想確定目標(biāo)位置,須先進(jìn)行ReLU激活,將正值保留,負(fù)值置零。像素值正值所在的(一個(gè)或多個(gè))區(qū)域即為目標(biāo)定位區(qū)域。
- 上步獲得的激活圖還只是一個(gè)普通矩陣,需要變換成圖像規(guī)格,將其值歸一化到[0,1]之間。
我們首先定義歸一化函數(shù):
def _normalize(cams: Tensor) -> Tensor:
"""CAM normalization"""
cams.sub_(cams.flatten(start_dim=-2).min(-1).values.unsqueeze(-1).unsqueeze(-1))
cams.div_(cams.flatten(start_dim=-2).max(-1).values.unsqueeze(-1).unsqueeze(-1))
return cams
然后對(duì)類激活圖執(zhí)行ReLU激活和歸一化,并利用PyTorch的 to_pil_image函數(shù)將其轉(zhuǎn)換為PIL格式以便下步處理:
cam = _normalize(F.relu(cam, inplace=True)).cpu()
mask = to_pil_image(cam.detach().numpy(), mode='F')
將類激活圖轉(zhuǎn)換成PIL格式是為了方便下一步和輸入圖像融合,因?yàn)楸纠形覀冞x用的PIL庫將輸入圖像打開,選用PIL庫也是因?yàn)镻yTorch處理圖像時(shí)默認(rèn)的圖像格式是PIL格式的。
6.將類激活圖覆蓋到輸入圖像上,實(shí)現(xiàn)目標(biāo)定位
這一步也有很多細(xì)節(jié)需要注意:
- 上步得到的類激活圖只有
的尺寸,想要將其覆蓋在輸入圖像上顯示,就需將其用插值的方法擴(kuò)大到和輸入圖像相同大小。
- 我們的目的是用類激活圖中被激活(非零值)的位置區(qū)域,來高亮原始圖像中相應(yīng)的位置區(qū)域,這一高亮的方法就是將激活圖變換為熱力圖的形式:值越大的像素顏色越紅,值越小的像素顏色越藍(lán)。
- 如果直接將熱力圖覆蓋到原始輸入圖像上,會(huì)遮蔽圖像中的內(nèi)容導(dǎo)致不容易觀察,因此需要設(shè)置兩個(gè)圖像融合的比例(透明度),即在兩種圖像融合在一起時(shí),將原始輸入圖像的像素值權(quán)重設(shè)置大一些,而把熱力圖的像素值權(quán)重設(shè)置小一些,這樣就會(huì)使生成圖像中原始輸入圖像的內(nèi)容更加清晰,易于觀察。(mixup方法同理)
- 兩種圖像融合后的像素值會(huì)超出圖像規(guī)格像素值的范圍[0,1],因此還需要將其轉(zhuǎn)換為圖像規(guī)格。
我們將兩個(gè)圖像交疊融合的過程封裝成了函數(shù):
def overlay_mask(img: Image.Image, mask: Image.Image, colormap: str = 'jet', alpha: float = 0.6) -> Image.Image:
"""Overlay a colormapped mask on a background image
Args:
img: background image
mask: mask to be overlayed in grayscale
colormap: colormap to be applied on the mask
alpha: transparency of the background image
Returns:
overlayed image
"""
if not isinstance(img, Image.Image) or not isinstance(mask, Image.Image):
raise TypeError('img and mask arguments need to be PIL.Image')
if not isinstance(alpha, float) or alpha < 0 or alpha >= 1:
raise ValueError('alpha argument is expected to be of type float between 0 and 1')
cmap = cm.get_cmap(colormap)
# Resize mask and apply colormap
overlay = mask.resize(img.size, resample=Image.BICUBIC)
overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, 1:]).astype(np.uint8)
# Overlay the image with the mask
overlayed_img = Image.fromarray((alpha * np.asarray(img) + (1 - alpha) * overlay).astype(np.uint8))
return overlayed_img
接下來就是激動(dòng)人心的時(shí)刻了?。。㈩惣せ顖D作為掩碼,以一定的比例覆蓋到原始輸入圖像上,生成類激活圖:
result = overlay_mask(orign_img, mask)
這里的變量result已經(jīng)是有著PIL圖片格式的類激活圖了,我們可以通過:
result.show()
可視化輸出,也可以通過:
result.save(save_path)
將圖片保存在本地查看。我們?cè)谶@里展示一下輸入圖像和輸出定位圖像的對(duì)比:
