torch.nn.utils.clip_grad_norm

梯度剪裁源碼地址以及函數(shù)的說明

函數(shù)源碼
函數(shù)官方說明

我的理解:

對于存在梯度爆炸的情況, 在優(yōu)化器函數(shù)之前執(zhí)行這個(gè)函數(shù),可以重新整合一遍梯度梯度縮小到指定范圍。

函數(shù)需要的參數(shù):

  1. parameters:計(jì)算了梯度之后的權(quán)重參數(shù)
  2. max_norm:認(rèn)為設(shè)定的閾值
  3. norm_type:指定的范數(shù)

函數(shù)執(zhí)行的操作
1. 對所有需要進(jìn)行梯度計(jì)算的參數(shù),收集所有參數(shù)的梯度的指定范數(shù)(通過參數(shù)norm_type進(jìn)行設(shè)置,1表示絕對值,2表示二階范數(shù)也就是平方和開根號)

2. 計(jì)算所有參數(shù)的梯度范數(shù)總和(一個(gè)標(biāo)量)和設(shè)定的max_norm的比值。如果max_norm/total_norm>1, 所有參數(shù)的梯度不變,可以直接反向傳播。如果比值小于1,說明參數(shù)梯度需要被縮減,縮減比率為rate= max_norm/total_norm,所有反向傳播的梯度變?yōu)樵镜膔ate倍。

這樣的意義就是避免權(quán)重梯度爆炸導(dǎo)致模型訓(xùn)練困難,對于大梯度的縮小,小梯度的不變。
但是存在的問題是,參數(shù)原本的分布很不均勻,有的梯度大有的梯度小;而梯度的總體范數(shù)值對于閾值,那么所有的梯度都會被同比例縮小。

import warnings
import torch
from torch._six import inf
from typing import Union, Iterable

_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]

def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False) -> torch.Tensor:
    r"""Clips gradient norm of an iterable of parameters.

 The norm is computed over all gradients together, as if they were
 concatenated into a single vector. Gradients are modified in-place.

 Args:
 parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
 single Tensor that will have gradients normalized
 max_norm (float or int): max norm of the gradients
 norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
 infinity norm.
 error_if_nonfinite (bool): if True, an error is thrown if the total
 norm of the gradients from :attr:``parameters`` is ``nan``,
 ``inf``, or ``-inf``. Default: False (will switch to True in the future)

 Returns:
 Total norm of the parameters (viewed as a single vector).
 """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        norms = [p.grad.detach().abs().max().to(device) for p in parameters]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for p in parameters:
        p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))
    return total_norm
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • TORCH.NN.FUNCTIONAL Convolution functions conv1d torch.nn...
    shcho閱讀 5,990評論 0 0
  • 函數(shù)調(diào)用形式 其為一個(gè)簡單的存儲固定大小的詞典的嵌入向量的查找表,意思就是說,給一個(gè)編號,嵌入層就能返回這個(gè)編號對...
    top_小醬油閱讀 183,936評論 10 100
  • 部分參考:z.defying https://zhuanlan.zhihu.com/p/76459295貢獻(xiàn)地址:...
    log1302閱讀 156評論 0 0
  • https://mp.weixin.qq.com/s/o-V07uM5NBn-0kQOQYrImw本文僅作為學(xué)術(shù)分...
    顧北向南閱讀 539評論 0 1
  • 一.函數(shù)調(diào)用形式 其為一個(gè)簡單的存儲固定大小的詞典的嵌入向量的查找表,意思就是說,給一個(gè)編號,嵌入層就能返回這個(gè)編...
    Vivivivi安閱讀 2,116評論 0 0

友情鏈接更多精彩內(nèi)容