五、pytorch進階訓(xùn)練技巧——pytorch學(xué)習(xí)

1. 自定義損失函數(shù)

pytorch 在nn.Module中提供了很多的常用的損失函數(shù),但是有時需要提出全新的函數(shù)來提升模型的表現(xiàn),這時需要自己來定義損失函數(shù)

1.1 以函數(shù)方式定義

就是自己定義一個函數(shù),沒啥好說的

1.2 以類的方式定義

以類方式定義更加常用,在以類方式定義損失函數(shù)時,我們?nèi)绻疵恳粋€損失函數(shù)的繼承關(guān)系我們就可以發(fā)現(xiàn)Loss函數(shù)部分繼承自_loss, 部分繼承自_WeightedLoss, 而_WeightedLoss繼承自_loss, _loss繼承自 nn.Module。我們可以將其當作神經(jīng)網(wǎng)絡(luò)的一層來對待,同樣地,我們的損失函數(shù)類就需要繼承自nn.Module類。

如下舉例IoUloss函數(shù)定義:

在自定義損失函數(shù)時,涉及到數(shù)學(xué)運算時,我們最好全程使用PyTorch提供的張量計算接口,這樣就不需要我們實現(xiàn)自動求導(dǎo)功能并且我們可以直接調(diào)用cuda,使用numpy或者scipy的數(shù)學(xué)運算時,操作會有些麻煩,

from turtle import forward
import torch.nn as nn
import torch.nn.functional as F

class IoULoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = F.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()

        union = total - intersection
        IoU = (intersection + smooth) / (union + smooth)
        return  1 - IoU

2. 動態(tài)調(diào)整學(xué)習(xí)率

2.1 使用官方scheduler

PyTorch已經(jīng)在torch.optim.lr_scheduler為我們封裝好了一些動態(tài)調(diào)整學(xué)習(xí)率的方法供我們使用,如下面列出的這些scheduler。

from torch.optim import lr_scheduler

lr_scheduler.LambdaLR                   #將每個參數(shù)組的學(xué)習(xí)率設(shè)置為初始lr乘以給定函數(shù)
lr_scheduler.StepLR                     # 在每個epoch,衰減學(xué)習(xí)率
lr_scheduler.MultiStepLR                # 一旦epoch達到一定數(shù)量,按γ衰減學(xué)習(xí)率
lr_scheduler.ExponentialLR              # 按指數(shù)篩選學(xué)習(xí)率           
lr_scheduler.CosineAnnealingLR          # 按cosine函數(shù)衰減學(xué)習(xí)率  
lr_scheduler.ReduceLROnPlateau          # 當指標停止改進時衰減學(xué)習(xí)率
lr_scheduler.CyclicLR                   # 周期性衰減學(xué)習(xí)率
lr_scheduler.CosineAnnealingWarmRestarts    # 
# 使用官方的Scheduler 
# 選擇一種優(yōu)化器
optimizer = torch.optim.Adam(...) 
# 選擇上面提到的一種或多種動態(tài)調(diào)整學(xué)習(xí)率的方法
scheduler1 = torch.optim.lr_scheduler.... 
scheduler2 = torch.optim.lr_scheduler....
...
schedulern = torch.optim.lr_scheduler....
# 進行訓(xùn)練
for epoch in range(100):
    train(...)
    validate(...)
    optimizer.step()
    # 需要在優(yōu)化器參數(shù)更新之后再動態(tài)調(diào)整學(xué)習(xí)率
    scheduler1.step() 
    ...
    schedulern.step()

我們在使用官方給出的torch.optim.lr_scheduler時,需要將scheduler.step()放在optimizer.step()后面進行使用。

2.2 自定義scheduler

自定義函數(shù)adjust_learning_rate來改變param_group中l(wèi)r的值,在下面的敘述中會給出一個簡單的實現(xiàn)。

需要學(xué)習(xí)率每30輪下降為原來的1/10,假設(shè)已有的官方API中沒有符合我們需求的,那就需要自定義函數(shù)來實現(xiàn)學(xué)習(xí)率的改變。

def adjust_learning_rate(optimizer, epoch):
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

optimizer = torch.optim.SGD(model.parameters(),lr = args.lr,momentum = 0.9)
for epoch in range(10):
    train(...)
    validate(...)
    adjust_learning_rate(optimizer,epoch)

3. 模型微調(diào)-torchvision

遷移學(xué)習(xí)的一大應(yīng)用場景是模型微調(diào)(finetune)。簡單來說,就是我們先找到一個同類的別人訓(xùn)練好的模型,把別人現(xiàn)成的訓(xùn)練好了的模型拿過來,換成自己的數(shù)據(jù),通過訓(xùn)練調(diào)整一下參數(shù)。 在PyTorch中提供了許多預(yù)訓(xùn)練好的網(wǎng)絡(luò)模型(VGG,ResNet系列,mobilenet系列......),這些模型都是PyTorch官方在相應(yīng)的大型數(shù)據(jù)集訓(xùn)練好的。學(xué)習(xí)如何進行模型微調(diào),可以方便我們快速使用預(yù)訓(xùn)練模型完成自己的任務(wù)。

3.1 模型微調(diào)流程

  1. 在源數(shù)據(jù)集(如ImageNet數(shù)據(jù)集)上預(yù)訓(xùn)練一個神經(jīng)網(wǎng)絡(luò)模型,即源模型。
  2. 創(chuàng)建一個新的神經(jīng)網(wǎng)絡(luò)模型,即目標模型。它復(fù)制了源模型上除了輸出層外的所有模型設(shè)計及其參數(shù)。我們假設(shè)這些模型參數(shù)包含了源數(shù)據(jù)集上學(xué)習(xí)到的知識,且這些知識同樣適用于目標數(shù)據(jù)集。我們還假設(shè)源模型的輸出層跟源數(shù)據(jù)集的標簽緊密相關(guān),因此在目標模型中不予采用。
  3. 為目標模型添加一個輸出?小為?標數(shù)據(jù)集類別個數(shù)的輸出層,并隨機初始化該層的模型參數(shù)。
  4. 在目標數(shù)據(jù)集上訓(xùn)練目標模型。我們將從頭訓(xùn)練輸出層,而其余層的參數(shù)都是基于源模型的參數(shù)微調(diào)得到的。


    模型微調(diào).png

3.2 使用已有模型結(jié)構(gòu)

這里我們以torchvision中的常見模型為例,列出了如何在圖像分類任務(wù)中使用PyTorch提供的常見模型結(jié)構(gòu)和參數(shù)。對于其他任務(wù)和網(wǎng)絡(luò)結(jié)構(gòu),使用方式是類似的:實例化網(wǎng)絡(luò),再傳遞pretrained參數(shù)

  • 實例化網(wǎng)絡(luò)
from cgitb import reset
import torchvision.models as models 

reset18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet_v2 = models.mobilenet_v2()
mobilenet_v3_large = models.mobilenet_v3_large()
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
  • 傳遞pretrained參數(shù)

通過True或者False來決定是否使用預(yù)訓(xùn)練好的權(quán)重,在默認狀態(tài)下pretrained = False,意味著我們不使用預(yù)訓(xùn)練得到的權(quán)重,當pretrained = True,意味著我們將使用在一些數(shù)據(jù)集上預(yù)訓(xùn)練得到的權(quán)重。

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)

注意事項:

  1. 通常PyTorch模型的擴展為.pt或.pth,程序運行時會首先檢查默認路徑中是否有已經(jīng)下載的模型權(quán)重,一旦權(quán)重被下載,下次加載就不需要下載了。

  2. 一般情況下預(yù)訓(xùn)練模型的下載會比較慢,我們可以直接通過迅雷或者其他方式去 這里 查看自己的模型里面model_urls,然后手動下載,預(yù)訓(xùn)練模型的權(quán)重在Linux和Mac的默認下載路徑是用戶根目錄下的.cache文件夾。在Windows下就是C:\Users<username>.cache\torch\hub\checkpoint。我們可以通過使用 torch.utils.model_zoo.load_url()設(shè)置權(quán)重的下載地址。

  3. 如果覺得麻煩,還可以將自己的權(quán)重下載下來放到同文件夾下,然后再將參數(shù)加載網(wǎng)絡(luò)。

self.model = models.resnet50(pretrained=False)

self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))

  1. 如果中途強行停止下載的話,一定要去對應(yīng)路徑下將權(quán)重文件刪除干凈,要不然可能會報錯。
    mnasnet = models.mnasnet1_0(pretrained=True)

3.3 訓(xùn)練特定層

  • 在默認情況下,參數(shù)的屬性.requires_grad = True,如果我們從頭開始訓(xùn)練或微調(diào)不需要注意這里。但如果我們正在提取特征并且只想為新初始化的層計算梯度,其他參數(shù)不進行改變。那我們就需要通過設(shè)置requires_grad = False來凍結(jié)部分層。在PyTorch官方中提供了如下set_parameter_requires_grad的樣例。

  • 通過該樣例,我們使用resnet18為例的將1000類改為4類,但是僅改變最后一層的模型參數(shù),不改變特征提取的模型參數(shù);注意我們先凍結(jié)模型參數(shù)的梯度,再對模型輸出部分的全連接層進行修改,這樣修改后的全連接層的參數(shù)就是可計算梯度的。

  • 之后在訓(xùn)練過程中,model仍會進行梯度回傳,但是參數(shù)更新則只會發(fā)生在fc層。通過設(shè)定參數(shù)的requires_grad屬性,我們完成了指定訓(xùn)練模型的特定層的目標,這對實現(xiàn)模型微調(diào)非常重要。

import torchvision.models as models

# 凍結(jié)參數(shù)的梯度
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


# 凍結(jié)參數(shù)的梯度
feature_extract = True
model = models.resnet18(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
# 修改模型
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=4, bias=True)

4. 模型微調(diào)之timm

除了使用torchvision.models進行預(yù)訓(xùn)練以外,還有一個常見的預(yù)訓(xùn)練模型庫,叫做timm,這個庫是由來自加拿大溫哥華Ross Wightman創(chuàng)建的。里面提供了許多計算機視覺的SOTA模型,可以當作是torchvision的擴充版本,并且里面的模型在準確度上也較高。

  • timm安裝 pip install timm

4.1 查看/修改預(yù)訓(xùn)練模型種類

import timm
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models)
# 查詢指定系列的模型時,可以在list_models()輸入模型名稱 
timm.list_models('resnet*', pretrained=True)
# 查看模型的具體參數(shù),可以通過default_cfg實現(xiàn),
# 創(chuàng)建模型時,使用num_classes可以將模型的輸出進行修改
model = timm.create_model('resnet18', num_classes=10, pretrained=True)
model.default_cfg
# 輸出
{'url': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'conv1', 'classifier': 'fc', 'architecture': 'resnet18'}
# 查看模型的輸出
import torch
x = torch.randn(1, 3, 244, 244)
output = model(x)
output.shape
# 輸出
torch.Size([1, 10])

# 模型第一層
print(dict(model.named_children())['conv1'])
# 查看模型第一層的參數(shù),以第一層卷積為例)
print(list(dict(model.named_children())['conv1'].parameters()))
# print(model)          # 查看模型的網(wǎng)絡(luò)結(jié)構(gòu)
# 改變輸入通道數(shù)(比如我們傳入的圖片是單通道的,但是模型需要的是三通道圖片) 我們可以通過添加in_chans=1來改變
model = timm.create_model('resnet34',num_classes=10,pretrained=True,in_chans=1)
x = torch.randn(1,1,224,224)
output = model(x)

4.2 模型參數(shù)的保存

timm庫所創(chuàng)建的模型是torch.model的子類,我們可以直接使用torch庫中內(nèi)置的模型參數(shù)保存和加載的方法,具體操作如下方代碼所示

torch.save(model.state_dict(),'./checkpoint/timm_model.pth')

model.load_state_dict(torch.load('./checkpoint/timm_model.pth'))

# 保存模型權(quán)重
torch.save(model.state_dict(), './timm_model.pth')
# 加載模型權(quán)重
model.load_state_dict(torch.load('./timm_model.pth'))

5. 半精度訓(xùn)練

PyTorch默認的浮點數(shù)存儲方式用的是torch.float32,小數(shù)點后位數(shù)更多固然能保證數(shù)據(jù)的精確性,但絕大多數(shù)場景其實并不需要這么精確,只保留一半的信息也不會影響結(jié)果,也就是使用torch.float16格式。由于數(shù)位減了一半,因此被稱為“半精度”。半精度能夠減少顯存占用,使得顯卡可以同時加載更多數(shù)據(jù)進行計算

半精度.png

半精度訓(xùn)練的設(shè)置
使用autocast配置半精度訓(xùn)練

半精度訓(xùn)練主要適用于數(shù)據(jù)本身的size比較大(比如說3D圖像、視頻等)。當數(shù)據(jù)本身的size并不大時(比如手寫數(shù)字MNIST數(shù)據(jù)集的圖片尺寸只有28*28),使用半精度訓(xùn)練則可能不會帶來顯著的提升。

# autocast的導(dǎo)入
from torch.cuda.amp import autocast

# 模型設(shè)置
# 使用python的裝飾器方法,用autocast裝飾模型中的forward函數(shù)。關(guān)于裝飾器的使用。
@autocast()
def forward(self, x):
    ... 
    return x

# 訓(xùn)練過程
# 在訓(xùn)練過程中,只需在將數(shù)據(jù)輸入模型及其之后的部分放入“with autocast():
for  x in train_loader:
    x = x.cuda()
    with autocast():
        ouput = model(x)
        ...
        

6. 數(shù)據(jù)增強

①數(shù)據(jù)增強有什么用?

深度學(xué)習(xí)最重要的是數(shù)據(jù)。我們需要大量數(shù)據(jù)才能避免模型的過度擬合。但是我們在許多場景無法獲得大量數(shù)據(jù),例如醫(yī)學(xué)圖像分析。數(shù)據(jù)增強技術(shù)的存在是為了解決這個問題,這是針對有限數(shù)據(jù)問題的解決方案。數(shù)據(jù)增強一套技術(shù),可提高訓(xùn)練數(shù)據(jù)集的大小和質(zhì)量,以便我們可以使用它們來構(gòu)建更好的深度學(xué)習(xí)模型。 在計算視覺領(lǐng)域,生成增強圖像相對容易。即使引入噪聲或裁剪圖像的一部分,模型仍可以對圖像進行分類,數(shù)據(jù)增強有一系列簡單有效的方法可供選擇,有一些機器學(xué)習(xí)庫來進行計算視覺領(lǐng)域的數(shù)據(jù)增強,比如:imgaug 官網(wǎng)它封裝了很多數(shù)據(jù)增強算法,給開發(fā)者提供了方便

②數(shù)據(jù)增強的怎么做?

數(shù)據(jù)擴增是對讀取進行數(shù)據(jù)增強的操作,所以需要在數(shù)據(jù)讀取的時候完成。

③數(shù)據(jù)增強的方法有哪些?

數(shù)據(jù)擴增方法有很多:從顏色空間、尺度空間到樣本空間,同時根據(jù)不同任務(wù)數(shù)據(jù)擴增都有相應(yīng)的區(qū)別。
對于圖像分類,數(shù)據(jù)擴增一般不會改變標簽;對于物體檢測,數(shù)據(jù)擴增會改變物體坐標位置;對于圖像分割,數(shù)據(jù)擴增會像素標簽;

④數(shù)據(jù)增強庫

pytorch官方提供的數(shù)據(jù)擴增庫,提供了基本的數(shù)據(jù)數(shù)據(jù)擴增方法,可以無縫與torch進行集成;但數(shù)據(jù)擴增方法種類較少,且速度中等;

  • imgaug

https://github.com/aleju/imgaug

imgaug是常用的第三方數(shù)據(jù)擴增庫,提供了多樣的數(shù)據(jù)擴增方法,且組合起來非常方便,速度較快;

  • albumentations

https://albumentations.readthedocs.io

是常用的第三方數(shù)據(jù)擴增庫,提供了多樣的數(shù)據(jù)擴增方法,對圖像分類、語義分割、物體檢測和關(guān)鍵點檢測都支持,速度較快。

6.1 torchvision中的常見數(shù)據(jù)增強方法

基礎(chǔ)數(shù)據(jù)擴增方法指常見的數(shù)據(jù)擴增方法,且都是標簽一致的數(shù)據(jù)擴增方法,大都出現(xiàn)在torchvision中:

  • transforms.CenterCrop
    對圖片中心進行裁剪;
  • transforms.ColorJitter
    對圖像顏色的對比度、飽和度和零度進行變換;
  • transforms.FiveCrop
    對圖像四個角和中心進行裁剪得到五分圖像;
  • transforms.Grayscale
    對圖像進行灰度變換;
  • transforms.Pad
    使用固定值進行像素填充;
  • transforms.RandomAffine
    隨機仿射變換;
  • transforms.RandomCrop
    隨機區(qū)域裁剪;
  • transforms.RandomHorizontalFlip
    隨機水平翻轉(zhuǎn);
  • transforms.RandomRotation
    隨機旋轉(zhuǎn);
  • transforms.RandomVerticalFlip
    隨機垂直翻轉(zhuǎn);
import torchvision.transforms 
torchvision.transforms.CenterCrop()

6.2 imgaug的安裝和使用

imgaug的安裝方法和其他的Python包類似,我們可以通過以下兩種方式進行安裝

  • conda
    (我用這個安裝失敗了~~~)
conda config --add channels conda-forge
conda install imgaug
  • pip
    用下面第二行安裝成功了
#  install imgaug either via pypi

pip install imgaug

#  install the latest version directly from github

pip install git+https://github.com/aleju/imgaug.git
imgaug的使用

imgaug僅僅提供了圖像增強的一些方法,但是并未提供圖像的IO操作,因此我們需要使用一些庫來對圖像進行導(dǎo)入,建議使用imageio進行讀入,如果使用的是opencv進行文件讀取的時候,需要進行手動改變通道,將讀取的BGR圖像轉(zhuǎn)換為RGB圖像。除此以外,當我們用PIL.Image進行讀取時,因為讀取的圖片沒有shape的屬性,所以我們需要將讀取到的img轉(zhuǎn)換為np.array()的形式再進行處理。因此官方的例程中也是使用imageio進行圖片讀取。

單張圖片的處理

import imageio
import imgaug as ia 
%matplotlib inline 

# image讀取
import PIL
from PIL import Image
import numpy as np

# # Image讀取照片
# img2 = Image.open('car.jpg')
# image2 = np.array(img2)
# ia.imshow(image2)

#圖片的讀取,使用imageiod讀取,imgaug展示
img = imageio.imread('car.jpg')
print(img.shape)
# 可視化
ia.imshow(img)
  • 對單張圖片進行增強處理
    以下為旋轉(zhuǎn),和多種方式的組合介紹
from imgaug import augmenters as iaa 

# 設(shè)置隨機數(shù)種子
ia.seed(4)
# 實例化方法
rotate= iaa.Affine(rotate=(45))
img_aug = rotate(image = img)
ia.imshow(img_aug)
rotate.png
  • 對圖片進行多種組合的數(shù)據(jù)增強

使用imgaug.augmenters.Sequential()來構(gòu)造數(shù)據(jù)增強的pipline,與torchvison.transforms.Compose()類似

總的來說,對單張圖片處理的方式基本相同,我們可以根據(jù)實際需求,選擇合適的數(shù)據(jù)增強方法來對數(shù)據(jù)進行處理。

# iaa.Sequential的參數(shù)如下
# iaa.Sequential(children=None,   # Augmenter集合
#                 random_order=False, # 是否對每個batch使用不同順序的Augmenter list
#                 name = None,
#                 deterministic=False,
#                 random_state=None
# )

# 構(gòu)建處理序列
from email.mime import image


aug_seq = iaa.Sequential([
    iaa.Affine(rotate=(-25, 25)),           # 讓圖片在(-25, 25)間隨機旋轉(zhuǎn)
    iaa.AdditiveGaussianNoise(scale=(10, 60)),      # 給圖片添加高斯噪聲,高斯噪聲的標準差為(10, 60)
    iaa.Crop(percent=(0, 0.2))          # 對圖像進行隨機裁剪,裁剪范圍(0, 0.2)
])

# 對圖片進行處理
image_aug = aug_seq(image=img)
ia.imshow(image_aug)

對批次圖片進行處理

可以將圖形數(shù)據(jù)按照NHWC(N:batch H:height W:Width C:channel)的形式或者由列表組成的HWC的形式對批量的圖像進行處理。主要分為以下兩部分:

  • 對批次的圖片以同一種方式處理
  • 對批次的圖片進行分部分處理。

import os
import imageio.v2 as imageio
import imgaug as ia 
%matplotlib inline 

root_path = '/Users/anker/Desktop/python_code/datasets/test'
image_list = []
# os.listdir(root_path)
for i in os.listdir(root_path):
    image_path = os.path.join(root_path, i)
    img = imageio.imread(image_path)
    image_list.append(img)
    print(img.shape)

# 對一批次的圖片進行處理時,只需要將待處理的圖片放在一個list中,并將image改為image即可進行數(shù)據(jù)增強操作,具體實際操作如下:
images = [image_list[0], image_list[0], image_list[0]]
# 傳參時需要指明是images參數(shù)
images_aug = rotate(images = images)
# ia.imshow圖片時,輸入的圖片必須是相同的大小
ia.imshow(np.hstack(images_aug))

輸出


批次1.png
# 對批次圖片使用多種增強方法,傳參時注意傳的是images參數(shù)
images_aug_seq = aug_seq.augment_images(images = images)
# images_aug_seq = aug_seq(images = images)         # 可以用上面的寫法,也可以用本行的方法傳參
ia.imshow(np.hstack(images_aug_seq))

輸出

批次輸出2.png

對批次的圖片分部分處理

imgaug.augmenters.Sometimes()對batch中的一部分圖片應(yīng)用一部分Augmenters,剩下的圖片應(yīng)用另外的Augmenters。

aug_sometimes = iaa.Sometimes(0.5, iaa.GaussianBlur(0.7), iaa.Fliplr(1.0))
images_aug_sometimes = aug_sometimes(images = images)
ia.imshow(np.hstack(images_aug_sometimes))

輸出

批次輸出3.png

對不同大小的圖片進行處理

除了可視化與其他不同外,其他都相同

image_list
# 構(gòu)建pipline
seq = iaa.Sequential([
    iaa.CropAndPad(percent=(-0.2, 0.2), pad_mode='edge'),       # 對圖片進行剪切和填充
    iaa.AddToHueAndSaturation((-60, 60)),               # 對圖片的飽和度和色調(diào)進行調(diào)整
    iaa.ElasticTransformation(alpha=0.9, sigma=9),      # 對圖片進行像素調(diào)整,產(chǎn)生水波紋的效果
    iaa.Cutout()          # 填充圖像
])

# 對圖像進行增強
images_seq = seq(images= image_list)

for i in range(len(image_list)):
    print("Image %d (input shape: %s, output shape: %s)" % (i, image_list[i].shape, images_seq[i].shape))
    ia.imshow(np.hstack([image_list[i], images_seq[i]]))

輸出


批次輸出4.png

6.3 imgaug在PyTorch的應(yīng)用

import numpy as np
import imgaug
from imgaug import augmenters as iaa
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# 構(gòu)建pipline
tfs = transforms.Compose([
    iaa.Sequential([
        iaa.flip.Fliplr(p=0.5),
        iaa.flip.Flipud(p=0.5),
        iaa.GaussianBlur(sigma=(0.0, 0.1)),
        iaa.MultiplyBrightness(mul=(0.65, 1.35)),
    ]).augment_image,
    # 不要忘記了使用ToTensor()
    transforms.ToTensor()
])

# 自定義數(shù)據(jù)集
class CustomDataset(Dataset):
    def __init__(self, n_images, n_classes, transform=None):
        # 圖片的讀取,建議使用imageio
        self.images = np.random.randint(0, 255,
                                        (n_images, 224, 224, 3),
                                        dtype=np.uint8)
        self.targets = np.random.randn(n_images, n_classes)
        self.transform = transform

    def __getitem__(self, item):
        image = self.images[item]
        target = self.targets[item]

        if self.transform:
            image = self.transform(image)

        return image, target

    def __len__(self):
        return len(self.images)


def worker_init_fn(worker_id):
    imgaug.seed(np.random.get_state()[1][0] + worker_id)


custom_ds = CustomDataset(n_images=50, n_classes=10, transform=tfs)
custom_dl = DataLoader(custom_ds, batch_size=64,
                       num_workers=4, pin_memory=True, 
                       worker_init_fn=worker_init_fn)

關(guān)于num_workers在Windows系統(tǒng)上只能設(shè)置成0,但是當我們使用Linux遠程服務(wù)器時,可能使用不同的num_workers的數(shù)量,這是我們就需要注意worker_init_fn()函數(shù)的作用了。它保證了我們使用的數(shù)據(jù)增強在num_workers>0時是對數(shù)據(jù)的增強是隨機的。

除去imgaug以外,還可以學(xué)習(xí)下Albumentations,因為Albumentations跟imgaug都有著豐富的教程資源,這個以后再看,先學(xué)完教程再說。

7. 使用argparse進行調(diào)參

argparse的作用就是將命令行傳入的其他參數(shù)進行解析、保存和使用。在使用argparse后,我們在命令行輸入的參數(shù)就可以以這種形式python file.py --lr 1e-4 --batch_size 32來完成對常見超參數(shù)的設(shè)置。

argparse的使用

  • 創(chuàng)建ArgumentParser()對象
  • 調(diào)用add_argument()方法添加參數(shù)
  • 使用parse_args()解析參數(shù)
# 簡單demo
import argparse

# 創(chuàng)建ArgumentParse()對象
parser = argparse.ArgumentParser()

# 添加參數(shù)
parser.add_argument('-o', '--output', action='store_true', help='shows output')
# action = `store_true` 會將output參數(shù)記錄為True
# type 規(guī)定了參數(shù)的格式
# default 規(guī)定了默認值
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3')
parser.add_argument('--batch_size', type=int, required=True, help='input batch size')

# 使用parse_args()解析函數(shù)
args = parser.parse_args()

if args.output:  
    print(f"learning rate:{args.lr} ")

更加高效的使用

一種方式是將超參數(shù)的設(shè)置寫在單獨的config.py文件中,然后在調(diào)用使用

另外一種是封裝為函數(shù),調(diào)用的時候進行使用

# 將超參數(shù)設(shè)置寫在單獨的config.py文件中
import argparse  
  
def get_options(parser=argparse.ArgumentParser()):  
  
    parser.add_argument('--workers', type=int, default=0,  
                        help='number of data loading workers, you had better put it '  
                              '4 times of your gpu')  
  
    parser.add_argument('--batch_size', type=int, default=4, help='input batch size, default=64')  
  
    parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for, default=10')  
  
    parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3')  
  
    parser.add_argument('--seed', type=int, default=118, help="random seed")  
  
    parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')  
    parser.add_argument('--checkpoint_path',type=str,default='',  
                        help='Path to load a previous trained model if not empty (default empty)')  
    parser.add_argument('--output',action='store_true',default=True,help="shows output")  
  
    opt = parser.parse_args()  
  
    if opt.output:  
        print(f'num_workers: {opt.workers}')  
        print(f'batch_size: {opt.batch_size}')  
        print(f'epochs (niters) : {opt.niter}')  
        print(f'learning rate : {opt.lr}')  
        print(f'manual_seed: {opt.seed}')  
        print(f'cuda enable: {opt.cuda}')  
        print(f'checkpoint_path: {opt.checkpoint_path}')  
  
    return opt  
  
if __name__ == '__main__':  
    opt = get_options()
# 在隨后的train.py等文件中,單獨使用
# 導(dǎo)入必要庫
...
import config

opt = config.get_options()

manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
lr = opt.lr
niters = opt.niters
checkpoint_path = opt.checkpoint_path

# 隨機數(shù)的設(shè)置,保證復(fù)現(xiàn)結(jié)果
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

...


if __name__ == '__main__':
    set_seed(manual_seed)
    for epoch in range(niters):
        train(model,lr,batch_size,num_workers,checkpoint_path)
        val(model,lr,batch_size,num_workers,checkpoint_path)
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容