【優(yōu)化技巧】指數(shù)移動(dòng)平均(EMA)的原理及PyTorch實(shí)現(xiàn)

在深度學(xué)習(xí)中,經(jīng)常會(huì)使用EMA(指數(shù)移動(dòng)平均)這個(gè)方法對(duì)模型的參數(shù)做平均,以求提高測(cè)試指標(biāo)并增加模型魯棒。

今天瓦礫準(zhǔn)備介紹一下EMA以及它的Pytorch實(shí)現(xiàn)代碼。

EMA的定義

指數(shù)移動(dòng)平均(Exponential Moving Average)也叫權(quán)重移動(dòng)平均(Weighted Moving Average),是一種給予近期數(shù)據(jù)更高權(quán)重的平均方法。

假設(shè)我們有n個(gè)數(shù)據(jù):[\theta_1, \theta_2, ..., \theta_n]?

  • 普通的平均數(shù):\overline{v}=\frac{1}{n}\sum_{i=1}^n \theta_i
  • EMA:v_t = \alpha\cdot v_{t-1} + (1-\alpha)\cdot \theta_t,其中,v_t表示前t條的平均值 (v_0=0),\alpha 是加權(quán)權(quán)重值 (一般設(shè)為0.9-0.999)。

Andrew Ng在Course 2 Improving Deep Neural Networks中講到,EMA可以近似看成過去1/(1-\alpha)個(gè)時(shí)刻v值的平均。

普通的過去n時(shí)刻的平均是這樣的:
v_t =\frac{(n-1)\cdot v_{t-1}+\theta_t}{n}
類比EMA,可以發(fā)現(xiàn)當(dāng)\alpha=\frac{n-1}{n}時(shí),兩式形式上相等。需要注意的是,兩個(gè)平均并不是嚴(yán)格相等的,這里只是為了幫助理解。

實(shí)際上,EMA計(jì)算時(shí),過去1/(1-\alpha)個(gè)時(shí)刻之前的平均會(huì)decay到 \frac{1}{e} ,證明如下。

如果將這里的v_t展開,可以得到:
v_t = \alpha^n v_{t-n} + (1-\alpha)(\alpha^{n-1}\theta_{t-n+1}+ ... +\alpha^0\theta_t)
其中,n=\frac{1}{1-\alpha},代入可以得到\alpha^n=\alpha^{\frac{1}{1-\alpha}}\approx \frac{1}{e}

EMA的偏差修正

實(shí)際使用中,如果令v_0=0,步數(shù)較少的情況下,ema的計(jì)算結(jié)果會(huì)有一定偏差。

偏差

理想的平均是綠色的,因?yàn)槌跏贾禐?,所以得到的是紫色的。

因此可以加一個(gè)偏差修正(bias correction)。
v_t = \frac{v_t}{1-\alpha^t}
顯然,當(dāng)t很大時(shí),修正近似于1。

在深度學(xué)習(xí)的優(yōu)化中的EMA

上面講的是廣義的ema定義和計(jì)算方法,特別的,在深度學(xué)習(xí)的優(yōu)化過程中,\theta_t 是t時(shí)刻的模型權(quán)重weights,v_t是t時(shí)刻的影子權(quán)重(shadow weights)。在梯度下降的過程中,會(huì)一直維護(hù)著這個(gè)影子權(quán)重,但是這個(gè)影子權(quán)重并不會(huì)參與訓(xùn)練。基本的假設(shè)是,模型權(quán)重在最后的n步內(nèi),會(huì)在實(shí)際的最優(yōu)點(diǎn)處抖動(dòng),所以我們?nèi)∽詈髇步的平均,能使得模型更加的魯棒。

EMA為什么有效

網(wǎng)上大多數(shù)介紹EMA的博客,在介紹其為何有效的時(shí)候,只做了一些直覺上的解釋,缺少嚴(yán)謹(jǐn)?shù)耐评?,瓦礫在這補(bǔ)充一下,不喜歡看公式的讀者可以跳過。

令第n時(shí)刻的模型權(quán)重(weights)為v_n,梯度為g_n,可得:
\begin{align} \theta_n &= \theta_{n-1}-g_{n-1} \\\\ &=\theta_{n-2}-g_{n-1}-g_{n-2} \\\\ &= ... \\\\ &= \theta_1-\sum_{i=1}^{n-1}g_i \end{align}
令第n時(shí)刻EMA的影子權(quán)重為v_n,可得:
\begin{align} v_n &= \alpha v_{n-1}+(1-\alpha)\theta_n \\\\ &= \alpha (\alpha v_{n-2}+(1-\alpha)\theta_{n-1})+(1-\alpha)\theta_n \\\\ &= ... \\\\ &= \alpha^n v_0+(1-\alpha)(\theta_n+\alpha\theta_{n-1}+\alpha^2\theta_{n-2}+...+\alpha^{n-1}\theta_{1}) \end{align}

代入上面\theta_n的表達(dá),令v_0=\theta_1展開上面的公式,可得:
\begin{align} v_n &= \alpha^n v_0+(1-\alpha)(\theta_n+\alpha\theta_{n-1}+\alpha^2\theta_{n-2}+...+\alpha^{n-1}\theta_{1})\\\\ &= \alpha^n v_0+(1-\alpha)(\theta_1-\sum_{i=1}^{n-1}g_i+\alpha(\theta_1-\sum_{i=1}^{n-2}g_i)+...+ \alpha^{n-2}(\theta_1-\sum_{i=1}^{1}g_i)+\alpha^{n-1}\theta_{1})\\\\ &= \alpha^n v_0+(1-\alpha)(\frac{1-\alpha^n}{1-\alpha}\theta_1-\sum_{i=1}^{n-1}\frac{1-\alpha^{n-i}}{1-\alpha}g_i) \\\\ &= \alpha^n v_0+(1-\alpha^n)\theta_1 -\sum_{i=1}^{n-1}(1-\alpha^{n-i})g_i\\\\ &= \theta_1 -\sum_{i=1}^{n-1}(1-\alpha^{n-i})g_i \end{align}
對(duì)比兩式:
\theta_n = \theta_1-\sum_{i=1}^{n-1}g_i
v_n = \theta_1 -\sum_{i=1}^{n-1}(1-\alpha^{n-i})g_i
EMA對(duì)第i步的梯度下降的步長增加了權(quán)重系數(shù)1-\alpha^{n-i}?,相當(dāng)于做了一個(gè)learning rate decay。

PyTorch實(shí)現(xiàn)

瓦礫看了網(wǎng)上的一些實(shí)現(xiàn),使用起來都不是特別方便,所以自己寫了一個(gè)。

class EMA():
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()
    
    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]
    
    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

# 初始化
ema = EMA(model, 0.999)
ema.register()

# 訓(xùn)練過程中,更新完參數(shù)后,同步update shadow weights
def train():
    optimizer.step()
    ema.update()

# eval前,apply shadow weights;eval之后,恢復(fù)原來模型的參數(shù)
def evaluate():
    ema.apply_shadow()
    # evaluate
    ema.restore()

References

  1. 機(jī)器學(xué)習(xí)模型性能提升技巧:指數(shù)加權(quán)平均(EMA)
  2. Exponential Weighted Average for Deep Neutal Networks
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容