論文 | NeurIPS2020 CrossTransformers:spatially-aware few-shot transfer

一 寫在前面

未經(jīng)允許,不得轉(zhuǎn)載,謝謝~~~

嘿,好久不見,我要開始慢慢恢復(fù)科研論文筆記的更新啦~

今天分享的文章是做小樣本圖像識(shí)別的。

主要信息:

二 主要內(nèi)容

2.1 相關(guān)背景

小樣本圖像識(shí)別的方法從整體上來看大概可以分成兩個(gè)階段:

  1. representation learning:獲取到一個(gè)比較好的圖像特征提取器;
  2. classifier:通過比對(duì)query images和support images進(jìn)行query image的標(biāo)簽預(yù)測(cè)

文章首先總結(jié)了現(xiàn)有方法的共同點(diǎn):

  1. 在representation leanring的學(xué)習(xí)上具有的一個(gè)共同點(diǎn)就是都會(huì)使用訓(xùn)練圖片的類別標(biāo)簽做一個(gè)監(jiān)督學(xué)習(xí);
  2. 在classifier學(xué)習(xí)階段具有的一個(gè)共同點(diǎn)是會(huì)將query和support圖像之間的整體特征進(jìn)行比較,例如ProtoNet就是將query的特征與support set中每個(gè)類中心的特征進(jìn)行比較。

2.2 本文工作

文章首先支持現(xiàn)有方法的不足:

  1. 完全依靠類別標(biāo)簽進(jìn)行特征學(xué)習(xí)的方式會(huì)導(dǎo)致只能學(xué)習(xí)到跟類別相關(guān)的信息,而忽略其他更加通用的特征表示;
  2. 在做圖像比較的時(shí)候,圖像中的一些重要objects和scenes通常是local的,直接用整體特征進(jìn)行比較的效果不一定是最好的;

相對(duì)應(yīng)地,文章提出了從兩個(gè)方面進(jìn)行優(yōu)化:

  1. 針對(duì)第一個(gè)問題,提出引入自監(jiān)督學(xué)習(xí)的方法SimCLR來獲取更加通用的圖像特征表示;
  2. 針對(duì)第二個(gè)問題,提出基于Transformer的新結(jié)構(gòu)CrossTransformer,希望能夠進(jìn)行l(wèi)ocal信息的圖像匹配;

三 方法介紹

文章是基于ProtoNet結(jié)構(gòu)的,所以首先介紹下ProtoNet, 然后分別介紹以上兩點(diǎn)novelty。

3.1 ProtoNet

ProtoNet算的上是小樣本圖像識(shí)別領(lǐng)域最flagship的工作了,這里只做個(gè)簡(jiǎn)單的介紹。

N-way-K-shot
給定一堆帶標(biāo)簽可供參考的Support Images,具體表示為有N個(gè)類別,每個(gè)類別有K張帶標(biāo)注的圖像,以及一個(gè)等待被分類的query image (query image的類別一定屬于N個(gè)類別),我們需要根據(jù)support images預(yù)測(cè)出query image的類別標(biāo)簽。

key idea:
Protonet的想法非常直接但有效。即對(duì)每張圖像都先用神經(jīng)網(wǎng)絡(luò)得到一個(gè)特征表示,然后對(duì)support set中每個(gè)類別c的所有特征取一個(gè)平均,作為這個(gè)類別的類中心。最后比較query feature跟各個(gè)類中心之間的距離,取最近的一個(gè)類別作為預(yù)測(cè)結(jié)果。

3.2 SSL with SimCLR

這里的想法也比較直接,就是覺得自監(jiān)督學(xué)習(xí)得到的特征表示不僅對(duì)semantic敏感,而且對(duì)屬于相同類別的不同圖片也具有區(qū)分度,可以理解為只用class informaction進(jìn)行監(jiān)督學(xué)習(xí)得到的特征是class-level的,SSL學(xué)習(xí)到的是instance-level的,因此作者認(rèn)為SSL學(xué)習(xí)到的特征泛化性會(huì)更好。

具體的做法也比較簡(jiǎn)單。為了區(qū)分原來的episode和現(xiàn)在用自監(jiān)督的episode, 分別用MD-categorization episode以及SimCLR episode來表示它們。在訓(xùn)練的過程中隨機(jī)轉(zhuǎn)化50%的MD-categorization episode為SimCLR episode, 對(duì)SimCLR episode用SimCLR中的方法進(jìn)行增強(qiáng),然后對(duì)query image也進(jìn)行增強(qiáng),最后用各自對(duì)應(yīng)的loss function進(jìn)行優(yōu)化。

:( 這邊的具體細(xì)節(jié)感覺只看文章還不是特別清楚,可能需要感興趣的同學(xué)可以自己看看他們的code

3.3 CrossTransformers

這部分都是基于Transformer構(gòu)建的,如果之前完全不了解的話或許是會(huì)比較困難的,建議看看原文:https://arxiv.org/abs/1706.03762, 或者推薦一個(gè)我個(gè)人最推薦的blog:https://zhuanlan.zhihu.com/p/48508221。

文章的主要框架圖如下圖所示:

第一張是文章原圖,第二張是我在原圖的基礎(chǔ)上把各個(gè)重要部分對(duì)應(yīng)的數(shù)據(jù)維度標(biāo)注上去以及補(bǔ)充了額外內(nèi)容的圖,可以對(duì)照著看。

文章原圖
帶標(biāo)記圖

主要的pipeline包括以下幾步:

  1. 首先看輸入,給定最左邊的一個(gè)query image x_q, 以及最上面的support set中類別為c的幾個(gè)圖像{x_1^c, x_2^c, ...}, 網(wǎng)絡(luò)的目的是要獲取到一個(gè)query-specific的類中心(不再是原始ProtoNet版本中直接取平均的方法)
  2. 首先注意到不管是對(duì)于query還是對(duì)于support images,都是先用一個(gè)\phi()得到圖像的特征表示,這里文章中用的是ResNet,并且去掉了最后一個(gè)pooling層,所以得到的特征維度為R^{H`^ \times W^` \times D}。
  3. 接下來就是基于query,key,value的attention操作。這里的query是指query image,而key和value都是指support sets。理解這一點(diǎn)對(duì)理解整個(gè)attention還挺重要的。

網(wǎng)絡(luò)圖中的query heads,key heads都是將輸入特征從D維度映射到d_k維度,而value heads將輸入特征從D維度映射到d_v維度。

具體地,(建議對(duì)著圖看)

  • query heads將query特征從R^{H^` \times W^` \times D}維度映射到R^{H^` \times W^` \times d_k}維度(圖中shi黃色的框框);
  • key heads將support特征從R^{H^` \times W^` \times D}維度映射到R^{H^` \times W^` \times d_k}維度(圖中亮黃色的框框,左右兩個(gè)表示的是一樣的意思,看第一個(gè)就行了);
  • value heads將support特征從R^{H^` \times W^` \times D}映射到R^{H^` \times W^` \times d_v}維度(圖中紅色框框,也看其中一個(gè)就可)。
  1. 然后就是計(jì)算query和key之間的attention,我們還是只看一個(gè)query(shi黃色框)和一個(gè)support圖像特征(第一個(gè)亮黃色框框),經(jīng)過映射之后兩個(gè)的特征維度都是R^{H^` \times W^` \times d_k},對(duì)于query中任意一個(gè)位置p和support中的任意一個(gè)位置m,特征維度都是d_k, 通過向量點(diǎn)乘的方法可以得到這2個(gè)點(diǎn)之間的attention值,圖中小黑點(diǎn)在的位置。對(duì)每個(gè)HxW中的點(diǎn)都計(jì)算一次attention,最終就會(huì)得到一張query和一張support的attention mapa_1^c, 當(dāng)然還做了一個(gè)softmax操作得到更新后的attention map\tilde{a_1^c}。對(duì)suppport中的多張圖采取同樣的操作就會(huì)得到多張attention map。

  2. 最后就是利用這些attention maps對(duì)support set中不同圖像的vaule特征進(jìn)行加權(quán)平均。這部分操作可以理解為,對(duì)于<query, support image i>, 對(duì)于HxW中的任意一個(gè)位置,都用其第i張attention map的值乘上對(duì)應(yīng)第i個(gè)紅色框框位置的value,最后把不同support images的結(jié)果值進(jìn)行相加得到最終query-aligned prototype的特征表示,其維度為R^{H^` \times W^` \times d_v}。

  3. 到這里為止我們獲取到了query-aligned prototype R^{H^` \times W^` \times d_v}。 但是要做小樣本預(yù)測(cè)到這里還沒有完全完整,我把第二張圖中把剩下的部分補(bǔ)上了。對(duì)于query image,其實(shí)也用value head做了一個(gè)映射,得到一個(gè)query image的value 特征表示,其維度為R^{H^` \times W^` \times d_v}, 跟prototype的維度是一樣的,這樣就可以比較這兩者之間的距離,進(jìn)而進(jìn)行l(wèi)abel預(yù)測(cè)了.

五 寫在最后

我在寫這個(gè)blog的時(shí)候,盡量避免了公式的出現(xiàn),但可能有些地方解釋的還是有些不好理解,尤其是crossTransformer部分涉及的符號(hào)略多,大家見諒啦。

這篇文章暫時(shí)介紹到這里,最后打個(gè)不那么相關(guān)的廣告,我們做小樣本視頻分類的工作(AMeFu-Net)近期開源了,link: https://github.com/lovelyqian/AMeFu-Net,歡迎大家關(guān)注~

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容