多器官分割比賽總結(jié)

比賽目標(biāo):

本次競賽的目標(biāo)是確定來自幾個(gè)不同器官的活檢切片中每個(gè)功能組織單元(FTU)的位置。基礎(chǔ)數(shù)據(jù)包括來自不同來源的圖像,這些圖像采用不同的協(xié)議以不同的分辨率制作,反映了處理醫(yī)療數(shù)據(jù)的典型挑戰(zhàn)。本次比賽使用了來自兩個(gè)不同聯(lián)盟的數(shù)據(jù),即人類蛋白質(zhì)圖譜(HPA)和人類生物分子圖譜計(jì)劃(HuBMAP)。訓(xùn)練數(shù)據(jù)集由公共HPA數(shù)據(jù)中的數(shù)據(jù)組成,公共測試集是私有HPA數(shù)據(jù)和HuBMAP數(shù)據(jù)的組合,私有測試集僅包含HuBMAP數(shù)據(jù)。當(dāng)使用不同協(xié)議下的數(shù)據(jù)時(shí),調(diào)整模型以使其正常工作將是這場競爭的核心挑戰(zhàn)之一。開發(fā)泛化的模型是這項(xiàng)工作的關(guān)鍵目標(biāo)。

難點(diǎn)

In this competition we have:

  • Private dataset has a different set of image scales compared to the train (relatively easy to model)
  • Private dataset has a different color domain (different stains which attach to different molecules/tissues) - (Harder to model)
  • Different slice thickness… That's going to be tough to incorporate.

Color Domain

數(shù)字病理切片的制作首先需要組織染色。為了突出切片中特定的細(xì)胞核和腺體特征,限定并檢查組織,通常使用染色劑來增強(qiáng)組織成分間的對(duì)比度,,主要包括蘇木精-伊紅(hematoxylin-eosin, 簡稱H&E)和免疫組織化學(xué)(immuno histo chemical, 簡稱IHC),H&E是最常用的染色方法。與H&E常規(guī)染色相比,IHC染色利用抗原抗體的特異性結(jié)合反應(yīng)來檢測和定位組織和細(xì)胞中的某些化學(xué)物質(zhì),具有較高的敏感性,可將形態(tài)學(xué)改變與功能代謝變化相結(jié)合,從而能夠鑒別、診斷和治療惡性腫。H&E的問題在于,在一周的不同日期進(jìn)行染色時(shí),實(shí)驗(yàn)室中的染色變異很大(HPA和HuBMAP的染色標(biāo)準(zhǔn)不一),甚至在同一實(shí)驗(yàn)室也是如此。這是因?yàn)樽罱K結(jié)果很大程度上取決于染料的類型和密度以及組織實(shí)際暴露于染色劑的時(shí)間。


常用解決辦法:

  1. 染色歸一化:不同的實(shí)驗(yàn)室和掃描儀可以為特定污漬生成具有不同顏色配置文件的圖像。染色標(biāo)準(zhǔn)化的目標(biāo)是標(biāo)準(zhǔn)化這些染色的外觀。傳統(tǒng)上,使用顏色匹配 [ Reinhard2001 ] 和染色分離 [ Macenko2009 , Khan2014 , Vahadane2016 ] 等方法。但是,這些方法依賴于選擇單個(gè)參考幻燈片。

病理圖像常用的顏色標(biāo)準(zhǔn)化方法
STST: 這個(gè)方法是用c-gan做的,但我實(shí)際跑的時(shí)候效果很差,細(xì)胞核細(xì)胞質(zhì)顏色很相近,分不開,可能因?yàn)樗菍?duì)灰度圖像染色,所以模型分不太開細(xì)胞核跟細(xì)胞質(zhì),不建議用。
Vahadane: 這個(gè)是比較推薦的方法。這個(gè)方法是用非負(fù)矩陣分解得到兩個(gè)染料矩陣,然后將reference和source的進(jìn)行配準(zhǔn),然后再合成新的圖像。這個(gè)方法比較穩(wěn)定,得到的圖像顏色比較真實(shí)。實(shí)際的標(biāo)化速度也可以接受。
Khan: 效果里面看很差,不太建議用。
Macehko: 這個(gè)方法是將RGB轉(zhuǎn)成OD,然后再用SVD分解得到兩個(gè)垂直的顏色矩陣,再將這個(gè)顏色矩陣進(jìn)行歸一化,或者歸一化到reference上。在實(shí)際應(yīng)用的時(shí)候得到的圖像顏色不是很自然,會(huì)變得比較奇怪。一方面可能是SVD分解并沒有非負(fù)分解這么好用,另一個(gè)可能是reference圖像選擇的不好。但速度上很快,效果一般。
Reinhard: 最開始并不是用在病理圖像上的,它是用在自然圖像上的,就是在Lab空間把兩個(gè)圖像的顏色進(jìn)行統(tǒng)計(jì)學(xué)上的匹配。這個(gè)效果很差,但很快。差的原因很多,Lab顏色空間并不符合病理圖像的光學(xué)特性,OD和HED更加符合他們的特性。其次,這個(gè)方法受輸入圖像的影響非常大,有些組織少的區(qū)域,為了得到相同的統(tǒng)計(jì)學(xué)分布,會(huì)將組織染色非常深,總>之就很不智能。不推薦用。

  1. 色彩增強(qiáng):通過應(yīng)用隨機(jī)仿射變換或添加噪聲來增強(qiáng)圖像是對(duì)抗過度擬合的最常見正則化技術(shù)之一。類似地,可以利用染色的變化來增加訓(xùn)練期間呈現(xiàn)給模型的圖像外觀的多樣性。雖然顏色的劇烈變化對(duì)于組織學(xué)來說是不現(xiàn)實(shí)的,但通過對(duì)每個(gè)顏色通道的隨機(jī)加法和乘法變化產(chǎn)生的更微妙的變化已被證明可以提高模型性能。顏色增強(qiáng)的強(qiáng)度是一個(gè)額外的超參數(shù),應(yīng)該在訓(xùn)練期間進(jìn)行試驗(yàn),并在來自不同實(shí)驗(yàn)室或掃描儀的測試集上進(jìn)行驗(yàn)證。
  2. 無監(jiān)督域?qū)褂?xùn)練:域適應(yīng)的下一個(gè)技術(shù)是域?qū)褂?xùn)練。這種方法利用來自目標(biāo)域的未標(biāo)記圖像。域?qū)鼓K被添加到現(xiàn)有模型中。該分類器的目標(biāo)是預(yù)測圖像屬于源域還是目標(biāo)域。梯度反轉(zhuǎn)層將此模塊連接到現(xiàn)有網(wǎng)絡(luò),以便訓(xùn)練優(yōu)化原始任務(wù)并鼓勵(lì)網(wǎng)絡(luò)學(xué)習(xí)域不變特征。
    特征提取器提取的信息會(huì)傳入域分類器,之后域分類器會(huì)判斷傳入的信息到底是來自源域還是目標(biāo)域,并計(jì)算損失。域分類器的訓(xùn)練目標(biāo)是盡量將輸入的信息分到正確的域類別(源域還是目標(biāo)域),而特征提取器的訓(xùn)練目標(biāo)卻恰恰相反(由于梯度反轉(zhuǎn)層的存在),特征提取器所提取的特征(或者說映射的結(jié)果)目的是是域判別器不能正確的判斷出信息來自哪一個(gè)域,因此形成一種對(duì)抗關(guān)系。特征提取器提取的信息也會(huì)傳入Label predictor (類別預(yù)測器)了,因?yàn)樵从驑颖臼怯袠?biāo)記的,所以在提取特征時(shí)不僅僅要考慮后面的域判別器的情況,還要利用源域的帶標(biāo)記樣本進(jìn)行有監(jiān)督訓(xùn)練從而兼顧分類的準(zhǔn)確性。

解決方案

采用了染色增強(qiáng)+CutMix+CutOut+像素大小自適應(yīng)的數(shù)據(jù)增強(qiáng)策略。采用原始數(shù)據(jù)集訓(xùn)練CoaT+Daformer和Swin transformer+ UPerNet模型,然后用過采樣肺部策略訓(xùn)練Swin transformer+UPerNet作為針對(duì)肺部的預(yù)測最終形成集成預(yù)測。loss采用Focal loss+Dice loss進(jìn)行難例挖掘。

A Comprehensive Study of Vision Transformers on Dense Prediction Tasks
文章研究了 Vision Transformers (VTs),VTs 和 CNN 作為特征提取器的不同方面,用于對(duì)具有挑戰(zhàn)性的現(xiàn)實(shí)世界數(shù)據(jù)進(jìn)行目標(biāo)檢測和語義分割。實(shí)驗(yàn)得出的主要結(jié)果和主要見解如下:

  • VTs 在分布式數(shù)據(jù)集中優(yōu)于 CNN,同時(shí)具有較低的推理速度,但計(jì)算復(fù)雜度較低。因此,如果 GPU 針對(duì) Transformer 架構(gòu)進(jìn)行了優(yōu)化,它們就有可能在計(jì)算機(jī)視覺領(lǐng)域占據(jù)主導(dǎo)地位。
  • VTs 可以更好地泛化到 OOD 數(shù)據(jù)集。我們的損失情況分析表明,與 CNN 相比,VTs 收斂到更平坦的最小值,這可以解釋它們的普遍性。
  • 與 CNN 相比,VTs 對(duì)自然損壞和對(duì)抗性攻擊更穩(wěn)健。我們認(rèn)為這可以歸因于全局感受域以及自我注意的動(dòng)態(tài)特性。
  • VTs 的紋理偏差比 CNN 少,這可以歸因于它們的全局感受野,這使得它們能夠更好地關(guān)注基于全局形狀的線索,而不是基于局部紋理的線索。

數(shù)據(jù)增強(qiáng)策略

基本數(shù)據(jù)增強(qiáng)
包括翻轉(zhuǎn)、旋轉(zhuǎn)、對(duì)比度、HSV顏色空間、噪聲和一些尺度變換。

def do_random_flip(image, mask):
    if np.random.rand()>0.5:
        image = cv2.flip(image,0)
        mask = cv2.flip(mask,0)
    if np.random.rand()>0.5:
        image = cv2.flip(image,1)
        mask = cv2.flip(mask,1)
    if np.random.rand()>0.5:
        image = image.transpose(1,0,2)
        mask = mask.transpose(1,0)
    
    image = np.ascontiguousarray(image)
    mask = np.ascontiguousarray(mask)
    return image, mask

def do_random_rot90(image, mask):
    r = np.random.choice([
        0,
        cv2.ROTATE_90_CLOCKWISE,
        cv2.ROTATE_90_COUNTERCLOCKWISE,
        cv2.ROTATE_180,
    ])
    if r==0:
        return image, mask
    else:
        image = cv2.rotate(image, r)
        mask = cv2.rotate(mask, r)
        return image, mask
    
def do_random_contast(image, mask, mag=0.3):
    alpha = 1 + random.uniform(-1,1)*mag
    image = image * alpha
    image = np.clip(image,0,1)
    return image, mask

def do_random_hsv(image, mask, mag=[0.15,0.25,0.25]):
    image = (image*255).astype(np.uint8)
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    h = hsv[:, :, 0].astype(np.float32)  # hue
    s = hsv[:, :, 1].astype(np.float32)  # saturation
    v = hsv[:, :, 2].astype(np.float32)  # value
    h = (h*(1 + random.uniform(-1,1)*mag[0]))%180
    s =  s*(1 + random.uniform(-1,1)*mag[1])
    v =  v*(1 + random.uniform(-1,1)*mag[2])

    hsv[:, :, 0] = np.clip(h,0,180).astype(np.uint8)
    hsv[:, :, 1] = np.clip(s,0,255).astype(np.uint8)
    hsv[:, :, 2] = np.clip(v,0,255).astype(np.uint8)
    image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    image = image.astype(np.float32)/255
    return image, mask


def do_random_noise(image, mask, mag=0.1):
    height, width = image.shape[:2]
    noise = np.random.uniform(-1,1, (height, width,1))*mag
    image = image + noise
    image = np.clip(image,0,1)
    return image, mask

def do_random_rotate_scale(image, mask, angle=30, scale=[0.8,1.2] ):
    angle = np.random.uniform(-angle, angle)
    scale = np.random.uniform(*scale) if scale is not None else 1
    
    height, width = image.shape[:2]
    center = (height // 2, width // 2)
    
    transform = cv2.getRotationMatrix2D(center, angle, scale)
    image = cv2.warpAffine( image, transform, (width, height), flags=cv2.INTER_LINEAR,
                            borderMode=cv2.BORDER_CONSTANT, borderValue=(0,0,0))
    mask  = cv2.warpAffine( mask, transform, (width, height), flags=cv2.INTER_LINEAR,
                            borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    return image, mask

針對(duì)性的數(shù)據(jù)增強(qiáng)
解決訓(xùn)練集HPA和測試集Hubmap中的圖片像素間距層厚不同的自適應(yīng)策略:

imaging_measurements = {
    'hpa': {
        'pixel_size': {
            'kidney': 0.4,
            'prostate': 0.4,
            'largeintestine': 0.4,
            'spleen': 0.4,
            'lung': 0.4
        },
        'tissue_thickness': {
            'kidney': 4,
            'prostate': 4,
            'largeintestine': 4,
            'spleen': 4,
            'lung': 4
        }
    },
    'hubmap': {
        'pixel_size': {
            'kidney': 0.5,
            'prostate': 6.263,
            'largeintestine': 0.229,
            'spleen': 0.4945,
            'lung': 0.7562
        },
        'tissue_thickness': {
            'kidney': 10,
            'prostate': 5,
            'largeintestine': 8,
            'spleen': 4,
            'lung': 5
        }
    }
}

def pixelSize_tissueThickness_adaptation(image, mask, organ, alpha=0.15):
    image = (image*255).astype(np.uint8)
    domain_pixel_size=imaging_measurements['hpa']['pixel_size'][organ],
    target_pixel_size=imaging_measurements['hubmap']['pixel_size'][organ],
    domain_tissue_thickness=imaging_measurements['hpa']['tissue_thickness'][organ],
    target_tissue_thickness=imaging_measurements['hubmap']['tissue_thickness'][organ],
    
    # Augment tissue thickness
    tissue_thickness_scale_factor = target_tissue_thickness[0] - domain_tissue_thickness[0]
    image_hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
    image_hsv[:, :, 1] *= (1 + (alpha * tissue_thickness_scale_factor))
    image_hsv[:, :, 2] *= (1 - (alpha * tissue_thickness_scale_factor))
    image_hsv = image_hsv.astype(np.uint8)
    image_scaled = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2RGB)
    
    # Standardize luminosity
    image_scaled = staintools.LuminosityStandardizer.standardize(image_scaled)

    # Augment pixel size
    pixel_size_scale_factor = domain_pixel_size[0] / target_pixel_size[0]
    image_resized = cv2.resize(
        image_scaled,
        dsize=None,
        fx=pixel_size_scale_factor,
        fy=pixel_size_scale_factor,
        interpolation=cv2.INTER_CUBIC
    )
    image_resized = cv2.resize(
        image_resized,
        dsize=(
            image.shape[1],
            image.shape[0]
        ),
        interpolation=cv2.INTER_CUBIC
    )
    
    # Standardize luminosity
    image = staintools.LuminosityStandardizer.standardize(image)
    image_augmented = staintools.LuminosityStandardizer.standardize(image_resized)

    image = image_augmented.astype(np.float32)/255
    
    return image, mask

解決訓(xùn)練集HPA和測試集Hubmap中的染色標(biāo)準(zhǔn)不一的顏色增強(qiáng)策略:

def color_transfer(image, mask):
    image = (image*255).astype(np.uint8)
    hed_lighter_aug = stainlib.augmentation.augmenter.HedLightColorAugmenter()
    hed_lighter_aug.randomize()
    transformed = hed_lighter_aug.transform(image)
    image = image.astype(np.float32)/255
    return image, mask

提高在測試集泛化能力的其他數(shù)據(jù)增強(qiáng)策略:CutMix+CutOut(數(shù)據(jù)增強(qiáng):Mixup,Cutout,CutMix Mosaic),提供高強(qiáng)度的、變化的擾動(dòng)。消融實(shí)驗(yàn)顯示CutMix+CutOut的漲點(diǎn)是可觀的。

#  CutMix 的切塊功能
def rand_bbox(size, lam):
    if len(size) == 4:
        W = size[2]
        H = size[3]
    elif len(size) == 3:
        W = size[0]
        H = size[1]
    else:
        raise Exception

    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

randoms_int = random.randint(0, 99)
if randoms_int < 20:
    rand_index = random.randint(0, len(self.df)-1)
    rand_path = self.df.loc[rand_index, 'image_path']
    rand_img_height = self.df.loc[rand_index, 'img_height']
    rand_img_width = self.df.loc[rand_index, 'img_width']

    if stained:
        rand_image = tifffile.imread(rand_path.replace('train_images', 'train_images_augment'))
    else:
        rand_image = tifffile.imread(rand_path)
    
    rand_rle = self.df.loc[rand_index, 'rle']
    rand_mask = rle2mask(rand_rle, (rand_img_height, rand_img_width))

    rand_image = rand_image.astype(np.float32)/255

    rand_image = cv2.resize(rand_image,dsize=(image_size,image_size),interpolation=cv2.INTER_LINEAR)
    rand_mask  = cv2.resize(rand_mask, dsize=(image_size,image_size),interpolation=cv2.INTER_LINEAR)

    lam = np.random.beta(1,1)
    bbx1, bby1, bbx2, bby2 = rand_bbox(rand_image.shape, lam)

    image[bbx1:bbx2, bby1:bby2, :] = rand_image[bbx1:bbx2, bby1:bby2, :]
    mask[bbx1:bbx2, bby1:bby2] = rand_mask[bbx1:bbx2, bby1:bby2]

if randoms_int >= 80:
    alpha = random.uniform(0.2, 0.5)
    y = np.random.randint(image_size)
    x = np.random.randint(image_size)
    #劃出正方形區(qū)域,邊界處截?cái)?    y1 = np.clip(y - int(alpha*image_size) // 2, 0, image_size)
    y2 = np.clip(y + int(alpha*image_size) // 2, 0, image_size)
    x1 = np.clip(x - int(alpha*image_size) // 2, 0, image_size)
    x2 = np.clip(x + int(alpha*image_size) // 2, 0, image_size)
    #全0填充區(qū)域
    image[y1: y2, x1: x2, :] = 0
    mask[y1: y2, x1: x2] = 0

fname = self.fnames[index]
organ = self.organ_to_label[self.df.loc[index].organ]

CutOut和CutMix已被證實(shí)可以用于語義分割(Semi-supervised semantic segmentation needs strong, varied perturbations),其在語義分割中的作用符合人的直觀理解,但是其背后的數(shù)學(xué)解釋值得探究。本比賽中也發(fā)現(xiàn)一個(gè)有趣的現(xiàn)象,好像CutOut和CutMix對(duì)于transformer模型相比CNN更有價(jià)值。

模型主要利用transformer模型作為特征提取器(Encoder),然后采用不用的Decoder構(gòu)成語義分割網(wǎng)絡(luò)。



Swin transformer+UPerNet總結(jié):

Swin transformer

性能優(yōu)于DeiT、ViT和EfficientNet等主干網(wǎng)絡(luò),已經(jīng)替代經(jīng)典的CNN架構(gòu),成為了計(jì)算機(jī)視覺領(lǐng)域通用的backbone。它基于了ViT模型的思想,創(chuàng)新性的引入了滑動(dòng)窗口機(jī)制,讓模型能夠?qū)W習(xí)到跨窗口的信息,同時(shí)也。同時(shí)通過下采樣層,使得模型能夠處理超分辨率的圖片,節(jié)省計(jì)算量以及能夠關(guān)注全局和局部的信息。
目前將 Transformer 從自然語言處理領(lǐng)域應(yīng)用到計(jì)算機(jī)視覺領(lǐng)域主要有兩大挑戰(zhàn):

  • 視覺實(shí)體的方差較大,例如同一個(gè)物體,拍攝角度不同,轉(zhuǎn)化為二進(jìn)制后的圖片就會(huì)具有很大的差異。同時(shí)在不同場景下視覺 Transformer 性能未必很好。
  • 圖像分辨率高,像素點(diǎn)多,如果采用ViT模型,自注意力的計(jì)算量會(huì)與像素的平方成正比。
    針對(duì)上述兩個(gè)問題,論文中提出了一種基于滑動(dòng)窗口機(jī)制,具有層級(jí)設(shè)計(jì)(下采樣層) 的 Swin Transformer。

其中滑窗操作包括不重疊的 local window,和重疊的 cross-window。將注意力計(jì)算限制在一個(gè)窗口(window size固定)中,一方面能引入 CNN 卷積操作的局部性,另一方面能大幅度節(jié)省計(jì)算量,它只和窗口數(shù)量成線性關(guān)系。

整體結(jié)構(gòu)

整個(gè)模型采取層次化的設(shè)計(jì),一共包含 4 個(gè) Stage,除第一個(gè) stage 外,每個(gè) stage 都會(huì)先通過 Patch Merging 層縮小輸入特征圖的分辨率,進(jìn)行下采樣操作,像 CNN 一樣逐層擴(kuò)大感受野,以便獲取到全局的信息:

  • 在輸入開始的時(shí)候做了一個(gè)Patch Partition,即ViT中Patch Embedding操作,通過 Patch_size 為4的卷積層將圖片切成一個(gè)個(gè) Patch ,并嵌入到Embedding,將 embedding_size轉(zhuǎn)變?yōu)?8(可以將 CV 中圖片的通道數(shù)理解為NLP中token的詞嵌入長度)。
  • 隨后在第一個(gè)Stage中,通過Linear Embedding調(diào)整通道數(shù)為C。
  • 在每個(gè) Stage 里(除第一個(gè) Stage ),均由Patch Merging和多個(gè)Swin Transformer Block組成。
  • Swin Transformer Block注意這里的Block其實(shí)有兩種結(jié)構(gòu)W-MSASW-MSA。兩個(gè)結(jié)構(gòu)是成對(duì)使用的,先使用一個(gè)W-MSA結(jié)構(gòu)再使用一個(gè)SW-MSA結(jié)構(gòu)。
  • Patch Merging模塊主要在每個(gè) Stage 一開始降低圖片分辨率,進(jìn)行下采樣的操作。
  • Swin Transformer Block具體結(jié)構(gòu)如右圖所示,主要是LayerNorm,Window Attention ,Shifted Window Attention和MLP組成 。

Patch Embedding

在輸入進(jìn) Block 前,我們需要將圖片切成一個(gè)個(gè) patch,然后嵌入向量。具體做法是對(duì)原始圖片裁成一個(gè)個(gè) window_size * window_size 的窗口大小,然后進(jìn)行嵌入。這里可以通過二維卷積層,將 stride,kernel_size 設(shè)置為 window_size 大小。設(shè)定輸出通道來確定嵌入向量的大小。最后將 H,W 維度展開,并移動(dòng)到第一維度。
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 norm_layer=None
                 ):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None
    
    def forward(self, x):
        B, C, H, W = x.shape
        
        # padding
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
            
            
        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        
        return x
#通過程序?qū)崿F(xiàn)可以發(fā)現(xiàn)其并沒有使用nn.Linear轉(zhuǎn)換輸入通道數(shù),而是使用nn.Conv2d在進(jìn)行patches轉(zhuǎn)換時(shí)同時(shí)更換了通道數(shù)。

Patch Merging

Patch Merging的作用是分辨率減半,通道數(shù)加倍,類似于CNN的作用,在Transformer中實(shí)現(xiàn)Hierarchical。用在每個(gè) Stage 開始前做降采樣能節(jié)省一定運(yùn)算量。在 CNN 中,則是在每個(gè) Stage 開始前用stride=2的卷積/池化層來降低分辨率。
每次降采樣是兩倍,因此在行方向和列方向上,間隔 2 選取元素。然后拼接在一起作為一整個(gè)張量,最后展開。此時(shí)通道維度會(huì)變成原先的 4 倍(因?yàn)?H,W 各縮小 2 倍),此時(shí)再通過一個(gè)全連接層再調(diào)整通道維度為原來的兩倍。

# H,W as input
# padding
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)
    
    def forward(self, x, H, W):
        """
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
     
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
 
        
        x = x.view(B, H, W, C)
        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B, H/2, W/2, 4*C
        x = x.view(B, -1, 4 * C)  # B, H/2*W/2, 4*C
        
        x = self.norm(x)
        x = self.reduction(x)
        
        return x

Window Partition/Reverse

window partition函數(shù)是用于對(duì)張量劃分窗口,指定窗口大小。將原本的張量從 N H W C, 劃分成 num_windows*B, window_size, window_size, C,其中 num_windows = H*W / (window_size*window_size),即窗口的個(gè)數(shù)。而window reverse函數(shù)則是對(duì)應(yīng)的逆過程。這兩個(gè)函數(shù)會(huì)在后面的Window Attention用到。


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

Window Attention

這傳統(tǒng)的 Transformer 都是基于全局來計(jì)算注意力的,因此計(jì)算復(fù)雜度十分高。而 Swin Transformer 則將注意力的計(jì)算限制在每個(gè)窗口內(nèi),進(jìn)而減少了計(jì)算量。
先簡單看下公式:
主要區(qū)別是在原始計(jì)算Attention的公式中的Q,K時(shí)加入了相對(duì)位置編碼。后續(xù)實(shí)驗(yàn)有證明相對(duì)位置編碼的加入提升了模型性能。

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim # 輸入通道的數(shù)量
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0]) # coords_h = tensor([0,1,2,...,self.window_size[0]-1])  維度=Wh
        coords_w = torch.arange(self.window_size[1]) # coords_w = tensor([0,1,2,...,self.window_size[1]-1])  維度=Ww

        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww


        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1

        '''
        后面我們需要將其展開成一維偏移量。而對(duì)于(2,1)和(1,2)這兩個(gè)坐標(biāo),在二維上是不同的,但是通過將x\y坐標(biāo)相加轉(zhuǎn)換為一維偏移的時(shí)候
        他們的偏移量是相等的,所以需要對(duì)其做乘法操作,進(jìn)行區(qū)分
        '''

        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # 計(jì)算得到相對(duì)位置索引
        # relative_position_index.shape = (M2, M2) 意思是一共有這么多個(gè)位置
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww 

        '''
        relative_position_index注冊為一個(gè)不參與網(wǎng)絡(luò)學(xué)習(xí)的變量
        '''
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        '''
        使用從截?cái)嗾龖B(tài)分布中提取的值填充輸入張量
        self.relative_position_bias_table 是全0張量,通過trunc_normal_ 進(jìn)行數(shù)值填充
        '''
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            N: number of all patches in the window
            C: 輸入通過線性層轉(zhuǎn)化得到的維度C
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        '''
        x.shape = (num_windows*B, N, C)
        self.qkv(x).shape = (num_windows*B, N, 3C)
        self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).shape = (num_windows*B, N, 3, num_heads, C//num_heads)
        self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).shape = (3, num_windows*B, num_heads, N, C//num_heads)
        '''
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        '''
        q.shape = k.shape = v.shape = (num_windows*B, num_heads, N, C//num_heads)
        N = M2 代表patches的數(shù)量
        C//num_heads代表Q,K,V的維數(shù)
        '''
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # q乘上一個(gè)放縮系數(shù),對(duì)應(yīng)公式中的sqrt(d)
        q = q * self.scale

        # attn.shape = (num_windows*B, num_heads, N, N)  N = M2 代表patches的數(shù)量
        attn = (q @ k.transpose(-2, -1))

        '''
        self.relative_position_bias_table.shape = (2*Wh-1 * 2*Ww-1, nH)
        self.relative_position_index.shape = (Wh*Ww, Wh*Ww)
        self.relative_position_index矩陣中的所有值都是從self.relative_position_bias_table中取的
        self.relative_position_index是計(jì)算出來不可學(xué)習(xí)的量
        '''
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        '''
        attn.shape = (num_windows*B, num_heads, M2, M2)  N = M2 代表patches的數(shù)量
        .unsqueeze(0):擴(kuò)張維度,在0對(duì)應(yīng)的位置插入維度1
        relative_position_bias.unsqueeze(0).shape = (1, num_heads, M2, M2)
        num_windows*B 通過廣播機(jī)制傳播,relative_position_bias.unsqueeze(0).shape = (1, nH, M2, M2) 的維度1會(huì)broadcast到數(shù)量num_windows*B
        表示所有batch通用一個(gè)索引矩陣和相對(duì)位置矩陣
        '''
        attn = attn + relative_position_bias.unsqueeze(0)

        # mask.shape = (num_windows, M2, M2)
        # attn.shape = (num_windows*B, num_heads, M2, M2)
        if mask is not None:
            nW = mask.shape[0]
            # attn.view(B_ // nW, nW, self.num_heads, N, N).shape = (B, num_windows, num_heads, M2, M2) 第一個(gè)M2代表有M2個(gè)token,第二個(gè)M2代表每個(gè)token要計(jì)算M2次QKT的值
            # mask.unsqueeze(1).unsqueeze(0).shape =                (1, num_windows, 1,         M2, M2) 第一個(gè)M2代表有M2個(gè)token,第二個(gè)M2代表每個(gè)token要計(jì)算M2次QKT的值
            # broadcast相加
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # attn.shape = (B, num_windows, num_heads, M2, M2)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        '''
        v.shape = (num_windows*B, num_heads, M2, C//num_heads)  N=M2 代表patches的數(shù)量, C//num_heads代表輸入的維度
        attn.shape = (num_windows*B, num_heads, M2, M2)
        attn@v .shape = (num_windows*B, num_heads, M2, C//num_heads)
        '''
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)   # B_:num_windows*B  N:M2  C=num_heads*C//num_heads

        #   self.proj = nn.Linear(dim, dim)  dim = C
        #   self.proj_drop = nn.Dropout(proj_drop)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x  # x.shape = (num_windows*B, N, C)  N:窗口中所有patches的數(shù)量

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
  1. 首先輸入張量形狀為 [numWindows*B, window_size* window_size, C]
  2. 然后經(jīng)過self.qkv這個(gè)全連接層后,進(jìn)行 reshape,調(diào)整軸的順序,得到形狀為[3, numWindows*B, num_heads, window_size*window_size, c//num_heads],并分配給q,k,v。
  3. 根據(jù)公式,我們對(duì)q乘以一個(gè)scale縮放系數(shù),然后與k(為了滿足矩陣乘要求,需要將最后兩個(gè)維度調(diào)換)進(jìn)行相乘。得到形狀為[numWindows*B, num_heads, window_size*window_size, window_size*window_size]的attn張量。
  4. 之前我們針對(duì)位置編碼設(shè)置了個(gè)形狀為(2*window_size-1*2*window_size-1, numHeads)的可學(xué)習(xí)變量。我們用計(jì)算得到的相對(duì)編碼位置索引self.relative_position_index.vew(-1)選取,得到形狀為(window_size*window_size, window_size*window_size, numHeads)的編碼,再permute(2,0,1)后加到attn張量上。
  5. 暫不考慮 mask 的情況,剩下就是跟 transformer 一樣的 softmax,dropout,與V矩陣乘,再經(jīng)過一層全連接層和dropout。

相關(guān)位置編碼的代碼詳解

絕對(duì)位置編碼是在進(jìn)行self-attention計(jì)算之前為每一個(gè)token添加一個(gè)可學(xué)習(xí)的參數(shù),相對(duì)位置編碼如上式所示,是在進(jìn)行self-attention計(jì)算時(shí),在計(jì)算過程中添加一個(gè)可學(xué)習(xí)的相對(duì)位置參數(shù)。
假設(shè)window_size = 2*2即每個(gè)窗口有4個(gè)token(M=2) ,如圖所示,在計(jì)算self-attention時(shí),每個(gè)token都要與所有的token計(jì)算QK值,如圖6所示,當(dāng)位置1的token計(jì)算self-attention時(shí),要計(jì)算位置1與位置(1,2,3,4)的QK值,即以位置1的token為中心點(diǎn),中心點(diǎn)位置坐標(biāo)(0,0),其他位置計(jì)算與當(dāng)前位置坐標(biāo)的偏移量。

相對(duì)位置索引求解流程圖

最后生成的是相對(duì)位置索引,relative_position_index.shape = (M^{2}*M^{2}) ,在網(wǎng)絡(luò)中注冊成為一個(gè)不可學(xué)習(xí)的變量,relative_position_index的作用就是根據(jù)最終的索引值找到對(duì)應(yīng)的可學(xué)習(xí)的相對(duì)位置編碼。relative_position_index的數(shù)值范圍(0~8),即 (2M-1,2M-1),所以相對(duì)位置編碼可以由一個(gè)3*3的矩陣表示,如圖s所示:
相對(duì)位置編碼
圖中的0-8為索引值,每個(gè)索引值都對(duì)應(yīng)了 維可學(xué)習(xí)數(shù)據(jù)(根據(jù)圖1,每個(gè)token都要計(jì)算 個(gè)QK值,每個(gè)QK值都要加上對(duì)應(yīng)的相對(duì)位置編碼)

繼續(xù)以圖中 M=2 的窗口為例,當(dāng)計(jì)算位置1對(duì)應(yīng)的 M^2個(gè)QK值時(shí),應(yīng)用的relative_position_index = [ 4, 5, 7, 8] 個(gè) ,對(duì)應(yīng)的數(shù)據(jù)就是相對(duì)位置編碼]圖中位置索引4,5,7,8位置對(duì)應(yīng)的 M^2 維數(shù)據(jù),即relative_position.shape = (M^{2}*M^{2})

相對(duì)位置編碼在源碼WindowAttention中應(yīng)用,了解原理之后就很容易能夠讀懂程序:


Shifted Window Attention

采用W-MSA模塊時(shí),只會(huì)在每個(gè)窗口內(nèi)進(jìn)行自注意力計(jì)算,所以窗口與窗口之間是無法進(jìn)行信息傳遞的。為了解決這個(gè)問題,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模塊,即進(jìn)行偏移的W-MSA。如下圖所示,左側(cè)使用的是剛剛講的W-MSA(假設(shè)是第L層),那么根據(jù)之前介紹的W-MSA和SW-MSA是成對(duì)使用的,那么第L+1層使用的就是SW-MSA(右側(cè)圖)。根據(jù)左右兩幅圖對(duì)比能夠發(fā)現(xiàn)窗口(Windows)發(fā)生了偏移(可以理解成窗口從左上角分別向右側(cè)和下方各偏移了? M/2 ? 個(gè)像素)??聪缕坪蟮拇翱冢ㄓ覀?cè)圖),比如對(duì)于第一行第2列的2x4的窗口,它能夠使第L層的第一排的兩個(gè)窗口信息進(jìn)行交流。再比如,第二行第二列的4x4的窗口,他能夠使第L層的四個(gè)窗口信息進(jìn)行交流,其他的同理。那么這就解決了不同窗口之間無法進(jìn)行信息交流的問題。

根據(jù)上圖,可以發(fā)現(xiàn)通過將窗口進(jìn)行偏移后,由原來的4個(gè)窗口變成9個(gè)窗口了。后面又要對(duì)每個(gè)窗口內(nèi)部進(jìn)行MSA,這樣做感覺又變麻煩了。為了解決這個(gè)麻煩,作者又提出而了Efficient batch computation for shifted configuration,一種更加高效的計(jì)算方法。下面是原論文給的示意圖。

在進(jìn)行cyclic shift之前,需要給子窗口進(jìn)行編碼,編碼之后通過torch.roll對(duì)窗口進(jìn)行滾動(dòng),達(dá)到cyclic shift的效果
計(jì)算 Attention 的時(shí)候,讓具有相同 index QK 進(jìn)行計(jì)算,而忽略不同 index QK 計(jì)算結(jié)果。

想在原始四個(gè)窗口下得到正確的結(jié)果,我們就必須給Attention的結(jié)果加入一個(gè)mask。

if self.shift_size > 0:
    # calculate attention mask for SW-MSA
    H, W = self.input_resolution
    # 生成全零張量
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    # 按區(qū)域劃分mask
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # tensor([[[0., 0., 0., 0., 1., 1., 2., 2.],
                 # [0., 0., 0., 0., 1., 1., 2., 2.],
                 # [0., 0., 0., 0., 1., 1., 2., 2.],
                 # [0., 0., 0., 0., 1., 1., 2., 2.],
                 # [3., 3., 3., 3., 4., 4., 5., 5.],
                 # [3., 3., 3., 3., 4., 4., 5., 5.],
                 # [6., 6., 6., 6., 7., 7., 8., 8.],
                 # [6., 6., 6., 6., 7., 7., 8., 8.]]])
    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
    attn_mask = None

SwinTransformerBlock調(diào)用了上面介紹到的WindowAttention,有mask的情況下,在WindowAttention中應(yīng)用mask對(duì)self-attention結(jié)果進(jìn)行調(diào)整。SwinTransformerBlock主要就是W-MSA/SW-MSA的實(shí)現(xiàn),其結(jié)構(gòu)為:LN+(W?MSA/SW?MSA)+LN+MLP。要注意的是shifted的特征圖最后會(huì)還原。這里的LN為nn.LayerNorm;MLP為作者自己的實(shí)現(xiàn)。

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion. 這里的分辨率是轉(zhuǎn)換成patches之后的分辨率,不是原圖像素的分辨率
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. dim*mlp_ration=隱藏層神經(jīng)元個(gè)數(shù)
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim  # 輸入通道C
        self.input_resolution = input_resolution # 輸入分辨率
        self.num_heads = num_heads # self_attention head
        self.window_size = window_size # 窗口大小
        self.shift_size = shift_size # sw-window
        self.mlp_ratio = mlp_ratio #
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim) # nn.LayerNorm
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1

            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            # 上述操作是為了給每個(gè)窗口給上索引

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # nW, window_size*window_size
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size*window_size, window_size*window_size
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):

        H, W = self.input_resolution  # H,W不是像素的分辨率,而是轉(zhuǎn)化成patches之后的分辨率
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)  # self.norm1 = norm_layer(dim) = nn.LayerNorm(dim)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            # torch.roll shifts為正則向下滾動(dòng),為負(fù)則向上滾動(dòng),可以是一個(gè)數(shù)組也可以是一個(gè)元組
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        '''
        經(jīng)過torch.roll之后計(jì)算self.attn是SW-MSA
        不經(jīng)過torch.roll計(jì)算的self.attn是W-MSA
        '''
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        '''
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        '''
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

Block整體結(jié)構(gòu)如下:

  1. 先對(duì)特征圖進(jìn)行LayerNorm
  2. 通過self.shift_size決定是否需要對(duì)特征圖進(jìn)行shift
  3. 然后將特征圖切成一個(gè)個(gè)窗口
  4. 計(jì)算Attention,通過self.attn_mask來區(qū)分Window Attention還是Shift Window Attention
  5. 將各個(gè)窗口合并回來
  6. 如果之前有做shift操作,此時(shí)進(jìn)行reverse shift,把之前的shift操作恢復(fù)
  7. 做dropout和殘差連接
  8. 再通過一層LayerNorm+全連接層,以及dropout和殘差連接

SwinTransformerBlock,此程序中調(diào)用了上面介紹到的WindowAttention,有mask的情況下,在WindowAttention中應(yīng)用mask對(duì)self-attention結(jié)果進(jìn)行調(diào)整。

class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """
    def __init__(self,
        dim,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop=0.,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        downsample=None,
        #use_checkpoint=False,
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
            )
            for i in range(depth)
        ])
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None
    
    def forward(self, x, H, W):
        """
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        
        # calculate attention mask for SW-MSA ----
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        
        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        #------
     
    
        for blk in self.blocks:
            x = blk(x, H, W, attn_mask)
            
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W

UperNet

主要使用的是UPerNet的decoder結(jié)構(gòu),Decoder = FPN+PPM:理論上講,深度卷積網(wǎng)絡(luò)的感受野足夠大,但實(shí)際可用的要小很多。為克服這一問題,本文把 PSPNet 中的金字塔池化模塊(PPM)用于骨干網(wǎng)絡(luò)的最后一層,在其被饋送至 FPN 自上而下的分支之前。結(jié)果實(shí)驗(yàn)證明,在帶來有效的全局先驗(yàn)表征方面,PPM 和 FPN 架構(gòu)是高度一致的。

UPerNet 架構(gòu)圖

class UPerDecoder(nn.Module):
    def __init__(self,
        in_dim=[256, 512, 1024, 2048],
        ppm_pool_scale=[1, 2, 3, 6],
        ppm_dim=512,
        fpn_out_dim=256
    ):
        super(UPerDecoder, self).__init__()
        
        # PPM ----
        dim = in_dim[-1]
        ppm_pooling = []
        ppm_conv = []
        
        for scale in ppm_pool_scale:
            ppm_pooling.append(
                nn.AdaptiveAvgPool2d(scale)
            )
            ppm_conv.append(
                nn.Sequential(
                    nn.Conv2d(dim, ppm_dim, kernel_size=1, bias=False),
                    nn.BatchNorm2d(ppm_dim),
                    nn.ReLU(inplace=True)
                )
            )
        self.ppm_pooling   = nn.ModuleList(ppm_pooling)
        self.ppm_conv      = nn.ModuleList(ppm_conv)
        self.ppm_out = conv3x3_bn_relu(dim + len(ppm_pool_scale)*ppm_dim, fpn_out_dim, 1)
        
        # FPN ----
        fpn_in = []
        for i in range(0, len(in_dim)-1):  # skip the top layer
            fpn_in.append(
                nn.Sequential(
                    nn.Conv2d(in_dim[i], fpn_out_dim, kernel_size=1, bias=False),
                    nn.BatchNorm2d(fpn_out_dim),
                    nn.ReLU(inplace=True)
                )
            )
        self.fpn_in = nn.ModuleList(fpn_in)
        
        fpn_out = []
        for i in range(len(in_dim) - 1):  # skip the top layer
            fpn_out.append(
                conv3x3_bn_relu(fpn_out_dim, fpn_out_dim, 1),
            )
        self.fpn_out = nn.ModuleList(fpn_out)
        
        self.fpn_fuse = nn.Sequential(
            conv3x3_bn_relu(len(in_dim) * fpn_out_dim, fpn_out_dim, 1),
        )
    
    def forward(self, feature):
        f = feature[-1]
        pool_shape = f.shape[2:]
        
        ppm_out = [f]
        for pool, conv in zip(self.ppm_pooling, self.ppm_conv):
            p = pool(f)
            p = F.interpolate(p, size=pool_shape, mode='bilinear', align_corners=False)
            p = conv(p)
            ppm_out.append(p)
        ppm_out = torch.cat(ppm_out, 1)
        down = self.ppm_out(ppm_out)
        
        
        #--------------------------------------
        fpn_out = [down]
        for i in reversed(range(len(feature) - 1)):
            lateral = feature[i]
            lateral = self.fpn_in[i](lateral) # lateral branch
            down = F.interpolate(down, size=lateral.shape[2:], mode='bilinear', align_corners=False) # top-down branch
            down = down + lateral
            fpn_out.append(self.fpn_out[i](down))
        
        fpn_out.reverse() # [P2 - P5]
        fusion_shape = fpn_out[0].shape[2:]
        fusion = [fpn_out[0]]
        for i in range(1, len(fpn_out)):
            fusion.append(
                F.interpolate( fpn_out[i], fusion_shape, mode='bilinear', align_corners=False)
            )
        x = self.fpn_fuse( torch.cat(fusion, 1))
        
        return x, fusion

Swin+UPerNet

參數(shù)設(shè)置:

cfg = dict(

        #configs/_base_/models/upernet_swin.py
        basic = dict(
            swin=dict(
                embed_dim=96,
                depths=[2, 2, 6, 2],
                num_heads=[3, 6, 12, 24],
                window_size=7,
                mlp_ratio=4.,
                qkv_bias=True,
                qk_scale=None,
                drop_rate=0.,
                attn_drop_rate=0.,
                drop_path_rate=0.3,
                ape=False,
                patch_norm=True,
                out_indices=(0, 1, 2, 3),
                use_checkpoint=False
            ),

        ),

        #configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k.py
        swin_tiny_patch4_window7_224=dict(
            checkpoint = pretrain_dir+'swin_tiny_patch4_window7_224_22k.pth',

            swin = dict(
                embed_dim=96,
                depths=[2, 2, 6, 2],
                num_heads=[3, 6, 12, 24],
                window_size=7,
                ape=False,
                drop_path_rate=0.3,
                patch_norm=True,
                use_checkpoint=False,
            ),
            upernet=dict(
                in_channels=[96, 192, 384, 768],
            ),
        ),

        #/configs/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k.py
        swin_small_patch4_window7_224_22k=dict(
            checkpoint = pretrain_dir+'swin_small_patch4_window7_224_22k.pth',

            swin = dict(
                embed_dim=96,
                depths=[2, 2, 18, 2],
                num_heads=[3, 6, 12, 24],
                window_size=7,
                ape=False,
                drop_path_rate=0.3,
                patch_norm=True,
                use_checkpoint=False
            ),
            upernet=dict(
                in_channels=[96, 192, 384, 768],
            ),
        ),
    )

整體結(jié)構(gòu):

class Net(nn.Module):
    
    def load_pretrain(self,):

        checkpoint = cfg[self.arch]['checkpoint']
        print('loading %s ...'%checkpoint)
     
        checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)['model']
        if 0:
            skip = ['relative_coords_table','relative_position_index']
            filtered={}
            for k,v in checkpoint.items():
                if any([s in k for s in skip ]): continue
                filtered[k]=v
            checkpoint = filtered
        print(self.encoder.load_state_dict(checkpoint,strict=False))  #True


    def __init__( self,):
        super(Net, self).__init__()
        self.output_type = ['inference', 'loss']

        self.rgb = RGB()
        self.arch = 'swin_small_patch4_window7_224_22k'#'swin_tiny_patch4_window7_224'

        self.encoder = SwinTransformerV1(
            ** {**cfg['basic']['swin'], **cfg[self.arch]['swin'],
                **{'out_norm' : LayerNorm2d} }
        )
        encoder_dim =cfg[self.arch]['upernet']['in_channels']
        #[96, 192, 384, 768]

        self.decoder = UPerDecoder(
            in_dim=encoder_dim,
            ppm_pool_scale=[1, 2, 3, 6],
            ppm_dim=512,
            fpn_out_dim=256
        )

        self.logit = nn.Sequential(
            nn.Conv2d(256, 1, kernel_size=1)
        )
        self.aux = nn.ModuleList([
            nn.Conv2d(256, 1, kernel_size=1, padding=0) for i in range(4)
        ])



    def forward(self, batch):
        x = batch['image']
        B,C,H,W = x.shape
        x = self.rgb(x)
        encoder = self.encoder(x)
        last, decoder = self.decoder(encoder)
        logit = self.logit(last)
        logit = F.interpolate(logit, size=None, scale_factor=4, mode='bilinear', align_corners=False)

        output = {}
        if 'loss' in self.output_type:
            output['bce_loss'] = F.binary_cross_entropy_with_logits(logit, batch['mask'])
            output['dice_loss'] = DiceLoss()(logit, batch['mask'])
            output['focal_loss'] = FocalLoss(logits=True, reduce=False)(logit, batch['mask'])
            for i in range(4):
                output['aux%d_loss'%i] = criterion_aux_loss(self.aux[i](decoder[i]),batch['mask'])

        if 'inference' in self.output_type:
            output['probability'] = torch.sigmoid(logit)

        return output

CoaT+Daformer總結(jié):

CoaT

Co-scale conv-attentional image Transformers(CoaT),這是一種基于Transformer的圖像分類器,其主要包含Co-scale和conv-attentional機(jī)制設(shè)計(jì)。

  • 首先,Co-scale機(jī)制在各個(gè)尺度上都保持了Transformers編碼器分支的完整性,同時(shí)允許在不同尺度下學(xué)習(xí)的表示形式能夠有效地進(jìn)行彼此間的通信。同時(shí),作者還設(shè)計(jì)了一系列的串行和并行塊用來實(shí)現(xiàn)Co-scale Attention機(jī)制。
  • 其次,本文通過一種類似于卷積的實(shí)現(xiàn)方式設(shè)計(jì)了一種Factorized Attention機(jī)制,可以使得在因式注意力模塊中實(shí)現(xiàn)相對(duì)位置的嵌入。CoaT為 Vision Transformer提供了豐富的多尺度和上下文建模功能。
    盡管CNN和Self-Attention操作都執(zhí)行一個(gè)加權(quán)和,但它們的權(quán)值計(jì)算方式不同:在CNN中權(quán)值在訓(xùn)練過程中學(xué)習(xí),但在測試過程中固定;而在Self-Attention中,根據(jù)每對(duì)Token之間的相似度或親和度動(dòng)態(tài)計(jì)算權(quán)重。因此,Self-Attention中的自相似操作提供了比卷積操作更具有潛在適應(yīng)性和通用性的建模手段。此外,位置編碼和位置嵌入的引入為Transformer建模提供了靈活性。
    模型詳解:https://jishuin.proginn.com/p/763bfbd566f3

Daformer

DAFormer是使用Transformer進(jìn)行語義分割無監(jiān)督域自適應(yīng)的開篇之作。DAFormer的網(wǎng)絡(luò)結(jié)構(gòu)包括一個(gè)Transformer編碼器和一個(gè)多級(jí)上下文感知特征融合解碼器。它是由3個(gè)簡單但很關(guān)鍵的訓(xùn)練策略來穩(wěn)定訓(xùn)練和避免對(duì)源域的過擬合:

  • 源域上的罕見類采樣通過減輕Self-training對(duì)普通類的確認(rèn)偏差提高了Pseudo-labels的質(zhì)量
  • Thing-Class ImageNet Feature Distance
  • Learning rate warmup促進(jìn)了預(yù)訓(xùn)練的特征遷移

這里我們只是借用decoder模塊,完整的Daformer:https://blog.csdn.net/amusi1994/article/details/124833996

class DaformerDecoder(nn.Module):
    def __init__(
            self,
            encoder_dim = [32, 64, 160, 256],
            decoder_dim = 256,
            dilation = [1, 6, 12, 18],
            use_bn_mlp  = True,
            fuse = 'conv3x3',
    ):
        super().__init__()
        self.mlp = nn.ModuleList([
            nn.Sequential(
                # Conv2dBnReLU(dim, decoder_dim, 1, padding=0), #follow mmseg to use conv-bn-relu
                *(
                  ( nn.Conv2d(dim, decoder_dim, 1, padding= 0,  bias=False),
                    nn.BatchNorm2d(decoder_dim),
                    nn.ReLU(inplace=True),
                )if use_bn_mlp else
                  ( nn.Conv2d(dim, decoder_dim, 1, padding= 0,  bias=True),)
                ),
                
                MixUpSample(2**i) if i!=0 else nn.Identity(),
            ) for i, dim in enumerate(encoder_dim)])
      
        if fuse=='conv1x1':
            self.fuse = nn.Sequential(
                nn.Conv2d(len(encoder_dim) * decoder_dim, decoder_dim, 1, padding=0, bias=False),
                nn.BatchNorm2d(decoder_dim),
                nn.ReLU(inplace=True),
            )
        
        if fuse=='conv3x3':
            self.fuse = nn.Sequential(
                nn.Conv2d(len(encoder_dim) * decoder_dim, decoder_dim, 3, padding=1, bias=False),
                nn.BatchNorm2d(decoder_dim),
                nn.ReLU(inplace=True),
            )
        
        if fuse=='aspp':
            self.fuse = ASPP(
                decoder_dim*len(encoder_dim),
                decoder_dim,
                dilation,
            )
            
        if fuse=='ds-aspp':
            self.fuse = DSASPP(
                decoder_dim*len(encoder_dim),
                decoder_dim,
                dilation,
            )
        
    
    def forward(self, feature):
        
        out = []
        for i,f in enumerate(feature):
            f = self.mlp[i](f)
            out.append(f)
            #print(f.shape)
        x = self.fuse(torch.cat(out, dim = 1))
        return x, out


class daformer_conv3x3 (DaformerDecoder):
    def __init__(self, **kwargs):
        super(daformer_conv3x3, self).__init__(
            fuse = 'conv3x3',
            **kwargs
        )
class daformer_conv1x1 (DaformerDecoder):
    def __init__(self, **kwargs):
        super(daformer_conv1x1, self).__init__(
            fuse = 'conv1x1',
            **kwargs
        )

class daformer_aspp (DaformerDecoder):
    def __init__(self, **kwargs):
        super(daformer_aspp, self).__init__(
            fuse = 'aspp',
            **kwargs
        )

CoaT+Daformer

class Net(nn.Module):
    
    def __init__(self,
                 encoder=coat_lite_medium,
                 decoder=daformer_conv3x3,
                 encoder_cfg={},
                 decoder_cfg={},
                 ):
        
        super(Net, self).__init__()
        self.output_type = ['inference', 'loss']
        decoder_dim = decoder_cfg.get('decoder_dim', 320)
        self.encoder = encoder
        self.rgb = RGB()
        encoder_dim = self.encoder.embed_dims
        # [64, 128, 320, 512]

        self.decoder = decoder(
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
        )
        self.logit = nn.Sequential(
            nn.Conv2d(decoder_dim, 1, kernel_size=1),
            nn.Upsample(scale_factor = 4, mode='bilinear', align_corners=False),
        )

        self.aux = nn.ModuleList([
            nn.Conv2d(decoder_dim, 1, kernel_size=1, padding=0) for i in range(4)
        ])

    def forward(self, batch):
        x = batch['image']
        x = self.rgb(x)

        B, C, H, W = x.shape
        encoder = self.encoder(x)

        last, decoder = self.decoder(encoder)
        logit = self.logit(last)

        output = {}
        if 'loss' in self.output_type:
            output['bce_loss'] = F.binary_cross_entropy_with_logits(logit, batch['mask'])
            output['dice_loss'] = DiceLoss()(logit, batch['mask'])
            output['focal_loss'] = FocalLoss(logits=True, reduce=False)(logit, batch['mask'])
            for i in range(4):
                output['aux%d_loss'%i] = criterion_aux_loss(self.aux[i](decoder[i]),batch['mask'])

        if 'inference' in self.output_type:
            output['probability'] = torch.sigmoid(logit)

        return output

Badcase分析

Yellow: True Positive, Red: False Negative, Green: False Positive

在肺部區(qū)域表現(xiàn)較差,解決思路使用過采樣肺部數(shù)據(jù)的方式重新訓(xùn)練Swin transformer+UPerNet,將其作為肺部預(yù)測模型,與原數(shù)據(jù)模型構(gòu)成集成預(yù)測。

if self.train:
   self.df = df.append(df.loc[df[df['organ']=="lung"].index.repeat(5)]).reset_index(drop=True)
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 1、motivation CNN的優(yōu)勢是平移不變、尺度不變、層次感受野; transformer應(yīng)用到CV的優(yōu)勢全...
    HORSEMAN_跬步閱讀 4,315評(píng)論 0 0
  • 0. 前言 近兩年學(xué)術(shù)界對(duì)Transformer在CV上的應(yīng)用可謂異常青睞,這里重點(diǎn)強(qiáng)調(diào)學(xué)術(shù)界的原因是目前工業(yè)界還...
    mrhalyang閱讀 2,587評(píng)論 0 0
  • 解讀SwinTrack: A Simple and Strong Baseline for Transformer...
    雨新閱讀 3,902評(píng)論 0 3
  • 復(fù)現(xiàn)可能是坑的地方:https://hub.fastgit.org/whai362/PVT/issues/21 作...
    Valar_Morghulis閱讀 2,906評(píng)論 1 1
  • Pay Attention to MLPs https://arxiv.org/abs/2105.08050htt...
    Valar_Morghulis閱讀 3,289評(píng)論 0 0

友情鏈接更多精彩內(nèi)容