論文名稱:Cross-Image Pixel Contrasting for Semantic Segmentation
這是一種將對(duì)比學(xué)習(xí)運(yùn)用到全監(jiān)督語義分割里的方法,主要起到輔助訓(xùn)練的作用,在實(shí)際推理部署的時(shí)候,原本用來對(duì)比學(xué)習(xí)的分支是去除的。
主要解決的問題:
- 模型訓(xùn)練的時(shí)候只考慮當(dāng)前一張圖像的內(nèi)容,無法站在整個(gè)數(shù)據(jù)集的內(nèi)容上考慮問題。
創(chuàng)新點(diǎn):
- 基于像素到像素的對(duì)比(pixel to pixel)+像素到區(qū)域的對(duì)比(pixel to region),設(shè)計(jì)了更加高效的memory bank。
- 設(shè)計(jì)了一種更加合理的難樣本采樣策略,Segmentation-Aware Hard Anchor Sampling。
整體結(jié)構(gòu)

通過上圖可以看到,相比常規(guī)的全監(jiān)督語義分割結(jié)構(gòu),本文的方法只是額外增加了一條用于對(duì)比計(jì)算的輔助分支,該分支在實(shí)際推理部署的時(shí)候是去除的,所以對(duì)于語義分割模型本身是不增加推理負(fù)擔(dān)的。
Memory Bank
這個(gè)東西就是一個(gè)數(shù)據(jù)池,保存了歷史數(shù)據(jù)用于對(duì)比計(jì)算,這里面保存的都是經(jīng)過了模型特征提取后的D維特征,本文D=256。
Pixel to Pixel
這個(gè)是針對(duì)整個(gè)數(shù)據(jù)集的圖像來操作的,就是對(duì)所有類設(shè)置一個(gè)專屬的隊(duì)列(整個(gè)數(shù)據(jù)集有多少個(gè)類就有多少個(gè)隊(duì)列,比如COCO數(shù)據(jù)集有80類,那么就有80個(gè)隊(duì)列),訓(xùn)練的時(shí)候從每個(gè)mini-batch中的每個(gè)類選取V個(gè)D維像素加入到對(duì)應(yīng)類的隊(duì)列T里,T是遠(yuǎn)大于V的。一旦T被裝滿了,那么就去舊留新。通過這種方式Memory Bank能動(dòng)態(tài)存儲(chǔ)絕大部分圖像的內(nèi)容特征。
num_pixel = idxs.shape[0]
perm = torch.randperm(num_pixel) #隨機(jī)選擇一定像素
K = min(num_pixel, self.pixel_update_freq) #跟預(yù)設(shè)值比較,減少代碼出錯(cuò)的操作
feat = this_feat[:, perm[:K]]
feat = torch.transpose(feat, 0, 1)
ptr = int(pixel_queue_ptr[lb])
if ptr + K >= self.memory_size: #隊(duì)列滿了則去舊留新
pixel_queue[lb, -K:, :] = nn.functional.normalize(feat, p=2, dim=1)
pixel_queue_ptr[lb] = 0
else:
pixel_queue[lb, ptr:ptr + K, :] = nn.functional.normalize(feat, p=2, dim=1)
pixel_queue_ptr[lb] = (pixel_queue_ptr[lb] + 1) % self.memory_size
上面的源碼可以大致看到,首先會(huì)隨機(jī)選擇一定數(shù)量的像素加入到隊(duì)列,如果隊(duì)列滿了則舊的數(shù)據(jù)會(huì)被新的數(shù)據(jù)代替。
Pixel to Region
就是將區(qū)域的一大塊特征用一個(gè)像素點(diǎn)的特征去表示,主要是用來彌補(bǔ)pixel to pixel 采樣不充分的問題,這個(gè)方法是針對(duì)一張圖像的操作,將多個(gè)一張圖像的特征拼接到一起訓(xùn)練就能獲取全局信息。怎么操作的呢?比如一張圖像上有3個(gè)地方是貓的區(qū)域,首先對(duì)這個(gè)3塊區(qū)域在XY坐標(biāo)上進(jìn)行求平均,最后變?yōu)?個(gè)D維的像素特征(D,1,1),然后再對(duì)這個(gè)3個(gè)像素點(diǎn)在對(duì)應(yīng)通道維度上求平均,最后當(dāng)前圖像的3只貓就被一個(gè)D維的像素點(diǎn)表示。
# segment enqueue and dequeue
feat = torch.mean(this_feat[:, idxs], dim=1).squeeze(1)
ptr = int(segment_queue_ptr[lb])
segment_queue[lb, ptr, :] = nn.functional.normalize(feat.view(-1), p=2, dim=0)
segment_queue_ptr[lb] = (segment_queue_ptr[lb] + 1) % self.memory_size
這個(gè)方法是訓(xùn)練前期開辟一大塊內(nèi)存,然后每一個(gè)mini-batch都會(huì)加入一定的特征進(jìn)去,越到訓(xùn)練后期特征越多,等一輪訓(xùn)練完成后就清空,再重新開始。
困難樣本采樣策略Segmentation-Aware Hard Anchor Sampling
這個(gè)采樣策略其實(shí)很簡(jiǎn)單,相比現(xiàn)有的難負(fù)挖掘采樣策略,它是隨機(jī)采一半困難樣本,剩下的一半就隨機(jī)采樣,這剩下的一本里面應(yīng)該既有困難樣本也有簡(jiǎn)單樣本,這樣做的目的是防止全部使用困難樣本訓(xùn)練導(dǎo)致過擬合。舉個(gè)例子,當(dāng)前是一個(gè)類別為貓的像素特征,首先會(huì)從memory bank中選擇512個(gè)D維的像素點(diǎn),這些像素點(diǎn)屬于狗、羊等其他跟貓?zhí)卣鹘咏膭?dòng)物,或者是貓的但經(jīng)常分類分錯(cuò)的像素點(diǎn),再隨機(jī)從memory bank中選擇一些像素點(diǎn)放一起,源碼中是總共選擇1024個(gè)點(diǎn)用于跟當(dāng)前的限度點(diǎn)進(jìn)行損失計(jì)算。怎么確定是困難樣本還是簡(jiǎn)單樣本呢?就是通過模型的mask圖的像素值跟label值對(duì)不對(duì)的上,mask值跟label值匹配就是困難樣本。