PyTorch Pocket Reference Notes, Part 6

pytorch pocket reference

Original book download:
I have shared "OReilly.PyTorch.Pocket.R...odels.149209000X.pdf" on Aliyun Drive; you can download it without speed limits.
Copy this link and open the Aliyun Drive app to get the file.
Link: https://www.aliyundrive.com/s/NZvnGbTYr6C

Chapter 4: Building Neural Network Applications from Existing Network Designs

This chapter uses three examples to demonstrate how convenient and efficient PyTorch is for developing neural network applications:

  • Image classification with transfer learning
  • Sentiment analysis in natural language processing
  • Text-to-image generation with a GAN

Image Classification with Transfer Learning

In earlier chapters we saw that many models already exist for image processing, such as AlexNet and VGG, and that most of them were trained on the ImageNet dataset (which has 1,000 classes). In real applications, however, the classes we need to distinguish are often not among those 1,000. In that case we fine-tune an existing model, and this transfer-learning approach lets us develop practical projects quickly.

Example: classifying bees and ants
Fine-tune ResNet18 to classify images of bees and ants.

  • Data preprocessing
    Load the data, define the transform operations, and set the data-loading rules.

The data used in this example can be downloaded with the following program:

from io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFile


zipurl = 'https://pytorch.tips/bee-zip'
with urlopen(zipurl) as zipresp:
    with ZipFile(BytesIO(zipresp.read())) as zfile:
        zfile.extractall('./data')
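The extracted archive must match the directory layout that `datasets.ImageFolder` (used below) expects: one subdirectory per class. A quick check, where the `list_classes` helper is my own scaffolding, not from the book:

```python
import os

# ImageFolder infers class labels from subdirectory names, so the
# extracted data should look like:
#   data/hymenoptera_data/train/ants/...  and .../train/bees/...
#   data/hymenoptera_data/val/ants/...    and .../val/bees/...
def list_classes(split_dir):
    """Return the sorted class subdirectories of a train/val split."""
    return sorted(
        d for d in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, d))
    )

for split in ('train', 'val'):
    path = os.path.join('data/hymenoptera_data', split)
    if os.path.isdir(path):
        print(split, list_classes(path))  # expect ['ants', 'bees']
```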
  • Model design

  • Model training and validation
    Example code:

%matplotlib inline
import os
import matplotlib.pyplot as plt
import numpy as np
from zipfile import ZipFile

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR


device = 'cuda' if torch.cuda.is_available() else 'cpu'
# do not use multiprocessing workers on Windows
workers = 0 if os.name=='nt' else 4

# extract the zip file (downloaded with the program above) into ./data
data_path_zip = 'hymenoptera_data.zip'

with ZipFile(data_path_zip) as zfile:
    zfile.extractall('./data')

# define the transforms for the training data
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456,0.406],
        [0.229, 0.224, 0.225]
    ), 
])

# define the transforms for the validation data
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456,0.406],
        [0.229, 0.224, 0.225]
    ), 
])


# instantiate the datasets
train_dataset = datasets.ImageFolder(
        root='data/hymenoptera_data/train',
        transform=train_transforms
    )
val_dataset = datasets.ImageFolder(
        root='data/hymenoptera_data/val',
        transform=val_transforms
    )

# define the data loaders
train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=workers
    )
val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=workers
    )

# model design and construction
model = models.resnet18(pretrained=True)
print(model.fc)

# replace the final fully connected layer so the model outputs two classes
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
print(model.fc)

# training and validation
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# decay the learning rate by a factor of 10 every 7 epochs
exp_lr_scheduler = StepLR(optimizer, step_size=7, gamma=0.1)

num_epochs = 25

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        
        # loss.item() is already the mean loss over the batch, so
        # weight it by the batch size and normalize by the dataset
        # size below (dividing by inputs.size(0) here would be wrong)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    exp_lr_scheduler.step()
    train_epoch_loss = running_loss / len(train_dataset)
    train_epoch_acc = running_corrects.double() / len(train_dataset)
        
    # validation (no gradients are needed, so wrap in torch.no_grad())
    model.eval()
    val_running_loss = 0.0
    val_running_corrects = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            val_running_loss += loss.item() * inputs.size(0)
            val_running_corrects += torch.sum(preds == labels.data)

    epoch_loss = val_running_loss / len(val_dataset)
    epoch_acc = val_running_corrects.double() / len(val_dataset)
    
    
    print("Epoch {}/{} Train: Loss: {:.4f} Acc: {:.4f} Val: Loss: {:.4f} Acc: {:.4f}".format(
            epoch + 1, num_epochs,
            train_epoch_loss, train_epoch_acc, epoch_loss, epoch_acc)
         )
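As a side note, the effect of the `StepLR(optimizer, step_size=7, gamma=0.1)` scheduler used above can be observed in isolation. This sketch uses a throwaway one-parameter optimizer (my own scaffolding, not the book's code) to record the learning rate in effect at each epoch:

```python
import torch
from torch.optim.lr_scheduler import StepLR

# A throwaway parameter just so the optimizer has something to hold
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.001, momentum=0.9)
scheduler = StepLR(optimizer, step_size=7, gamma=0.1)

lrs = []
for epoch in range(15):
    lrs.append(optimizer.param_groups[0]['lr'])  # lr used this epoch
    optimizer.step()     # would normally follow loss.backward()
    scheduler.step()     # decay lr by gamma every step_size epochs

print(lrs)
# epochs 0-6 use 0.001, epochs 7-13 use 0.0001, epoch 14 uses 0.00001
```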
   
  • Model testing and saving
# helper function to display a (de-normalized) grid of images
def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    
    if title is not None:
        plt.title(title)
        
inputs, classes = next(iter(val_loader))
out = torchvision.utils.make_grid(inputs)
class_name = val_dataset.classes

outputs = model(inputs.to(device))
_, preds = torch.max(outputs, 1)

imshow(out, title=[class_name[x] for x in preds])

# save the model parameters
torch.save(model.state_dict(), './resnet18_demo.pt')

This example can also be found in the official PyTorch tutorials:
https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
