參考這個鏈接,加了一些自己的注釋
導(dǎo)入模塊
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
建立模型
LeNet 1998年提出
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 3x3 square conv kernel
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
檢查模塊
檢查未修建的conv1層
module = model.conv1
print(list(module.named_parameters()))
輸出結(jié)果如下:
權(quán)重有6個矩陣,每個矩陣的size是3*3,偏差為6個value。
[('weight', Parameter containing:
tensor([[[[ 0.3161, -0.2212, 0.0417],
[ 0.2488, 0.2415, 0.2071],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.3322, -0.2106],
[ 0.1776, -0.1845, -0.3134],
[-0.0708, 0.1921, 0.3095]]],
[[[-0.2070, 0.0723, 0.2876],
[ 0.2209, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.1527, -0.0388],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.2245, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.3146, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.1067],
[ 0.2159, -0.1725, 0.0723]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021, 0.1425], device='cuda:0',
requires_grad=True))]
檢測是否有緩沖區(qū)
print(list(module.named_buffers()))
輸出為空矩陣
parameter和buffer的區(qū)別:
模型中需要保存下來的參數(shù)包括兩種:
一種是反向傳播需要被optimizer更新的,稱之為 parameter
一種是反向傳播不需要被optimizer更新,稱之為 buffer
修剪模塊 part1
目標:我們將在conv1層中名為weight的參數(shù)中隨機修剪 30%的連接。
- 從
torch.nn.utils.prune選擇修建技術(shù) - 指定模塊和該模塊中需要修剪的參數(shù)名稱
3.使用所選修剪技術(shù)所需的適當關(guān)鍵字參數(shù),指定修剪參數(shù)。
name=weight的理解:
之前print(list(module.named_parameters()))的輸出結(jié)果是以字典形式保存的,關(guān)鍵字有weight
amount=0.3的理解:剪掉百分之30的連接。
prune.random_unstructured(module, name="weight", amount=0.3)
修剪函數(shù)執(zhí)行時候的內(nèi)部原理:
修剪是通過從參數(shù)中刪除weight并將其替換為名為weight_orig的新參數(shù)(即,將"_orig"附加到初始參數(shù)name)來進行的。 weight_orig存儲未修剪的張量版本。 bias未修剪,因此它將保持完整。
print(list(module.named_parameters()))
輸出如下:
[('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021, 0.1425], device='cuda:0',
requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212, 0.0417],
[ 0.2488, 0.2415, 0.2071],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.3322, -0.2106],
[ 0.1776, -0.1845, -0.3134],
[-0.0708, 0.1921, 0.3095]]],
[[[-0.2070, 0.0723, 0.2876],
[ 0.2209, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.1527, -0.0388],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.2245, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.3146, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.1067],
[ 0.2159, -0.1725, 0.0723]]]], device='cuda:0', requires_grad=True))]
通過以上選擇的修剪技術(shù)生成的修剪掩碼將保存為名為weight_mask的模塊緩沖區(qū)
print(list(module.named_buffers()))
輸出:
[('weight_mask', tensor([[[[0., 1., 0.],
[1., 0., 0.],
[1., 1., 1.]]],
[[[1., 0., 1.],
[1., 1., 0.],
[1., 0., 1.]]],
[[[1., 0., 0.],
[0., 1., 1.],
[1., 1., 1.]]],
[[[1., 0., 0.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 0.],
[1., 1., 0.]]]], device='cuda:0'))]
這里需要注意:
mask里標記為0的位置對應(yīng)的weight是被pruned掉的,在retrained的時候保持為0。
對比mask和weight可以發(fā)現(xiàn),mask里標記為0的量對應(yīng)每個3*3矩陣里weight magnitude較小的權(quán)重。
這時候打印weight,會得到掩碼和原始參數(shù)結(jié)合的版本(即pruned的權(quán)重變?yōu)?)。注意這里的weight不是一個參數(shù),只是一個屬性。
print(module.weight)
tensor([[[[ 0.0000, -0.2212, 0.0000],
[ 0.2488, 0.0000, 0.0000],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.0000, -0.2106],
[ 0.1776, -0.1845, -0.0000],
[-0.0708, 0.0000, 0.3095]]],
[[[-0.2070, 0.0000, 0.0000],
[ 0.0000, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.0000, -0.0000],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.0000, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.0000, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.0000],
[ 0.2159, -0.1725, 0.0000]]]], device='cuda:0',
grad_fn=<MulBackward0>)
剪枝需要在每次前向傳播之前被應(yīng)用。通過PyTorch 的forward_pre_hooks可以應(yīng)用剪枝。
當模型被剪枝時,它將為與該模型關(guān)聯(lián)的每個參數(shù)獲取forward_pre_hook進行修剪。(注意,在這里模型不是指整個網(wǎng)絡(luò)模型,而是指被剪枝的子模型,比如在這里是指conv1)
在這種情況下,由于到目前為止我們只修剪了名稱為weight的原始參數(shù),因此只會出現(xiàn)一個鉤子。
print(module._forward_pre_hooks)
輸出為:
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>)])
修建模塊 part2
為了完整起見,我們現(xiàn)在也可以修剪bias,以查看module的參數(shù)parameter,緩沖區(qū)buffer,掛鉤hook和屬性property如何變化。
在這里我們嘗試另一種修剪方法,按 L1 范數(shù)修剪掉最小的3個偏差bias
prune.l1_unstructured(module, name="bias", amount=3)
預(yù)計目標:
現(xiàn)在,我們希望命名的參數(shù)同時包含之前的weight_orig和bias_orig。 緩沖區(qū)buffer將包括weight_mask和bias_mask。 兩個張量(weight和bias)的修剪版本將作為模塊屬性存在,并且該模塊現(xiàn)在將具有兩個forward_pre_hooks。
實際輸出:
參數(shù):
[('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212, 0.0417],
[ 0.2488, 0.2415, 0.2071],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.3322, -0.2106],
[ 0.1776, -0.1845, -0.3134],
[-0.0708, 0.1921, 0.3095]]],
[[[-0.2070, 0.0723, 0.2876],
[ 0.2209, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.1527, -0.0388],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.2245, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.3146, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.1067],
[ 0.2159, -0.1725, 0.0723]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021, 0.1425], device='cuda:0',
requires_grad=True))]
緩沖區(qū):
[('weight_mask', tensor([[[[0., 1., 0.],
[1., 0., 0.],
[1., 1., 1.]]],
[[[1., 0., 1.],
[1., 1., 0.],
[1., 0., 1.]]],
[[[1., 0., 0.],
[0., 1., 1.],
[1., 1., 1.]]],
[[[1., 0., 0.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 0.],
[1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]
屬性:
tensor([-0.0000, -0.0000, -0.2656, -0.1519, -0.0000, 0.1425], device='cuda:0',
grad_fn=<MulBackward0>)
鉤子:可以看到有兩個鉤子
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f1e6c425550>)])
迭代修剪
一個模塊中的同一參數(shù)可以被多次修剪。
暫時用不到這一塊,因此不深入下去了。
序列化修剪的模型
所有相關(guān)的張量,包括掩碼緩沖區(qū)和用于計算修剪的張量的原始參數(shù),都存儲在模型的state_dict中,因此可以根據(jù)需要輕松地序列化和保存。
print(model.state_dict().keys())
輸出
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
刪除剪枝重新參數(shù)化
torch.nn.utils.prune中的remove
修建模型中的多個參數(shù)
下面這段代碼對不同的層采用了不同的sparsity percentage
new_model = LeNet()
for name, module in new_model.named_modules():
# prune 20% of connections in all 2D-conv layers
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
# prune 40% of connections in all linear layers
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist
全局剪枝(Global Pruning)
之前我們的剪枝方法為“局部剪枝”(local pruning)研究了通常被稱為“局部”修剪的方法,即通過比較每個條目的統(tǒng)計信息(weight magnitude, activation, gradient, etc.)來逐一修剪模型中的張量的做法。 但是,一種常見且可能更強大的技術(shù)是通過刪除(例如)刪除整個模型中最低的20%的連接,而不是刪除每一層中最低的 20%的連接來一次修剪模型。 這很可能導(dǎo)致每個層的修剪百分比不同。 讓我們看看如何使用torch.nn.utils.prune中的global_unstructured進行操作。
model = LeNet()
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
下面我們檢查在每個修剪參數(shù)中引起的稀疏性,該稀疏性將不等于每層中的 20%。 但是,全局稀疏度大約為 20%。
print(
"Sparsity in conv1.weight: {:.2f}%".format(
100\. * float(torch.sum(model.conv1.weight == 0))
/ float(model.conv1.weight.nelement())
)
)
print(
"Sparsity in conv2.weight: {:.2f}%".format(
100\. * float(torch.sum(model.conv2.weight == 0))
/ float(model.conv2.weight.nelement())
)
)
print(
"Sparsity in fc1.weight: {:.2f}%".format(
100\. * float(torch.sum(model.fc1.weight == 0))
/ float(model.fc1.weight.nelement())
)
)
print(
"Sparsity in fc2.weight: {:.2f}%".format(
100\. * float(torch.sum(model.fc2.weight == 0))
/ float(model.fc2.weight.nelement())
)
)
print(
"Sparsity in fc3.weight: {:.2f}%".format(
100\. * float(torch.sum(model.fc3.weight == 0))
/ float(model.fc3.weight.nelement())
)
)
print(
"Global sparsity: {:.2f}%".format(
100\. * float(
torch.sum(model.conv1.weight == 0)
+ torch.sum(model.conv2.weight == 0)
+ torch.sum(model.fc1.weight == 0)
+ torch.sum(model.fc2.weight == 0)
+ torch.sum(model.fc3.weight == 0)
)
/ float(
model.conv1.weight.nelement()
+ model.conv2.weight.nelement()
+ model.fc1.weight.nelement()
+ model.fc2.weight.nelement()
+ model.fc3.weight.nelement()
)
)
)
使用自定義修剪功能擴展torch.nn.utils.prune
這部分暫時用不到