1. The MNIST dataset
1.1 Obtaining the MNIST dataset
MNIST is one of the classic entry-level datasets for machine learning and pattern recognition. It was first introduced in Yann LeCun's 1998 paper, which also proposed the classic LeNet-5 CNN architecture.
The dataset contains handwritten digit images covering the 10 classes 0-9. Every image is size-normalized to a 28x28 grayscale image, with pixel values between 0 and 255, where 0 is the black background and 255 is the white foreground, as shown below:

MNIST contains 70,000 handwritten digit images in total: 60,000 for the training set and 10,000 for the test set. The original dataset can be downloaded from the MNIST website.
The download consists of 4 compressed files:
train-images-idx3-ubyte.gz # the 60000 training-set images
train-labels-idx1-ubyte.gz # labels for the 60000 training-set images
t10k-images-idx3-ubyte.gz  # the 10000 test-set images
t10k-labels-idx1-ubyte.gz  # labels for the 10000 test-set images
Decompressing them yields:
train-images-idx3-ubyte
train-labels-idx1-ubyte
t10k-images-idx3-ubyte
t10k-labels-idx1-ubyte
1.2 Storage format of the MNIST binary files
The four decompressed files are all in binary format, so how do we get the information out of them? First we need to understand the MNIST binary storage format (described at the bottom of the official site). Take the training-set image file train-images-idx3-ubyte as an example:

In the image file:
- bytes 1-4 (1 byte = 8 bits), i.e. the first 32 bits, store the file's magic number, which is 2051 in decimal;
- bytes 5-8 store the number of images, 60000;
- bytes 9-12 store the number of rows (the height) of each image, 28;
- bytes 13-16 store the number of columns (the width) of each image, 28;
- from byte 17 onward, each byte stores the value of one pixel of one image.
Since the train-images-idx3-ubyte file holds the data of 60,000 images in this layout, we can compute the expected file size:
- one image has 28x28=784 pixels and therefore needs 784 bytes of storage;
- 60,000 images need 784x60000=47040000 bytes;
- in addition, the first 16 bytes of the file store the magic number, image count, image height, and image width, so the training-set image file should be 47040000+16=47040016 bytes.
Checking the properties of the decompressed train-images-idx3-ubyte file:

The actual file size matches our calculation.
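The header layout and size arithmetic above can be checked with a short Python 3 sketch built on the standard struct module (the helper name is mine, and the header here is synthesized in memory rather than read from the real file; the interactive session later in this post uses Python 2):

```python
import struct

def parse_idx3_header(raw):
    # '>IIII' = four big-endian unsigned 32-bit integers:
    # magic number, image count, rows, cols.
    return struct.unpack('>IIII', raw[:16])

# Synthetic bytes identical to the first 16 bytes of train-images-idx3-ubyte
header = struct.pack('>IIII', 2051, 60000, 28, 28)
print(parse_idx3_header(header))   # (2051, 60000, 28, 28)
print(16 + 60000 * 28 * 28)        # 47040016, the expected file size
```

The `>` in the format string matters: the MNIST files are big-endian, so reading the header with native little-endian byte order would give wrong values.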
Similarly, let's look at the storage format of the training-set label file train-labels-idx1-ubyte:

Its layout is analogous to the image file's:
- bytes 1-4 store the file's magic number, 2049 in decimal;
- bytes 5-8 store the number of items, i.e. the 60000 labels;
- from byte 9 onward, each byte stores the label of one image, a digit from 0 to 9.
Computing the size of the training-set label file train-labels-idx1-ubyte:
- 1x60000+8=60008 bytes.
This matches the file's actual size:

The other two files, the test-set image file and the test-set label file, are stored the same way as their training-set counterparts, except that the image count drops from 60000 to 10000.
1.3 Accessing the MNIST files with Python
Now that we know how the MNIST binaries are laid out, let's read their contents with Python, again using the training-set image file train-images-idx3-ubyte as an example.
First, open the file with open() and read all of its data into a string with read():
yan@yanubuntu:~/codes/Deep-Learning-21-Examples/chapter_1/MNIST_data$ python
Python 2.7.12 (default, Nov 12 2018, 14:36:49)
[GCC 5.4.0 20160609] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> with open('train-images.idx3-ubyte', 'rb') as f:
... file = f.read()
...
>>>
file is a str (note that the session above runs Python 2), and each of its elements holds one byte. Let's inspect the first 4 bytes, i.e. the magic number, to check whether it really is 2051:
>>> magic_number=file[:4]
>>> magic_number
'\x00\x00\x08\x03'
>>> magic_number.encode('hex')
'00000803'
>>> int(magic_number.encode('hex'),16)
2051
The first 4 bytes do decode to 2051, but the magic number cannot be printed as an integer directly: it first has to be hex-encoded and then converted to a decimal int (in effect, the four bytes are interpreted as a single big-endian integer).
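In Python 3 the encode('hex') round-trip is unnecessary: read() returns bytes, and int.from_bytes does the conversion in one step. A minimal equivalent of the session above:

```python
magic_number = b'\x00\x00\x08\x03'
value = int.from_bytes(magic_number, byteorder='big')
print(value)  # 2051

# The same arithmetic by hand: each byte is one base-256 digit.
print(0 * 256**3 + 0 * 256**2 + 8 * 256 + 3)  # 2051
```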
In the same way, read the image count, image height, and image width:
>>> num_images = int(file[4:8].encode('hex'),16)
>>> num_images
60000
>>> h_image = int(file[8:12].encode('hex'),16)
>>> h_image
28
>>> w_image = int(file[12:16].encode('hex'),16)
>>> w_image
28
Now extract the pixel values of the first image, convert them with the numpy and cv2 modules, and save them as a .jpg image:
>>> image1 = [int(item.encode('hex'), 16) for item in file[16:16+784]]
>>> len(image1)
784
>>> import numpy as np
>>> import cv2
>>> image1_np = np.array(image1, dtype=np.uint8).reshape(28,28,1)
>>> image1_np.shape
(28, 28, 1)
>>> cv2.imwrite('image1.jpg', image1_np)
True
>>>
The saved image image1.jpg looks like this:

This image's label is 5. We can verify that the first label in the training-set label file train-labels-idx1-ubyte matches the image content:
>>> with open('train-labels.idx1-ubyte', 'rb') as f:
... label_file = f.read()
...
>>> label1 = int(label_file[8].encode('hex'), 16)
>>> label1
5
>>>
The first image's label is the 9th byte of the label file (indexing starts at 0, hence label_file[8]), and the result checks out.
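The same extraction is simpler in Python 3, where indexing a bytes object yields ints directly. The sketch below demonstrates this on a tiny fake pair of files built in memory (one 784-pixel image and one label, with made-up pixel values), rather than the real MNIST files:

```python
import struct

# Fake image file: 16-byte header plus one 784-pixel image.
image_bytes = struct.pack('>IIII', 2051, 1, 28, 28) + bytes(range(256)) * 3 + bytes(16)
# Fake label file: 8-byte header plus one label byte.
label_bytes = struct.pack('>II', 2049, 1) + bytes([5])

# No encode('hex') needed: slicing/indexing bytes gives ints.
image1 = list(image_bytes[16:16 + 784])
label1 = label_bytes[8]
print(len(image1), label1)  # 784 5
```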
1.4 Saving the MNIST dataset as .jpg images
Since the file and label_file objects obtained above are of type str, we can iterate over them and convert all of the training- and test-set binaries into .jpg images. The conversion script mnist2jpg.py:
# coding=utf-8
'''Convert the binary MNIST files into .jpg images; each image's label is encoded in its file name.'''
import numpy as np
import cv2
import os


def save_mnist_to_jpg(mnist_image_file, mnist_label_file, save_dir):
    if 'train' in os.path.basename(mnist_image_file):
        num_file = 60000
        prefix = 'train'
    else:
        num_file = 10000
        prefix = 'test'
    with open(mnist_image_file, 'rb') as f1:
        image_file = f1.read()
    with open(mnist_label_file, 'rb') as f2:
        label_file = f2.read()
    image_file = image_file[16:]   # skip the 16-byte image-file header
    label_file = label_file[8:]    # skip the 8-byte label-file header
    for i in range(num_file):
        label = int(label_file[i].encode('hex'), 16)
        image_list = [int(item.encode('hex'), 16) for item in image_file[i * 784:i * 784 + 784]]
        image_np = np.array(image_list, dtype=np.uint8).reshape(28, 28, 1)
        save_name = os.path.join(save_dir, '{}_{}_{}.jpg'.format(prefix, i, label))
        cv2.imwrite(save_name, image_np)
        print '{} ==> {}_{}_{}.jpg'.format(i, prefix, i, label)  # Python 2 print statement


if __name__ == '__main__':
    train_image_file = './train-images.idx3-ubyte'
    train_label_file = './train-labels.idx1-ubyte'
    test_image_file = './t10k-images.idx3-ubyte'
    test_label_file = './t10k-labels.idx1-ubyte'
    save_train_dir = './train_images/'
    save_test_dir = './test_images/'
    if not os.path.exists(save_train_dir):
        os.makedirs(save_train_dir)
    if not os.path.exists(save_test_dir):
        os.makedirs(save_test_dir)
    save_mnist_to_jpg(train_image_file, train_label_file, save_train_dir)
    save_mnist_to_jpg(test_image_file, test_label_file, save_test_dir)
2. How TensorFlow handles the MNIST dataset
The MNIST-reading code above is not particularly efficient. The TensorFlow library ships a dedicated API for the MNIST dataset; its source is spread across several Python files, which I have condensed into a single read_mnist.py:
# coding=utf-8
"""Simplified version of TensorFlow's MNIST-reading code"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import collections
import gzip
import numpy

# A named tuple, so the train/validation/test splits are easy to tell apart
Datasets = collections.namedtuple('Datasets', ['train', 'validation', 'test'])


def _read32(bytestream):
    # numpy.dtype.newbyteorder() returns a dtype with the given byte order.
    # '>' means big-endian; the bottom of the MNIST site notes that the
    # MNIST binaries use this byte order.
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    # numpy.frombuffer() reads 4 bytes from bytestream and returns a
    # one-element 1-D array, hence the [0] index.
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]


# Read a downloaded MNIST image file (.gz).
# Returns a 4-D np.uint8 ndarray with shape [num_images, h, w, channels];
# for the training-set image file the shape is [60000, 28, 28, 1].
def extract_images(f):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth].
    Args:
      f: A file object that can be passed into a gzip reader.
    Returns:
      data: A 4D uint8 numpy array [index, y, x, depth].
    Raises:
      ValueError: If the bytestream does not start with 2051.
    """
    print('Extracting', f.name)
    with gzip.GzipFile(fileobj=f) as bytestream:
        # The first 4 bytes hold the magic number, 2051
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                             (magic, f.name))
        # Bytes 4-8 hold the image count: 60000 (train) or 10000 (test)
        num_images = _read32(bytestream)
        # Bytes 8-12 hold the image height
        rows = _read32(bytestream)
        # Bytes 12-16 hold the image width
        cols = _read32(bytestream)
        # The rest of the file holds all pixel values; read them into a
        # 1-D array with dtype=np.uint8
        buf = bytestream.read(rows * cols * num_images)
        data = numpy.frombuffer(buf, dtype=numpy.uint8)
        # reshape to [num_images, h, w, channels]
        data = data.reshape(num_images, rows, cols, 1)
        return data


# Convert labels to one-hot vectors
def dense_to_one_hot(labels_dense, num_classes):
    """Convert class labels from scalars to one-hot vectors."""
    num_labels = labels_dense.shape[0]
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot


# Read a downloaded MNIST label file (.gz).
# If one_hot=True:  returns a 2-D ndarray with shape [num_images, 10]
# If one_hot=False: returns a 1-D np.uint8 ndarray with shape [num_images]
def extract_labels(f, one_hot=False, num_classes=10):
    """Extract the labels into a 1D uint8 numpy array [index].
    Args:
      f: A file object that can be passed into a gzip reader.
      one_hot: Does one hot encoding for the result.
      num_classes: Number of classes for the one hot encoding.
    Returns:
      labels: a 1D uint8 numpy array.
    Raises:
      ValueError: If the bytestream doesn't start with 2049.
    """
    print('Extracting', f.name)
    with gzip.GzipFile(fileobj=f) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError('Invalid magic number %d in MNIST label file: %s' %
                             (magic, f.name))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        labels = numpy.frombuffer(buf, dtype=numpy.uint8)
        if one_hot:
            return dense_to_one_hot(labels, num_classes)
        return labels


class DataSet(object):
    def __init__(self, images, labels, dtype=numpy.float32, reshape=True):
        assert images.shape[0] == labels.shape[0], (
            'images.shape: %s labels.shape: %s' % (images.shape, labels.shape))
        self._num_examples = images.shape[0]
        # Convert shape from [num examples, rows, columns, depth]
        # to [num examples, rows*columns] (assuming depth == 1)
        if reshape:
            assert images.shape[3] == 1
            images = images.reshape(images.shape[0], images.shape[1] * images.shape[2])
        if dtype == numpy.float32:
            # Convert from [0, 255] -> [0.0, 1.0].
            images = images.astype(numpy.float32)
            images = numpy.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        return self._images

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size, shuffle=True):
        """Return the next `batch_size` examples from this data set."""
        start = self._index_in_epoch
        # Shuffle for the first epoch
        if self._epochs_completed == 0 and start == 0 and shuffle:
            perm0 = numpy.arange(self._num_examples)
            numpy.random.shuffle(perm0)
            self._images = self.images[perm0]
            self._labels = self.labels[perm0]
        # Go to the next epoch
        if start + batch_size > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Get the rest examples in this epoch
            rest_num_examples = self._num_examples - start
            images_rest_part = self._images[start:self._num_examples]
            labels_rest_part = self._labels[start:self._num_examples]
            # Shuffle the data
            if shuffle:
                perm = numpy.arange(self._num_examples)
                numpy.random.shuffle(perm)
                self._images = self.images[perm]
                self._labels = self.labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size - rest_num_examples
            end = self._index_in_epoch
            images_new_part = self._images[start:end]
            labels_new_part = self._labels[start:end]
            return numpy.concatenate((images_rest_part, images_new_part), axis=0), \
                numpy.concatenate((labels_rest_part, labels_new_part), axis=0)
        else:
            self._index_in_epoch += batch_size
            end = self._index_in_epoch
            return self._images[start:end], self._labels[start:end]


def read_data_sets(mnist_dir, one_hot=False, dtype=numpy.float32,
                   reshape=True, validation_size=5000):
    '''Read the MNIST dataset.
    Args:
      mnist_dir: directory holding the 4 compressed MNIST files downloaded
                 from http://yann.lecun.com/exdb/mnist/
      one_hot: if True, the returned labels are one-hot encoded
      reshape: if True, the returned images are flattened into 784-dim vectors
    Return:
      A Datasets namedtuple:
        Datasets.train holds the training data
        Datasets.validation holds the validation data
        Datasets.test holds the test data
    '''
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    # Read the training images: train_images.shape=[60000, 28, 28, 1], dtype=np.uint8
    local_file = os.path.join(mnist_dir, TRAIN_IMAGES)
    with open(local_file, 'rb') as f:
        train_images = extract_images(f)
    # Read the training labels: train_labels.shape=[60000,] if one_hot=False,
    # [60000, 10] if one_hot=True
    local_file = os.path.join(mnist_dir, TRAIN_LABELS)
    with open(local_file, 'rb') as f:
        train_labels = extract_labels(f, one_hot=one_hot)
    local_file = os.path.join(mnist_dir, TEST_IMAGES)
    with open(local_file, 'rb') as f:
        test_images = extract_images(f)
    local_file = os.path.join(mnist_dir, TEST_LABELS)
    with open(local_file, 'rb') as f:
        test_labels = extract_labels(f, one_hot=one_hot)
    if not 0 <= validation_size <= len(train_images):
        raise ValueError(
            'Validation size should be between 0 and {}. Received: {}.'
            .format(len(train_images), validation_size))
    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]
    train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
    validation = DataSet(validation_images, validation_labels, dtype=dtype, reshape=reshape)
    test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)
    return Datasets(train=train, validation=validation, test=test)
For brevity, some non-essential code was removed; the main goal is to understand the core logic.
First, manually download the 4 dataset files from the MNIST website into a local folder; there is no need to decompress them. Then call the read_data_sets() function in read_mnist.py to read them. The condensed code is fairly simple and should be easy to follow with the comments; it is pure Python and uses no TensorFlow functions.
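One piece worth understanding in isolation is the flat-index trick in dense_to_one_hot: row i's "hot" entry sits at flat offset i * num_classes + label[i]. The following standalone sketch replicates it with plain numpy (the label values here are made up for illustration):

```python
import numpy as np

def dense_to_one_hot(labels_dense, num_classes=10):
    # Row i's hot position sits at flat index i*num_classes + labels_dense[i]
    num_labels = labels_dense.shape[0]
    index_offset = np.arange(num_labels) * num_classes
    labels_one_hot = np.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot

labels = np.array([5, 0, 4])
one_hot = dense_to_one_hot(labels)
print(one_hot.shape)           # (3, 10)
print(one_hot.argmax(axis=1))  # [5 0 4]
```

Taking argmax along axis 1 recovers the original dense labels, which is also how convolution.py below computes its accuracy metric.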
3. Training a CNN on the MNIST dataset
Finally, here is convolution.py, which trains a two-convolutional-layer CNN classifier on MNIST using TensorFlow. The code is copied directly from
- Chapter 1 of 《21個(gè)項(xiàng)目玩轉(zhuǎn)深度學(xué)習(xí):基于TensorFlow的實(shí)踐詳解》
and can be used to verify the functions in read_mnist.py:
# coding: utf-8
import tensorflow as tf
import read_mnist as input_data


def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')


if __name__ == '__main__':
    # Read the data
    mnist = input_data.read_data_sets(mnist_dir="MNIST_data/", one_hot=True)
    # x is the placeholder for training images, y_ for their labels
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])
    # Reshape each 784-dim vector back into a 28x28 image
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    # First convolutional layer
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)
    # Second convolutional layer
    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)
    # Fully connected layer with a 1024-dim output
    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    # Dropout; keep_prob is a placeholder: 0.5 during training, 1.0 at test time
    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
    # Map the 1024-dim vector to 10 outputs, one per class
    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])
    y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
    # Instead of applying softmax and then computing the cross entropy,
    # compute both in one step with tf.nn.softmax_cross_entropy_with_logits
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
    # Define train_step
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    # Define the accuracy metric
    correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # Create the Session and initialize the variables
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    # Train for 20000 steps
    for i in range(20000):
        batch = mnist.train.next_batch(50)
        # Every 100 steps, report accuracy on a validation batch
        if i % 100 == 0:
            batch_val = mnist.validation.next_batch(50)
            train_accuracy = accuracy.eval(feed_dict={
                x: batch_val[0], y_: batch_val[1], keep_prob: 1.0})
            print("step %d, training accuracy %g" % (i, train_accuracy))
        train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
    # After training, report accuracy on the test set
    print("test accuracy %g" % accuracy.eval(feed_dict={
        x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
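A quick sanity check of the [7 * 7 * 64, 1024] shape of W_fc1: with 'SAME' padding and stride 1 the convolutions preserve the 28x28 spatial size, while each 2x2 max-pool with stride 2 halves it, since the 'SAME' output size is ceil(input / stride). A minimal sketch of that arithmetic (the helper name is mine):

```python
import math

def same_pool_out(size, stride=2):
    # 'SAME' padding output size: ceil(input / stride)
    return math.ceil(size / stride)

h = same_pool_out(28)   # after the first 2x2 max-pool -> 14
h = same_pool_out(h)    # after the second 2x2 max-pool -> 7
print(h)                # 7
print(7 * 7 * 64)       # 3136, the flattened input width of the FC layer
```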
4. Summary
The above summarizes the layout of the MNIST dataset in detail and how to use it, drawing on TensorFlow's official MNIST-handling code and a simple example from the book below; it is recorded here mainly for my own future reference.
References
《21個(gè)項(xiàng)目玩轉(zhuǎn)深度學(xué)習(xí):基于TensorFlow的實(shí)踐詳解》