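Preface
The usual entry point to deep learning is the classic MNIST + LeNet-5 pairing: LeNet-5 has a simple structure and MNIST is small, which makes the combination convenient and beginner-friendly. As a next step, once you are comfortable with basic PyTorch usage, it is tempting to write a CNN by hand and then train and test it on a dataset.
FashionMNIST is a good dataset for this kind of practice. This experiment trains and tests a CNN on FashionMNIST from start to finish.
The FashionMNIST dataset
Overview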
https://github.com/zalandoresearch/fashion-mnist

Fashion-MNIST is a dataset of Zalando's article images, consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes. We intend Fashion-MNIST to serve as a direct drop-in replacement for the original MNIST dataset for benchmarking machine learning algorithms. It shares the same image size and structure of training and testing splits.
Key properties of FashionMNIST:
- 60,000 training samples and 10,000 test samples
- Each sample is a 28x28 grayscale image
- 10 classes
Labels
Each training and test example is assigned to one of the following labels:
| Label | Description |
|---|---|
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |
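The network predicts integer labels, so a small lookup helps when reading its output. The following is a minimal sketch; the names `FASHION_MNIST_CLASSES` and `label_name` are made up for illustration and are not part of the original project.

```python
# Class names in label order, copied from the table above.
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_name(label):
    """Return the description for an integer label in 0..9 (hypothetical helper)."""
    if not 0 <= label < len(FASHION_MNIST_CLASSES):
        raise ValueError(f"label must be in 0..9, got {label}")
    return FASHION_MNIST_CLASSES[label]


# Translate a batch of four integer labels into readable names.
print([label_name(i) for i in (9, 2, 1, 1)])
```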
Why we made Fashion-MNIST
The original MNIST dataset contains a lot of handwritten digits. Members of the AI/ML/Data Science community love this dataset and use it as a benchmark to validate their algorithms. In fact, MNIST is often the first dataset researchers try. "If it doesn't work on MNIST, it won't work at all", they said. "Well, if it does work on MNIST, it may still fail on others."
To Serious Machine Learning Researchers
Seriously, we are talking about replacing MNIST. Here are some good reasons:
- MNIST is too easy. Convolutional nets can achieve 99.7% on MNIST. Classic machine learning algorithms can also achieve 97% easily. Check out our side-by-side benchmark for Fashion-MNIST vs. MNIST, and read "Most pairs of MNIST digits can be distinguished pretty well by just one pixel."
- MNIST is overused. In this April 2017 Twitter thread, Google Brain research scientist and deep learning expert Ian Goodfellow calls for people to move away from MNIST.
- MNIST can not represent modern CV tasks, as noted in this April 2017 Twitter thread, deep learning expert/Keras author François Chollet.
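Experiment
Getting the dataset
You can download the data manually from the website, but PyTorch offers a more convenient route: the dataset API in torchvision.datasets downloads the data automatically.
Because this run uses CPU mode, the batch size is set to 4. In GPU mode the batch size can be larger if there is enough GPU memory; on an NVIDIA 1080 Ti I set batch size = 16.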
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from net import Net
import utils
# https://blog.csdn.net/weixin_41278720/article/details/80778640
# --------------------------- Dataset ---------------------------
data_dir = '/media/weipenghui/Extra/FashionMNIST'
# ToTensor() converts each image to a float tensor of shape (1, 28, 28) in [0, 1]
transform = transforms.Compose([transforms.ToTensor()])
# download=True fetches the files into data_dir if they are not there yet
train_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=True, transform=transform, download=True)
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform, download=True)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)
# Show one random batch
plt.figure()
utils.imshow_batch(next(iter(train_dataloader)))
plt.show()
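The batch size determines how many iterations one epoch takes. The sketch below is plain arithmetic on the dataset sizes quoted earlier, with no PyTorch required; the helper name `batches_per_epoch` is hypothetical, and it assumes the DataLoader keeps the last partial batch, which is its default behavior.

```python
import math

TRAIN_SAMPLES = 60_000  # training-set size from the dataset description
VAL_SAMPLES = 10_000    # test-set size from the dataset description


def batches_per_epoch(num_samples, batch_size):
    """Batches yielded per pass over the data, counting a final partial batch."""
    return math.ceil(num_samples / batch_size)


for batch_size in (4, 16):
    print("batch size", batch_size,
          "train batches:", batches_per_epoch(TRAIN_SAMPLES, batch_size),
          "val batches:", batches_per_epoch(VAL_SAMPLES, batch_size))
```

With batch size 4, one pass over the validation set is itself 2,500 forward passes, which is worth keeping in mind when deciding how often to evaluate during training.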
Once the download finishes, the dataset files are stored under data_dir.



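Defining a CNN
The general pattern for defining a network:
- Inherit from nn.Module
- Define the network's layers in __init__()
- Override the parent class's forward() method
Unlike the earlier LeNet-5 definition, this one uses nn.Sequential and passes in an OrderedDict. BatchNorm and Dropout layers are added, and the first convolution is not followed by pooling, so more information is carried into the next layer.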
import torch
import torch.nn as nn
from collections import OrderedDict


class Net(nn.Module):
    '''
    Custom CNN: 4 convolutional layers, each followed by ReLU and batch norm,
    with 3 max-pooling layers, then 3 fully connected layers with Dropout.
    Input: a 28x28 single-channel image
    '''

    def __init__(self):
        super(Net, self).__init__()
        self.feature = nn.Sequential(
            OrderedDict(
                [
                    # 28x28x1
                    ('conv1', nn.Conv2d(in_channels=1,
                                        out_channels=32,
                                        kernel_size=5,
                                        stride=1,
                                        padding=2)),
                    ('relu1', nn.ReLU()),
                    ('bn1', nn.BatchNorm2d(num_features=32)),
                    # 28x28x32 (no pooling after the first convolution)
                    ('conv2', nn.Conv2d(in_channels=32,
                                        out_channels=64,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),
                    ('relu2', nn.ReLU()),
                    ('bn2', nn.BatchNorm2d(num_features=64)),
                    ('pool1', nn.MaxPool2d(kernel_size=2)),
                    # 14x14x64
                    ('conv3', nn.Conv2d(in_channels=64,
                                        out_channels=128,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),
                    ('relu3', nn.ReLU()),
                    ('bn3', nn.BatchNorm2d(num_features=128)),
                    ('pool2', nn.MaxPool2d(kernel_size=2)),
                    # 7x7x128
                    ('conv4', nn.Conv2d(in_channels=128,
                                        out_channels=64,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),
                    ('relu4', nn.ReLU()),
                    ('bn4', nn.BatchNorm2d(num_features=64)),
                    ('pool3', nn.MaxPool2d(kernel_size=2)),
                    # out 3x3x64
                ]
            )
        )
        # The classifier input is a flat (N, features) tensor, so plain
        # nn.Dropout is used; nn.Dropout2d is meant for (N, C, H, W) feature maps.
        self.classifier = nn.Sequential(
            OrderedDict(
                [
                    ('fc1', nn.Linear(in_features=3 * 3 * 64,
                                      out_features=128)),
                    ('dropout1', nn.Dropout(p=0.5)),
                    ('fc2', nn.Linear(in_features=128,
                                      out_features=64)),
                    ('dropout2', nn.Dropout(p=0.6)),
                    ('fc3', nn.Linear(in_features=64, out_features=10))
                ]
            )
        )

    def forward(self, x):
        out = self.feature(x)
        # Flatten (N, 64, 3, 3) to (N, 576) for the fully connected layers
        out = out.view(-1, 64 * 3 * 3)
        out = self.classifier(out)
        return out
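The feature-map sizes noted in the network's comments, and the 3 * 3 * 64 input size of the first fully connected layer, can be checked by hand. The sketch below applies the standard output-size formulas for convolution and pooling to the layer parameters listed above; the helper names `conv_out`, `pool_out`, and `feature_map_sizes` are hypothetical, and the pooling formula assumes PyTorch's default of stride equal to kernel size with floor rounding.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output side length of a convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def pool_out(size, kernel_size):
    """Output side length of max pooling with stride == kernel_size, no padding."""
    return (size - kernel_size) // kernel_size + 1


def feature_map_sizes(size=28):
    """Trace the spatial side length through the feature extractor, layer by layer."""
    sizes = []
    size = conv_out(size, kernel_size=5, padding=2)
    sizes.append(("conv1", size))
    size = conv_out(size, kernel_size=3, padding=1)
    sizes.append(("conv2", size))
    size = pool_out(size, kernel_size=2)
    sizes.append(("pool1", size))
    size = conv_out(size, kernel_size=3, padding=1)
    sizes.append(("conv3", size))
    size = pool_out(size, kernel_size=2)
    sizes.append(("pool2", size))
    size = conv_out(size, kernel_size=3, padding=1)
    sizes.append(("conv4", size))
    size = pool_out(size, kernel_size=2)
    sizes.append(("pool3", size))
    return sizes


sizes = feature_map_sizes()
final_side = sizes[-1][1]
flattened = final_side * final_side * 64  # 64 channels leave conv4
print(sizes)
print("flattened features:", flattened)
```

The last pooling layer rounds 7 / 2 down to 3, which is why the classifier expects 3 * 3 * 64 = 576 inputs rather than a size derived from 3.5.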
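Training the CNN
- The number of epochs is set to 100; on a GPU this actually finishes quickly.
- Every 20 iterations (the i % 20 == 19 check in the code), run a test pass, compute the accuracy, print it together with the running loss, and write both to log files for later analysis.
- During training, call net.train() to put the model in training mode; during testing, call net.eval() to put it in evaluation mode. Otherwise the results are wrong, because the network uses BatchNorm and Dropout, and both behave differently in the two modes; see the PyTorch documentation for details.
- Save the model once training finishes.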
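The difference between the two modes is easy to see on isolated layers. The following is a small sketch using standalone `nn.Dropout` and `nn.BatchNorm1d` layers rather than the network from this post; the tensor shapes and the seed are arbitrary choices made for the demonstration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Dropout: random zeroing (with rescaling) in train mode, identity in eval mode.
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)
dropout.train()
train_out = dropout(x)  # each element is either 0 or 1 / (1 - p) = 2
dropout.eval()
eval_out = dropout(x)   # unchanged input

# BatchNorm: batch statistics in train mode, running statistics in eval mode.
bn = nn.BatchNorm1d(3)
batch = torch.randn(8, 3) * 5 + 10
bn.train()
bn_train_out = bn(batch)  # normalized with this batch's mean and variance
bn.eval()
bn_eval_out = bn(batch)   # normalized with the running estimates instead

print("dropout values in train mode:", train_out.unique().tolist())
print("dropout is identity in eval mode:", torch.equal(eval_out, x))
print("batchnorm outputs match across modes:", torch.allclose(bn_train_out, bn_eval_out))
```

Evaluating with the model left in training mode would therefore both drop activations at random and normalize each small validation batch with its own statistics.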
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from net import Net
import utils
# https://blog.csdn.net/weixin_41278720/article/details/80778640
# --------------------------- Dataset ---------------------------
data_dir = '/media/weipenghui/Extra/FashionMNIST'
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=True, transform=transform, download=True)
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform, download=True)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)
# Show one random batch
plt.figure()
utils.imshow_batch(next(iter(train_dataloader)))
plt.show()
# ------------------- Network and hyperparameters -------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = Net()
print(net)
net = net.to(device)
loss_fc = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
# --------------------------- Training ---------------------------
# Make sure the output directories exist before writing to them
os.makedirs('./log', exist_ok=True)
os.makedirs('./model', exist_ok=True)
epoch_num = 100
log_interval = 20  # evaluate and log every 20 iterations
with open('./log/running_loss.txt', 'w') as file_running_loss, \
        open('./log/test_accuracy.txt', 'w') as file_test_accuracy:
    for epoch in range(epoch_num):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            net.train()
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_fc(outputs, labels)
            loss.backward()
            optimizer.step()
            print(i, loss.item())
            # Accumulate statistics: loss and accuracy
            running_loss += loss.item()
            if i % log_interval == log_interval - 1:
                correct = 0
                total = 0
                net.eval()
                # No gradients are needed for evaluation
                with torch.no_grad():
                    for val_inputs, val_labels in val_dataloader:
                        # Validation data must be on the same device as the model
                        val_inputs = val_inputs.to(device)
                        val_labels = val_labels.to(device)
                        val_outputs = net(val_inputs)
                        _, prediction = torch.max(val_outputs, 1)
                        correct += (prediction == val_labels).sum().item()
                        total += val_labels.size(0)
                accuracy = correct / total
                print('[{},{}] running loss = {:.5f} acc = {:.5f}'.format(
                    epoch + 1, i + 1, running_loss / log_interval, accuracy))
                file_running_loss.write(str(running_loss / log_interval) + '\n')
                file_test_accuracy.write(str(accuracy) + '\n')
                running_loss = 0.0
        # Step the schedule once per epoch, after that epoch's optimizer updates
        scheduler.step()
print('\n train finish')
torch.save(net.state_dict(), './model/model_100_epoch.pth')
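The StepLR settings in the script (step_size=20, gamma=0.1, starting from lr=0.001) follow a simple closed form: the learning rate is multiplied by gamma once every step_size scheduler steps. The sketch below reproduces that arithmetic in plain Python; the function name `step_lr` is hypothetical, and `num_steps` is assumed to count how many times `scheduler.step()` has been called so far.

```python
def step_lr(base_lr, num_steps, step_size=20, gamma=0.1):
    """Learning rate after `num_steps` calls to scheduler.step() under StepLR."""
    return base_lr * gamma ** (num_steps // step_size)


# Print the learning rate at each point where it changes over 100 epochs.
for num_steps in (0, 19, 20, 40, 60, 80, 99):
    print(num_steps, step_lr(0.001, num_steps))
```

Under these settings the learning rate falls by a factor of ten at steps 20, 40, 60, and 80, so the final 20 epochs run at roughly 1e-7, where the weights barely change.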
Training results
The training results are quite good: the best accuracy reaches about 93%.
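Testing the network
Feed in one batch (batch size 4) and load the trained model.
Note: the model was trained on a GPU, so the saved tensors are stored as GPU tensors. To load a GPU-trained model in CPU mode, pass map_location:
net.load_state_dict(torch.load('xxx.pth', map_location='cpu'))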
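The save and load round trip can be tried on a tiny model first. This is a sketch under stated assumptions: the `nn.Linear` layer is a stand-in for the trained network, and the temporary file path is made up for the demonstration.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the trained network.
model = nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), "demo.pth")
torch.save(model.state_dict(), path)

# map_location='cpu' remaps every stored tensor to the CPU while loading,
# so the file can be read on a machine without a GPU.
state_dict = torch.load(path, map_location="cpu")
restored = nn.Linear(4, 2)
restored.load_state_dict(state_dict)

print({name: tensor.device.type for name, tensor in state_dict.items()})
```

Only the parameters are stored in the file, so the loading side has to construct the same architecture before calling `load_state_dict`.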
import torch
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from net import Net
import utils
data_dir = '/media/weipenghui/Extra/FashionMNIST'
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=4, num_workers=4, shuffle=False)
# Load the trained weights onto the CPU
net = Net()
net.load_state_dict(torch.load(f='./model/model_100_epoch.pth', map_location='cpu'))
# Evaluation mode: BatchNorm uses running statistics and Dropout is disabled
net.eval()
print(net)
# Show the first test batch and predict on that same batch
images, labels = next(iter(test_dataloader))
plt.figure()
utils.imshow_batch((images, labels))
with torch.no_grad():
    outputs = net(images)
_, prediction = torch.max(outputs, 1)
print('label:', labels)
print('prediction:', prediction)
plt.show()


Complete project
- Network definition: net.py, identical to the Net class listed in the CNN definition section above
- Training: train.py, identical to the training script listed in the training section above
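- Visualization helper: utils.py
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np


def imshow_batch(sample_batch):
    '''Show one (images, labels) batch as an image grid titled with the labels.'''
    images = sample_batch[0]
    labels = sample_batch[1]
    # ToTensor() scales pixels to [0, 1], so pad with 1.0 (white) rather than 255
    images = make_grid(images, nrow=4, pad_value=1.0)
    # make_grid returns (C, H, W); imshow expects (H, W, C)
    images_transformed = np.transpose(images.numpy(), (1, 2, 0))
    plt.imshow(images_transformed)
    plt.axis('off')
    # Use the integer labels of the batch as the title
    plt.title(str(labels.numpy()))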
- Testing: the test script, identical to the one listed in the testing section above

