前言
在比較大規(guī)模的iOS項目開發(fā)中,會遇到這樣的場景,一個新需求使用的icon可能之前有,但是想找到實在是太難了。最近在學PyTorch,于是想到是否能用PyTorch做一個本地的圖片搜索功能,經(jīng)過一番搜索,非常可行
項目代碼
ImageSearchApp: ImageSearchApp (gitee.com)使用QT做了一個簡單的UI,安裝好依賴,運行image_search_app.py即可。目前代碼未作整理,僅供參考

image.png
上圖搜索到的是VOC2007圖庫中的汽車,在實際iOS項目中我也有嘗試,比如搜索叉號icon,在項目中找到了11個叉號圖片。
Python依賴
- PyQt5
- pyperclip
- torch
- torchvision
- PIL
- PySide2
圖片搜索原理
圖片搜索主要分為以下幾步
- 構建被搜索圖片的特征庫
- 提取輸入圖片的特征
- 將輸入圖片的特征和被搜索圖片的特征進行比對,得出最相近的top K個結果
構建被搜索圖片的特征庫
抽取特征
直接基于PreTrain的模型進行特征抽取,我的代碼中使用了resnet-18模型avgpool模塊的特征輸出,輸出尺寸為512
model = models.resnet18(pretrained=True)
layer = model._modules.get('avgpool')
通過注冊hook獲取特征輸出
image = self.normalize(self.toTensor(img)).unsqueeze(0).to(self.device)
embedding = torch.zeros(1, self.number_features, 1, 1)
def copy_data(m ,i, o): embedding.copy_(o.data)
h = self.feature_layer.register_forward_hook(copy_data)
self.model(image)
h.remove()
特征抽取的完整代碼在feature_extrator.py中
特征持久化
為了避免每次都得重新計算特征,我使用了h5py保存特征值,使用圖片文件的路徑md5作為主key,分別保存path和feature值
h5_base_key = self.md5_of_path(img_full_path)
path_data = dbfile.create_dataset(h5_base_key + '/path', (1), dtype=h5py.special_dtype(vlen=str))
path_data[:] = img_full_path
dbfile[h5_base_key + "/feature"] = feature
這塊的完整代碼在batch_feature_processor.py中
提取輸入圖片的特征
輸入圖片的特征提取直接使用feature_extrator.py即可
特征比對
比對主要使用余弦相似度來評估圖片特征向量的相似度。在二維空間,余弦相似度可以理解為兩個向量的夾角,夾角為0時,相似度最高,此時余弦為1,余弦的計算公式如下
cos(angle) = dot(VecA, VecB) / (|VecA| * |VecB|)
這個公式在高維度同樣適用,比如我們輸出的特征向量,是512維,計算代碼如下
np.inner(feature_a.T, feature_b.T) / ((np.linalg.norm(feature_a, axis=0).reshape(-1, 1)) * ((np.linalg.norm(feature_b, axis=0).reshape(-1,1)).T))
np.inner表示內(nèi)積,在高維空間,使用內(nèi)積計算向量的點乘。np.linalg.norm則是計算第二范數(shù),對應到二維空間就是計算長度。轉置T是為了讓矩陣的Shape匹配。
通過比對余弦值的大小就可以得到最匹配的圖片啦~