tensorflow卷積神經(jīng)處理過程中的conv2d和max_pool

(ps:在寫conv2d和max_pool之前在網(wǎng)上有位朋友已經(jīng)把這個(gè)方法講解得很透徹,我把連接奉上,http://blog.csdn.net/mao_xiao_feng/article/details/53444333 ,http://blog.csdn.net/mao_xiao_feng/article/details/53453926 開頭部分我先作下簡單的表述)

首先看下官方文檔的關(guān)于conv2d

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)

除去name和use_cudnn_on_gpu,與方法有關(guān)的一共四個(gè)參數(shù):

  • 第一個(gè)參數(shù)input:指需要做卷積的輸入圖像,它要求是一個(gè)Tensor,具有[batch, in_height, in_width, in_channels]這樣的shape,具體含義是[訓(xùn)練時(shí)一個(gè)batch的圖片數(shù)量, 圖片高度, 圖片寬度, 圖像通道數(shù)]

比方我們這么寫意味著:輸入值為數(shù)量為10的28x28x像素的有一個(gè)圖像通道的圖片

input_data = tf.Variable(tf.random_uniform([10, 28, 28, 1]))

(關(guān)于圖像的通道,在做測(cè)試的時(shí)候mnist就是只有一個(gè)顏色通道的灰色圖,而CIFAR-10則擁有RGB三種顏色的通道)

  • 第二個(gè)參數(shù)filter:是CNN中的卷積核,它要求是一個(gè)Tensor,具有[filter_height, filter_width, in_channels, out_channels]這樣的shape,具體含義是[卷積核的高度,卷積核的寬度,圖像通道數(shù),卷積核個(gè)數(shù)],有一個(gè)地方需要注意,第三維in_channels,就是參數(shù)input的第四維。而第四緯的out_channels則代表卷積核需要提取的特征。

以下的寫法卷積核截取5x5的大小的圖片,第三個(gè)緯度1與input_data對(duì)應(yīng),提取32種特征

W_con = tf.Variable(tf.random_uniform([5, 5, 1, 32]))
  • 第三個(gè)參數(shù)strides:卷積時(shí)在圖像每一維的步長,這是一個(gè)一維的向量,長度4

  • 第四個(gè)參數(shù)padding:只能寫"SAME","VALID"其中之一,寫SAME的話在做滑動(dòng)遇到元素不足時(shí)允許補(bǔ)全,VALID遇到此情況則多余會(huì)被拋棄

我們用代碼對(duì)mnist做實(shí)際的測(cè)試,首先引入tensorflow、mnist和numpy


import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

首先對(duì)于mnist手寫數(shù)字測(cè)試集,相當(dāng)多的資料講述已經(jīng)很多了,簡單的說它包含兩部分?jǐn)?shù)據(jù),一部分是像素28x28的灰色調(diào)圖片,另一部分是這些這些圖片的指代數(shù)字。

比如我們首取一個(gè)數(shù)據(jù),輸出batch_x的shape,batch_y我們暫不理會(huì)

batch_x,batch_y=mnist.train.next_batch(1)
print np.shape(batch_x)
輸出結(jié)果:

(1, 784)

可見batch_x是由28x28=784像素組成的圖片,并且我們還發(fā)現(xiàn),原始的數(shù)據(jù)的緯度不符合運(yùn)算規(guī)則,這時(shí)候我們需要進(jìn)行reshape,然后才能得到我們一開始所說的 [圖片數(shù)量、寬度、高度、顏色通道]

input_data=tf.reshape(batch_x,[1,28,28,1])

為了便于我們觀察卷積神經(jīng)的處理過程中維度shape的變化,我們先定義兩個(gè)通用的卷積和池化方法:

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
#第一層卷積,提取32種特征
W_conv1 = tf.Variable(tf.random_uniform([5, 5, 1, 32]))
b_conv1 = tf.constant(0.1, shape=[32])
h_conv1 = conv2d(input_data, W_conv1)
conv2d_relu1 = tf.nn.relu(h_conv1 + b_conv1)
pooling1 = max_pool_2x2(conv2d_relu1)

#第二層卷積,提取64種特征
W_conv2 = tf.Variable(tf.random_uniform([5, 5, 32, 64]))
b_conv2 = tf.constant(0.1, shape=[64])
h_conv2 = conv2d(pooling1, W_conv2)
conv2d_relu2 = tf.nn.relu(h_conv2 + b_conv2)
pooling2 = max_pool_2x2(conv2d_relu2)

打印一下兩層卷積處理過程的數(shù)據(jù)變化,可以為:

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

print 'input_data:  ', np.shape(input_data)
print 'conv2d_relu1:', np.shape(sess.run(conv2d_relu1))
print 'pooling1:    ', np.shape(sess.run(pooling1))
print 'conv2d_relu2:', np.shape(sess.run(conv2d_relu2))
print 'pooling2:    ', np.shape(sess.run(pooling2))
輸出結(jié)果:

input_data: (1, 28, 28, 1)
conv2d_relu1: (1, 28, 28, 32)
pooling1: (1, 14, 14, 32)
conv2d_relu2: (1, 14, 14, 64)
pooling2: (1, 7, 7, 64)

我們發(fā)現(xiàn)每一步運(yùn)行conv2d,數(shù)據(jù)末尾的緯度值都會(huì)變成我們預(yù)先設(shè)定的特征值,而結(jié)果經(jīng)過池化pooling,圖片的尺寸也是成半的減少

未完待續(xù)...

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • CNN on TensorFlow 本文大部分內(nèi)容均參考于: An Intuitive Explanation o...
    _Randolph_閱讀 8,024評(píng)論 2 31
  • 介紹 先前的教程展示了一個(gè)簡單的線性模型,對(duì)MNIST數(shù)據(jù)集中手寫數(shù)字的識(shí)別率達(dá)到了91%。 在這個(gè)教程中,我們會(huì)...
    Kimichen7764閱讀 1,716評(píng)論 0 7
  • 卷積神經(jīng)網(wǎng)絡(luò)是基于人工神經(jīng)網(wǎng)絡(luò)的深度機(jī)器學(xué)習(xí)方法,成功應(yīng)用于圖像識(shí)別領(lǐng)域。CNN采用了局部連接和權(quán)值共享,保持了網(wǎng)...
    dopami閱讀 1,126評(píng)論 0 0
  • 今天和大家一起來看下基于TensorFlow實(shí)現(xiàn)CNN的代碼示例,源碼參見Convolutional_netWor...
    Jerry_wl閱讀 1,490評(píng)論 0 2
  • 有個(gè)作者群,里面都是熱愛寫作的人,沒有大神。群里討論的主題一直都是技術(shù)類,現(xiàn)在的讀者喜歡什么,什么樣的題材能上首頁...
    Jenny喬閱讀 310評(píng)論 0 1

友情鏈接更多精彩內(nèi)容