從圖片相似度學(xué)習(xí)圖片的表示

很多時(shí)候帶分類(lèi)標(biāo)注的圖片樣本是很難獲得的,但是圖片之間的相似度卻不難獲得。最簡(jiǎn)單的方式有幾個(gè):

最早用來(lái)從相似圖片數(shù)據(jù)集上學(xué)習(xí)圖片表示的網(wǎng)絡(luò)結(jié)構(gòu)是siamese網(wǎng)絡(luò)。

siamese.png

兩幅圖通過(guò)兩個(gè)共享權(quán)重的CNN得到各自的表示。而各自表示的距離決定了他們是相似還是不相似。

在siamese網(wǎng)絡(luò)之后,又提出了用triplet loss來(lái)學(xué)習(xí)圖片的表示,大概思路如下:

  • 拿到3張圖片A, B, C。其中A,B相似,A,C不相似。
  • 學(xué)到A, B, C 的表示,使得A,B之間的距離盡量小,而A,C之間的距離盡量大。

用mxnet實(shí)現(xiàn)triplet loss很簡(jiǎn)單,代碼如下:

def get_net(batch_size):
    same = mx.sym.Variable('same')
    diff = mx.sym.Variable('diff')
    anchor = mx.sym.Variable('anchor')
    one = mx.sym.Variable('one')
    one = mx.sym.Reshape(data = one, shape = (-1, 1))
    conv_weight = []
    conv_bias = []
    for i in range(3):
        conv_weight.append(mx.sym.Variable('conv' + str(i) + '_weight'))
        conv_bias.append(mx.sym.Variable('conv' + str(i) + '_bias'))
    fc_weight = mx.sym.Variable('fc_weight')
    fc_bias = mx.sym.Variable('fc_bias')
    fa = get_conv(anchor, conv_weight, conv_bias, fc_weight, fc_bias)
    fs = get_conv(same, conv_weight, conv_bias, fc_weight, fc_bias)
    fd = get_conv(diff, conv_weight, conv_bias, fc_weight, fc_bias)
    
    fs = fa - fs
    fd = fa - fd
    fs = fs * fs
    fd = fd * fd
    fs = mx.sym.sum(fs, axis = 1, keepdims = 1)
    fd = mx.sym.sum(fd, axis = 1, keepdims = 1)
    loss = fd - fs
    loss = one - loss
    loss = mx.sym.Activation(data = loss, act_type = 'relu')
    return mx.sym.MakeLoss(loss)

這里conv_weight[], fc_weight, conv_bias[], fc_bias是兩個(gè)CNN網(wǎng)絡(luò)共享的模型。理論上這里可以用任何的CNN網(wǎng)絡(luò)(AlexNet, GoogleNet, ResNet)。我們用了一個(gè)特別簡(jiǎn)單的CNN,如下:

def get_conv(data, conv_weight, conv_bias, fc_weight, fc_bias):
    cdata = data
    ks = [5, 3, 3]
    for i in range(3):
        cdata = mx.sym.Convolution(data=cdata, kernel=(ks[i],ks[i]), num_filter=32,
                                   weight = conv_weight[i], bias = conv_bias[i],
                                   name = 'conv' + str(i))
        cdata = mx.sym.Pooling(data=cdata, pool_type="avg", kernel=(2,2), stride=(1, 1))
        cdata = mx.sym.Activation(data=cdata, act_type="relu")

    cdata = mx.sym.Flatten(data = cdata)
    cdata = mx.sym.FullyConnected(data = cdata, num_hidden = 1024,
                                  weight = fc_weight, bias = fc_bias, name='fc')
    cdata = mx.sym.L2Normalization(data = cdata)
    return cdata

Triple loss用的Simultaneous Feature Learning and Hash Coding with Deep Neural Networks里的定義:

Triple Loss

下面是在cifar10數(shù)據(jù)集上測(cè)試的結(jié)果。為了形象的表示,采用了圖片檢索的方式(因?yàn)椴皇钦撐?,所以就不那么?yán)謹(jǐn)了)。在訓(xùn)練集上學(xué)習(xí)圖片的表示,然后對(duì)于測(cè)試集的一張隨機(jī)圖片,找到測(cè)試集上和他最相似的其他圖片:

cifar_triple.png

在其他的論文中還有一些其它評(píng)測(cè)方式,比如學(xué)習(xí)到表示后,用一個(gè)SVM去學(xué)習(xí)分類(lèi),看看分類(lèi)的準(zhǔn)確度相比End-End的CNN如何。基本的結(jié)論都是精度會(huì)稍微低一些,但是沒(méi)用明顯區(qū)別。這說(shuō)明學(xué)到的表示是靠譜的。

全部的代碼見(jiàn) github

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容