PyTorch Model Quantization (2): An Introduction to FX Graph Mode Quantization

Introduction

A recent project required me to study how to use PTQ and QAT quantization in PyTorch. Newer PyTorch versions currently recommend FX Graph Mode Quantization.


FX Graph Mode Quantization: Demo

The main flow of post-training static quantization (PTQ) in PyTorch FX Graph mode, step1 to step5:

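  • step1: Configure the quantization scheme, e.g. per-channel vs. per-layer QScheme and the representable value range after quantization (Qmin, Qmax)
  • step2: prepare_fx:
    * a) Convert the input model (nn.Module) to a GraphModule (IR conversion)
    * b) Fuse ops within graph subgraphs (e.g. conv + relu --> ConvReLU)
    * c) Insert Observers before and after ops such as Conv and Linear to collect statistics (ranges) of the activation feature maps
  • step3: Feed data to calibrate the activations
  • step4: Compute the weight and activation quantization parameters (e.g. scale, zero_point) and convert the model FP32 --> INT8
  • step5: Validate the accuracy of the INT8 quantized model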

from torchvision.models import resnet18, resnet50
import torch
from torch.ao.quantization import quantize_fx, get_default_qconfig
import os
import copy
import utils  # local helper module providing prepare_dataloader() and evaluate()


def calibrate(model, data_loader, num_batch, device):
    utils.evaluate(model=model, data_loader=data_loader, neval_batches=num_batch, n_print=1, device=device)


if __name__ == '__main__':
    device = torch.device('cuda', 0)
    eval_batch_size = 32
    imagenet_data='/media/wei/Document/ImageNet/ILSVRC2012'

    model_fp = resnet50(pretrained=True, progress=True).to(device)
    model_fp.eval()

    _, test_dataloader = utils.prepare_dataloader(data_path=imagenet_data, eval_batch_size=eval_batch_size, num_workers=8)
    utils.evaluate(model=model_fp, criterion=None, data_loader=test_dataloader, device=device)
    # ResNet-18: Tested on imagenet-val: batch:3125 Acc@1  56.25 ( 69.76), Acc@5  75.00 ( 89.08)
    # ResNet-50: batch:1560 Acc@1  59.38 ( 76.18), Acc@5  90.62 ( 92.87)

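    # torch quantization
    # The fbgemm INT8 kernels run on CPU, so the copy to be quantized is moved to CPU.
    cpu_device = torch.device('cpu')
    model_prepare = copy.deepcopy(model_fp).to(cpu_device)
    model_prepare.eval()

    # Configure the quantization scheme; the "" key applies the qconfig globally.
    # Note the colon: {"": qconfig} is a dict, whereas {"", qconfig} would be a set.
    qconfig_dict = {"": get_default_qconfig(backend='fbgemm')}
    # This is the qconfig_dict-based prepare_fx signature used throughout this article;
    # newer PyTorch releases expect a QConfigMapping plus example_inputs instead.
    model_prepare = quantize_fx.prepare_fx(model=model_prepare, qconfig_dict=qconfig_dict)
    model_prepare.eval()

    # Calibration: determine the quantization range of the activations
    calibrate(model_prepare, test_dataloader, 10, cpu_device)

    # Convert the model FP32 --> INT8 using the configured scheme and the calibrated parameters
    quantized_model = quantize_fx.convert_fx(graph_module=model_prepare)
    quantized_model.eval()

    # Evaluate the accuracy of the quantized model
    utils.evaluate(model=quantized_model, criterion=None, data_loader=test_dataloader, device=cpu_device)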

Thanks to the lean design of the PyTorch FX Graph Quantization API, quantization takes very little code and very few changes, which is exciting! Next, we look at how FX Graph quantization is actually implemented under the hood.

The sections below walk through the FX Graph quantization process step by step.

PyTorch FX Graph Quantization, Step 1: Choosing the Quantization Configuration

This is PyTorch's default PTQ configuration. 'fbgemm' is a matrix computation library that supports INT8 ops such as Conv and Linear on server-side x86 CPUs.

qconfig_dict = {"": get_default_qconfig(backend='fbgemm')}

def get_default_qconfig(backend='fbgemm'):
    """
    Returns the default PTQ qconfig for the specified backend.

    Args:
      * `backend`: a string representing the target backend. Currently supports `fbgemm`
        and `qnnpack`.

    Return:
        qconfig
    """

    if backend == 'fbgemm':
        qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
                          weight=default_per_channel_weight_observer)
    elif backend == 'qnnpack':
        qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
                          weight=default_weight_observer)
    else:
        qconfig = default_qconfig
    return qconfig
default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)

The qconfig has two parts: one configures how weights are quantized, the other how activations are quantized. Activations use HistogramObserver, a histogram-based, per-tensor/per-layer asymmetric scheme. Weights use PerChannelMinMaxObserver, a per-channel symmetric scheme.
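To make the two halves of the qconfig tangible, here is a minimal sketch that rebuilds a QConfig equivalent to the 'fbgemm' default quoted above and instantiates both observers. It assumes a PyTorch build that exposes `torch.ao.quantization`; the variable names are illustrative, and some versions may emit a deprecation warning for `reduce_range`.

```python
import torch
from torch.ao.quantization import QConfig
from torch.ao.quantization.observer import (
    HistogramObserver,
    PerChannelMinMaxObserver,
)

# Rebuild a QConfig equivalent to the 'fbgemm' default shown above.
qconfig = QConfig(
    activation=HistogramObserver.with_args(reduce_range=True),
    weight=PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric
    ),
)

# Each field is a factory: calling it creates a fresh observer module.
act_observer = qconfig.activation()
weight_observer = qconfig.weight()

print(type(act_observer).__name__, act_observer.dtype, act_observer.qscheme)
print(type(weight_observer).__name__, weight_observer.dtype, weight_observer.qscheme)
```

Printing the dtype and qscheme of each observer is a quick way to confirm which granularity and symmetry a given qconfig will apply before running prepare_fx.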

Why do activations and weights use different quantization schemes?

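  1. How weights are quantized:
  • The distribution of weight elements differs from that of activations: weights are typically close to a zero-mean, symmetric Gaussian distribution, so symmetric quantization suits them
  • Symmetric quantization has zero_point=0, which reduces the amount of computation inside quantized ops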
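The computational argument can be made concrete with a short derivation. This is a sketch with notation introduced here for illustration (it is not taken from the PyTorch source): $s_w, z_w$ and $s_x, z_x$ are the scale and zero_point of the weights and activations, $q^w_i, q^x_i$ are their integer values, and dequantization follows $w_i \approx s_w (q^w_i - z_w)$.

```latex
\begin{aligned}
\sum_{i=1}^{N} w_i x_i
  &\approx \sum_{i=1}^{N} s_w (q^w_i - z_w)\, s_x (q^x_i - z_x) \\
  &= s_w s_x \Big[ \sum_{i=1}^{N} q^w_i q^x_i
     \;-\; z_w \sum_{i=1}^{N} q^x_i
     \;-\; z_x \sum_{i=1}^{N} q^w_i
     \;+\; N z_w z_x \Big] \\
\text{with } z_w = 0:\quad
  &= s_w s_x \Big[ \sum_{i=1}^{N} q^w_i q^x_i
     \;-\; z_x \sum_{i=1}^{N} q^w_i \Big]
\end{aligned}
```

The term $z_w \sum_i q^x_i$ depends on the input and would have to be recomputed at every inference, so setting $z_w = 0$ removes it. The remaining correction $z_x \sum_i q^w_i$ involves only the weights and a calibrated constant, so it can be computed once offline.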

For reference, see the discussion in Qualcomm AI's quantization white paper:

[Figure: image.png]

The Role of the Observer

In short, an Observer watches the data distribution and computes the quantization parameters scale and zero_point. The code is analyzed below.
Analyzing PerChannelMinMaxObserver

class PerChannelMinMaxObserver(_ObserverBase):
    r"""Observer module for computing the quantization parameters based on the
    running per channel min and max values.

    This observer uses the tensor min/max statistics to compute the per channel
    quantization parameters. The module records the running minimum and maximum
    of incoming tensors, and uses this statistic to compute the quantization
    parameters.

    Args:
        ch_axis: Channel axis
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit
        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
        memoryless: Boolean that controls whether observer removes old data when a new input is seen.
                    This is most useful for simulating dynamic quantization, especially during QAT.

    The quantization parameters are computed the same way as in
    :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference
    that the running min/max values are stored per channel.
    Scales and zero points are thus computed per channel as well.

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
        self,
        ch_axis=0,
        dtype=torch.quint8,
        qscheme=torch.per_channel_affine,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
        memoryless=False,
    ) -> None:
        super(PerChannelMinMaxObserver, self).__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
        )
        self.memoryless = memoryless
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.ch_axis = ch_axis
        self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
        self.register_buffer("max_val", torch.tensor([], **factory_kwargs))
        if (
            self.qscheme == torch.per_channel_symmetric
            and self.reduce_range
            and self.dtype == torch.quint8
        ):
            raise NotImplementedError(
                "Cannot reduce range for symmetric quantization for quint8"
            )

    def forward(self, x_orig):
        return self._forward(x_orig)

    def _forward(self, x_orig):
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()  # avoid keeping autograd tape
        min_val = self.min_val
        max_val = self.max_val
        x_dim = x.size()

        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
        new_axis_list[self.ch_axis] = 0
        new_axis_list[0] = self.ch_axis
        y = x.permute(new_axis_list)
        # Need to match dtype of min/max because the updates to buffers
        # are done in place and types need to match for comparisons
        y = y.to(self.min_val.dtype)
        y = torch.flatten(y, start_dim=1)
        if min_val.numel() == 0 or max_val.numel() == 0 or self.memoryless:
            min_val, max_val = torch.aminmax(y, dim=1)
        else:
            min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
            min_val = torch.min(min_val_cur, min_val)
            max_val = torch.max(max_val_cur, max_val)
        self.min_val.resize_(min_val.shape)
        self.max_val.resize_(max_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.copy_(max_val)
        return x_orig

    @torch.jit.export
    def calculate_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)

To compute the parameters needed for quantization, PyTorch defines a family of Observers, such as MinMaxObserver and MovingAverageMinMaxObserver. All of these observers inherit from a common base class, and each one centers on two important functions:

  • forward(self, x_orig): observes the minimum and maximum of the elements in the input tensor (here, the weight)
  • calculate_qparams(self): computes scale and zero_point
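The division of labor between the two functions can be seen with the simplest observer. The sketch below uses MinMaxObserver on two small hand-written batches; the tensors and variable names are made up for illustration, and it assumes the observer API exposed under `torch.ao.quantization.observer`.

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

observer = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)

# forward(): each call updates the running min/max and returns its input unchanged.
observer(torch.tensor([-1.0, 0.5, 2.0]))
observer(torch.tensor([0.0, 3.0]))
print("running min/max:", observer.min_val.item(), observer.max_val.item())

# calculate_qparams(): turns the recorded range into scale and zero_point.
scale, zero_point = observer.calculate_qparams()
print("scale:", scale.item(), "zero_point:", int(zero_point.item()))
```

After both batches the recorded range is [-1, 3], so the affine scale is 4/255 for the 0..255 quint8 range.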

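What forward(self, x_orig) does:

  • Input: x_orig, i.e. the weight tensor; a typical CNN weight is a 4D tensor of shape Oc * Ic * Kh * Kw
  • Output/result: the observed minimum and maximum values, stored in the min_val and max_val buffers (the function itself returns x_orig unchanged)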

When an Observer object is instantiated, the ch_axis=0 argument of __init__() specifies the channel dimension. ch_axis=0 means the Observer tracks the minimum and maximum along the Oc (output_channels) direction of the weight. The core code for observing the minimum and maximum:
min_val, max_val = torch.aminmax(y, dim=1)
Because Oc sits on axis 0 and y has been flattened to shape (Oc, Ic*Kh*Kw), aminmax(dim=1) reduces over axis 1 and yields Oc pairs of min_val and max_val, so each output channel of the weight gets its own scale and zero_point.
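The per-channel reduction can be checked by hand. The following sketch, on a random weight with made-up dimensions, repeats the flatten-and-aminmax step manually and compares it with what PerChannelMinMaxObserver records; it is an illustration under those assumptions, not a replacement for the library code.

```python
import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver

torch.manual_seed(0)
weight = torch.randn(8, 3, 3, 3)  # hypothetical conv weight: (Oc, Ic, Kh, Kw)

# Manual version of the core step: keep axis 0 (Oc), reduce over everything else.
flat = torch.flatten(weight, start_dim=1)        # shape (Oc, Ic*Kh*Kw)
manual_min, manual_max = torch.aminmax(flat, dim=1)

# The observer performs the same reduction when ch_axis=0.
observer = PerChannelMinMaxObserver(
    ch_axis=0, dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)
observer(weight)
scales, zero_points = observer.calculate_qparams()

print("min/max shapes:", manual_min.shape, manual_max.shape)
print("scales shape:", scales.shape, "zero_points:", zero_points.tolist())
```

One scale per output channel comes back, and with the symmetric qint8 scheme every zero_point is 0.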


What calculate_qparams does

As the name suggests, this function computes the quantization parameters scale and zero_point (for linear quantization). The implementation is analyzed below:

  • Input: the observed max_val and min_val, plus the predefined qmax and qmin
  • Output: the computed scale and zero_point
    def _calculate_qparams(
        self, min_val: torch.Tensor, max_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Calculates the quantization parameters, given min and max
        value tensors. Works for both per tensor and per channel cases

        Args:
            min_val: Minimum values per channel
            max_val: Maximum values per channel

        Returns:
            scales: Scales tensor of shape (#channels,)
            zero_points: Zero points tensor of shape (#channels,)
        """
        if not check_min_max_valid(min_val, max_val):
            return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)

        quant_min, quant_max = self.quant_min, self.quant_max
        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

        device = min_val_neg.device
        scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
        zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

        if (
            self.qscheme == torch.per_tensor_symmetric
            or self.qscheme == torch.per_channel_symmetric
        ):
            max_val_pos = torch.max(-min_val_neg, max_val_pos)
            scale = max_val_pos / (float(quant_max - quant_min) / 2)
            scale = torch.max(scale, self.eps)
            if self.dtype == torch.quint8:
                if self.has_customized_qrange:
                    # When customized quantization range is used, down-rounded midpoint of the range is chosen.
                    zero_point = zero_point.new_full(
                        zero_point.size(), (quant_min + quant_max) // 2
                    )
                else:
                    zero_point = zero_point.new_full(zero_point.size(), 128)
        elif self.qscheme == torch.per_channel_affine_float_qparams:
            scale = (max_val - min_val) / float(quant_max - quant_min)
            scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
            # We use the quantize function
            # xq = Round(Xf * inv_scale + zero_point),
            # setting zero_point to (-1 * min *inv_scale) we get
            # Xq = Round((Xf - min) * inv_scale)
            zero_point = -1 * min_val / scale
        else:
            scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
            scale = torch.max(scale, self.eps)
            zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
            zero_point = torch.clamp(zero_point, quant_min, quant_max)

        # For scalar values, cast them to Tensors of size 1 to keep the shape
        # consistent with default values in FakeQuantize.
        if len(scale.shape) == 0:
            # TODO: switch to scale.item() after adding JIT support
            scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
        if len(zero_point.shape) == 0:
            # TODO: switch to zero_point.item() after adding JIT support
            zero_point = torch.tensor(
                [int(zero_point)], dtype=zero_point.dtype, device=device
            )
            if self.qscheme == torch.per_channel_affine_float_qparams:
                zero_point = torch.tensor(
                    [float(zero_point)], dtype=zero_point.dtype, device=device
                )

        return scale, zero_point

Core code for computing the quantization parameters scale and zero_point

  • Symmetric quantization
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
  • Asymmetric quantization (affine quantization)
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
scale = torch.max(scale, self.eps)
zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
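A worked example may help here. The sketch below re-expresses both formulas for a single scalar range in plain Python; the function names and the sample range [-1, 3] are chosen for illustration, and the default eps is assumed to be the float32 machine epsilon.

```python
FLOAT32_EPS = 1.1920929e-07  # assumed lower bound for scale (float32 epsilon)


def symmetric_qparams(min_val, max_val, quant_min=-128, quant_max=127, eps=FLOAT32_EPS):
    """Symmetric case: the range is centered on 0, so zero_point is 0 for qint8."""
    min_val_neg = min(min_val, 0.0)
    max_val_pos = max(max_val, 0.0)
    max_abs = max(-min_val_neg, max_val_pos)
    scale = max(max_abs / ((quant_max - quant_min) / 2), eps)
    return scale, 0


def affine_qparams(min_val, max_val, quant_min=0, quant_max=255, eps=FLOAT32_EPS):
    """Affine case: the full [min, max] range is mapped onto [quant_min, quant_max]."""
    min_val_neg = min(min_val, 0.0)
    max_val_pos = max(max_val, 0.0)
    scale = max((max_val_pos - min_val_neg) / (quant_max - quant_min), eps)
    zero_point = quant_min - round(min_val_neg / scale)
    zero_point = max(quant_min, min(quant_max, zero_point))
    return scale, zero_point


sym_scale, sym_zp = symmetric_qparams(-1.0, 3.0)
aff_scale, aff_zp = affine_qparams(-1.0, 3.0)
print("symmetric:", sym_scale, sym_zp)
print("affine:   ", aff_scale, aff_zp)
```

For the range [-1, 3], the symmetric scheme covers [-3, 3] and leaves part of the integer range unused, while the affine scheme uses all 256 levels at the cost of a nonzero zero_point.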

The above covers how the quantization parameters of an nn.Conv2d layer's weight are computed and how PerChannelMinMaxObserver is implemented. Next comes the computation of the quantization parameters for activations.


Activation Quantization Parameters and HistogramObserver

In the default quantization configuration for backend=fbgemm, activations use HistogramObserver, which computes the quantization parameters from a histogram analysis.

if backend == 'fbgemm':
        qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
                          weight=default_per_channel_weight_observer)

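HistogramObserver Walkthrough

  1. Initialization: __init__()
  • bins=2048 by default: building a histogram requires a number of bins, so the range from min_val to max_val is divided into 2048 intervals.
  • qscheme=per_tensor_affine: the quantization granularity is per-tensor/per-layer affine quantization. Per-tensor means a single pair of quantization parameters (scale + zero_point) rather than one pair per channel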
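To illustrate what the bins parameter means, here is a tiny sketch with torch.histc using 4 bins instead of 2048. The sample values are made up; the point is only that the observed [min, max] range is split into equal-width intervals and each element is counted once.

```python
import torch

values = torch.tensor([0.0, 0.5, 1.0, 2.5, 4.0])  # toy "activation" values
bins = 4                                          # HistogramObserver defaults to 2048

lo, hi = values.min().item(), values.max().item()
hist = torch.histc(values, bins=bins, min=lo, max=hi)
bin_width = (hi - lo) / bins

print("bin width:", bin_width)
print("counts per bin:", hist.tolist())
```

With 2048 bins the same idea gives a much finer picture of where the activation mass actually lies inside the observed range.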

class HistogramObserver(_ObserverBase):
    r"""
    The module records the running histogram of tensor values along with
    min/max values. ``calculate_qparams`` will calculate scale and zero_point.

    Args:
        bins: Number of bins to use for the histogram
        upsample_rate: Factor by which the histograms are upsampled, this is
                       used to interpolate histograms with varying ranges across observations
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit

    The scale and zero point are computed as follows:

    1. Create the histogram of the incoming inputs.
        The histogram is computed continuously, and the ranges per bin change
        with every new tensor observed.
    2. Search the distribution in the histogram for optimal min/max values.
        The search for the min/max values ensures the minimization of the
        quantization error with respect to the floating point model.
    3. Compute the scale and zero point the same way as in the
        :class:`~torch.ao.quantization.MinMaxObserver`
    """
    histogram: torch.Tensor
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
        self,
        bins: int = 2048,
        upsample_rate: int = 128,
        dtype: torch.dtype = torch.quint8,
        qscheme=torch.per_tensor_affine,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
    ) -> None:
        # bins: The number of bins used for histogram calculation.
        super(HistogramObserver, self).__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
        )
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.bins = bins
        self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs))
        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
        self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
        self.upsample_rate = upsample_rate

  2. Collecting statistics on the activation tensor: forward()
    def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()
        min_val = self.min_val
        max_val = self.max_val
        same_values = min_val.item() == max_val.item()
        is_uninitialized = min_val == float("inf") and max_val == float("-inf")
        if is_uninitialized or same_values:
            min_val, max_val = torch.aminmax(x)
            self.min_val.resize_(min_val.shape)
            self.min_val.copy_(min_val)
            self.max_val.resize_(max_val.shape)
            self.max_val.copy_(max_val)
            assert (
                min_val.numel() == 1 and max_val.numel() == 1
            ), "histogram min/max values must be scalar."
            torch.histc(
                x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram
            )
        else:
            new_min, new_max = torch.aminmax(x)
            combined_min = torch.min(new_min, min_val)
            combined_max = torch.max(new_max, max_val)
            # combine the existing histogram and new histogram into 1 histogram
            # We do this by first upsampling the histogram to a dense grid
            # and then downsampling the histogram efficiently
            (
                combined_min,
                combined_max,
                downsample_rate,
                start_idx,
            ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
            assert (
                combined_min.numel() == 1 and combined_max.numel() == 1
            ), "histogram min/max values must be scalar."
            combined_histogram = torch.histc(
                x, self.bins, min=int(combined_min), max=int(combined_max)
            )
            if combined_min == min_val and combined_max == max_val:
                combined_histogram += self.histogram
            else:
                combined_histogram = self._combine_histograms(
                    combined_histogram,
                    self.histogram,
                    self.upsample_rate,
                    downsample_rate,
                    start_idx,
                    self.bins,
                )

            self.histogram.detach_().resize_(combined_histogram.shape)
            self.histogram.copy_(combined_histogram)
            self.min_val.detach_().resize_(combined_min.shape)
            self.min_val.copy_(combined_min)
            self.max_val.detach_().resize_(combined_max.shape)
            self.max_val.copy_(combined_max)
        return x_orig
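As a rough way to see the effect of the histogram-based search, the sketch below feeds the same synthetic activations, including one artificial outlier, to both HistogramObserver and MinMaxObserver and prints the resulting parameters. The data and variable names are assumptions for illustration, and the exact numbers depend on the PyTorch version.

```python
import torch
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver

torch.manual_seed(0)
activations = torch.randn(10_000)
activations[0] = 50.0  # one artificial outlier far from the bulk of the data

hist_observer = HistogramObserver(bins=2048)
minmax_observer = MinMaxObserver()

# Both observers see the same tensor and record its range.
hist_observer(activations)
minmax_observer(activations)

hist_scale, hist_zp = hist_observer.calculate_qparams()
minmax_scale, minmax_zp = minmax_observer.calculate_qparams()

print("observed range:", hist_observer.min_val.item(), hist_observer.max_val.item())
print("HistogramObserver scale/zero_point:", hist_scale.item(), int(hist_zp.item()))
print("MinMaxObserver    scale/zero_point:", minmax_scale.item(), int(minmax_zp.item()))
```

MinMaxObserver always maps the full observed range, outlier included, whereas HistogramObserver searches the histogram for a range that minimizes quantization error, so it may settle on a narrower range and a smaller scale.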