tf.nn.conv2d()

概述

tf.nn.conv2d是TensorFlow里面實(shí)現(xiàn)卷積的函數(shù),參考文檔對它的介紹并不是很詳細(xì),實(shí)際上這是搭建卷積神經(jīng)網(wǎng)絡(luò)比較核心的一個方法,非常重要。

說明

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)

參數(shù)

  • input:指需要做卷積的輸入圖像,它要求是一個Tensor,具有[batch, in_height, in_width, in_channels]這樣的shape,具體含義是[訓(xùn)練時(shí)一個batch的圖片數(shù)量, 圖片高度, 圖片寬度, 圖像通道數(shù)],注意這是一個4維的Tensor,要求類型為float32和float64其中之一
  • filter:相當(dāng)于CNN中的卷積核,它要求是一個Tensor,具有[filter_height, filter_width, in_channels, out_channels]這樣的shape,具體含義是[卷積核的高度,卷積核的寬度,圖像通道數(shù),卷積核個數(shù)],要求類型與參數(shù)input相同,有一個地方需要注意,第三維in_channels,就是參數(shù)input的第四維
  • strides:卷積時(shí)在圖像每一維的步長,這是一個一維的向量,長度4
  • padding:string類型的量,只能是"SAME","VALID"其中之一,這個值決定了不同的卷積方法,當(dāng)其為‘SAME’時(shí),表示卷積核可以停留在圖像邊緣。
  • use_cudnn_on_gpu:bool類型,是否使用cudnn加速,默認(rèn)為true
  • name:指定該操作的name

返回

結(jié)果返回一個Tensor,這個輸出,就是我們常說的feature map

實(shí)例

1.考慮一種最簡單的情況,現(xiàn)在有一張3×3單通道的圖像(對應(yīng)的shape:[1,3,3,1]),用一個1×1的卷積核(對應(yīng)的shape:[1,1,1,1])去做卷積,最后會得到一張3×3的feature map。輸出:[1,3, 3, 1]

input_arg = tf.Variable(tf.ones([1, 3, 3, 1]))
filter_arg = tf.Variable(tf.ones([1, 1, 1, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 3, 3, 1], 
                   use_cudnn_on_gpu=False, padding='VALID')

--------------case1--------------
[[[[ 1.]
[ 1.]
[ 1.]]

[[ 1.]
[ 1.]
[ 1.]]

[[ 1.]
[ 1.]
[ 1.]]]]

2.增加圖片的通道數(shù),使用一張3×3五通道的圖像(對應(yīng)的shape:[1,3,3,5]),用一個1×1的卷積核(對應(yīng)的shape:[1,1,1,1])去做卷積,仍然是一張3×3的feature map,這就相當(dāng)于每一個像素點(diǎn),卷積核都與該像素點(diǎn)的每一個通道做點(diǎn)積。輸出:[1, 3, 3, 1]

input_arg = tf.Variable(tf.ones([1, 3, 3, 5]))
filter_arg = tf.Variable(tf.ones([1, 1, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1], 
                   use_cudnn_on_gpu=False, padding='VALID')

--------------case2--------------
[[[[ 5.]
[ 5.]
[ 5.]]

[[ 5.]
[ 5.]
[ 5.]]

[[ 5.]
[ 5.]
[ 5.]]]]

3.把卷積核擴(kuò)大,現(xiàn)在用3×3的卷積核做卷積,最后的輸出是一個值,相當(dāng)于情況2的feature map所有像素點(diǎn)的值求和。輸出:[1, 1, 1, 1]

input_arg = tf.Variable(tf.ones([1, 3, 3, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1], 
                  use_cudnn_on_gpu=False, padding='VALID')

--------------case3--------------
[[[[ 45.]]]]

4.使用更大的圖片將情況2的圖片擴(kuò)大到5×5,仍然是3×3的卷積核,令步長為1,輸出3×3的feature map。
注意我們可以把這種情況看成情況2和情況3的中間狀態(tài),卷積核以步長1滑動遍歷全圖,以下x表示的位置,表示卷積核停留的位置,每停留一個,輸出feature map的一個像素。輸出:[1, 3, 3, 1]
.....
.xxx.
.xxx.
.xxx.
.....

input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1], 
                  use_cudnn_on_gpu=False, padding='VALID')

--------------case4--------------
[[[[ 45.]
[ 45.]
[ 45.]]

[[ 45.]
[ 45.]
[ 45.]]

[[ 45.]
[ 45.]
[ 45.]]]]

5.上面我們一直令參數(shù)padding的值為‘VALID’,當(dāng)其為‘SAME’時(shí),表示卷積核可以停留在圖像邊緣,輸出:[1, 5, 5, 1]
xxxxx
xxxxx
xxxxx
xxxxx
xxxxx

input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1], 
                  use_cudnn_on_gpu=False, padding='SAME')

--------------case5--------------
[[[[ 20.]
[ 30.]
[ 30.]
[ 30.]
[ 20.]]

[[ 30.]
[ 45.]
[ 45.]
[ 45.]
[ 30.]]

[[ 30.]
[ 45.]
[ 45.]
[ 45.]
[ 30.]]

[[ 30.]
[ 45.]
[ 45.]
[ 45.]
[ 30.]]

[[ 20.]
[ 30.]
[ 30.]
[ 30.]
[ 20.]]]]

6.如果卷積核有多個,此時(shí)輸出7張5×5的feature map。輸出:[1, 5, 5, 7]

 input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 7]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1], 
                  use_cudnn_on_gpu=False, padding='SAME')
oplist.append([op, 'case6'])

--------------case6--------------
[[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]]

7.步長不為1的情況,文檔里說了對于圖片,因?yàn)橹挥袃删S,通常strides取[1,stride,stride,1]。輸出:[1, 3, 3, 7]
x.x.x
.....
x.x.x
.....
x.x.x

input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 7]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 2, 2, 1], 
                  use_cudnn_on_gpu=False, padding='SAME')
oplist.append([op, 'case7'])

--------------case7--------------
[[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]]

8.如果batch值不為1,同時(shí)輸入4張圖,輸出的每張圖,都有7張3×3的feature map。輸出:[4, 3, 3, 7]

input_arg = tf.Variable(tf.ones([4, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 7]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 2, 2, 1], 
                  use_cudnn_on_gpu=False, padding='SAME')

--------------case8--------------
[[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]

[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]

[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]

[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]]

代碼

import tensorflow as tf

oplist = []

# input_arg = [batch, in_height, in_width, in_channels]
# filter_arg = [filter_height, filter_width, in_channels, out_channels]

# case 1
input_arg = tf.Variable(tf.ones([1, 3, 3, 1]))
filter_arg = tf.Variable(tf.ones([1, 1, 1, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1], 
                   use_cudnn_on_gpu=False, padding='VALID')
oplist.append([op, 'case1'])

# case 2
input_arg = tf.Variable(tf.ones([1, 3, 3, 5]))
filter_arg = tf.Variable(tf.ones([1, 1, 5, 1]))

op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1], 
                   use_cudnn_on_gpu=False, padding='VALID')
oplist.append([op, 'case2'])

# case 3
input_arg = tf.Variable(tf.ones([1, 3, 3, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1], 
                  use_cudnn_on_gpu=False, padding='VALID')
oplist.append([op, 'case3'])

# case 4
input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1], 
                  use_cudnn_on_gpu=False, padding='VALID')
oplist.append([op, 'case4'])

# case 5
input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 1]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1], 
                  use_cudnn_on_gpu=False, padding='SAME')
oplist.append([op, 'case5'])

# case 6
input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 7]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 1, 1, 1], 
                  use_cudnn_on_gpu=False, padding='SAME')
oplist.append([op, 'case6'])

# case 7
input_arg = tf.Variable(tf.ones([1, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 7]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 2, 2, 1], 
                  use_cudnn_on_gpu=False, padding='SAME')
oplist.append([op, 'case7'])

# case 8
input_arg = tf.Variable(tf.ones([4, 5, 5, 5]))
filter_arg = tf.Variable(tf.ones([3, 3, 5, 7]))
op = tf.nn.conv2d(input_arg, filter_arg, strides=[1, 2, 2, 1], 
                  use_cudnn_on_gpu=False, padding='SAME')
oplist.append([op, 'case8'])

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    for aop in oplist:
        print('--------------{}--------------'.format(aop[1]))  
        print(sess.run(aop[0]))
        print('\n')

--------------case1--------------
[[[[ 1.]
[ 1.]
[ 1.]]

[[ 1.]
[ 1.]
[ 1.]]

[[ 1.]
[ 1.]
[ 1.]]]]

--------------case2--------------
[[[[ 5.]
[ 5.]
[ 5.]]

[[ 5.]
[ 5.]
[ 5.]]

[[ 5.]
[ 5.]
[ 5.]]]]

--------------case3--------------
[[[[ 45.]]]]

--------------case4--------------
[[[[ 45.]
[ 45.]
[ 45.]]

[[ 45.]
[ 45.]
[ 45.]]

[[ 45.]
[ 45.]
[ 45.]]]]

--------------case5--------------
[[[[ 20.]
[ 30.]
[ 30.]
[ 30.]
[ 20.]]

[[ 30.]
[ 45.]
[ 45.]
[ 45.]
[ 30.]]

[[ 30.]
[ 45.]
[ 45.]
[ 45.]
[ 30.]]

[[ 30.]
[ 45.]
[ 45.]
[ 45.]
[ 30.]]

[[ 20.]
[ 30.]
[ 30.]
[ 30.]
[ 20.]]]]

--------------case6--------------
[[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]]

--------------case7--------------
[[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]]

--------------case8--------------
[[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]

[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]

[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]

[[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]

[[ 30. 30. 30. ..., 30. 30. 30.]
[ 45. 45. 45. ..., 45. 45. 45.]
[ 30. 30. 30. ..., 30. 30. 30.]]

[[ 20. 20. 20. ..., 20. 20. 20.]
[ 30. 30. 30. ..., 30. 30. 30.]
[ 20. 20. 20. ..., 20. 20. 20.]]]]

參考

http://www.cnblogs.com/welhzh/p/6607581.html

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容