模型參數(shù)的裁剪

基礎使用

1 如何使用id( )

我們隨便定義一個模型:

m_seq = torch.nn.Sequential(
    torch.nn.Linear(2, 2),
    torch.nn.Linear(2, 2),
)

如果只是使用id(m_seq.parameters()) 只會返回整個m_seq.parameters()的一個id. 因此我們使用map().

list(map(id, m_seq.parameters())) # [140575686657984, 140575686318016, 140579143672432, 140579143765264]

可以看見返回了4個id. 它們分別是創(chuàng)建的兩個線性層的weight 和 bias 的參數(shù)的id(['0.weight', '0.bias', '1.weight', '1.bias']).

2 使用filter根據(jù)id濾除參數(shù)

當我們明確要濾除參數(shù)的模塊的時候, 可以使用下面這個方法:

def filter_params(model, to_filter_module):
    ignored_id = list(map(id, to_filter_module.parameters())) # list
    out_params = filter(lambda p: id(p) not in ignored_id, model.parameters())
    return out_params

在這個過程中filter 找到根據(jù)對比layer3里面參數(shù)的id和model里面所有參數(shù)的id(指的是list(map(id, model.parameters())))將layer3的參數(shù)濾除.
注意: 該method返回的out_params是一個filter.
使用的例子如下:
例如我們有這樣一個有3層的模型:

class MyModel(nn.Module):
    def __init__(self):
        super (MyModel, self).__init__()
        self.layer1 = torch.nn.Linear(2, 2)
        self.layer2 = torch.nn.Linear(2, 2)
        self.layer3 = torch.nn.Linear(2, 2)
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

model = MyModel()

我們希望第一層和第二層用lr1更新, 第三層用lr2更新. 因此我們需要將第三層和第一二層分開, 實現(xiàn)方式如下:

param_exclude_layer3 = filter_params(model, model.layer3) # filter
param_layer3 = model.layer3.parameters() # generator

opt = torch.optim.Adam([
    {'params': param_exclude_layer3, 'lr': 1e-3},
    {'params': param_layer3, 'lr': 1e-4},
])

濾除多層

使用以下代碼濾除多層. 將需要濾除的多個模塊裝在list里面?zhèn)魅?例如: param_exclude_layer3_and_layer1 = filter_params(model, [model.layer3,model.layer1])

def filter_params(model, to_filter_module_list):
    ignored_id = []
    for module in to_filter_module_list:
        ignored_id += list(map(id, module.parameters()))
    print(ignored_id)
    out_params = filter(lambda p: id(p) not in ignored_id, model.parameters())
    return out_params

根據(jù)id來濾除的方法因為id的唯一性, 不太可能濾除錯誤, 但如果模型復雜, 就需要手動索引模塊的位置, 例如: model.enc.linears.layer1, 這樣比較麻煩, 可能需要測試定位它的位置.

?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容