Pytorch剪枝代碼示例和注釋

參考這個鏈接,加了一些自己的注釋

導(dǎo)入模塊

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

建立模型

LeNet 1998年提出

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

檢查模塊

檢查未修建的conv1

module = model.conv1
print(list(module.named_parameters()))

輸出結(jié)果如下:
權(quán)重有6個矩陣,每個矩陣的size是3*3,偏差為6個value。

[('weight', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],

        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]

檢測是否有緩沖區(qū)

print(list(module.named_buffers()))

輸出為空矩陣
parameter和buffer的區(qū)別:

模型中需要保存下來的參數(shù)包括兩種:
一種是反向傳播需要被optimizer更新的,稱之為 parameter
一種是反向傳播不需要被optimizer更新,稱之為 buffer

修剪模塊 part1

目標:我們將在conv1層中名為weight的參數(shù)中隨機修剪 30%的連接。

  1. torch.nn.utils.prune選擇修建技術(shù)
  2. 指定模塊和該模塊中需要修剪的參數(shù)名稱
    3.使用所選修剪技術(shù)所需的適當關(guān)鍵字參數(shù),指定修剪參數(shù)。
    name=weight的理解:
    之前print(list(module.named_parameters()))的輸出結(jié)果是以字典形式保存的,關(guān)鍵字有weight
    amount=0.3的理解:剪掉百分之30的連接。
prune.random_unstructured(module, name="weight", amount=0.3)

修剪函數(shù)執(zhí)行時候的內(nèi)部原理:

修剪是通過從參數(shù)中刪除weight并將其替換為名為weight_orig的新參數(shù)(即,將"_orig"附加到初始參數(shù)name)來進行的。 weight_orig存儲未修剪的張量版本。 bias未修剪,因此它將保持完整。

print(list(module.named_parameters()))

輸出如下:

[('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],

        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True))]

通過以上選擇的修剪技術(shù)生成的修剪掩碼將保存為名為weight_mask的模塊緩沖區(qū)

print(list(module.named_buffers()))

輸出:

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],

        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 0., 1.]]],

        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],

        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0'))]

這里需要注意:
mask里標記為0的位置對應(yīng)的weight是被pruned掉的,在retrained的時候保持為0。
對比mask和weight可以發(fā)現(xiàn),mask里標記為0的量對應(yīng)每個3*3矩陣里weight magnitude較小的權(quán)重。

這時候打印weight,會得到掩碼和原始參數(shù)結(jié)合的版本(即pruned的權(quán)重變?yōu)?)。注意這里的weight不是一個參數(shù),只是一個屬性。

print(module.weight)
tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.0000, -0.2106],
          [ 0.1776, -0.1845, -0.0000],
          [-0.0708,  0.0000,  0.3095]]],

        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.0000, -0.0000],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.0000],
          [ 0.2159, -0.1725,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

剪枝需要在每次前向傳播之前被應(yīng)用。通過PyTorch 的forward_pre_hooks可以應(yīng)用剪枝。
當模型被剪枝時,它將為與該模型關(guān)聯(lián)的每個參數(shù)獲取forward_pre_hook進行修剪。(注意,在這里模型不是指整個網(wǎng)絡(luò)模型,而是指被剪枝的子模型,比如在這里是指conv1

在這種情況下,由于到目前為止我們只修剪了名稱為weight的原始參數(shù),因此只會出現(xiàn)一個鉤子。

print(module._forward_pre_hooks)

輸出為:

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>)])

修建模塊 part2

為了完整起見,我們現(xiàn)在也可以修剪bias,以查看module的參數(shù)parameter,緩沖區(qū)buffer,掛鉤hook和屬性property如何變化。
在這里我們嘗試另一種修剪方法,按 L1 范數(shù)修剪掉最小的3個偏差bias

prune.l1_unstructured(module, name="bias", amount=3)

預(yù)計目標:
現(xiàn)在,我們希望命名的參數(shù)同時包含之前的weight_orig和bias_orig。 緩沖區(qū)buffer將包括weight_mask和bias_mask。 兩個張量(weight和bias)的修剪版本將作為模塊屬性存在,并且該模塊現(xiàn)在將具有兩個forward_pre_hooks。
實際輸出:
參數(shù):

[('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],

        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]

緩沖區(qū):

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],

        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 0., 1.]]],

        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],

        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]

屬性:

tensor([-0.0000, -0.0000, -0.2656, -0.1519, -0.0000,  0.1425], device='cuda:0',
       grad_fn=<MulBackward0>)

鉤子:可以看到有兩個鉤子

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f1e6c425550>)])

迭代修剪

一個模塊中的同一參數(shù)可以被多次修剪。
暫時用不到這一塊,因此不深入下去了。

序列化修剪的模型

所有相關(guān)的張量,包括掩碼緩沖區(qū)和用于計算修剪的張量的原始參數(shù),都存儲在模型的state_dict中,因此可以根據(jù)需要輕松地序列化和保存。

print(model.state_dict().keys())

輸出

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

刪除剪枝重新參數(shù)化

torch.nn.utils.prune中的remove

修建模型中的多個參數(shù)

下面這段代碼對不同的層采用了不同的sparsity percentage

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

全局剪枝(Global Pruning)

之前我們的剪枝方法為“局部剪枝”(local pruning)研究了通常被稱為“局部”修剪的方法,即通過比較每個條目的統(tǒng)計信息(weight magnitude, activation, gradient, etc.)來逐一修剪模型中的張量的做法。 但是,一種常見且可能更強大的技術(shù)是通過刪除(例如)刪除整個模型中最低的20%的連接,而不是刪除每一層中最低的 20%的連接來一次修剪模型。 這很可能導(dǎo)致每個層的修剪百分比不同。 讓我們看看如何使用torch.nn.utils.prune中的global_unstructured進行操作。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

下面我們檢查在每個修剪參數(shù)中引起的稀疏性,該稀疏性將不等于每層中的 20%。 但是,全局稀疏度大約為 20%。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100\. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

使用自定義修剪功能擴展torch.nn.utils.prune

這部分暫時用不到

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容