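Competition goal:
The goal of this competition is to locate every functional tissue unit (FTU) in biopsy slides from several different organs. The underlying data include images from different sources, prepared with different protocols and at different resolutions, which reflects the typical challenges of working with medical data. The competition uses data from two different consortia, the Human Protein Atlas (HPA) and the Human BioMolecular Atlas Program (HuBMAP). The training set consists of public HPA data, the public test set combines private HPA data with HuBMAP data, and the private test set contains only HuBMAP data. Adapting models so that they still work on data produced under different protocols is one of the core challenges of the competition, and building models that generalize is its key goal.
Difficulties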
In this competition we have:
- The private dataset has a different set of image scales compared with the training set (relatively easy to model).
- The private dataset has a different color domain: different stains, which attach to different molecules/tissues (harder to model).
- Different slice thickness, which is going to be tough to incorporate.
Color Domain
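Preparing digital pathology slides starts with staining the tissue. To highlight specific nuclear and glandular features so that the tissue can be delineated and examined, stains are used to increase the contrast between tissue components. The two main families are hematoxylin-eosin (H&E) and immunohistochemistry (IHC), and H&E is the most widely used. Compared with routine H&E staining, IHC uses the specific binding between antigens and antibodies to detect and localize particular chemical substances in tissues and cells. It is more sensitive and can link morphological changes to functional and metabolic changes, which helps in identifying, diagnosing and treating malignant tumors. The problem with H&E is that staining varies considerably between laboratories, and even within the same laboratory when slides are stained on different days of the week (HPA and HuBMAP also follow different staining standards). This is because the final result depends heavily on the type and density of the dye and on how long the tissue is actually exposed to the stain.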

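Common solutions:
- Stain normalization: different laboratories and scanners can produce images with different color profiles for the same stain. Stain normalization aims to standardize the appearance of these stains. Traditionally, methods such as color matching [Reinhard2001] and stain separation [Macenko2009, Khan2014, Vahadane2016] are used. However, these methods rely on choosing a single reference slide.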
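Common color normalization methods for pathology images
STST: this method is based on a conditional GAN (cGAN), but when I actually ran it the results were poor: nuclei and cytoplasm came out in very similar colors and could not be told apart. Probably because it colorizes a grayscale image, the model has trouble separating nuclei from cytoplasm. Not recommended.
Vahadane: the recommended method. It uses (sparse) non-negative matrix factorization to obtain a two-stain matrix, aligns the source's stain matrix with the reference's, and then recomposes a new image. It is stable, the resulting colors look realistic, and the normalization speed is acceptable in practice.
Khan: the results look poor; not really recommended.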
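Macenko: converts RGB to optical density (OD), uses SVD to find the two principal directions of the OD values, derives the two stain vectors from the extreme angles within that plane, and then normalizes the stain matrix (optionally to a reference). In practice the resulting colors are not very natural and can look odd, possibly because SVD is less suited to this than non-negative factorization, or because the reference image was poorly chosen. It is fast, but the results are mediocre.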
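The Macenko estimation steps above can be sketched in NumPy as follows. This is a minimal sketch, not the implementation used in this solution: the defaults `io=240`, `alpha=1` (percentile) and `beta=0.15` (background OD threshold) are assumptions borrowed from common open-source versions, the basis-orientation step is an addition of this sketch to keep the angles away from the ±π wrap-around, and ordering the columns by red-channel OD is a heuristic.

```python
import numpy as np


def macenko_stain_matrix(img, io=240.0, alpha=1.0, beta=0.15):
    """Estimate a 3x2 stain matrix (columns: hematoxylin, eosin) from an RGB image."""
    # 1) RGB -> optical density (natural log); +1 avoids log(0)
    od = -np.log((img.reshape(-1, 3).astype(np.float64) + 1.0) / io)
    # 2) drop near-transparent (background) pixels
    od = od[np.all(od > beta, axis=1)]
    # 3) plane spanned by the two leading eigenvectors of the OD covariance
    _, eigvecs = np.linalg.eigh(np.cov(od.T))  # eigenvalues in ascending order
    plane = eigvecs[:, 1:3].copy()
    if od.mean(axis=0) @ plane[:, 0] < 0:  # orient the basis so that angles
        plane[:, 0] *= -1                  # stay away from the +-pi wrap-around
    proj = od @ plane
    # 4) robust extreme angles in that plane give the two stain directions
    phi = np.arctan2(proj[:, 1], proj[:, 0])
    phi_lo, phi_hi = np.percentile(phi, [alpha, 100 - alpha])
    v1 = plane @ np.array([np.cos(phi_lo), np.sin(phi_lo)])
    v2 = plane @ np.array([np.cos(phi_hi), np.sin(phi_hi)])
    # 5) heuristic: hematoxylin is the vector with the larger red-channel OD
    he = np.stack([v1, v2], axis=1) if v1[0] > v2[0] else np.stack([v2, v1], axis=1)
    return he / np.linalg.norm(he, axis=0, keepdims=True)
```

Normalizing to a reference then amounts to solving for per-pixel stain concentrations with the source matrix and recombining them with the reference matrix.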
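Reinhard: originally designed for natural images rather than pathology images; it statistically matches the colors of two images in a Lab-type color space. The results are poor, but it is fast. There are several reasons for the poor results. Lab does not reflect the optical properties of pathology images, whereas OD and HED spaces fit them better. The method is also very sensitive to the input image: in regions with little tissue, matching the statistics makes the tissue stain very dark. In short, it is not very smart. Not recommended.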
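A minimal NumPy sketch of Reinhard-style statistics matching. It assumes float RGB input in [0, 1] and uses the lαβ space of Reinhard et al. (2001) rather than OpenCV's CIELab, to avoid an extra dependency; the function names are illustrative. Because the mean and standard deviation are taken over the whole image, background pixels shift the statistics, which is the sensitivity to tissue fraction described above.

```python
import numpy as np

# RGB -> LMS matrix and LMS -> l-alpha-beta transform from Reinhard et al. (2001)
_RGB2LMS = np.array([[0.3811, 0.5783, 0.0402],
                     [0.1967, 0.7244, 0.0782],
                     [0.0241, 0.1288, 0.8444]])
_LMS2LAB = np.diag([1 / np.sqrt(3), 1 / np.sqrt(6), 1 / np.sqrt(2)]) @ np.array(
    [[1.0, 1.0, 1.0], [1.0, 1.0, -2.0], [1.0, -1.0, 0.0]])


def rgb_to_lab(img):
    """(H, W, 3) float RGB in [0, 1] -> (H*W, 3) l-alpha-beta values."""
    lms = img.reshape(-1, 3) @ _RGB2LMS.T
    return np.log10(np.maximum(lms, 1e-6)) @ _LMS2LAB.T  # floor avoids log10(0)


def lab_to_rgb(lab, shape):
    """Inverse of rgb_to_lab, reshaped back to `shape` and clipped to [0, 1]."""
    lms = 10.0 ** (lab @ np.linalg.inv(_LMS2LAB).T)
    return np.clip(lms @ np.linalg.inv(_RGB2LMS).T, 0.0, 1.0).reshape(shape)


def reinhard_transfer(source, reference):
    """Match the per-channel mean and std of `source` to `reference` in l-alpha-beta space."""
    src, ref = rgb_to_lab(source), rgb_to_lab(reference)
    out = (src - src.mean(0)) / (src.std(0) + 1e-8) * ref.std(0) + ref.mean(0)
    return lab_to_rgb(out, source.shape)
```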
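- Color augmentation: augmenting images with random affine transformations or added noise is one of the most common regularization techniques against overfitting. In the same way, stain variation can be used to increase the diversity of image appearance the model sees during training. Although drastic color changes are unrealistic for histology, subtler variations produced by random additive and multiplicative changes to each color channel have been shown to improve model performance. The strength of the color augmentation is an additional hyperparameter that should be tuned during training and validated on test sets from different laboratories or scanners.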
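The per-channel additive and multiplicative jitter described above can be sketched as follows. The function name and the default magnitudes (`mul_mag`, `add_mag`) are hypothetical and would be the hyperparameters to tune; the sketch assumes a float image in [0, 1].

```python
import numpy as np


def random_channel_jitter(image, rng, mul_mag=0.05, add_mag=0.05):
    """Per-channel random scale and offset for a float image in [0, 1] of shape (H, W, C)."""
    n_channels = image.shape[-1]
    mul = 1.0 + rng.uniform(-mul_mag, mul_mag, size=n_channels)  # one factor per channel
    add = rng.uniform(-add_mag, add_mag, size=n_channels)        # one offset per channel
    return np.clip(image * mul + add, 0.0, 1.0)                  # broadcast over H and W
```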
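- Unsupervised domain-adversarial training: the next domain adaptation technique is domain-adversarial training, which makes use of unlabeled images from the target domain. A domain-adversarial module, a classifier that predicts whether an image belongs to the source or the target domain, is added to the existing model. A gradient reversal layer connects this module to the existing network, so that training optimizes the original task while encouraging the network to learn domain-invariant features. The features produced by the feature extractor are fed to the domain classifier, which predicts whether they come from the source or the target domain and computes a loss. The domain classifier is trained to assign its inputs to the correct domain, while the feature extractor, because of the gradient reversal layer, is trained for the opposite: its features should prevent the domain classifier from telling which domain they came from. This creates an adversarial relationship. The same features are also fed to the label predictor. Because the source samples are labeled, the feature extractor must not only account for the domain classifier but also support supervised training on the labeled source samples, so that task accuracy is preserved.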
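A minimal PyTorch sketch of the gradient reversal layer and domain head described above. `DomainHead`, its hidden size and `dann_loss` are hypothetical names for illustration; the sketch assumes pooled encoder features of shape (batch, in_dim), domain label 0 for source and 1 for target, and a task loss computed separately on labeled source samples.

```python
import torch
from torch import nn


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -lambd in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.lambd, None  # no gradient for lambd


class DomainHead(nn.Module):
    """Hypothetical domain classifier attached to encoder features through the GRL."""

    def __init__(self, in_dim, hidden=64, lambd=1.0):
        super().__init__()
        self.lambd = lambd
        self.net = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, feats):
        return self.net(GradReverse.apply(feats, self.lambd))  # logit: source (0) vs target (1)


def dann_loss(task_loss, domain_logits, domain_labels):
    # Supervised loss on labeled source data + domain loss on source and target.
    # Thanks to the GRL, minimizing this sum trains the domain head to separate the
    # domains while pushing the encoder toward domain-invariant features.
    domain_loss = nn.functional.binary_cross_entropy_with_logits(
        domain_logits.squeeze(1), domain_labels)
    return task_loss + domain_loss
```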
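Solution
The data augmentation strategy combines stain augmentation, CutMix, CutOut and pixel-size adaptation. CoaT + DAFormer and Swin Transformer + UPerNet models were trained on the original dataset; an additional Swin Transformer + UPerNet was trained with lung oversampling to serve as the lung-specific predictor, and the final prediction is an ensemble of these models. The loss is Focal loss + Dice loss, which emphasizes hard examples.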
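A minimal PyTorch sketch of a Focal + Dice loss for binary segmentation logits. The write-up does not give the weights or focal parameters, so `gamma=2`, `alpha=0.25`, the Dice smoothing `eps` and the equal weighting are assumptions.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, target, gamma=2.0, alpha=0.25):
    """Binary focal loss: down-weights easy pixels by (1 - p_t)^gamma."""
    bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * target + (1 - p) * (1 - target)              # probability of the true class
    alpha_t = alpha * target + (1 - alpha) * (1 - target)  # class-balancing weight
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()


def dice_loss(logits, target, eps=1.0):
    """Soft Dice loss computed per sample, averaged over the batch."""
    p = torch.sigmoid(logits).flatten(1)
    t = target.flatten(1)
    inter = (p * t).sum(1)
    dice = (2 * inter + eps) / (p.sum(1) + t.sum(1) + eps)
    return (1 - dice).mean()


def focal_dice_loss(logits, target, w_focal=1.0, w_dice=1.0):
    return w_focal * focal_loss(logits, target) + w_dice * dice_loss(logits, target)
```

The focal term concentrates the pixel-wise loss on hard pixels, while the Dice term works at the level of whole masks and is less affected by the foreground/background imbalance.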
A Comprehensive Study of Vision Transformers on Dense Prediction Tasks
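The paper studies Vision Transformers (VTs), comparing different aspects of VTs and CNNs as feature extractors for object detection and semantic segmentation on challenging real-world data. The main results and insights from the experiments are:
- On in-distribution datasets, VTs outperform CNNs; they have lower inference speed despite lower computational complexity. Therefore, if GPUs were optimized for Transformer architectures, VTs could come to dominate computer vision.
- VTs generalize better to OOD datasets. The authors' loss-landscape analysis shows that VTs converge to flatter minima than CNNs, which may explain their better generalization.
- VTs are more robust than CNNs to natural corruptions and adversarial attacks. The authors attribute this to the global receptive field and the dynamic nature of self-attention.
- VTs have less texture bias than CNNs, which can be attributed to their global receptive field: it lets them focus on global shape-based cues rather than local texture-based cues.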
Data augmentation strategy
Basic data augmentation
This includes flips, rotations, contrast, HSV color-space jitter, noise and some scale transformations.
```python
import random

import cv2
import numpy as np

# All functions expect a float32 image in [0, 1] of shape (H, W, 3) and a mask of shape (H, W).


def do_random_flip(image, mask):
    if np.random.rand() > 0.5:
        image = cv2.flip(image, 0)  # vertical flip
        mask = cv2.flip(mask, 0)
    if np.random.rand() > 0.5:
        image = cv2.flip(image, 1)  # horizontal flip
        mask = cv2.flip(mask, 1)
    if np.random.rand() > 0.5:
        image = image.transpose(1, 0, 2)  # swap H and W
        mask = mask.transpose(1, 0)
    image = np.ascontiguousarray(image)
    mask = np.ascontiguousarray(mask)
    return image, mask


def do_random_rot90(image, mask):
    # cv2.ROTATE_90_CLOCKWISE == 0, so 0 cannot also mean "no rotation"; use -1 for that
    r = np.random.choice([
        -1,
        cv2.ROTATE_90_CLOCKWISE,
        cv2.ROTATE_90_COUNTERCLOCKWISE,
        cv2.ROTATE_180,
    ])
    if r == -1:
        return image, mask
    image = cv2.rotate(image, int(r))
    mask = cv2.rotate(mask, int(r))
    return image, mask


def do_random_contast(image, mask, mag=0.3):
    # multiplicative intensity scaling by a factor in [1 - mag, 1 + mag]
    alpha = 1 + random.uniform(-1, 1) * mag
    image = image * alpha
    image = np.clip(image, 0, 1)
    return image, mask


def do_random_hsv(image, mask, mag=(0.15, 0.25, 0.25)):
    image = (image * 255).astype(np.uint8)
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    h = hsv[:, :, 0].astype(np.float32)  # hue (0..179 in OpenCV)
    s = hsv[:, :, 1].astype(np.float32)  # saturation
    v = hsv[:, :, 2].astype(np.float32)  # value
    h = (h * (1 + random.uniform(-1, 1) * mag[0])) % 180
    s = s * (1 + random.uniform(-1, 1) * mag[1])
    v = v * (1 + random.uniform(-1, 1) * mag[2])
    hsv[:, :, 0] = np.clip(h, 0, 180).astype(np.uint8)
    hsv[:, :, 1] = np.clip(s, 0, 255).astype(np.uint8)
    hsv[:, :, 2] = np.clip(v, 0, 255).astype(np.uint8)
    image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    image = image.astype(np.float32) / 255
    return image, mask


def do_random_noise(image, mask, mag=0.1):
    height, width = image.shape[:2]
    # the same uniform noise value is added to all channels of a pixel
    noise = np.random.uniform(-1, 1, (height, width, 1)) * mag
    image = image + noise
    image = np.clip(image, 0, 1)
    return image, mask


def do_random_rotate_scale(image, mask, angle=30, scale=(0.8, 1.2)):
    angle = np.random.uniform(-angle, angle)
    scale = np.random.uniform(*scale) if scale is not None else 1
    height, width = image.shape[:2]
    center = (width / 2, height / 2)  # OpenCV expects (x, y)
    transform = cv2.getRotationMatrix2D(center, angle, scale)
    image = cv2.warpAffine(image, transform, (width, height), flags=cv2.INTER_LINEAR,
                           borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0))
    mask = cv2.warpAffine(mask, transform, (width, height), flags=cv2.INTER_LINEAR,
                          borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    return image, mask
```
Targeted data augmentation
Adaptation strategy for the different pixel sizes and tissue thicknesses of the HPA training images and the HuBMAP test images:
```python
import cv2
import numpy as np
import staintools

# Pixel size (micrometers per pixel) and tissue thickness (micrometers) per source and organ
imaging_measurements = {
    'hpa': {
        'pixel_size': {
            'kidney': 0.4,
            'prostate': 0.4,
            'largeintestine': 0.4,
            'spleen': 0.4,
            'lung': 0.4
        },
        'tissue_thickness': {
            'kidney': 4,
            'prostate': 4,
            'largeintestine': 4,
            'spleen': 4,
            'lung': 4
        }
    },
    'hubmap': {
        'pixel_size': {
            'kidney': 0.5,
            'prostate': 6.263,
            'largeintestine': 0.229,
            'spleen': 0.4945,
            'lung': 0.7562
        },
        'tissue_thickness': {
            'kidney': 10,
            'prostate': 5,
            'largeintestine': 8,
            'spleen': 4,
            'lung': 5
        }
    }
}


def pixelSize_tissueThickness_adaptation(image, mask, organ, alpha=0.15):
    image = (image * 255).astype(np.uint8)
    domain_pixel_size = imaging_measurements['hpa']['pixel_size'][organ]
    target_pixel_size = imaging_measurements['hubmap']['pixel_size'][organ]
    domain_tissue_thickness = imaging_measurements['hpa']['tissue_thickness'][organ]
    target_tissue_thickness = imaging_measurements['hubmap']['tissue_thickness'][organ]

    # Augment tissue thickness: a thicker target section raises saturation and lowers value
    tissue_thickness_scale_factor = target_tissue_thickness - domain_tissue_thickness
    image_hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
    image_hsv[:, :, 1] *= (1 + (alpha * tissue_thickness_scale_factor))
    image_hsv[:, :, 2] *= (1 - (alpha * tissue_thickness_scale_factor))
    # clip before casting, otherwise out-of-range values wrap around in uint8
    image_hsv = np.clip(image_hsv, 0, 255).astype(np.uint8)
    image_scaled = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2RGB)
    # Standardize luminosity
    image_scaled = staintools.LuminosityStandardizer.standardize(image_scaled)

    # Augment pixel size: resample to the target resolution, then back to the input size
    pixel_size_scale_factor = domain_pixel_size / target_pixel_size
    image_resized = cv2.resize(
        image_scaled,
        dsize=None,
        fx=pixel_size_scale_factor,
        fy=pixel_size_scale_factor,
        interpolation=cv2.INTER_CUBIC
    )
    image_resized = cv2.resize(
        image_resized,
        dsize=(image.shape[1], image.shape[0]),
        interpolation=cv2.INTER_CUBIC
    )

    # Standardize luminosity
    image_augmented = staintools.LuminosityStandardizer.standardize(image_resized)
    image = image_augmented.astype(np.float32) / 255
    return image, mask
```
Color augmentation strategy for the different staining standards of the HPA training set and the HuBMAP test set:
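```python
import numpy as np
import stainlib.augmentation.augmenter


def color_transfer(image, mask):
    image = (image * 255).astype(np.uint8)
    # random perturbation in HED (hematoxylin-eosin-DAB) space, "light" setting
    hed_lighter_aug = stainlib.augmentation.augmenter.HedLightColorAugmenter()
    hed_lighter_aug.randomize()
    transformed = hed_lighter_aug.transform(image)
    # assumes transform() returns a uint8 RGB array
    image = np.asarray(transformed).astype(np.float32) / 255
    return image, mask
```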
Other augmentations to improve generalization on the test set: CutMix + CutOut (see "Data augmentation: Mixup, Cutout, CutMix, Mosaic"), which provide strong and varied perturbations. Ablation experiments showed a sizeable gain from CutMix + CutOut.
```python
import random

import cv2
import numpy as np
import tifffile


# CutMix: cut a random box covering a (1 - lam) fraction of the image
def rand_bbox(size, lam):
    if len(size) == 4:    # (B, C, H, W)
        W = size[2]
        H = size[3]
    elif len(size) == 3:  # (H, W, C): here "W" indexes axis 0 and "H" axis 1
        W = size[0]
        H = size[1]
    else:
        raise ValueError(f"unsupported shape: {size}")
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    # uniformly sampled box center
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2


# Excerpt from the Dataset's __getitem__: self.df, self.fnames, self.organ_to_label,
# stained, image_size, image, mask, index and rle2mask are defined elsewhere.
randoms_int = random.randint(0, 99)
if randoms_int < 20:  # CutMix for 20% of samples
    rand_index = random.randint(0, len(self.df) - 1)
    rand_path = self.df.loc[rand_index, 'image_path']
    rand_img_height = self.df.loc[rand_index, 'img_height']
    rand_img_width = self.df.loc[rand_index, 'img_width']
    if stained:  # pre-computed stain-augmented copies
        rand_image = tifffile.imread(rand_path.replace('train_images', 'train_images_augment'))
    else:
        rand_image = tifffile.imread(rand_path)
    rand_rle = self.df.loc[rand_index, 'rle']
    rand_mask = rle2mask(rand_rle, (rand_img_height, rand_img_width))
    rand_image = rand_image.astype(np.float32) / 255
    rand_image = cv2.resize(rand_image, dsize=(image_size, image_size), interpolation=cv2.INTER_LINEAR)
    rand_mask = cv2.resize(rand_mask, dsize=(image_size, image_size), interpolation=cv2.INTER_LINEAR)
    lam = np.random.beta(1, 1)
    bbx1, bby1, bbx2, bby2 = rand_bbox(rand_image.shape, lam)
    # paste the box from the random sample into the current image and mask
    image[bbx1:bbx2, bby1:bby2, :] = rand_image[bbx1:bbx2, bby1:bby2, :]
    mask[bbx1:bbx2, bby1:bby2] = rand_mask[bbx1:bbx2, bby1:bby2]
if randoms_int >= 80:  # CutOut for 20% of samples
    alpha = random.uniform(0.2, 0.5)
    y = np.random.randint(image_size)
    x = np.random.randint(image_size)
    # square region, truncated at the image border
    y1 = np.clip(y - int(alpha * image_size) // 2, 0, image_size)
    y2 = np.clip(y + int(alpha * image_size) // 2, 0, image_size)
    x1 = np.clip(x - int(alpha * image_size) // 2, 0, image_size)
    x2 = np.clip(x + int(alpha * image_size) // 2, 0, image_size)
    # fill the region with zeros
    image[y1:y2, x1:x2, :] = 0
    mask[y1:y2, x1:x2] = 0
fname = self.fnames[index]
organ = self.organ_to_label[self.df.loc[index].organ]
```
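CutOut and CutMix have been shown to work for semantic segmentation (Semi-supervised semantic segmentation needs strong, varied perturbations). Their effect on semantic segmentation matches intuition, but the mathematical explanation behind it deserves further study. We also observed an interesting phenomenon in this competition: CutOut and CutMix seem to be more valuable for transformer models than for CNNs.
The models mainly use a transformer as the feature extractor (encoder), combined with different decoders to form the semantic segmentation networks.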

Swin Transformer + UPerNet summary:
Swin transformer
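Swin Transformer is reported to outperform backbones such as DeiT, ViT and EfficientNet, and is designed as a general-purpose backbone for computer vision, offering an alternative to classic CNN architectures. Building on the ideas of ViT, it introduces a shifted-window mechanism that lets the model learn information across windows. Its downsampling layers let it process high-resolution images while saving computation and attending to both global and local information.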
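Applying the Transformer from natural language processing to computer vision currently faces two main challenges:
- Visual entities vary widely, for example in scale: the same object photographed from different distances or angles yields very different pixel data, so a vision Transformer does not necessarily perform well across different scenes.
- Images have high resolution and many pixels; with the ViT model, the cost of self-attention grows quadratically with the number of pixels (tokens).
To address these two problems, the paper proposes Swin Transformer, which is based on shifted windows and has a hierarchical design (downsampling layers).
The windowing comes in two forms: non-overlapping local windows, and shifted windows whose partition straddles the boundaries of the previous layer's windows, creating cross-window connections. Restricting attention to a window of fixed size brings in the locality of CNN convolutions on the one hand and greatly reduces computation on the other, since the cost grows only linearly with the number of windows (and therefore with image size).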
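The complexity argument can be made concrete with the formulas from the Swin Transformer paper, for a feature map of h × w patches with C channels and windows of M × M patches (softmax cost ignored):

$$
\Omega(\mathrm{MSA}) = 4hwC^2 + 2(hw)^2C, \qquad \Omega(\mathrm{W\text{-}MSA}) = 4hwC^2 + 2M^2hwC
$$

The first term (the Q, K, V and output projections) is identical; only the attention term changes, from quadratic in hw to linear in hw. For example, for the first stage of Swin-T at 224 × 224 input (h = w = 56, C = 96, M = 7), the attention term shrinks by a factor of hw / M² = 3136 / 49 = 64.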

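The model has a hierarchical design with 4 stages. Every stage except the first starts with a Patch Merging layer that reduces the resolution of the input feature map (downsampling), enlarging the receptive field layer by layer as in a CNN in order to capture global information:
- At the input, a Patch Partition (the Patch Embedding step of ViT) splits the image into non-overlapping 4 × 4 patches (patch_size = 4); flattening a patch gives 4 × 4 × 3 = 48 features, so the embedding size becomes 48 (the channels of an image in CV can be thought of as the token embedding length in NLP).
- Then, in the first stage, a Linear Embedding projects the channels to C. In the PatchEmbed code below, a single convolution with kernel_size = stride = patch_size performs both steps.
- Every stage except the first consists of Patch Merging followed by several Swin Transformer Blocks.
- Swin Transformer Blocks come in two variants, W-MSA and SW-MSA, which are used in pairs: a W-MSA block followed by an SW-MSA block.
- The Patch Merging module reduces the resolution at the start of each stage, i.e. it performs the downsampling.
- As shown in the figure on the right, a Swin Transformer Block mainly consists of LayerNorm, Window Attention, Shifted Window Attention and an MLP.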
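The stage-by-stage shapes implied by this design can be traced with a few lines of Python. The defaults (224 × 224 input, patch size 4, C = 96) correspond to Swin-T and are assumptions for illustration; the helper name is hypothetical.

```python
def swin_stage_shapes(img_size=224, patch_size=4, embed_dim=96, num_stages=4):
    """Return the (H, W, C) feature-map shape after each stage (square input assumed)."""
    shapes = []
    h = w = img_size // patch_size  # after Patch Partition + Linear Embedding
    c = embed_dim
    for stage in range(num_stages):
        if stage > 0:  # Patch Merging: halve H and W, double C
            h, w, c = h // 2, w // 2, c * 2
        shapes.append((h, w, c))
    return shapes


print(swin_stage_shapes())  # [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 768)]
```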
Patch Embedding
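Before entering the blocks, the image is cut into patches and each patch is embedded as a vector. Concretely, the image is split into patch_size × patch_size patches, which are then embedded. This can be done with a 2-D convolution whose stride and kernel_size are both set to patch_size; the number of output channels sets the size of the embedding vector. Finally, the H and W dimensions are flattened and moved in front of the channel dimension. (In the segmentation variant below, the output stays (B, C, Wh, Ww) and is flattened only temporarily for the optional normalization.)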
```python
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import to_2tuple


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 norm_layer=None
                 ):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # pad so that H and W are multiples of the patch size
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        return x
```
As the implementation shows, no nn.Linear is used to change the number of channels; the nn.Conv2d changes the channel count at the same time as it splits the image into patches.
Patch Merging
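Patch Merging halves the resolution and doubles the number of channels, playing the role of downsampling in a CNN and giving the Transformer a hierarchical structure. Downsampling at the start of each stage saves some computation; in a CNN, the same is done at the start of each stage with a stride-2 convolution or pooling layer.
Since each downsampling is by a factor of 2, elements are taken every 2 positions along rows and columns, concatenated into one tensor and flattened. The channel dimension then becomes 4 times the original (because H and W each shrink by a factor of 2), and a fully connected layer reduces it to twice the original.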
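The 2 × 2 gathering step can be checked on a tiny NumPy array. This is a sketch that mirrors the slicing order of PatchMerging.forward for a single (H, W, C) map; the LayerNorm and the 4C → 2C linear layer are omitted, and the helper name is hypothetical.

```python
import numpy as np


def patch_merge_gather(x):
    """Stack each 2x2 neighbourhood of an (H, W, C) map into channels -> (H/2, W/2, 4C)."""
    x0 = x[0::2, 0::2]  # top-left of each 2x2 block
    x1 = x[1::2, 0::2]  # bottom-left
    x2 = x[0::2, 1::2]  # top-right
    x3 = x[1::2, 1::2]  # bottom-right
    return np.concatenate([x0, x1, x2, x3], axis=-1)


x = np.arange(16).reshape(4, 4, 1)  # one channel; values are the flat positions
out = patch_merge_gather(x)
print(out.shape)  # (2, 2, 4)
print(out[0, 0])  # [0 4 1 5]
```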


```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Variant used in the segmentation code: forward() takes H, W as inputs and pads odd sizes.
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        x = x.view(B, H, W, C)
        # padding: add one row/column if H or W is odd
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B, H/2, W/2, 4*C
        x = x.view(B, -1, 4 * C)  # B, H/2*W/2, 4*C
        x = self.norm(x)
        x = self.reduction(x)  # 4*C -> 2*C
        return x
```
Window Partition/Reverse
The window_partition function splits a tensor into windows of a given size: it turns a (B, H, W, C) tensor into (num_windows*B, window_size, window_size, C), where num_windows = H*W / (window_size*window_size) is the number of windows per image. The window_reverse function performs the inverse. Both functions are used by the Window Attention below.

```python
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
```
Window Attention
A traditional Transformer computes attention globally, so its computational complexity is very high. Swin Transformer restricts the attention computation to each window, which reduces the amount of computation.
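The formula is:
$$\mathrm{Attention}(Q,K,V)=\mathrm{SoftMax}\left(\frac{QK^{T}}{\sqrt{d}}+B\right)V$$
where Q, K, V ∈ ℝ^{M²×d} are the queries, keys and values of the M² patches in a window and B ∈ ℝ^{M²×M²} is the relative position bias. The main difference from the original attention formula is this relative position term added to QK^T. Ablation experiments in the Swin paper show that adding the relative position bias improves performance.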
```python
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim  # number of input channels
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])  # [0, 1, ..., Wh-1], length Wh
        coords_w = torch.arange(self.window_size[1])  # [0, 1, ..., Ww-1], length Ww
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # The 2-D offsets are later summed into a 1-D index. Offsets such as (2, 1) and (1, 2)
        # differ in 2-D but would collide after a plain sum, so the row offset is first
        # multiplied by (2*Ww - 1) to keep them distinct.
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # relative position index, shape (M^2, M^2): one entry per (query, key) pair
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # registered as a buffer: saved with the model but not learned
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # the all-zero bias table is filled with values drawn from a truncated normal distribution
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
                N: number of patches in a window
                C: channel dimension produced by the preceding linear projection
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        # self.qkv(x): (num_windows*B, N, 3C)
        # -> reshape: (num_windows*B, N, 3, num_heads, C//num_heads)
        # -> permute: (3, num_windows*B, num_heads, N, C//num_heads)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # q, k, v: (num_windows*B, num_heads, N, C//num_heads); N = M^2 patches per window,
        # C//num_heads is the per-head dimension of Q, K, V
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # scale q by 1/sqrt(d), the sqrt(d) in the formula
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (num_windows*B, num_heads, N, N)

        # Gather the learnable table (shape ((2Wh-1)*(2Ww-1), nH)) with the fixed,
        # non-learnable index buffer (shape (Wh*Ww, Wh*Ww)).
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        # unsqueeze(0) -> (1, nH, M^2, M^2); the leading 1 broadcasts over num_windows*B,
        # so every window and batch element shares the same bias
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # mask: (num_windows, M^2, M^2); attn: (num_windows*B, num_heads, M^2, M^2)
            nW = mask.shape[0]
            # view attn as (B, num_windows, num_heads, M^2, M^2) and broadcast-add the mask as
            # (1, num_windows, 1, M^2, M^2); entry [i, j] is the score of token i for token j
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # attn @ v: (num_windows*B, num_heads, M^2, C//num_heads)
        # -> transpose/reshape: (num_windows*B, M^2, C) with C = num_heads * (C//num_heads)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)  # nn.Linear(dim, dim)
        x = self.proj_drop(x)
        return x  # (num_windows*B, N, C), N = number of patches in a window

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
```

- The input tensor has shape [num_windows*B, window_size*window_size, C].
- After the fully connected layer self.qkv, it is reshaped and its axes are permuted to shape [3, num_windows*B, num_heads, window_size*window_size, C//num_heads], then split into q, k and v.
- Following the formula, q is multiplied by the scale factor and then by k (whose last two dimensions are swapped for the matrix product), giving an attn tensor of shape [num_windows*B, num_heads, window_size*window_size, window_size*window_size].
- For the position encoding, a learnable variable of shape ((2*window_size-1)*(2*window_size-1), num_heads) was defined earlier. It is indexed with the precomputed relative position index self.relative_position_index.view(-1), giving an encoding of shape (window_size*window_size, window_size*window_size, num_heads), which is permuted with permute(2, 0, 1) and added to the attn tensor.
- Leaving the mask aside, the rest is the same as in a standard Transformer: softmax, dropout, matrix multiplication with V, then a fully connected layer and dropout.
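Relative position encoding, explained with code
Absolute position encoding adds a learnable parameter to each token before the self-attention computation. Relative position encoding, as in the formula above, instead adds a learnable relative-position term during the self-attention computation.
Assume window_size = 2 × 2, i.e. each window has 4 tokens (M = 2), as shown in the figure. In self-attention, every token computes a QK value with every token. As shown in Figure 6, when the token at position 1 computes self-attention, QK values are computed between position 1 and positions 1, 2, 3 and 4. That is, taking the token at position 1 as the center, with coordinates (0, 0), the offset of every other position relative to the current one is computed.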

The final result is the relative position index, with relative_position_index.shape = (M², M²).

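Continuing with the M = 2 window in the figure: when the M² = 4 QK values for position 1 are computed, the relative position indices used are [4, 5, 7, 8] (M² of them). The corresponding data are the entries at indices 4, 5, 7 and 8 of the relative position bias table in the figure, i.e. an M²-dimensional vector for position 1; over all positions, relative_position_bias.shape = (M², M²) (per head).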
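The index construction from WindowAttention.__init__ can be transcribed into NumPy to inspect it for M = 2 (a sketch; the function name is illustrative). Note that the code computes offsets as query coordinate minus key coordinate, so row 0 (position 1 as the query) comes out as [4, 3, 1, 0], while [4, 5, 7, 8] appears as column 0; the figure uses the opposite sign convention. Both conventions index the same learnable table consistently.

```python
import numpy as np


def relative_position_index(wh, ww):
    coords = np.stack(np.meshgrid(np.arange(wh), np.arange(ww), indexing="ij"))  # 2, Wh, Ww
    flat = coords.reshape(2, -1)                       # 2, Wh*Ww
    rel = flat[:, :, None] - flat[:, None, :]          # 2, N, N  (query - key)
    rel = rel.transpose(1, 2, 0).copy()                # N, N, 2
    rel[:, :, 0] += wh - 1                             # shift offsets to start at 0
    rel[:, :, 1] += ww - 1
    rel[:, :, 0] *= 2 * ww - 1                         # make (row, col) pairs unique after summing
    return rel.sum(-1)                                 # N, N


print(relative_position_index(2, 2))
# [[4 3 1 0]
#  [5 4 2 1]
#  [7 6 4 3]
#  [8 7 5 4]]
```

For a 7 × 7 window, the indices cover 0 to 168, i.e. all (2·7 − 1)² = 169 rows of the bias table.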
The relative position encoding is applied in WindowAttention in the source code; once the principle is clear, the code is easy to follow:

Shifted Window Attention
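With the W-MSA module, self-attention is computed only inside each window, so no information can pass between windows. To solve this, the authors introduced the Shifted Windows Multi-Head Self-Attention (SW-MSA) module, i.e. W-MSA with shifted windows. As shown in the figure below, the left side uses the W-MSA just described (say at layer L); since W-MSA and SW-MSA are used in pairs, layer L+1 uses SW-MSA (right side). Comparing the two sides shows that the windows have been shifted: the partition moves ⌊M/2⌋ patches to the right and ⌊M/2⌋ patches down from the top-left corner. Looking at the shifted windows on the right, the 2×4 window in the first row, second column lets the two windows in the top row of layer L exchange information, and the 4×4 window in the second row, second column lets all four windows of layer L exchange information; the other windows work the same way. This solves the problem that different windows cannot exchange information.
The figure also shows that shifting turns the original 4 windows into 9, and MSA must then be computed inside each of them, which makes things more cumbersome. To avoid this, the authors proposed efficient batch computation for the shifted configuration. Below is the illustration from the original paper.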

Before the cyclic shift, each sub-window is given a label; after labeling, the feature map is rolled with torch.roll to achieve the cyclic shift.
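The cyclic shift and its reversal can be illustrated on a small map of patch ids. This sketch uses np.roll as a dependency-free stand-in for torch.roll (same semantics for these arguments) and assumes an 8 × 8 map with M = 4 and a shift of ⌊M/2⌋ = 2.

```python
import numpy as np

M, s = 4, 2                              # window size and shift (s = M // 2)
x = np.arange(64).reshape(8, 8)          # an 8x8 map of patch ids
# like torch.roll(x, shifts=(-s, -s), dims=(1, 2)) on a (B, H, W, C) tensor
shifted = np.roll(x, shift=(-s, -s), axis=(0, 1))
print(shifted[0, 0])                     # 18: patch (2, 2) moves to the top-left corner
restored = np.roll(shifted, shift=(s, s), axis=(0, 1))
print((restored == x).all())             # True: the reverse roll undoes the shift
```

After the roll, the map is still partitioned into (8 / 4)² = 4 windows; the rows and columns that wrapped around to the bottom and right edges are what the attention mask has to keep apart.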


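After the cyclic shift, a window can contain patches from sub-windows that were not adjacent in the original feature map. To get correct results while still computing attention in only four windows, a mask must be added to the attention result so that patches attend only to patches from the same sub-window.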
```python
# Excerpt from SwinTransformerBlock.__init__ (the self.* attributes are set there)
if self.shift_size > 0:
    # calculate attention mask for SW-MSA
    H, W = self.input_resolution
    # all-zero label map
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    # split the map into regions and give each region its own label
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # Example: img_mask.squeeze(-1) for H = W = 8, window_size = 4, shift_size = 2
    # tensor([[[0., 0., 0., 0., 1., 1., 2., 2.],
    #          [0., 0., 0., 0., 1., 1., 2., 2.],
    #          [0., 0., 0., 0., 1., 1., 2., 2.],
    #          [0., 0., 0., 0., 1., 1., 2., 2.],
    #          [3., 3., 3., 3., 4., 4., 5., 5.],
    #          [3., 3., 3., 3., 4., 4., 5., 5.],
    #          [6., 6., 6., 6., 7., 7., 8., 8.],
    #          [6., 6., 6., 6., 7., 7., 8., 8.]]])
    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
    # a nonzero label difference means the two tokens come from different regions
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
    attn_mask = None
```
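To see what the mask contains, the same construction can be transcribed into NumPy for the 8 × 8 example above. This is a sketch assuming H and W divisible by the window size; the function name is hypothetical. The top-left window holds a single region and needs no masking, while the bottom-right window mixes four regions.

```python
import numpy as np


def shifted_window_attn_mask(H, W, window_size, shift_size):
    """NumPy transcription of the SW-MSA mask: (num_windows, M*M, M*M) with 0 / -100."""
    img_mask = np.zeros((H, W))
    slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[h, w] = cnt  # label each region
            cnt += 1
    M = window_size
    # window_partition for a 2-D map -> (num_windows, M*M), windows in row-major order
    windows = img_mask.reshape(H // M, M, W // M, M).transpose(0, 2, 1, 3).reshape(-1, M * M)
    diff = windows[:, None, :] - windows[:, :, None]  # (nW, M*M, M*M)
    return np.where(diff != 0, -100.0, 0.0)


mask = shifted_window_attn_mask(8, 8, 4, 2)
print(mask.shape)                 # (4, 16, 16)
print(int((mask[3] == 0).sum()))  # 64: four regions of 4 tokens each
```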
SwinTransformerBlock calls the WindowAttention introduced above; when a mask is given, WindowAttention applies it to adjust the self-attention result. SwinTransformerBlock essentially implements W-MSA/SW-MSA, with the structure LN + (W-MSA/SW-MSA) + LN + MLP (each half wrapped in a residual connection). Note that the shifted feature map is shifted back at the end. Here LN is nn.LayerNorm and the MLP is the authors' own implementation.
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion. 這里的分辨率是轉(zhuǎn)換成patches之后的分辨率,不是原圖像素的分辨率
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. dim*mlp_ration=隱藏層神經(jīng)元個(gè)數(shù)
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim # 輸入通道C
self.input_resolution = input_resolution # 輸入分辨率
self.num_heads = num_heads # self_attention head
self.window_size = window_size # 窗口大小
self.shift_size = shift_size # sw-window
self.mlp_ratio = mlp_ratio #
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim) # nn.LayerNorm
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# 上述操作是為了給每個(gè)窗口給上索引
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # nW, window_size*window_size
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size*window_size, window_size*window_size
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
    def forward(self, x):
        H, W = self.input_resolution  # H, W are not pixel resolutions but the resolution after patch embedding (in patches)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)  # self.norm1 = norm_layer(dim) = nn.LayerNorm(dim)
        x = x.view(B, H, W, C)
        # cyclic shift
        if self.shift_size > 0:
            # torch.roll: a positive shift moves elements towards higher indices (down/right), a negative one
            # towards lower indices (up/left); shifts may be an int or a tuple with one entry per dim
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA:
        # self.attn computed after torch.roll is SW-MSA (shifted windows);
        # self.attn computed without torch.roll is W-MSA (regular windows)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # FFN
        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
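The sketch below is a worked example of the arithmetic in flops() above. It evaluates one block in the first stage of Swin-T at 224×224 input, which means 56×56 patches, dim=96, 3 heads and window 7.

WindowAttention.flops is not shown in this section. The per-window attention cost here therefore assumes the formula used in the official Swin Transformer repository: the qkv projection, the two N×N matrix products, and the output projection.

```python
def window_attention_flops(N, dim, num_heads):
    # Assumed to mirror WindowAttention.flops from the official Swin repo (not shown in this section)
    flops = N * dim * 3 * dim                            # qkv = self.qkv(x)
    flops += num_heads * N * (dim // num_heads) * N      # attn = q @ k^T
    flops += num_heads * N * N * (dim // num_heads)      # x = attn @ v
    flops += N * dim * dim                               # x = self.proj(x)
    return flops


def swin_block_flops(H, W, dim, num_heads, window_size, mlp_ratio=4.0):
    # Same bookkeeping as SwinTransformerBlock.flops(); H and W are measured in patches
    flops = dim * H * W                                  # norm1
    nW = H * W / window_size / window_size               # number of windows
    flops += nW * window_attention_flops(window_size * window_size, dim, num_heads)
    flops += 2 * H * W * dim * dim * mlp_ratio           # mlp: two linear layers
    flops += dim * H * W                                 # norm2
    return flops


stage1 = swin_block_flops(H=56, W=56, dim=96, num_heads=3, window_size=7)
print(f"{stage1 / 1e6:.1f} MFLOPs")  # 376.9 MFLOPs
```

Every term grows linearly with H·W, because attention is restricted to fixed-size windows. Doubling both H and W therefore multiplies the block's cost by exactly 4, instead of the 16× that global attention would incur.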
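The overall structure of the block:
- First apply LayerNorm to the feature map
- Use self.shift_size to decide whether the feature map needs to be shifted
- Split the feature map into windows
- Compute attention; self.attn_mask distinguishes window attention (W-MSA) from shifted window attention (SW-MSA)
- Merge the windows back together
- If the map was shifted earlier, apply the reverse shift to undo it
- Apply drop path (stochastic depth) and the residual connection
- Then apply LayerNorm + MLP, again followed by drop path and a residual connection
SwinTransformerBlock calls the WindowAttention introduced earlier. When a mask is given, WindowAttention adds it to the self-attention scores before the softmax, so tokens from different regions effectively ignore each other. Note that the BasicLayer below calls each block as blk(x, H, W, attn_mask). It therefore expects a block variant whose forward takes the spatial size and the mask as arguments, whereas the forward shown above reads self.input_resolution and the precomputed self.attn_mask buffer.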
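The steps above (cyclic shift, window partition, masked attention, window merge, reverse shift) can be traced end to end in a small numpy sketch. window_partition and window_reverse use the same reshape/permute layout as the PyTorch helpers. sw_msa_mask repeats the mask construction from __init__.

For brevity the sketch makes some simplifying assumptions:
- It uses a single image and a single head.
- The tokens themselves serve as q, k and v, with no learned projections.
- The 8×8 map with window 4 and shift 2 is an arbitrary toy size.

```python
import numpy as np


def window_partition(x, ws):
    # (B, H, W, C) -> (B*nW, ws, ws, C), same layout as the PyTorch helper
    B, H, W, C = x.shape
    x = x.reshape(B, H // ws, ws, W // ws, ws, C)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)


def window_reverse(windows, ws, H, W):
    # Inverse of window_partition: (B*nW, ws, ws, C) -> (B, H, W, C)
    B = windows.shape[0] // ((H // ws) * (W // ws))
    x = windows.reshape(B, H // ws, W // ws, ws, ws, -1)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)


def sw_msa_mask(H, W, ws, shift):
    # Label the regions created by the cyclic shift, then forbid attention across labels
    img_mask = np.zeros((1, H, W, 1))
    slices = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    labels = window_partition(img_mask, ws).reshape(-1, ws * ws)   # (nW, N)
    diff = labels[:, None, :] - labels[:, :, None]                   # (nW, N, N)
    return np.where(diff != 0, -100.0, 0.0)


def masked_window_attention(q, k, v, mask):
    # q, k, v: (nW, N, d); the mask is added to the scores before the softmax
    scores = q @ np.swapaxes(k, -1, -2) * q.shape[-1] ** -0.5 + mask
    scores -= scores.max(axis=-1, keepdims=True)                     # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


H = W = 8
ws, shift = 4, 2
x = np.random.default_rng(0).standard_normal((1, H, W, 6))
shifted = np.roll(x, shift=(-shift, -shift), axis=(1, 2))           # cyclic shift
tokens = window_partition(shifted, ws).reshape(-1, ws * ws, 6)       # (nW, N, C)
mask = sw_msa_mask(H, W, ws, shift)
out, weights = masked_window_attention(tokens, tokens, tokens, mask)
restored = np.roll(window_reverse(out.reshape(-1, ws, ws, 6), ws, H, W),
                   shift=(shift, shift), axis=(1, 2))                # reverse cyclic shift
print(restored.shape)  # (1, 8, 8, 6)
```

The -100 entries do not remove tokens from the window. They only push the softmax weight of cross-region pairs to practically zero, which keeps every window the same size and lets all windows be processed as one batch.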
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate; a list gives one rate per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
    """
    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 #use_checkpoint=False,
                 ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        # even-indexed blocks use W-MSA (shift 0), odd-indexed blocks use SW-MSA
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
            )
            for i in range(depth)
        ])
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None
    def forward(self, x, H, W):
        """
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        # calculate attention mask for SW-MSA ----
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        #------
        for blk in self.blocks:
            x = blk(x, H, W, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W
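BasicLayer takes drop_path either as a float or as a list with one rate per block, and it builds the SW-MSA mask on a grid padded up to a multiple of window_size. The numpy sketch below shows both pieces of bookkeeping.

The linear stochastic-depth schedule is an assumption: it follows the convention of the official Swin code, which is not shown here. The 768×768 tile (192×192 patches with patch size 4) is only an example size. The block variant used with this layer presumably pads the feature map to the same Hp×Wp.

```python
import numpy as np


def drop_path_schedule(depths, drop_path_rate):
    # Rate grows linearly over all blocks of the network, then is split per stage
    dpr = np.linspace(0, drop_path_rate, sum(depths)).tolist()
    per_stage, start = [], 0
    for d in depths:
        per_stage.append(dpr[start:start + d])   # a list, so BasicLayer's isinstance(drop_path, list) check passes
        start += d
    return per_stage


def stage_resolutions(H, W, depths, window_size):
    # (H, W) in patches for each stage, plus the padded (Hp, Wp) used for the SW-MSA mask
    shapes = []
    for _ in depths:
        Hp = int(np.ceil(H / window_size)) * window_size
        Wp = int(np.ceil(W / window_size)) * window_size
        shapes.append(((H, W), (Hp, Wp)))
        H, W = (H + 1) // 2, (W + 1) // 2        # patch merging halves the resolution (rounding up)
    return shapes


depths = [2, 2, 18, 2]                           # swin_small
print(drop_path_schedule(depths, 0.3)[-1])      # rates of the two blocks in the last stage
for hw, padded in stage_resolutions(192, 192, depths, 7):
    print(hw, padded)
```

With window 7, only the 49×49 stage needs no padding. The other stages pad by 1 to 4 patches, which is why the mask is built on Hp×Wp rather than on H×W.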
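UPerNet
We mainly use UPerNet's decoder, where Decoder = FPN + PPM. In theory, the receptive field of a deep CNN is large enough, but the effective receptive field is much smaller. To overcome this, UPerNet applies the Pyramid Pooling Module (PPM) from PSPNet to the last backbone stage, before it is fed into FPN's top-down branch. Experiments in the UPerNet paper show that PPM and the FPN architecture are highly compatible in providing an effective global prior representation.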
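The pyramid pooling idea can be sketched in numpy. The top feature map is average-pooled onto 1×1, 2×2, 3×3 and 6×6 grids, each result is stretched back to full size, and everything is concatenated along channels.

adaptive_avg_pool2d follows the bin rule of nn.AdaptiveAvgPool2d, where bin i covers floor(i·H/out) to ceil((i+1)·H/out). The sketch departs from the decoder in two ways:
- It upsamples with nearest neighbour, whereas the decoder uses bilinear interpolation.
- It omits the 1×1 conv that reduces each branch to ppm_dim. With that conv, the decoder's concatenation for Swin has 768 + 4×512 = 2816 channels.

```python
import numpy as np


def adaptive_avg_pool2d(x, out_size):
    # x: (C, H, W); bin i spans rows floor(i*H/out) .. ceil((i+1)*H/out), as in nn.AdaptiveAvgPool2d
    C, H, W = x.shape
    y = np.empty((C, out_size, out_size))
    for i in range(out_size):
        h0, h1 = (i * H) // out_size, -(-(i + 1) * H // out_size)
        for j in range(out_size):
            w0, w1 = (j * W) // out_size, -(-(j + 1) * W // out_size)
            y[:, i, j] = x[:, h0:h1, w0:w1].mean(axis=(1, 2))
    return y


def pyramid_pool(x, scales=(1, 2, 3, 6)):
    # Pool at several grid sizes, upsample each back (nearest neighbour) and concatenate with the input
    C, H, W = x.shape
    branches = [x]
    for s in scales:
        p = adaptive_avg_pool2d(x, s)
        rows = np.arange(H) * s // H
        cols = np.arange(W) * s // W
        branches.append(p[:, rows][:, :, cols])
    return np.concatenate(branches, axis=0)


x = np.random.default_rng(0).standard_normal((3, 24, 24))
print(pyramid_pool(x).shape)  # (15, 24, 24)
```

The 1×1 branch is simply the global average of each channel. It gives every position the same image-level context, which is the "global prior" the paragraph above refers to.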

class UPerDecoder(nn.Module):
    def __init__(self,
                 in_dim=[256, 512, 1024, 2048],
                 ppm_pool_scale=[1, 2, 3, 6],
                 ppm_dim=512,
                 fpn_out_dim=256
                 ):
        super(UPerDecoder, self).__init__()
        # PPM ----
        dim = in_dim[-1]
        ppm_pooling = []
        ppm_conv = []
        for scale in ppm_pool_scale:
            ppm_pooling.append(
                nn.AdaptiveAvgPool2d(scale)
            )
            ppm_conv.append(
                nn.Sequential(
                    nn.Conv2d(dim, ppm_dim, kernel_size=1, bias=False),
                    nn.BatchNorm2d(ppm_dim),
                    nn.ReLU(inplace=True)
                )
            )
        self.ppm_pooling = nn.ModuleList(ppm_pooling)
        self.ppm_conv = nn.ModuleList(ppm_conv)
        self.ppm_out = conv3x3_bn_relu(dim + len(ppm_pool_scale) * ppm_dim, fpn_out_dim, 1)
        # FPN ----
        fpn_in = []
        for i in range(0, len(in_dim) - 1):  # skip the top layer
            fpn_in.append(
                nn.Sequential(
                    nn.Conv2d(in_dim[i], fpn_out_dim, kernel_size=1, bias=False),
                    nn.BatchNorm2d(fpn_out_dim),
                    nn.ReLU(inplace=True)
                )
            )
        self.fpn_in = nn.ModuleList(fpn_in)
        fpn_out = []
        for i in range(len(in_dim) - 1):  # skip the top layer
            fpn_out.append(
                conv3x3_bn_relu(fpn_out_dim, fpn_out_dim, 1),
            )
        self.fpn_out = nn.ModuleList(fpn_out)
        self.fpn_fuse = nn.Sequential(
            conv3x3_bn_relu(len(in_dim) * fpn_out_dim, fpn_out_dim, 1),
        )
    def forward(self, feature):
        f = feature[-1]
        pool_shape = f.shape[2:]
        ppm_out = [f]
        for pool, conv in zip(self.ppm_pooling, self.ppm_conv):
            p = pool(f)
            p = F.interpolate(p, size=pool_shape, mode='bilinear', align_corners=False)
            p = conv(p)
            ppm_out.append(p)
        ppm_out = torch.cat(ppm_out, 1)
        down = self.ppm_out(ppm_out)
        #--------------------------------------
        fpn_out = [down]
        for i in reversed(range(len(feature) - 1)):
            lateral = feature[i]
            lateral = self.fpn_in[i](lateral)  # lateral branch
            down = F.interpolate(down, size=lateral.shape[2:], mode='bilinear', align_corners=False)  # top-down branch
            down = down + lateral
            fpn_out.append(self.fpn_out[i](down))
        fpn_out.reverse()  # [P2 - P5]
        fusion_shape = fpn_out[0].shape[2:]
        fusion = [fpn_out[0]]
        for i in range(1, len(fpn_out)):
            fusion.append(
                F.interpolate(fpn_out[i], fusion_shape, mode='bilinear', align_corners=False)
            )
        x = self.fpn_fuse(torch.cat(fusion, 1))
        return x, fusion
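To make the tensor sizes in UPerDecoder concrete, the pure-Python trace below follows (channels, height, width) through the decoder.

It assumes:
- Swin encoder strides of 4, 8, 16 and 32 with the Swin channel list [96, 192, 384, 768].
- A square input whose side is divisible by 32. The 768 example size is arbitrary.
- A final 1×1 conv followed by the ×4 bilinear upsample done in Net.

```python
def uper_shapes(img_size, in_dim=(96, 192, 384, 768), strides=(4, 8, 16, 32),
                pool_scales=(1, 2, 3, 6), ppm_dim=512, fpn_out_dim=256):
    # Pure bookkeeping of (channels, height, width) through UPerDecoder for a square input
    feats = [(c, img_size // s, img_size // s) for c, s in zip(in_dim, strides)]
    ppm_concat = in_dim[-1] + len(pool_scales) * ppm_dim          # input channels of self.ppm_out
    fpn = [(fpn_out_dim, h, w) for _, h, w in feats]               # P2..P5 after the lateral/top-down path
    fuse_in = len(in_dim) * fpn_out_dim                            # every level upsampled to P2, then concatenated
    fused = (fpn_out_dim, feats[0][1], feats[0][2])                # output x of the decoder (stride 4)
    logit = (1, img_size, img_size)                                # 1x1 conv, then x4 upsample in Net
    return dict(features=feats, ppm_concat=ppm_concat, fpn=fpn,
                fuse_in=fuse_in, fused=fused, logit=logit)


for key, value in uper_shapes(768).items():
    print(key, value)
```

The trace shows why Net upsamples the logit by a factor of 4. The fused map lives at the resolution of the first encoder stage, one quarter of the input.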
Swin+UPerNet
Parameter settings:
cfg = dict(
    #configs/_base_/models/upernet_swin.py
    basic = dict(
        swin=dict(
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=7,
            mlp_ratio=4.,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.3,
            ape=False,
            patch_norm=True,
            out_indices=(0, 1, 2, 3),
            use_checkpoint=False
        ),
    ),
    #configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k.py
    swin_tiny_patch4_window7_224=dict(
        checkpoint = pretrain_dir+'swin_tiny_patch4_window7_224_22k.pth',
        swin = dict(
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=7,
            ape=False,
            drop_path_rate=0.3,
            patch_norm=True,
            use_checkpoint=False,
        ),
        upernet=dict(
            in_channels=[96, 192, 384, 768],
        ),
    ),
    #/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k.py
    swin_small_patch4_window7_224_22k=dict(
        checkpoint = pretrain_dir+'swin_small_patch4_window7_224_22k.pth',
        swin = dict(
            embed_dim=96,
            depths=[2, 2, 18, 2],
            num_heads=[3, 6, 12, 24],
            window_size=7,
            ape=False,
            drop_path_rate=0.3,
            patch_norm=True,
            use_checkpoint=False
        ),
        upernet=dict(
            in_channels=[96, 192, 384, 768],
        ),
    ),
)
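Net builds the encoder arguments as {**cfg['basic']['swin'], **cfg[self.arch]['swin'], 'out_norm': ...}. The pure-Python sketch below shows what that merge produces and why upernet in_channels is [96, 192, 384, 768]: Swin doubles the channel count at every stage.

The dicts are trimmed copies of the config above. The string 'LayerNorm2d' is only a placeholder for the real class.

```python
basic_swin = dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                  window_size=7, drop_path_rate=0.3, out_indices=(0, 1, 2, 3))
small_swin = dict(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24],
                  window_size=7, drop_path_rate=0.3)

# Later dicts win on duplicate keys, so the arch-specific entries override 'basic'
merged = {**basic_swin, **small_swin, **{'out_norm': 'LayerNorm2d'}}  # string is a placeholder for the class

stage_channels = [merged['embed_dim'] * 2 ** i for i in range(len(merged['depths']))]
head_dims = [c // h for c, h in zip(stage_channels, merged['num_heads'])]
print(merged['depths'], stage_channels, head_dims)
# [2, 2, 18, 2] [96, 192, 384, 768] [32, 32, 32, 32]
```

Keys that appear only in 'basic' (such as out_indices) survive the merge. The tiny and small variants therefore differ only in depths, and each head always works on 32 channels.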
Overall model:
class Net(nn.Module):
    def load_pretrain(self,):
        checkpoint = cfg[self.arch]['checkpoint']
        print('loading %s ...' % checkpoint)
        checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)['model']
        if 0:
            skip = ['relative_coords_table', 'relative_position_index']
            filtered = {}
            for k, v in checkpoint.items():
                if any([s in k for s in skip]): continue
                filtered[k] = v
            checkpoint = filtered
        print(self.encoder.load_state_dict(checkpoint, strict=False))  #True
    def __init__(self,):
        super(Net, self).__init__()
        self.output_type = ['inference', 'loss']
        self.rgb = RGB()
        self.arch = 'swin_small_patch4_window7_224_22k'  #'swin_tiny_patch4_window7_224'
        self.encoder = SwinTransformerV1(
            **{**cfg['basic']['swin'], **cfg[self.arch]['swin'],
               **{'out_norm': LayerNorm2d}}
        )
        encoder_dim = cfg[self.arch]['upernet']['in_channels']
        #[96, 192, 384, 768]
        self.decoder = UPerDecoder(
            in_dim=encoder_dim,
            ppm_pool_scale=[1, 2, 3, 6],
            ppm_dim=512,
            fpn_out_dim=256
        )
        self.logit = nn.Sequential(
            nn.Conv2d(256, 1, kernel_size=1)
        )
        self.aux = nn.ModuleList([
            nn.Conv2d(256, 1, kernel_size=1, padding=0) for i in range(4)
        ])
    def forward(self, batch):
        x = batch['image']
        B, C, H, W = x.shape
        x = self.rgb(x)
        encoder = self.encoder(x)
        last, decoder = self.decoder(encoder)
        logit = self.logit(last)
        logit = F.interpolate(logit, size=None, scale_factor=4, mode='bilinear', align_corners=False)
        output = {}
        if 'loss' in self.output_type:
            output['bce_loss'] = F.binary_cross_entropy_with_logits(logit, batch['mask'])
            output['dice_loss'] = DiceLoss()(logit, batch['mask'])
            output['focal_loss'] = FocalLoss(logits=True, reduce=False)(logit, batch['mask'])
            for i in range(4):
                output['aux%d_loss' % i] = criterion_aux_loss(self.aux[i](decoder[i]), batch['mask'])
        if 'inference' in self.output_type:
            output['probability'] = torch.sigmoid(logit)
        return output
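DiceLoss and FocalLoss are not defined in this section. The numpy sketch below uses common textbook formulations as an assumption about what they compute: a soft Dice loss on sigmoid probabilities, and a focal loss with gamma=2 reduced to a mean. The BCE term uses the numerically stable form that F.binary_cross_entropy_with_logits is based on.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def bce_with_logits(logit, target):
    # Stable form max(x, 0) - x*t + log(1 + exp(-|x|)), averaged over all pixels
    return np.mean(np.maximum(logit, 0) - logit * target + np.log1p(np.exp(-np.abs(logit))))


def soft_dice_loss(logit, target, eps=1.0):
    # 1 - Dice computed on probabilities; eps avoids 0/0 on empty masks (assumed formulation)
    p = sigmoid(logit)
    inter = (p * target).sum()
    return 1.0 - (2.0 * inter + eps) / (p.sum() + target.sum() + eps)


def focal_loss(logit, target, gamma=2.0):
    # Scales the per-pixel BCE by (1 - p_t)^gamma so easy pixels contribute less (assumed formulation)
    p = sigmoid(logit)
    p_t = p * target + (1 - p) * (1 - target)
    ce = np.maximum(logit, 0) - logit * target + np.log1p(np.exp(-np.abs(logit)))
    return np.mean((1 - p_t) ** gamma * ce)


mask = np.zeros((1, 1, 8, 8))
mask[..., 2:6, 2:6] = 1
good = np.where(mask == 1, 8.0, -8.0)             # confident, correct logits
print(bce_with_logits(np.zeros_like(mask), mask))  # ≈ 0.693 (ln 2): an uninformed prediction
print(soft_dice_loss(good, mask))                  # close to 0
```

BCE is a per-pixel average, so it is dominated by the background when FTUs are small. Dice is computed over the whole mask and stays sensitive to the foreground. That complementarity is a common reason to train on both.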
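CoaT+Daformer summary:
CoaT
Co-scale conv-attentional image Transformers (CoaT) is a Transformer-based image classifier built around two designs: a co-scale mechanism and a conv-attentional mechanism.
- First, the co-scale mechanism keeps the Transformer encoder branch intact at every scale while letting the representations learned at different scales communicate with each other effectively. The authors design a series of serial and parallel blocks to implement co-scale attention.
- Second, CoaT implements factorized attention in a convolution-like way, which makes it possible to embed relative positions inside the factorized attention module. Together these give Vision Transformers rich multi-scale and contextual modeling capabilities.
Although convolution and self-attention both compute a weighted sum, they obtain the weights differently. In a CNN the weights are learned during training and fixed at test time, whereas self-attention computes them dynamically from the similarity (affinity) between every pair of tokens. The self-similarity operation in self-attention therefore offers a potentially more adaptive and general modeling tool than convolution. In addition, positional encodings and embeddings give Transformer modeling extra flexibility.
Model details: https://jishuin.proginn.com/p/763bfbd566f3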
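The contrast between data-dependent attention weights and CoaT's factorized attention can be sketched in numpy. Factorized attention applies the softmax to K over the token axis, then forms the C×C product Kᵀ·V before multiplying by Q. The cost is O(N·C²) instead of the O(N²·C) of standard attention.

This is a single-head simplification. It leaves out CoaT's convolutional relative position term and the multi-head split, and the sizes N=196, C=32 are arbitrary.

```python
import numpy as np


def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(q, k, v):
    # Weights come from the input itself: an (N, N) matrix, O(N^2 C) cost
    C = q.shape[-1]
    return softmax(q @ k.T / np.sqrt(C), axis=-1) @ v


def factorized_attention(q, k, v):
    # CoaT-style: softmax over the token axis of K, then K^T V (a C x C matrix) first -> O(N C^2)
    C = q.shape[-1]
    context = softmax(k, axis=0).T @ v
    return (q / np.sqrt(C)) @ context


rng = np.random.default_rng(0)
N, C = 196, 32
q, k, v = (rng.standard_normal((N, C)) for _ in range(3))
print(self_attention(q, k, v).shape, factorized_attention(q, k, v).shape)  # (196, 32) (196, 32)
```

Because matrix multiplication is associative, computing (Q·softmax(K)ᵀ)·V would give the same result. Choosing the C×C product first is what keeps the cost linear in the number of tokens.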
Daformer
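DAFormer was the first work to use Transformers for unsupervised domain adaptation in semantic segmentation. Its network consists of a Transformer encoder and a multi-level context-aware feature fusion decoder. It relies on three simple but crucial training strategies to stabilize training and avoid overfitting to the source domain:
- Rare Class Sampling on the source domain improves the quality of pseudo-labels by mitigating the confirmation bias of self-training towards common classes
- Thing-Class ImageNet Feature Distance keeps the encoder's features for "thing" classes close to those of the ImageNet-pretrained model, so the useful pretrained knowledge is not forgotten
- Learning rate warmup promotes the transfer of pretrained features
Here we only borrow the decoder module; for the full DAFormer see: https://blog.csdn.net/amusi1994/article/details/124833996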
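Although only the decoder is reused in this pipeline, two of the training strategies listed above are easy to state concretely. In rare class sampling, the sampling probability of class c is P(c) = exp((1 − f_c)/T) / Σ exp((1 − f_c′)/T), where f_c is the pixel frequency of class c in the source data and T is a temperature. Warmup raises the learning rate linearly as η_t = η_base · t / t_warm.

The sketch below is a hedged illustration. T=0.01 is used as the default here. The class frequencies, base learning rate and warmup length are hypothetical values.

```python
import numpy as np


def rare_class_sampling_probs(class_freq, T=0.01):
    # P(c) = exp((1 - f_c) / T) / sum_c' exp((1 - f_c') / T); a smaller T favours rare classes more
    f = np.asarray(class_freq, dtype=float)
    z = (1.0 - f) / T
    z -= z.max()                      # numerical stability; the normalised result is unchanged
    p = np.exp(z)
    return p / p.sum()


def warmup_lr(step, base_lr, warmup_steps):
    # Linear warmup lr_t = base_lr * t / t_warm; any decay after warmup is not modelled here
    return base_lr * step / warmup_steps if step < warmup_steps else base_lr


probs = rare_class_sampling_probs([0.6, 0.3, 0.1])   # hypothetical pixel frequencies
print(probs.argmax())  # 2: the rarest class is sampled most often
```

In DAFormer a class is drawn from P(c) first, then an image containing that class is sampled. The oversampling of lung images in the bad case analysis below follows a similar intuition at the organ level.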

class DaformerDecoder(nn.Module):
    def __init__(
            self,
            encoder_dim=[32, 64, 160, 256],
            decoder_dim=256,
            dilation=[1, 6, 12, 18],
            use_bn_mlp=True,
            fuse='conv3x3',
    ):
        super().__init__()
        self.mlp = nn.ModuleList([
            nn.Sequential(
                # Conv2dBnReLU(dim, decoder_dim, 1, padding=0), #follow mmseg to use conv-bn-relu
                *(
                    (nn.Conv2d(dim, decoder_dim, 1, padding=0, bias=False),
                     nn.BatchNorm2d(decoder_dim),
                     nn.ReLU(inplace=True),
                     ) if use_bn_mlp else
                    (nn.Conv2d(dim, decoder_dim, 1, padding=0, bias=True),)
                ),
                MixUpSample(2 ** i) if i != 0 else nn.Identity(),
            ) for i, dim in enumerate(encoder_dim)])
        if fuse == 'conv1x1':
            self.fuse = nn.Sequential(
                nn.Conv2d(len(encoder_dim) * decoder_dim, decoder_dim, 1, padding=0, bias=False),
                nn.BatchNorm2d(decoder_dim),
                nn.ReLU(inplace=True),
            )
        if fuse == 'conv3x3':
            self.fuse = nn.Sequential(
                nn.Conv2d(len(encoder_dim) * decoder_dim, decoder_dim, 3, padding=1, bias=False),
                nn.BatchNorm2d(decoder_dim),
                nn.ReLU(inplace=True),
            )
        if fuse == 'aspp':
            self.fuse = ASPP(
                decoder_dim * len(encoder_dim),
                decoder_dim,
                dilation,
            )
        if fuse == 'ds-aspp':
            self.fuse = DSASPP(
                decoder_dim * len(encoder_dim),
                decoder_dim,
                dilation,
            )
    def forward(self, feature):
        out = []
        for i, f in enumerate(feature):
            f = self.mlp[i](f)
            out.append(f)
            #print(f.shape)
        x = self.fuse(torch.cat(out, dim=1))
        return x, out
class daformer_conv3x3(DaformerDecoder):
    def __init__(self, **kwargs):
        super(daformer_conv3x3, self).__init__(
            fuse='conv3x3',
            **kwargs
        )
class daformer_conv1x1(DaformerDecoder):
    def __init__(self, **kwargs):
        super(daformer_conv1x1, self).__init__(
            fuse='conv1x1',
            **kwargs
        )
class daformer_aspp(DaformerDecoder):
    def __init__(self, **kwargs):
        super(daformer_aspp, self).__init__(
            fuse='aspp',
            **kwargs
        )
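DaformerDecoder's per-stage MLP does two things. It projects each encoder stage to decoder_dim channels, and it upsamples stage i by 2**i so that all four maps meet at stride 4 before fusion. The numpy sketch below shows this.

Several parts are stand-ins:
- A random matrix replaces the trained 1×1 conv, and BatchNorm is omitted.
- Nearest-neighbour repetition replaces MixUpSample, which is not shown in this section.
- The encoder channels [64, 128, 320, 512] come from the CoaT comment below.
- The 16/8/4/2 spatial sizes correspond to a hypothetical 64×64 input.

```python
import numpy as np


def daformer_fuse_input(features, decoder_dim, seed=0):
    # features: list of (C_i, h_i, w_i) maps at strides 4, 8, 16, 32
    rng = np.random.default_rng(seed)
    out = []
    for i, f in enumerate(features):
        c = f.shape[0]
        weight = rng.standard_normal((decoder_dim, c)) / np.sqrt(c)   # random stand-in for a trained 1x1 conv
        y = np.maximum(np.einsum('oc,chw->ohw', weight, f), 0)         # 1x1 conv + ReLU (BatchNorm omitted)
        s = 2 ** i
        y = y.repeat(s, axis=1).repeat(s, axis=2)                      # nearest upsample, stand-in for MixUpSample
        out.append(y)
    return np.concatenate(out, axis=0)                                 # what self.fuse receives


rng = np.random.default_rng(1)
features = [rng.standard_normal((c, s, s)) for c, s in zip([64, 128, 320, 512], [16, 8, 4, 2])]
fused_in = daformer_fuse_input(features, decoder_dim=320)
print(fused_in.shape)  # (1280, 16, 16)
```

Compared with UPerNet, there is no top-down path: each stage is projected independently and only mixed in self.fuse. This makes the decoder lightweight.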
CoaT+Daformer:
class Net(nn.Module):
    def __init__(self,
                 encoder=coat_lite_medium,
                 decoder=daformer_conv3x3,
                 encoder_cfg={},
                 decoder_cfg={},
                 ):
        super(Net, self).__init__()
        self.output_type = ['inference', 'loss']
        decoder_dim = decoder_cfg.get('decoder_dim', 320)
        # accept either an already-built encoder module or an encoder class/factory to instantiate
        self.encoder = encoder if isinstance(encoder, nn.Module) else encoder(**encoder_cfg)
        self.rgb = RGB()
        encoder_dim = self.encoder.embed_dims
        # [64, 128, 320, 512]
        self.decoder = decoder(
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
        )
        self.logit = nn.Sequential(
            nn.Conv2d(decoder_dim, 1, kernel_size=1),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
        )
        self.aux = nn.ModuleList([
            nn.Conv2d(decoder_dim, 1, kernel_size=1, padding=0) for i in range(4)
        ])
    def forward(self, batch):
        x = batch['image']
        x = self.rgb(x)
        B, C, H, W = x.shape
        encoder = self.encoder(x)
        last, decoder = self.decoder(encoder)
        logit = self.logit(last)
        output = {}
        if 'loss' in self.output_type:
            output['bce_loss'] = F.binary_cross_entropy_with_logits(logit, batch['mask'])
            output['dice_loss'] = DiceLoss()(logit, batch['mask'])
            output['focal_loss'] = FocalLoss(logits=True, reduce=False)(logit, batch['mask'])
            for i in range(4):
                output['aux%d_loss' % i] = criterion_aux_loss(self.aux[i](decoder[i]), batch['mask'])
        if 'inference' in self.output_type:
            output['probability'] = torch.sigmoid(logit)
        return output
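Both Net variants call criterion_aux_loss on a stride-4 auxiliary logit together with the full-resolution mask, but the function is not shown in this section. The numpy sketch below is a hypothetical version: it average-pools the mask down to the auxiliary map's size and applies BCE with logits. Upsampling the logit to the mask's size instead would be an equally plausible design.

```python
import numpy as np


def criterion_aux_loss_sketch(aux_logit, mask):
    # Hypothetical: pool the full-resolution mask down to the aux map, then BCE with logits.
    # aux_logit: (h, w); mask: (H, W) with H = h * factor and W = w * factor
    h, w = aux_logit.shape
    factor = mask.shape[0] // h
    target = mask.reshape(h, factor, w, factor).mean(axis=(1, 3))   # soft targets in [0, 1]
    x = aux_logit
    return np.mean(np.maximum(x, 0) - x * target + np.log1p(np.exp(-np.abs(x))))


mask = np.zeros((16, 16))
mask[:8] = 1                                  # top half is foreground
aux = np.full((4, 4), -10.0)
aux[:2] = 10.0                                # aux prediction matching the pooled mask
print(criterion_aux_loss_sketch(aux, mask))   # small, since the prediction matches
```

Average pooling yields soft targets at FTU borders, so pixels that are half foreground are not forced to a hard 0 or 1 at the coarse resolution.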
Bad case analysis

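The model performs poorly on lung tissue. To address this, we retrained Swin Transformer + UPerNet with the lung data oversampled, used it as a dedicated lung prediction model, and ensembled it with the model trained on the original data.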
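if self.train:
    # oversample lung: every lung row is appended 5 more times (6 copies in total);
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0
    lung_rows = df.loc[df[df['organ'] == "lung"].index.repeat(5)]
    self.df = pd.concat([df, lung_rows]).reset_index(drop=True)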
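The text does not say how the lung model and the original model are combined. The numpy sketch below shows one simple possibility, and it is an assumption throughout: for lung images, average the two probability maps with a hypothetical weight lung_weight; for every other organ, keep the original model's output. The 0.5 binarization threshold is also an assumption.

```python
import numpy as np


def ensemble_probability(prob_base, prob_lung, organ, lung_weight=0.5):
    # Hypothetical blending: only lung images use the lung-oversampled model
    if organ == "lung":
        return (1.0 - lung_weight) * prob_base + lung_weight * prob_lung
    return prob_base


base = np.array([[0.2, 0.6], [0.4, 0.9]])
lung = np.array([[0.6, 0.8], [0.2, 0.9]])
blended = ensemble_probability(base, lung, "lung")
binary_mask = blended > 0.5                    # threshold is an assumption as well
print(binary_mask.tolist())  # [[False, True], [False, True]]
```

Routing by organ keeps the specialist model from affecting organs it was not tuned for. The weight and threshold would normally be chosen on validation lung images.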

