MXNet實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練量化

? ? ? ? 深度學(xué)習(xí)在移動(dòng)端的應(yīng)用是越來(lái)越廣泛,由于移動(dòng)端的運(yùn)算力與服務(wù)器相比還是有差距,所以在移動(dòng)端部署深度學(xué)習(xí)模型的難點(diǎn)就在于如何保證模型效果的同時(shí),運(yùn)行效率也有保證。在實(shí)驗(yàn)階段對(duì)于模型結(jié)構(gòu)可以選擇大模型,因?yàn)樵撾A段主要是為了驗(yàn)證方法的有效性。在驗(yàn)證完了之后,開(kāi)始著手部署到移動(dòng)端,這時(shí)候就要精簡(jiǎn)模型的結(jié)構(gòu)了,一般是對(duì)訓(xùn)好的大模型進(jìn)行剪枝,或者參考現(xiàn)有的比如MobileNetV2ShuffleNetV2 等輕量級(jí)的網(wǎng)絡(luò)重新設(shè)計(jì)自己的網(wǎng)絡(luò)模塊。而算法層面的優(yōu)化除了剪枝還有量化,量化就是把浮點(diǎn)數(shù)(高精度)表示的權(quán)值和激活值用更低精度的整數(shù)來(lái)近似表示。低精度的優(yōu)點(diǎn)有,相比于高精度算術(shù)運(yùn)算,其在單位時(shí)間內(nèi)能處理更多的數(shù)據(jù),而且權(quán)值量化之后模型的存儲(chǔ)空間能進(jìn)一步的減少等等。對(duì)訓(xùn)練好的網(wǎng)絡(luò)做量化,在實(shí)踐中嘗試過(guò)TensorRT的后訓(xùn)練量化算法,在一些任務(wù)上效果還不錯(cuò)。但是如果能在訓(xùn)練過(guò)程中去模擬量化的過(guò)程,讓網(wǎng)絡(luò)學(xué)習(xí)去修正量化帶來(lái)的誤差,那么得到的量化參數(shù)應(yīng)該是更準(zhǔn)確的,而且在實(shí)際量化推斷中模型的性能損失應(yīng)該能更小。而本文的內(nèi)容就是介紹google的論文和復(fù)現(xiàn)其過(guò)程中的一些細(xì)節(jié)。本文相關(guān)實(shí)驗(yàn)代碼:

MXNET-Scala TrainQuantization

訓(xùn)練模擬量化

? ? ? ? 首先來(lái)看下量化的具體定義,對(duì)于量化激活值到有符號(hào)8bit整數(shù),論文中給出的定義如下:

? ? 公式中的三角形表示量化的縮放因子,x表示浮點(diǎn)數(shù)激活值,首先通過(guò)除以縮放因子然后最近鄰取整,然后把范圍限制到一個(gè)區(qū)間內(nèi),比如量化到有符號(hào)8bit,那么范圍就是 [-128, 127]。而對(duì)于權(quán)值還有一個(gè)小的技巧,就是量化到[-127, 127]:

具體為什么這么做,論文中說(shuō)了是為了實(shí)現(xiàn)上的優(yōu)化,具體解釋可以看論文附錄B ARM NEON details這一小節(jié)。

? ? ? ? 而訓(xùn)練量化我理解就是在forward階段去模擬量化這個(gè)過(guò)程,就是把權(quán)值和激活值量化到8bit再反量化回有誤差的32bit,所以訓(xùn)練還是浮點(diǎn),backward階段是對(duì)模擬量化之后權(quán)值的求梯度,然后用這個(gè)梯度去更新量化前的權(quán)值。然后在下個(gè)batch繼續(xù)這個(gè)過(guò)程,通過(guò)這樣子能夠讓網(wǎng)絡(luò)學(xué)會(huì)去修正量化帶來(lái)的誤差。

? ? ? ? 上面給這個(gè)示意圖就很直觀的表示了模擬量化的過(guò)程,比如上面那條線表示的是量化前的范圍[rmin, rmax],然后下面那條線表示的就是量化之后的范圍[-128, 127],比如現(xiàn)在要進(jìn)行模擬量化的forward,先看上面那條線從左到右數(shù)第4個(gè)圓點(diǎn),通過(guò)除以縮放因子之后就會(huì)映射124到125之間的一個(gè)浮點(diǎn)數(shù),然后通過(guò)最近鄰取整就取到了125,再通過(guò)乘以縮放因子返回上面第五個(gè)圓點(diǎn),最后就用這個(gè)有誤差的數(shù)替換原來(lái)的去forward。forward階段的模擬量化用公式表示如下:

backward階段求梯度的公式表示如下:

? ? ? 對(duì)于縮放因子的計(jì)算,權(quán)值和激活值的不一樣,權(quán)值的計(jì)算方法是每次forward直接對(duì)權(quán)值求絕對(duì)值取最大值,然后縮放因子 weight scale = max(abs(weight)) / 127。然后對(duì)于激活值,稍微有些不一樣,激活值的量化范圍不是簡(jiǎn)單的計(jì)算最大值,而是通過(guò)EMA(exponential moving averages)在訓(xùn)練中去統(tǒng)計(jì)這個(gè)量化范圍,更新公式如下:

moving_max = moving_max * momenta + max(abs(activation)) * (1- momenta)

? ? ? ? 公式中的activation表示每個(gè)batch的激活值,而論文中說(shuō)momenta取接近1的數(shù)就行了,在實(shí)驗(yàn)中我是取0.95。然后縮放因子 activation scale = moving_max /128。

實(shí)現(xiàn)細(xì)節(jié)

? ? ? 在實(shí)現(xiàn)過(guò)程中我沒(méi)有按照論文的方法量化到無(wú)符號(hào)8bit,而是有符號(hào)8bit,第一是因?yàn)闊o(wú)符號(hào)8bit量化需要引入額外的零點(diǎn),增加復(fù)雜性,其次在實(shí)際應(yīng)用過(guò)程中都是量化到有符號(hào)8bit。然后論文中提到,對(duì)于權(quán)值的量化分通道進(jìn)行求縮放因子,然后對(duì)于激活值的量化整體求一個(gè)縮放因子,這樣效果最好。在實(shí)踐中發(fā)現(xiàn)權(quán)值不分通道量化效果也不錯(cuò),這個(gè)還是看具體任務(wù)吧,而本文給出的實(shí)驗(yàn)代碼是沒(méi)分的。

? ? ? 然后對(duì)于卷積層之后帶batchnorm的網(wǎng)絡(luò),因?yàn)橐话阍趯?shí)際使用階段,為了優(yōu)化速度,batchnorm的參數(shù)都會(huì)提前融合進(jìn)卷積層的參數(shù)中,所以訓(xùn)練模擬量化的過(guò)程也要按照這個(gè)流程。首先把batchnorm的參數(shù)與卷積層的參數(shù)融合,然后再對(duì)這個(gè)參數(shù)做量化。以下兩張圖片分別表示的是訓(xùn)練過(guò)程與實(shí)際應(yīng)用過(guò)程中對(duì)batchnorm層處理的區(qū)別:

? ? 對(duì)于如何融合batchnorm參數(shù)進(jìn)卷積層參數(shù),看以下公式:

? ? ? 公式中的,W和b分別表示卷積層的權(quán)值與偏置,x和y分別為卷積層的輸入與輸出,則根據(jù)bn的計(jì)算公式,可以推出融合了batchnorm參數(shù)之后的權(quán)值與偏置,Wmerge和bmerge。這里對(duì)于融合了bn權(quán)值的偏置的公式推導(dǎo)結(jié)果和論文中的有些不同,論文中的結(jié)果看起來(lái)應(yīng)該是沒(méi)有考慮卷積層本身帶有偏置的情況。

? ? ? 而且在實(shí)驗(yàn)中我是簡(jiǎn)化了融合batchnorm的流程,要是完全按照論文中的實(shí)現(xiàn)要復(fù)雜很多,而且是基于已經(jīng)訓(xùn)好的網(wǎng)絡(luò)去做模擬量化實(shí)驗(yàn)的,不基于預(yù)訓(xùn)練模型訓(xùn)不起來(lái),可能還有坑要踩。而且在模擬量化訓(xùn)練過(guò)程中batchnorm層參數(shù)固定,融合batchnorm參數(shù)也是用已經(jīng)訓(xùn)好的移動(dòng)均值和方差,而不是用每個(gè)batch的均值和方差。

? ? 具體實(shí)現(xiàn)的時(shí)候就是按照論文中的這個(gè)模擬量化卷積層示例圖去寫(xiě)訓(xùn)練網(wǎng)絡(luò)結(jié)構(gòu)的。

實(shí)驗(yàn)結(jié)果

? ? ? 用VGG在Cifar10上做了下實(shí)驗(yàn),效果還可以,因?yàn)槭菫榱蓑?yàn)證量化訓(xùn)練的有效性,所以訓(xùn)Cifar10的時(shí)候沒(méi)怎么調(diào)過(guò)參,數(shù)據(jù)增強(qiáng)也沒(méi)做,訓(xùn)出來(lái)的模型精確度最高只有0.877,比最好的結(jié)果0.93差不少,然后模擬量化是基于這個(gè)0.877的模型去做的,可以得到與普通訓(xùn)練精確度基本一樣的模型,可能是這個(gè)分類任務(wù)比較簡(jiǎn)單。然后得到訓(xùn)好的模型與每層的量化因子之后,就可以模擬真實(shí)的量化推斷過(guò)程,不過(guò)因?yàn)镸XNet的卷積層不支持整型運(yùn)算,所以模擬的過(guò)程也是用浮點(diǎn)來(lái)模擬,具體實(shí)現(xiàn)細(xì)節(jié)可見(jiàn)示例代碼。

結(jié)束語(yǔ)

? ? ? 以上內(nèi)容是根據(jù)最近的一些工作實(shí)踐總結(jié)得到的一篇博客,對(duì)于論文的實(shí)現(xiàn)很多地方都是我自己個(gè)人的理解,如果有讀者發(fā)現(xiàn)哪里有誤或者有疑問(wèn),也請(qǐng)指出,大家互相交流學(xué)習(xí):)。

相關(guān)資料:

1、https://heartbeat.fritz.ai/8-bit-quantization-and-tensorflow-lite-speeding-up-mobile-inference-with-low-precision-a882dfcafbbd

2、https://github.com/google/gemmlowp/blob/master/doc/quantization.md

3、https://arxiv.org/pdf/1712.05877.pdf

4、https://arxiv.org/abs/1806.08342

5、http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf

6、TensorRT(5)-INT8校準(zhǔn)原理


本文首發(fā)于:https://zhuanlan.zhihu.com/p/65468307

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容