簡介
Visual Transformer (ViT) 出自于論文《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》,是基于Transformer的模型在視覺領域的開篇之作。本文將盡可能簡潔地介紹一下ViT模型的整體架構以及基本原理。ViT模型是基于Transformer Encoder模型的,在這里假設讀者已經了解Transformer的基本知識,如果不了解可以參考鏈接。
Vision Transformer如何工作
我們知道Transformer模型最開始是用于自然語言處理(NLP)領域的,NLP主要處理的是文本、句子、段落等,即序列數據。但是視覺領域處理的是圖像數據,因此將Transformer模型應用到圖像數據上面臨著諸多挑戰(zhàn),理由如下:
- 與單詞、句子、段落等文本數據不同,圖像中包含更多的信息,并且是以像素值的形式呈現。
- 如果按照處理文本的方式來處理圖像,即逐像素處理的話,即使是目前的硬件條件也很難。
- Transformer缺少CNNs的歸納偏差,比如平移不變性和局部受限感受野。
- CNNs是通過相似的卷積操作來提取特征,隨著模型層數的加深,感受野也會逐步增加。但是由于Transformer的本質,其在計算量上會比CNNs更大。
- Transformer無法直接用于處理基于網格的數據,比如圖像數據。
為了解決上述問題,Google的研究團隊提出了ViT模型,它的本質其實也很簡單,既然Transformer只能處理序列數據,那么我們就把圖像數據轉換成序列數據就可以了唄。下面來看下ViT是如何做的。
ViT模型架構
我們先結合下面的動圖來粗略地分析一下ViT的工作流程,如下:
- 將一張圖片分成patches
- 將patches鋪平
- 將鋪平后的patches的線性映射到更低維的空間
- 添加位置embedding編碼信息
- 將圖像序列數據送入標準Transformer encoder中去
- 在較大的數據集上預訓練
- 在下游數據集上微調用于圖像分類

ViT工作原理解析
我們將上圖展示的過程近一步分解為6步,接下來一步一步地來解析它的原理。如下圖:
步驟1、將圖片轉換成patches序列
這一步很關鍵,為了讓Transformer能夠處理圖像數據,第一步必須先將圖像數據轉換成序列數據,但是怎么做呢?假如我們有一張圖片,patch大小為
,那么我們可以創(chuàng)建
個圖像patches,可以表示為
,其中
,
就是序列的長度,類似一個句子中單詞的個數。在上面的圖中,可以看到圖片被分為了9個patches。
步驟2、將Patches鋪平
在原論文中,作者選用的patch大小為16,那么一個patch的shape為(3,16,16),維度為3,將它鋪平之后大小為3x16x16=768。即一個patch變?yōu)殚L度為768的向量。不過這看起來還是有點大,此時可以使用加一個Linear transformation,即添加一個線性映射層,將patch的維度映射到我們指定的embedding的維度,這樣就和NLP中的詞向量類似了。
步驟3、添加Position embedding
與CNNs不同,此時模型并不知道序列數據中的patches的位置信息。所以這些patches必須先追加一個位置信息,也就是圖中的帶數字的向量。實驗表明,不同的位置編碼embedding對最終的結果影響不大,在Transformer原論文中使用的是固定位置編碼,在ViT中使用的可學習的位置embedding 向量,將它們加到對應的輸出patch embeddings上。
步驟4、添加class token
在輸入到Transformer Encoder之前,還需要添加一個特殊的class token,這一點主要是借鑒了BERT模型。添加這個class token的目的是因為,ViT模型將這個class token在Transformer Encoder的輸出當做是模型對輸入圖片的編碼特征,用于后續(xù)輸入MLP模塊中與圖片label進行l(wèi)oss計算。
步驟5、輸入Transformer Encoder
將patch embedding和class token拼接起來輸入標準的Transformer Encoder中,
步驟6、分類
注意Transformer Encoder的輸出其實也是一個序列,但是在ViT模型中只使用了class token的輸出,將其送入MLP模塊中,去輸出最終的分類結果。
總結
ViT的整體思想還是比較簡單,主要是將圖片分類問題轉換成了序列問題。即將圖片patch轉換成token,以便使用Transformer來處理。聽起來很簡單,但是ViT需要在海量數據集上預訓練,然后在下游數據集上進行微調才能取得較好的效果,否則效果不如ResNet50等基于CNN的模型。