Visual Transformer (ViT)模型結構以及原理解析

簡介

Visual Transformer (ViT) 出自于論文《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》,是基于Transformer的模型在視覺領域的開篇之作。本文將盡可能簡潔地介紹一下ViT模型的整體架構以及基本原理。ViT模型是基于Transformer Encoder模型的,在這里假設讀者已經了解Transformer的基本知識,如果不了解可以參考鏈接

Vision Transformer如何工作

我們知道Transformer模型最開始是用于自然語言處理(NLP)領域的,NLP主要處理的是文本、句子、段落等,即序列數據。但是視覺領域處理的是圖像數據,因此將Transformer模型應用到圖像數據上面臨著諸多挑戰(zhàn),理由如下:

  1. 與單詞、句子、段落等文本數據不同,圖像中包含更多的信息,并且是以像素值的形式呈現。
  2. 如果按照處理文本的方式來處理圖像,即逐像素處理的話,即使是目前的硬件條件也很難。
  3. Transformer缺少CNNs的歸納偏差,比如平移不變性和局部受限感受野。
  4. CNNs是通過相似的卷積操作來提取特征,隨著模型層數的加深,感受野也會逐步增加。但是由于Transformer的本質,其在計算量上會比CNNs更大。
  5. Transformer無法直接用于處理基于網格的數據,比如圖像數據。

為了解決上述問題,Google的研究團隊提出了ViT模型,它的本質其實也很簡單,既然Transformer只能處理序列數據,那么我們就把圖像數據轉換成序列數據就可以了唄。下面來看下ViT是如何做的。

ViT模型架構

我們先結合下面的動圖來粗略地分析一下ViT的工作流程,如下:

  1. 將一張圖片分成patches
  2. 將patches鋪平
  3. 將鋪平后的patches的線性映射到更低維的空間
  4. 添加位置embedding編碼信息
  5. 將圖像序列數據送入標準Transformer encoder中去
  6. 在較大的數據集上預訓練
  7. 在下游數據集上微調用于圖像分類
ViT原理展示

ViT工作原理解析

我們將上圖展示的過程近一步分解為6步,接下來一步一步地來解析它的原理。如下圖:
ViT分解圖

步驟1、將圖片轉換成patches序列

這一步很關鍵,為了讓Transformer能夠處理圖像數據,第一步必須先將圖像數據轉換成序列數據,但是怎么做呢?假如我們有一張圖片x \in R^{H \times W \times C},patch大小為p,那么我們可以創(chuàng)建N個圖像patches,可以表示為x_p \in R^{N \times (p^2 C)},其中N = \frac {HW} {P^2},N就是序列的長度,類似一個句子中單詞的個數。在上面的圖中,可以看到圖片被分為了9個patches。

步驟2、將Patches鋪平

在原論文中,作者選用的patch大小為16,那么一個patch的shape為(3,16,16),維度為3,將它鋪平之后大小為3x16x16=768。即一個patch變?yōu)殚L度為768的向量。不過這看起來還是有點大,此時可以使用加一個Linear transformation,即添加一個線性映射層,將patch的維度映射到我們指定的embedding的維度,這樣就和NLP中的詞向量類似了。

步驟3、添加Position embedding

與CNNs不同,此時模型并不知道序列數據中的patches的位置信息。所以這些patches必須先追加一個位置信息,也就是圖中的帶數字的向量。實驗表明,不同的位置編碼embedding對最終的結果影響不大,在Transformer原論文中使用的是固定位置編碼,在ViT中使用的可學習的位置embedding 向量,將它們加到對應的輸出patch embeddings上。

步驟4、添加class token

在輸入到Transformer Encoder之前,還需要添加一個特殊的class token,這一點主要是借鑒了BERT模型。添加這個class token的目的是因為,ViT模型將這個class token在Transformer Encoder的輸出當做是模型對輸入圖片的編碼特征,用于后續(xù)輸入MLP模塊中與圖片label進行l(wèi)oss計算。

步驟5、輸入Transformer Encoder

將patch embedding和class token拼接起來輸入標準的Transformer Encoder中,

步驟6、分類

注意Transformer Encoder的輸出其實也是一個序列,但是在ViT模型中只使用了class token的輸出,將其送入MLP模塊中,去輸出最終的分類結果。

總結

ViT的整體思想還是比較簡單,主要是將圖片分類問題轉換成了序列問題。即將圖片patch轉換成token,以便使用Transformer來處理。聽起來很簡單,但是ViT需要在海量數據集上預訓練,然后在下游數據集上進行微調才能取得較好的效果,否則效果不如ResNet50等基于CNN的模型。

參考

?著作權歸作者所有,轉載或內容合作請聯系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容