請注意,我們要談?wù)勆窠?jīng)網(wǎng)絡(luò)的注意機制和使用方法

神經(jīng)網(wǎng)絡(luò)中的注意機制(attention mechanism),也被稱為神經(jīng)注意(neural attention)或注意(attention),最近也得到了人們越來越多的關(guān)注。在本文中,作者將嘗試為不同機制和用例找到共同點,此外還將描述并實現(xiàn)兩個軟視覺注意(soft visual attention)機制范例。本文作者 Adam Kosiorek 為牛津大學(xué)在讀博士。

注意機制是什么?

我們可以粗略地把神經(jīng)注意機制類比成一個可以專注于輸入內(nèi)容的某一子集(或特征)的神經(jīng)網(wǎng)絡(luò):它可以選擇特定的輸入。設(shè) x∈R^d 為輸入,z∈R^k 為特征向量,a∈{0,1}^k 是注意向量,g∈R^k 為 attention glimpse,f?(x) 為注意網(wǎng)絡(luò)(attention network)。一般而言,注意實現(xiàn)為如下形式:

其中 ⊙ 是元素依次相乘。對于軟注意(soft attention),其將特征與一個(軟)掩模(mask)相乘,該掩模的值在 0 到 1 之間;對于硬注意(hard attention),這些值被限制為確定的 0 或 1,即 a∈{0,1}k。在后面的案例中,我們可以使用硬注意掩模來直接索引其特征向量

(用 Matlab 的表示方法),它會改變自己的維度,所以現(xiàn)在[圖片上傳中。。。(3)],其中 m≤k。

為了理解注意機制的重要性,我們必須考慮到神經(jīng)網(wǎng)絡(luò)實際上就是一個函數(shù)近似器。它近似不同類型的函數(shù)的能力取決于它的架構(gòu)。典型的神經(jīng)網(wǎng)絡(luò)的實現(xiàn)形式是矩陣乘法構(gòu)成的鏈?zhǔn)竭\算和元素上的非線性,其中輸入的元素或特征向量只會通過加法彼此交互。

注意機制會計算一個用于對特征進行乘法運算的掩模。這種看似無關(guān)痛癢的擴展會產(chǎn)生重大的影響:突然之間,可以使用神經(jīng)網(wǎng)絡(luò)近似的函數(shù)空間多了很多,讓全新的用例成為了可能。為什么會這樣?盡管我沒有證據(jù),但直觀的想法是:有一種理論認(rèn)為神經(jīng)網(wǎng)絡(luò)是一種通用的函數(shù)近似器,可以近似任意函數(shù)并達到任意精度,唯一的限制是隱藏單元的數(shù)量有限。在任何實際的設(shè)置中,情況卻不是:我們受限于可以使用的隱藏單元的數(shù)量??紤]以下案例:我們要近似神經(jīng)網(wǎng)絡(luò)輸入的乘積。前饋神經(jīng)網(wǎng)絡(luò)只能通過使用(許多)加法(以及非線性)來模擬乘法,因此它需要大量神經(jīng)網(wǎng)絡(luò)基礎(chǔ)。如果我們引入乘法交互,那它就會變得簡單且緊湊。

如果我們放松對注意掩模的值的限制,使 a∈R^k,那么上面將注意定義為乘法交互的做法能讓我們考慮更大范圍的模型。比如動態(tài)過濾器網(wǎng)絡(luò)(DFN:Dynamic Filter Networks)使用了一個過濾器生成網(wǎng)絡(luò),它可以基于輸入而計算過濾器(即任意幅度的權(quán)重),并將它們應(yīng)用于特征,這在效果上就是一種乘法交互。使用軟注意機制的唯一區(qū)別是注意權(quán)重并不局限于 0 到 1 之間。在這個方向上更進一步,了解哪些交互應(yīng)該是相加的、哪些應(yīng)該是相乘的是非常有意思的。

論文《A Differentiable Transition Between Additive and Multiplicative Neurons》對這一概念進行了探索,參閱:https://arxiv.org/abs/1604.03736。另外,《深度 | 深度學(xué)習(xí)能力的拓展,Google Brain 講解注意力模型和增強 RNN》這篇文章也對軟注意機制進行了很好的概述。

視覺注意

注意機制可應(yīng)用于任意種類的輸入,不管這些輸入的形態(tài)如何。在輸入為矩陣值的案例中(比如圖像),我們可以考慮使用視覺注意(visual attention)。設(shè)

為圖像,
為 attention glimpse,即將注意機制應(yīng)用于圖像 I 所得到的結(jié)果。

硬注意

對圖像的硬注意已經(jīng)存在了很長時間,即圖像裁剪。在概念上這非常簡單,因為僅需要索引。使用 Python(或 TensorFlow),硬注意可以實現(xiàn)為:

g = I[y:y+h, x:x+w]

上面代碼的唯一問題是不可微分;為了學(xué)習(xí)得到模型的參數(shù),比如借助分?jǐn)?shù)函數(shù)估計器(score-function estimator)等方法,我之前的文章也曾簡要提到過:https://goo.gl/nfPB6r

軟注意

軟注意最簡單的形式在圖像方面和向量值特征方面并無不同,還是和上面的(1)式一樣。論文《Show, Attend and Tell: Neural Image Caption Generation with Visual Attention》是最早使用這種類型的注意的研究之一:https://arxiv.org/abs/1502.03044


該模型可以學(xué)習(xí)注意圖像的特定部分,同時生成描述這部分的詞。

但是,這種類型的軟注意非常浪費計算資源。輸入中變暗的部分對結(jié)果沒有貢獻,但仍然還是需要處理。它也過度參數(shù)化了:實現(xiàn)這種注意的 sigmoid 激活函數(shù)是彼此獨立的。它可以同時選擇多個目標(biāo),但在實際中,我們往往希望進行選擇并且僅關(guān)注場景中的單個元素。下面這兩個機制解決了這個問題,它們分別是由 DRAW(https://arxiv.org/abs/1502.04623)和 Spatial Transformer Networks(https://arxiv.org/abs/1506.02025)這兩項研究引入的。它們也可以重新調(diào)整輸入的大小,從而進一步提升性能。

高斯注意(Gaussian Attention)

高斯注意是使用參數(shù)化的一維高斯過濾器來創(chuàng)造圖像大小的注意圖(attention map)。設(shè)

是注意向量,其分別通過 y 和 x 坐標(biāo)指定了應(yīng)該注意圖像中的哪一部分。其注意掩??梢詣?chuàng)建為


在上圖中,上面一行表示 ax,右邊一列表示 ay,中間的矩形表示得到的結(jié)果 a。這里為了可視化,向量中僅包含 0 和 1. 實際上,它們可以實現(xiàn)為一維的高斯向量。一般而言,高斯的數(shù)量就等于空間的維度,且每個向量都使用了 3 個參數(shù)進行參數(shù)化:第一個高斯的中心 μ、連續(xù)的高斯中心之間的距離 d 和這些高斯的標(biāo)準(zhǔn)差 σ。使用這種參數(shù)化,注意和 glimpse 在注意的參數(shù)方面都是可微分的,因此很容易學(xué)習(xí)。

上面形式的注意仍然很浪費,因為它只選擇了圖像的一部分,同時遮擋了圖像的其它部分。我們可以不直接使用這些向量,而是將它們分別投射進

中?,F(xiàn)在,每個矩陣的每一行都有一個高斯,參數(shù) d 指定了連續(xù)行中的高斯中心之間的距離(以列為單位)。現(xiàn)在可以將 glimpse 實現(xiàn)為:

我最近一篇關(guān)于使用帶有注意機制的 RNN 進行生物啟發(fā)式目標(biāo)跟蹤的論文 HART 中就使用了這種機制,參閱:https://arxiv.org/abs/1706.09262。這里給出一個例子,下面左圖是輸入圖像,右圖是 attention glimpse;這個 glimpse 給出了主圖中綠色標(biāo)記出的框。

下面的代碼可以讓你在 TensorFlow 中為某個 minibatch 樣本創(chuàng)建一個上述的帶有矩陣值的掩模。如果你想創(chuàng)造 Ay,你可以這樣調(diào)用:Ay = gaussian_mask(u, s, d, h, H),其中 u、s、d 即為 μ、σ、d,以這樣的順序并在像素中指定。

def gaussian_mask(u, s, d, R, C):

    """

    :param u: tf.Tensor, centre of the first Gaussian.

    :param s: tf.Tensor, standard deviation of Gaussians.

    :param d: tf.Tensor, shift between Gaussian centres.

    :param R: int, number of rows in the mask, there is one Gaussian per row.

    :param C: int, number of columns in the mask.

    """

    # indices to create centres

    R = tf.to_float(tf.reshape(tf.range(R), (1, 1, R)))

    C = tf.to_float(tf.reshape(tf.range(C), (1, C, 1)))

    centres = u[np.newaxis, :, np.newaxis] + R * d

    column_centres = C - centres

    mask = tf.exp(-.5 * tf.square(column_centres / s))

    # we add eps for numerical stability

    normalised_mask = mask / (tf.reduce_sum(mask, 1, keep_dims=True) + 1e-8)

    return normalised_mask

我們也可以寫一個函數(shù)來直接從圖像中提取 glimpse:

def gaussian_glimpse(img_tensor, transform_params, crop_size):

    """

    :param img_tensor: tf.Tensor of size (batch_size, Height, Width, channels)

    :param transform_params: tf.Tensor of size (batch_size, 6), where params are  (mean_y, std_y, d_y, mean_x, std_x, d_x) specified in pixels.

    :param crop_size): tuple of 2 ints, size of the resulting crop

    """

    # parse arguments

    h, w = crop_size

    H, W = img_tensor.shape.as_list()[1:3]

    split_ax = transform_params.shape.ndims -1

    uy, sy, dy, ux, sx, dx = tf.split(transform_params, 6, split_ax)

    # create Gaussian masks, one for each axis

    Ay = gaussian_mask(uy, sy, dy, h, H)

    Ax = gaussian_mask(ux, sx, dx, w, W)

    # extract glimpse

    glimpse = tf.matmul(tf.matmul(Ay, img_tensor, adjoint_a=True), Ax)

    return glimpse

空間變換器(Spatial Transformer)

空間變換器(STN)可以實現(xiàn)更加一般化的變換,而不僅僅是可微分的圖像裁剪,但圖像裁剪也是其可能的用例之一。它由兩個組件構(gòu)成:一個網(wǎng)格生成器和一個采樣器。這個網(wǎng)格生成器會指定一個點構(gòu)成的網(wǎng)格以用于采樣,而采樣器的工作當(dāng)然就是采樣。使用 DeepMind 最近發(fā)布的一個神經(jīng)網(wǎng)絡(luò)庫 Sonnet,可以很輕松地在 TensorFlow 中實現(xiàn)它。Sonnet 地址:https://github.com/deepmind/sonnet

def spatial_transformer(img_tensor, transform_params, crop_size):

    """

    :param img_tensor: tf.Tensor of size (batch_size, Height, Width, channels)

    :param transform_params: tf.Tensor of size (batch_size, 4), where params are  (scale_y, shift_y, scale_x, shift_x)

    :param crop_size): tuple of 2 ints, size of the resulting crop

    """

    constraints = snt.AffineWarpConstraints.no_shear_2d()

    img_size = img_tensor.shape.as_list()[1:]

    warper = snt.AffineGridWarper(img_size, crop_size, constraints)

    grid_coords = warper(transform_params)

    glimpse = snt.resampler(img_tensor[..., tf.newaxis], grid_coords)

    return glimpse

高斯注意 vs. 空間變換器

高斯注意和空間變換器可以實現(xiàn)非常相似的行為。我們該選擇使用哪一個呢?這兩者之間有一些細(xì)微的差別:

高斯注意是一種過度參數(shù)化的裁剪機制:需要 6 個參數(shù),但卻只有 4 個自由度(y、x、高度、寬度)。STN 只需要 4 個參數(shù)。

我還沒運行過任何測試,但 STN 應(yīng)該更快。它依賴于在采樣點上的線性插值法,而高斯注意則必須執(zhí)行兩個巨大的矩陣乘法運算。STN 應(yīng)該可以快上一個數(shù)量級(在輸入圖像中的像素方面)。

高斯注意應(yīng)該更容易訓(xùn)練(沒有測試運行)。這是因為結(jié)果得到的 glimpse 中的每個像素都可以是源圖像的相對大批量的像素的凸組合,這使得我們能更容易找到任何錯誤的原因。而 STN 依賴于線性插值法,這意味著每個采樣點的梯度僅相對其最近的兩個像素是非 0 的。

你可以在這里查看代碼示例:https://github.com/akosiorek/akosiorek.github.io/tree/master/notebooks/attention_glimpse.ipynb

一個簡單的范例

讓我們來創(chuàng)建一個簡單的高斯注意和 STN 范例。首先,我們需要載入一些庫,定義尺寸,創(chuàng)建并裁剪輸入圖片。

import tensorflow as tf

import sonnet as snt

import numpy as np

import matplotlib.pyplot as plt

img_size = 10, 10

glimpse_size = 5, 5

# Create a random image with a square

x = abs(np.random.randn(1, *img_size)) * .3

x[0, 3:6, 3:6] = 1

crop = x[0, 2:7, 2:7] # contains the square

隨后,我們需要 TensorFlow 變量的占位符。

tf.reset_default_graph()

# placeholders

tx = tf.placeholder(tf.float32, x.shape, 'image')

tu = tf.placeholder(tf.float32, [1], 'u')

ts = tf.placeholder(tf.float32, [1], 's')

td = tf.placeholder(tf.float32, [1], 'd')

stn_params = tf.placeholder(tf.float32, [1, 4], 'stn_params')

我們現(xiàn)在可以定義高斯注意和 STN 在 Tensorflow 上的簡單表達式。

# Gaussian Attention

gaussian_att_params = tf.concat([tu, ts, td, tu, ts, td], -1)

gaussian_glimpse_expr = gaussian_glimpse(tx, gaussian_att_params, glimpse_size)

# Spatial Transformer

stn_glimpse_expr = spatial_transformer(tx, stn_params, glimpse_size)

運行這些表達式并繪制它們:

sess = tf.Session()

# extract a Gaussian glimpse

u = 2

s = .5

d = 1

u, s, d = (np.asarray([i]) for i in (u, s, d))

gaussian_crop = sess.run(gaussian_glimpse_expr, feed_dict={tx: x, tu: u, ts: s, td: d})

# extract STN glimpse

transform = [.4, -.1, .4, -.1]

transform = np.asarray(transform).reshape((1, 4))

stn_crop = sess.run(stn_glimpse_expr, {tx: x, stn_params: transform})

# plots

fig, axes = plt.subplots(1, 4, figsize=(12, 3))

titles = ['Input Image', 'Crop', 'Gaussian Att', 'STN']

imgs = [x, crop, gaussian_crop, stn_crop]

for ax, title, img in zip(axes, titles, imgs):

    ax.imshow(img.squeeze(), cmap='gray', vmin=0., vmax=1.)

    ax.set_title(title)

    ax.xaxis.set_visible(False)

    ax.yaxis.set_visible(False)

以上代碼也在 Jupyter Notebook 上:https://github.com/akosiorek/akosiorek.github.io/blob/master/notebooks/attention_glimpse.ipynb

結(jié)語

注意機制能夠擴展神經(jīng)網(wǎng)絡(luò)的能力:它們允許近似更加復(fù)雜的函數(shù),用更直觀的話說就是能關(guān)注輸入的特定部分。它們已經(jīng)幫助提升了自然語言處理的基準(zhǔn)表現(xiàn),也帶來了圖像描述、記憶網(wǎng)絡(luò)尋址和神經(jīng)編程器等全新能力。

我相信注意機制最重要的用例還尚未被發(fā)現(xiàn)。比如,我們知道視頻中的目標(biāo)是連續(xù)連貫的,它們不會在幀切換時憑空消失。注意機制可以用于表達這種連貫性的先驗知識。具體怎么做?請拭目以待。 [圖片上傳中。。。(14)]

原文鏈接:http://akosiorek.github.io/ml/2017/10/14/visual-attention.html

入門 | 請注意,我們要談?wù)勆窠?jīng)網(wǎng)絡(luò)的注意機制和使用方法 https://mp.weixin.qq.com/s?__biz=MzA3MzI4MjgzMw==&mid=2650732434&idx=2&sn=c668f9e835a4dc48730048478ba24526&chksm=871b33ecb06cbafae7e8126b8726b273111231d841c1f980c5cc5594dc9a94aff4b676ef4fbd#rd

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容