mnist簡易分類網絡(pytorch)

網上找了一個代碼,閱讀代碼,加上了相應的注釋

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

BATCH_SIZE=512 #大概需要2G的顯存
EPOCHS=20 # 總共訓練批次
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 讓torch判斷是否使用GPU
#獲取數(shù)據(jù)
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, #有數(shù)據(jù)集后改為download=False
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=BATCH_SIZE, shuffle=True)

#定義模型
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 1,28x28
        self.conv1=nn.Conv2d(1,10,5) # 10, 24x24
        self.conv2=nn.Conv2d(10,20,3) # 128, 10x10
        self.fc1 = nn.Linear(20*10*10,500)
        self.fc2 = nn.Linear(500,10)
    def forward(self,x):
        in_size = x.size(0)
        out = self.conv1(x) #24
        out = F.relu(out)
        out = F.max_pool2d(out, 2, 2)  #12
        out = self.conv2(out) #10
        out = F.relu(out)
        out = out.view(in_size,-1)#展開成一維,方便進行FC
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.log_softmax(out,dim=1)
        return out
model = ConvNet().to(DEVICE)
optimizer = optim.Adam(model.parameters())
#訓練過程
'''
1 獲取loss:輸入圖像和標簽,通過infer計算得到預測值,計算損失函數(shù);
2 optimizer.zero_grad() 清空過往梯度;
3 loss.backward() 反向傳播,計算當前梯度;
4 optimizer.step() 根據(jù)梯度更新網絡參數(shù)
鏈接:https://www.zhihu.com/question/303070254/answer/573037166
'''
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()#梯度置零,清空過往梯度,這種操作模式的好處可參考https://www.zhihu.com/question/303070254
        output = model(data)
        loss = F.nll_loss(output, target)#調用內置函數(shù)
        loss.backward()#反向傳播,計算當前梯度
        optimizer.step()#根據(jù)梯度更新網絡參數(shù)
        if(batch_idx+1)%30 == 0: 
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))#Use torch.Tensor.item() to get a Python number from a tensor containing a single value:
#測試過程
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)#正向計算預測值
            test_loss += F.nll_loss(output, target, reduction='sum').item() # 將一批的損失相加
            pred = output.max(1, keepdim=True)[1] # 找到概率最大的下標
            correct += pred.eq(target.view_as(pred)).sum().item()#找到正確的預測值
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
#運行,這里也可以改成main的形式
for epoch in range(1, EPOCHS + 1):
    train(model, DEVICE, train_loader, optimizer, epoch)
    test(model, DEVICE, test_loader)

代碼部分比較簡單
1、數(shù)據(jù)讀取
2、構建網絡模型 ConvNet(nn.Module)
3、構建訓練函數(shù)train
4、構建測試函數(shù)test
5、關于代碼中的model.train(),model.eval()的說明:
參考 PyTorch進行訓練和測試時指定實例化的model模式為:train/eval
eval即evaluation模式,train即訓練模式。僅僅當模型中有Dropout和BatchNorm是才會有影響。因為訓練時dropout和BN都開啟,而一般而言測試時dropout被關閉,BN中的參數(shù)也是利用訓練時保留的參數(shù),所以測試時應進入評估模式。
(在訓練時,??和??2是在整個mini-batch 上計算出來的包含了像是64 或28 或其它一定數(shù)量的樣本,但在測試時,你可能需要逐一處理樣本,方法是根據(jù)你的訓練集估算??和??2,估算的方式有很多種,理論上你可以在最終的網絡中運行整個訓練集來得到??和??2,但在實際操作中,我們通常運用指數(shù)加權平均來追蹤在訓練過程中你看到的??和??2的值。還可以用指數(shù)加權平均,有時也叫做流動平均來粗略估算??和??2,然后在測試中使用??和??2的值來進行你所需要的隱藏單元??值的調整。在實踐中,不管你用什么方式估算??和??2,這套過程都是比較穩(wěn)健的,因此我不太會擔心你具體的操作方式,而且如果你使用的是某種深度學習框架,通常會有默認的估算??和??2的方式,應該一樣會起到比較好的效果)
6、損失函數(shù),這是torch的loss function,這里用的是負對數(shù)似然,推導可以參考負對數(shù)似然

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容