Pytorch框架學(xué)習(xí)(11)——優(yōu)化器

@[toc]

1. 什么是優(yōu)化器

pytorch的優(yōu)化器:管理并更新模型中可學(xué)習(xí)參數(shù)的值,使得模型輸出更接近真實(shí)標(biāo)簽

  • 基本屬性
    • defaults:優(yōu)化器超參數(shù)
    • state:參數(shù)的緩存,如momentum的緩存
    • param_groups:管理的參數(shù)組(list)
    • _step_count:記錄更新次數(shù),學(xué)習(xí)率調(diào)整中使用


      在這里插入圖片描述
  • 基本方法
    • zero_grad():清空所管理參數(shù)的梯度,因?yàn)?strong>pytorch特性:張量梯度不自動(dòng)清零
    • step():執(zhí)行一步更新
    • add_param_group():添加參數(shù)組
    • state_dict():獲取有乎其當(dāng)前狀態(tài)信息字典
    • load_state_dict():加載狀態(tài)信息字典

2. 學(xué)習(xí)率與動(dòng)量

  • 梯度下降:w_{i+1} = w_i - g(w_i)

  • 學(xué)習(xí)率:控制更新的步伐。
    增加學(xué)習(xí)率之后的梯度下降公式為:w_{i+1} = w_i - LR*g(w_i)

  • 動(dòng)量(Momentum):結(jié)合當(dāng)前梯度與上一次更新信息,用于當(dāng)前更新
    增加動(dòng)量之后的梯度下降公式為:
    v_i = m * v_{i-1} + g(w_i)
    w_{i+1} = w_i - lr * v_i
    其中g(w_i)表示w_i的梯度, m為momentum系數(shù), v_i表示更新量,lr表示學(xué)習(xí)率,w_{i+1}表示第i+1次更新的參數(shù)

3. torch.optim.SGD

  • optim.SGD
    • 主要參數(shù):
      • params:管理的參數(shù)組(list)
      • lr:初試學(xué)習(xí)率
      • momentum:動(dòng)量系數(shù),貝塔
      • weight_decay:L2正則化系數(shù)
      • nesterov:是否采用NAG

4. 優(yōu)化器

  1. optim.SGD:隨機(jī)梯度下降法
  2. optim.Adagrad:自適應(yīng)學(xué)習(xí)率梯度下降法
  3. optim.RMSprop:Adagrad的改進(jìn)
  4. optim.Adadelta:Adagrad的改進(jìn)
  5. optim.Adam:RMSprop結(jié)合Momentum
  6. optim.Adamax:Adam增加學(xué)習(xí)率上限
  7. optim.SparseAdam:稀疏版Adam
  8. optim.ASGD:隨機(jī)平均梯度下降
  9. optim.Rprop:彈性反向傳播
    10.optim.LBFGS:BFGS的改進(jìn)

5. 作業(yè)

優(yōu)化器的作用是管理并更新參數(shù)組,請(qǐng)構(gòu)建一個(gè)SGD優(yōu)化器,通過(guò)add_param_group方法添加三組參數(shù),三組參數(shù)的學(xué)習(xí)率分別為 0.01, 0.02, 0.03, momentum分別為0.9, 0.8, 0.7,構(gòu)建好之后,并打印優(yōu)化器中的param_groups屬性中的每一個(gè)元素的key和value(提示:param_groups是list,其每一個(gè)元素是一個(gè)字典)

w1 = torch.randn((2, 2), requires_grad=True)
w2 = torch.randn((2, 2), requires_grad=True)
w3 = torch.randn((2, 2), requires_grad=True)
w1.grad = torch.ones((2, 2))

optimizer = optim.SGD([w1], lr=0.01, momentum=0.9)
optimizer.add_param_group({"params": w2, 'lr': 0.02, 'momentum': 0.8})
optimizer.add_param_group({"params": w3, 'lr': 0.03, 'momentum': 0.7})

print("optimizer.param_groups is\n{}".format(optimizer.param_groups))

執(zhí)行結(jié)果:

[{'params': [tensor([[0.6614, 0.2669],
        [0.0617, 0.6213]], requires_grad=True)], 'lr': 0.01, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}, {'params': [tensor([[-0.4519, -0.1661],
        [-1.5228,  0.3817]], requires_grad=True)], 'lr': 0.02, 'momentum': 0.8, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}, {'params': [tensor([[-1.0276, -0.5631],
        [-0.8923, -0.0583]], requires_grad=True)], 'lr': 0.03, 'momentum': 0.7, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容