一 寫在前面
未經(jīng)允許,不得轉(zhuǎn)載,謝謝~~~
嘿,好久不見,我要開始慢慢恢復(fù)科研論文筆記的更新啦~
今天分享的文章是做小樣本圖像識(shí)別的。
主要信息:
- 視覺任務(wù): few-shot image classsification
- 文章出處:NeurIPS2020
- 文章代碼:https://github.com/google-research/meta-dataset
- 原文鏈接:https://arxiv.org/abs/2007.11498
二 主要內(nèi)容
2.1 相關(guān)背景
小樣本圖像識(shí)別的方法從整體上來看大概可以分成兩個(gè)階段:
- representation learning:獲取到一個(gè)比較好的圖像特征提取器;
- classifier:通過比對(duì)query images和support images進(jìn)行query image的標(biāo)簽預(yù)測(cè)
文章首先總結(jié)了現(xiàn)有方法的共同點(diǎn):
- 在representation leanring的學(xué)習(xí)上具有的一個(gè)共同點(diǎn)就是都會(huì)使用訓(xùn)練圖片的類別標(biāo)簽做一個(gè)監(jiān)督學(xué)習(xí);
- 在classifier學(xué)習(xí)階段具有的一個(gè)共同點(diǎn)是會(huì)將query和support圖像之間的整體特征進(jìn)行比較,例如ProtoNet就是將query的特征與support set中每個(gè)類中心的特征進(jìn)行比較。
2.2 本文工作
文章首先支持現(xiàn)有方法的不足:
- 完全依靠類別標(biāo)簽進(jìn)行特征學(xué)習(xí)的方式會(huì)導(dǎo)致只能學(xué)習(xí)到跟類別相關(guān)的信息,而忽略其他更加通用的特征表示;
- 在做圖像比較的時(shí)候,圖像中的一些重要objects和scenes通常是local的,直接用整體特征進(jìn)行比較的效果不一定是最好的;
相對(duì)應(yīng)地,文章提出了從兩個(gè)方面進(jìn)行優(yōu)化:
- 針對(duì)第一個(gè)問題,提出引入自監(jiān)督學(xué)習(xí)的方法SimCLR來獲取更加通用的圖像特征表示;
- 針對(duì)第二個(gè)問題,提出基于Transformer的新結(jié)構(gòu)CrossTransformer,希望能夠進(jìn)行l(wèi)ocal信息的圖像匹配;
三 方法介紹
文章是基于ProtoNet結(jié)構(gòu)的,所以首先介紹下ProtoNet, 然后分別介紹以上兩點(diǎn)novelty。
3.1 ProtoNet
ProtoNet算的上是小樣本圖像識(shí)別領(lǐng)域最flagship的工作了,這里只做個(gè)簡(jiǎn)單的介紹。
N-way-K-shot
給定一堆帶標(biāo)簽可供參考的Support Images,具體表示為有N個(gè)類別,每個(gè)類別有K張帶標(biāo)注的圖像,以及一個(gè)等待被分類的query image (query image的類別一定屬于N個(gè)類別),我們需要根據(jù)support images預(yù)測(cè)出query image的類別標(biāo)簽。
key idea:
Protonet的想法非常直接但有效。即對(duì)每張圖像都先用神經(jīng)網(wǎng)絡(luò)得到一個(gè)特征表示,然后對(duì)support set中每個(gè)類別c的所有特征取一個(gè)平均,作為這個(gè)類別的類中心。最后比較query feature跟各個(gè)類中心之間的距離,取最近的一個(gè)類別作為預(yù)測(cè)結(jié)果。
3.2 SSL with SimCLR
這里的想法也比較直接,就是覺得自監(jiān)督學(xué)習(xí)得到的特征表示不僅對(duì)semantic敏感,而且對(duì)屬于相同類別的不同圖片也具有區(qū)分度,可以理解為只用class informaction進(jìn)行監(jiān)督學(xué)習(xí)得到的特征是class-level的,SSL學(xué)習(xí)到的是instance-level的,因此作者認(rèn)為SSL學(xué)習(xí)到的特征泛化性會(huì)更好。
具體的做法也比較簡(jiǎn)單。為了區(qū)分原來的episode和現(xiàn)在用自監(jiān)督的episode, 分別用MD-categorization episode以及SimCLR episode來表示它們。在訓(xùn)練的過程中隨機(jī)轉(zhuǎn)化50%的MD-categorization episode為SimCLR episode, 對(duì)SimCLR episode用SimCLR中的方法進(jìn)行增強(qiáng),然后對(duì)query image也進(jìn)行增強(qiáng),最后用各自對(duì)應(yīng)的loss function進(jìn)行優(yōu)化。
:( 這邊的具體細(xì)節(jié)感覺只看文章還不是特別清楚,可能需要感興趣的同學(xué)可以自己看看他們的code
3.3 CrossTransformers
這部分都是基于Transformer構(gòu)建的,如果之前完全不了解的話或許是會(huì)比較困難的,建議看看原文:https://arxiv.org/abs/1706.03762, 或者推薦一個(gè)我個(gè)人最推薦的blog:https://zhuanlan.zhihu.com/p/48508221。
文章的主要框架圖如下圖所示:
第一張是文章原圖,第二張是我在原圖的基礎(chǔ)上把各個(gè)重要部分對(duì)應(yīng)的數(shù)據(jù)維度標(biāo)注上去以及補(bǔ)充了額外內(nèi)容的圖,可以對(duì)照著看。


主要的pipeline包括以下幾步:
- 首先看輸入,給定最左邊的一個(gè)query image
, 以及最上面的support set中類別為c的幾個(gè)圖像{
,
, ...}, 網(wǎng)絡(luò)的目的是要獲取到一個(gè)query-specific的類中心(不再是原始ProtoNet版本中直接取平均的方法)
- 首先注意到不管是對(duì)于query還是對(duì)于support images,都是先用一個(gè)
得到圖像的特征表示,這里文章中用的是ResNet,并且去掉了最后一個(gè)pooling層,所以得到的特征維度為
。
- 接下來就是基于query,key,value的attention操作。這里的query是指query image,而key和value都是指support sets。理解這一點(diǎn)對(duì)理解整個(gè)attention還挺重要的。
網(wǎng)絡(luò)圖中的query heads,key heads都是將輸入特征從維度映射到
維度,而value heads將輸入特征從
維度映射到
維度。
具體地,(建議對(duì)著圖看)
- query heads將query特征從
維度映射到
維度(圖中shi黃色的框框);
- key heads將support特征從
維度映射到
維度(圖中亮黃色的框框,左右兩個(gè)表示的是一樣的意思,看第一個(gè)就行了);
- value heads將support特征從
映射到
維度(圖中紅色框框,也看其中一個(gè)就可)。
然后就是計(jì)算query和key之間的attention,我們還是只看一個(gè)query(shi黃色框)和一個(gè)support圖像特征(第一個(gè)亮黃色框框),經(jīng)過映射之后兩個(gè)的特征維度都是
,對(duì)于query中任意一個(gè)位置p和support中的任意一個(gè)位置m,特征維度都是
, 通過向量點(diǎn)乘的方法可以得到這2個(gè)點(diǎn)之間的attention值,圖中小黑點(diǎn)在的位置。對(duì)每個(gè)HxW中的點(diǎn)都計(jì)算一次attention,最終就會(huì)得到一張query和一張support的attention map
, 當(dāng)然還做了一個(gè)softmax操作得到更新后的attention map
。對(duì)suppport中的多張圖采取同樣的操作就會(huì)得到多張attention map。
最后就是利用這些attention maps對(duì)support set中不同圖像的vaule特征進(jìn)行加權(quán)平均。這部分操作可以理解為,對(duì)于<query, support image i>, 對(duì)于HxW中的任意一個(gè)位置,都用其第i張attention map的值乘上對(duì)應(yīng)第i個(gè)紅色框框位置的value,最后把不同support images的結(jié)果值進(jìn)行相加得到最終query-aligned prototype的特征表示,其維度為
。
到這里為止我們獲取到了query-aligned prototype
。 但是要做小樣本預(yù)測(cè)到這里還沒有完全完整,我把第二張圖中把剩下的部分補(bǔ)上了。對(duì)于query image,其實(shí)也用value head做了一個(gè)映射,得到一個(gè)query image的value 特征表示,其維度為
, 跟prototype的維度是一樣的,這樣就可以比較這兩者之間的距離,進(jìn)而進(jìn)行l(wèi)abel預(yù)測(cè)了.
五 寫在最后
我在寫這個(gè)blog的時(shí)候,盡量避免了公式的出現(xiàn),但可能有些地方解釋的還是有些不好理解,尤其是crossTransformer部分涉及的符號(hào)略多,大家見諒啦。
這篇文章暫時(shí)介紹到這里,最后打個(gè)不那么相關(guān)的廣告,我們做小樣本視頻分類的工作(AMeFu-Net)近期開源了,link: https://github.com/lovelyqian/AMeFu-Net,歡迎大家關(guān)注~