多模態(tài)模型補充-CLIP

CLIP [OpenAI 21.01]

Learning Transferable Visual Models From Natural Language Supervision

https://arxiv.org/pdf/2103.00020.pdf
https://github.com/openai/CLIP

本來打算放在多模態(tài)模型匯總-按需更新里面的,發(fā)現(xiàn)CLIP內(nèi)容比較多,實驗很豐富,解讀很詳細,還是單獨介紹一下。

以下都是直覺翻譯和認知加工,如有問題,歡迎指正。
花了好幾天,仍然沒讀完...

私以為本文的幾個亮點:

  1. 雙流模型,單獨提取圖片和文本feature;
  2. 對比學習;
  3. 將分類模型轉(zhuǎn)換成圖文匹配任務,用文本來弱監(jiān)督圖片分類。

背景

一般的目標檢測,圖片分類等CV任務,都會預設有哪些類別,要識別哪些種類。實際圖片信息是很豐富的,除了這些預設的類別,其他的視覺信息沒有被充分利用,如果還要識別圖上其他類別,就需要再加標簽。

如果圖片有文本描述,通過文本監(jiān)督學習更多的圖片信息不失為一個好方法。

  1. Mori et al. (1999) 訓練圖文對,預測文本中的nouns和adjectives,探索基于圖片檢索的方式改進內(nèi)容;
  2. Quattoni et al. (2007) 訓練圖文對,預測文本描述,來學習分類權(quán)重空間,學習更多視覺表達;
  3. Srivastava & Salakhutdinov (2012) 訓練多模態(tài)Deep Boltzmann Machines,數(shù)據(jù)為top of low-level image和對應文本tag特征,學習深層表達;
  4. Joulin et al. (2016) 用CNN訓練圖文對,預測文本表述,來學習圖片表達;
  5. 還有些將標題,描述和hashtag等元數(shù)據(jù),和圖片一起訓練,預測這些元數(shù)據(jù)的模型;Li et al.(2017)擴展了這個方法,除了預測individual words,還會預測n-grams,并顯示這種系統(tǒng)在zero-shot transfer到其他分類數(shù)據(jù)集上的能力,即基于這些數(shù)據(jù)集的dictionary學習視覺n-grams,預測目標類別能拿到最高分;
  6. VirTex, ICMLML等模型展現(xiàn)了transformer模型在語言模型,masked語言模型,對比學習中的潛力。

雖然1~6展示了多模態(tài)(圖文)學習在學習模態(tài)表征上的能力,目前通過自然語言監(jiān)督學習視覺表達的研究還是比較少。

  1. 這種方式得到的模型結(jié)果不是最佳,比如 Li et al. (2017) 通過zero-shot方式,在ImageNet數(shù)據(jù)集上,只得到11.5%的準確率,但SOTA是88.4%;
  2. 文本數(shù)據(jù)是無限的,但是視覺數(shù)據(jù)的標簽是有限的,打標簽很費人力。

本文提出CLIP,Contrastive Language–Image Pre-training,用4億對來自網(wǎng)絡的圖文數(shù)據(jù)集,將文本作為圖像標簽,進行訓練。進行下游任務時,只需要提供和圖上的concepts對應的文本描述,就可以進行zero-shot transfer。
模型在30個CV數(shù)據(jù)集上做了實驗,實驗任務包括OCR, action recognition in videos, geo-localization, and many types of fine-grained object classification。模型在大部分的任務上都達到最佳。而且,一般不用再做specific training,就可以和其他baseline 模型媲美。

Model

數(shù)據(jù):4億個網(wǎng)絡公開的圖文對。為覆蓋到更多的視覺concepts, 用了50w個query在搜索引擎搜索圖片,一個query差不多有2w張圖片。
輸入:一個batch有N個圖像文本對;

CLIP 模型框架

模型:對比學習,預測N\times N對圖文數(shù)據(jù),將圖片分類任務轉(zhuǎn)換成圖文匹配任務:

  1. 雙流,2個encoder分別處理文本和圖片數(shù)據(jù),text encoder使用Transformer,image encoder用了2種模型,ResNetVision Transformer(ViT);
    a. 5種ResNet:ResNet-50, ResNet-101, EfficientNet-style的ResNet,包括RN50x4, RN50x16, RN50x64;
    b. 3種ViT:ViT-B/32, ViT-B/16, ViT-L/14;
  2. encoder representation直接線性投影到multi-modal embedding space;
  3. 計算2模態(tài)之間的cosine similarity,讓N個匹配的圖文對相似度最大,不匹配的圖文對相似度最?。?/li>
  4. 對稱的cross-entropy loss;
  5. 數(shù)據(jù)增強:對resized圖片進行random square crop。

偽代碼如下:

CLIP 偽代碼

實驗

1. Zero-shot Transfer

圖片分類的zero-shot指的是對未知類別進行推理。
本文的zero-shot指的是對未知任務進行推理,通過zero-shot transfer衡量任務學習的能力。
Visual N-Grams (Li et al., 2017) 是第一個將zero-shot transfer應用到圖片分類任務上的模型。模型用于學習長度為1~5grams的共142806個visual n-grams,對輸入的圖片,最大化對應的n-grams的probability。

同樣的,CLIP在進行zero-shot transfer時,將數(shù)據(jù)集中的類別標簽轉(zhuǎn)換為文字描述,主要步驟如下:

  1. 輸入:一張圖片 + 所有類別轉(zhuǎn)換的文本(100個類別就是100個文本描述);
  2. 轉(zhuǎn)換向量:經(jīng)過2個encoder,分別輸出image和text的feature embedding;
  3. 計算cosine similarity;
  4. 預測類別:multinomial logistic regression classifier。

模型結(jié)果:以下3個數(shù)據(jù)集,CLIP的表現(xiàn)都要高于Visual N-Grams。

CLIP vs Visual N-Grams

PROMPT ENGINEERING AND ENSEMBLING
有些CV分類數(shù)據(jù)集,類別用ID表示,有的有映射的文本描述。像Flower102和GTSRB數(shù)據(jù)集,缺乏映射關(guān)系,導致有些數(shù)據(jù)不能進行zero-shot transfer;

同義詞也會困擾模型:如果文本輸入只是一個類別,因為缺乏上下文,text encoder就不能很好的區(qū)分詞義。比如Oxford-IIIT Pet 數(shù)據(jù)集中的boxer指代一種狗,但因為缺乏上下文信息,CLIP的text encoder學習的不夠充分,以至于將boxer識別成一種運動。

CLIP的訓練數(shù)據(jù)集中,很少出現(xiàn)一張圖片對應一個單詞的現(xiàn)象。一般來說,都是用一句話來形容圖片。為了減少訓練集和其他任務數(shù)據(jù)集的gap,使用prompt template

Prompt template: “A photo of a {label}.”
比如,類別為狗,則輸入文本為“A photo of a dog.”。
通過這種方式,模型在ImageNet上,提高了1.3%的準確率。

如果提供更細粒度的類別信息(可以理解為文本描述中加屬性/類別標簽),比如上文的“A photo of a {label}.”,細化到“A photo of a {label}, a type of pet.”或者‘A photo of a big {label}”。按照這種方式,在ImageNet數(shù)據(jù)集上,構(gòu)建了80種不同形式的prompt template,準確率又提高了3.5%。

ANALYSIS OF ZERO-SHOT CLIP PERFORMANCE

  1. zero-shot classifier表現(xiàn)怎么樣?
    參照模型Linear Probe on ResNet50:ResNet-50 + logistic regression。
    下圖顯示了,在27個數(shù)據(jù)集中,CLIP在16個數(shù)據(jù)上表現(xiàn)更好。


    zero-shot CLIP vs. Linear Probe on ResNet50
  2. zero-shot CLIP怎么做prediction?
    zero-shot prediction
    基于輸入的圖片,在類別描述中檢索,找到最合適的類別。

"Ref:https://github.com/openai/CLIP"
import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
#cifar每個類別,輸入圖片,檢索匹配的類別

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

"""
Top predictions:
           snake: 65.31%
          turtle: 12.29%
    sweet_pepper: 3.83%
          lizard: 1.88%
       crocodile: 1.75%
"""

Linear-probe evaluation
通過CLIP的image_encoder得到視覺向量,結(jié)合標簽做Logistic Regression。

"Ref:https://github.com/openai/CLIP"
import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)

def get_features(dataset):
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) # c自定義
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(np.float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

2. Data Overlap Analysis

預訓練的數(shù)據(jù)集是否和下游數(shù)據(jù)集有重疊?如果有重疊,就會導致數(shù)據(jù)泄露,下游實驗的結(jié)果就不可信。本文作者分析了訓練集和下游任務的重疊程度,以及這種重疊對模型效果的影響。

  1. 對每個evaluation的數(shù)據(jù)集,做一個duplicate detector。對和訓練集相似度高的數(shù)據(jù)進行人工審核,并確定一個threshold,保證高精度的同時最大化召回。通過這個threshold,將數(shù)據(jù)集ALL(evaluation的數(shù)據(jù)集)分解成2個子集:
    a. Overlap:和訓練集的相似度高于threshold的數(shù)據(jù);
    b. Clean:相似度低于threshold的數(shù)據(jù)。
  2. 計算CLIP zero-shot在ALL,Overlap和Clean這三種數(shù)據(jù)集上的準確率,ALL-Clean的準確率差距作為衡量標準;
  3. 因為overlap的程度比較輕,因此進行binomial significance test,使用Clean的準確率作為null hypothesis,計算Overlap的p-value。

作者在35個數(shù)據(jù)集上做了分析,其中,有9個數(shù)據(jù)集和訓練集沒有數(shù)據(jù)重疊。像e MNIST, CLEVR和GTSRB這類偏synthetic或者specialized的數(shù)據(jù),往往不會被視為正常的圖片,和訓練集基本沒有重疊。
重疊率最高的是Country211,重疊率達21.5%,它是從YFCC100M數(shù)據(jù)集剝離出來的,雖然重疊程度很高,但這種重疊率只帶來0.2%的準確率的提升。也許是因為Country211主要是為了衡量geo-localization的ability,但在訓練集中,文本描述并沒有提及到image的location。

Overlap分析

雖然但是,作者也提到是不是duplicate detector做的不夠好。

  1. detector選擇threshold時,考慮保精度同時要求有高召回,但也沒有辦法去檢查完召回到的400million的樣本;
  2. Overlap和Clean數(shù)據(jù)集分布可能發(fā)生很大變化,比如Kinetics-700數(shù)據(jù)集,重疊的數(shù)據(jù)集都是black transition frames,所以上圖左,Detected Data Overlap,Kinetics-700的準確率,相比Clean,降低了20%。

Limitation

  1. 不是和SOTA的比較:以上的數(shù)據(jù)分析,都是和a linear classifier on top of ResNet-50 features進行比較,大部分的數(shù)據(jù)集,都有對應的SOTA模型。為了達到SOTA,zero-shot CLIP估計要提高1000x的算力,當前情況不支持;
  2. 在部分fine-grained分類上表現(xiàn)不佳:
    a. 前面實驗分析發(fā)現(xiàn),模型不能很好的區(qū)分cars,species of flowers, 以及variants of aircraft;
    b. abstract和systematic任務表現(xiàn)不好,比如統(tǒng)計圖上object的數(shù)量;
    c. 在訓練集中基本不會出現(xiàn)的比較novel的任務,表現(xiàn)欠佳,比如classifying
    the distance to the nearest car in a photo;
  3. 訓練集中沒有出現(xiàn)的圖片類型(out-of-distribution),表現(xiàn)不好,比如OCR識別數(shù)字效果可以,但是MNIST的準確率只有88%;
    ...
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容