基于指數(shù)族分布的變分推斷——變分推斷(二)

基于指數(shù)族分布的變分推斷——變分推斷(二)

讓我們書(shū)接上文。

前一篇博客(基于近似計(jì)算解決推斷問(wèn)題——變分推斷(一))我們說(shuō)到基于高斯貝葉斯混合的 CAVI (坐標(biāo)上升變分推斷),那么,我們能不能將這類變分推斷進(jìn)行擴(kuò)展,變成更為通用的算法框架呢?

顯然,基于指數(shù)分布族(exponential families)的某些特性,這樣的做法是可行的。下面讓我們先看看什么是指數(shù)分布族。

本文主要參考的文獻(xiàn)為David M.Blei 2018年發(fā)表的論文 Variational Inference: A Review for Statisticians

指數(shù)分布族

指數(shù)分布族的定義

指數(shù)族分布(exponential family of distributions)也叫指數(shù)型分布族,包含高斯分布、伯努利分布二項(xiàng)分布、泊松分布、Beta 分布、Dirichlet 分布、Gamma 分布。指數(shù)族分布通??梢员硎緸椋?/p>

\begin{aligned} p(x|\eta)&=h(x)\exp(\eta^T\phi(x)-A(\eta))\\ &=\frac1{\exp(A(\eta))}h(x)\exp(\eta^T\phi(x)) \end{aligned}\tag{a}
其中有幾個(gè)比較重要的參數(shù)后面可能會(huì)用到:

  • h(x)底層觀測(cè)值(underlying measure)

  • \eta 是參數(shù)向量,也成為自然參數(shù)(natural parameter);

  • A(\eta) 是對(duì)數(shù)配分函數(shù),可以認(rèn)為是對(duì)數(shù)歸一化因子(log normalizer);

  • \phi(x)充分統(tǒng)計(jì)量(sufficient statistic),包含樣本集合所有信息,如高斯分布中的均值和方差。對(duì)于一個(gè)數(shù)據(jù)集,只需要記錄樣本的充分統(tǒng)計(jì)量即可。

    統(tǒng)計(jì)“:即表示對(duì)樣本的統(tǒng)計(jì)值
    充分“:如通過(guò)兩個(gè)統(tǒng)計(jì)值,可以求得均值和方差,進(jìn)而可以獲得高斯分布表達(dá)式。

或者,也可以采用另一種表示形式:
p(x|\theta,\psi)=\exp\{\frac{x\theta-b(\theta)}{\psi}+c(x,\psi)\}\tag
其中, \theta 是指數(shù)族的自然參數(shù)\psi尺度參數(shù)討厭參數(shù)。 b(\cdot)c(\cdot) 依據(jù)不同指數(shù)族而確定的函數(shù)。注意 c(\cdot) 只由 x\psi 決定

常見(jiàn)的指數(shù)分布族

image-20211002102134532.png

一維高斯分布的指數(shù)分布族表示形式

一維高斯分布

一維變量 x 若服從均值為 \mu 、方差為 \sigma 的一維高斯分布,則可以表示為
p(x|\mu,\sigma)=\frac1{\sqrt{2\pi}\sigma}\exp(-\frac{(x-\mu)^2}{2\sigma^2})\tag{c}
公式(a)的形式

如果按照公式(a)對(duì)高斯分布的公式進(jìn)行轉(zhuǎn)變,則可以變?yōu)?br> \begin{aligned} p(x|\mu,\sigma)&=\frac1{\sqrt{2\pi}\sigma}\exp(-\frac{(x-\mu)^2}{2\sigma^2})\\ &=\exp(\log(2\pi\sigma^2)^{-\frac12})\exp(-\frac1{2\sigma^2}\begin{pmatrix}-2\mu,1\end{pmatrix}\begin{pmatrix}x\\x^2\end{pmatrix}-\frac{\mu^2}{2\sigma^2}) \end{aligned}\tagu0z1t8os
可以看到,自然參數(shù)可以表示為 \eta=\begin{pmatrix}\frac{\mu}{\sigma^2}\\-\frac1{2\sigma^2}\end{pmatrix}=\begin{pmatrix}\eta_1\\\eta_2\end{pmatrix} ,對(duì)數(shù)配分函數(shù)可以表示為 A(\eta)=-\frac{\eta_1^2}{4\eta_2}+\frac12\log(-\frac{\pi}{\eta_2}) 。按照這個(gè)公式,我們可以計(jì)算出均值、方差與自然函數(shù)的關(guān)系
\begin{aligned} \mu&=-\frac{\eta_1}{2\eta_2}\\ \sigma^2&=-\frac1{2\eta_2}\\ \end{aligned}\tag{e}
這也是上一篇博客中,公式(34)的由來(lái)。

公式(b)的形式

按照公式(b),可以化為

p(x|\mu,\sigma)=\exp\big\{\frac{x\mu-\frac{\mu^2}2}{\sigma^2}-\frac{x^2}{2\sigma^2}-\frac12\log(2\pi\sigma^2)\big\}\tag{f}
其中,

\begin{aligned} \theta &= \mu\\ \psi &= \sigma^2\\ b(\mu) &= \frac{\mu^2}2\\ c(x,\psi)&= \frac{x^2}{2\psi}+\frac12\log(2\pi\sigma^2) \end{aligned}

充分統(tǒng)計(jì)量和配分函數(shù)的關(guān)系

對(duì)概率密度函數(shù)求積分:
\exp(A(\eta))=\int h(x)\exp(\eta^T\phi(x))dx\tag{g}
兩邊對(duì)參數(shù)求導(dǎo)
\exp(A(\eta))A'(\eta)=\int h(x)\exp(\eta^T\phi(x))\phi(x)dx\\ \Rightarrow A'(\eta)=\mathbb E_{p(x|\eta)}[\phi(x)]\tag{h}
類似的
A''(\eta)=Var_{p(x|eta)}[\phi(x)]\tag{i}
由于方差為正,所以 A(\eta) 一定是凸函數(shù)

充分統(tǒng)計(jì)量和極大似然估計(jì)

對(duì)于獨(dú)立分布采樣得到的數(shù)據(jù)集 \mathcal D=\{x_1,x_2,\dots,x_N\}

\eta 的的極大似然估計(jì)為
\begin{aligned} \eta_{MLE}&=\arg\max_\eta\sum^N_{i=1}\log p(x_i|\eta)\\ &= \arg\max_\eta\sum^N_{i=1}(\eta^T \psi (x_i)-A(\eta))\\ &\Longrightarrow A'(\eta_{MLE})=\frac1N\sum^N_{i=1}\psi(x_i) \end{aligned}\tag{j}
所以,如果要進(jìn)行估算參數(shù),只要知道充分統(tǒng)計(jì)量就可以了

最大熵原理推導(dǎo)指數(shù)分布族公式

信息熵公式為
\text{Entropy}=\int-p(x)\log(p(x))dx\tag{k}
對(duì)于一個(gè)數(shù)據(jù)集 D ,在這個(gè)數(shù)據(jù)集上的經(jīng)驗(yàn)分布為 \hat{p}(x)=\frac{Count(x)}N ,實(shí)際不可能滿足所有的經(jīng)驗(yàn)概率相同,于是在上面的最大熵原理中還需要加入這個(gè)經(jīng)驗(yàn)分布的約束。

對(duì)于任意一個(gè)函數(shù),經(jīng)驗(yàn)分布的經(jīng)驗(yàn)期望可以求得為
\mathbb E_{\hat{p}}[f(x)]=\Delta
Lagrange 函數(shù)為
L(p,\lambda_0,\lambda)=\sum^K_{k=1}p_k\log p_k+\lambda_0(1-\sum^N_{K=1}p_k)+\lambda^T(\Delta-\mathbb E_{p}[f(x)])\tag{l}
求導(dǎo)可得
\begin{aligned} \frac{\partial}{\partial p(x)}L &=\sum^K_{k=1}(\log p_k(x)+1)-\sum^K_{k=1}\lambda_0-\sum^K_{k=1}\lambda^T f(x)\\ &\Longrightarrow \sum^K_{k=1}[\log p_k(x)+1-\lambda_0-\lambda^T f(x)] \end{aligned}\tag{m}
由于數(shù)據(jù)集是任意的,對(duì)數(shù)據(jù)集求和就意味著求和項(xiàng)里面的每一項(xiàng)都是0,所以有
p(x)=\exp(\lambda^Tf(x)+\lambda_0-1)\tag{n}
這就是指數(shù)族分布的公式。

共軛先驗(yàn)

在推斷問(wèn)題中,我們常常要計(jì)算下列式子
p(z|x)=\frac{P(x|z)p(z)}{\int_z P(x|z)p(z)dz}\tag{o}

上式中分母積分十分難計(jì)算,為了解決積分難計(jì)算的問(wèn)題,一個(gè)思路是能否繞過(guò)積分呢?我們知道存在如下關(guān)系 p(z|x)\propto P(x|z)p(z) ,其中 p(z|x) 是后驗(yàn)分布, P(x|z) 是似然, p(z) 是先驗(yàn)

如果存在這樣的?個(gè)先驗(yàn)分布,那么上?時(shí)刻的輸出可以作為下?時(shí)刻計(jì)算的先驗(yàn)分布,那么這樣整個(gè)計(jì)算就可以形成閉環(huán)。也就是說(shuō)如果后驗(yàn)分布和先驗(yàn)分布是同分布,此時(shí)我們稱先驗(yàn)分布和后驗(yàn)分布是共軛分布,且稱先驗(yàn)分布是似然函數(shù)的共軛先驗(yàn)。?如?斯分布家族在?斯似然函數(shù)下與其?身共軛,也叫?共軛。

共軛先驗(yàn)的好處主要在于代數(shù)上的方便性,可以直接給出后驗(yàn)分布的封閉形式,否則的話只能做數(shù)值計(jì)算

對(duì)于一個(gè)模型分布假設(shè)(似然),那么我們?cè)谇蠼庵?,常常需要尋找一個(gè)共軛先驗(yàn),使得先驗(yàn)與后驗(yàn)的形式相同,例如選取似然是二項(xiàng)分布,可取先驗(yàn)是 Beta 分布,那么后驗(yàn)也是 Beta 分布。指數(shù)族分布常常具有共軛的性質(zhì),于是我們?cè)谀P瓦x擇以及推斷具有很大的便利。

指數(shù)分布族中的 Complete Conditional

在上一篇博客中,我們提到,在推斷問(wèn)題中,對(duì)于第 j 個(gè)隱變量 z_j ,其 complete conditional (完全條件)為給定其他隱變量和觀測(cè)數(shù)據(jù)時(shí),它的條件密度,即 p(z_j|\boldsymbol z_{-j},\boldsymbol x) 。結(jié)合指數(shù)族分布的概念,當(dāng)后驗(yàn)分布為指數(shù)族分布時(shí),我們可以將隱變量的 complete conditional 寫(xiě)為
p(z_j|\boldsymbol z_{-j},\boldsymbol x)=h(z_j)\exp\{\eta_j(\boldsymbol z_{-j},\boldsymbol x)^\top z_j-a(\eta_j(\boldsymbol z_{-j},\boldsymbol x))\}\tag{36}
其中,

  • z_j 為充分統(tǒng)計(jì)量;
  • h(\cdot) 是基本測(cè)量(base measure)或底層觀測(cè)值;
  • a(\cdot) 是對(duì)數(shù)歸一化算子(log normalizer);
  • \eta_j(\boldsymbol z_{-j},\boldsymbol x) 是條件集合( \{\boldsymbol z_{-j},\boldsymbol x\})的函數(shù)。

所以,根據(jù)上一篇博客中,我們知道 CAVI 算法的參數(shù)更新公式(17),當(dāng)假設(shè)后驗(yàn)分布為指數(shù)族分布時(shí),坐標(biāo)上升的更新公式為
\begin{aligned} q(z_j)&\propto\exp\{\mathbb E[\log p(z_j|\boldsymbol z_{-j},\boldsymbol x)]\}\\ &=\exp\{\log h(z_j)+\mathbb E[\eta_j(\boldsymbol z_{-j},\boldsymbol x)]^\top z_j-\mathbb E[a(\eta_j(\boldsymbol z_{-j},\boldsymbol x))]\}\\ &=h(z_j)\exp\{\mathbb E[\eta_j(\boldsymbol z_{-j},\boldsymbol x)]^\top z_j\} \end{aligned}\tag{37-39}
更新公式揭示了更新變分因子的參數(shù)形式,每一個(gè)更新因子都與它對(duì)應(yīng)的 complete conditional 屬于同一指數(shù)族,它的參數(shù)擁有相同維度以及相同的基本測(cè)量 h(\cdot) 和對(duì)數(shù)歸因算子 a(\cdot) 。

我們可以令 v_j 為第 j 個(gè)數(shù)據(jù)點(diǎn)的變分參數(shù),當(dāng)我們更新每個(gè)因子時(shí),只需要令其變分參數(shù)等于完全條件的期望參數(shù)
v_j=\mathbb[\eta_j(\boldsymbol z_{-j},\boldsymbol x)]\tag{40}

條件共軛模型及其推斷

條件共軛模型

對(duì)于指數(shù)族模型,一個(gè)比較特殊的情況是條件共軛模型(conditionally conjugate models),它在貝葉斯學(xué)習(xí)和機(jī)器學(xué)習(xí)中常被運(yùn)用。

我們將條件共軛模型涉及的變量可以分為兩類

  • 一類變量是數(shù)族模型中我們要學(xué)習(xí)的參數(shù),對(duì)所有數(shù)據(jù)都有潛在的控制能力,稱之為全局隱變量(global latent variables),我們表示為向量 \boldsymbol\beta;
  • 令一類變量只對(duì)某一個(gè)數(shù)據(jù)點(diǎn)進(jìn)行控制,稱之為局部隱變量(local latent variables),我們表示為向量 \boldsymbol z 。其中 z_i 控制第 i 個(gè)數(shù)據(jù)點(diǎn)

根據(jù) i.i.d. 假設(shè),其聯(lián)合分布可以表示為
p(\boldsymbol\beta,\boldsymbol z,\boldsymbol x)=p(\boldsymbol\beta)\prod^n_i p(z_i,x_i|\boldsymbol\beta)\tag{41}
回顧前面提到的高斯混合,用這類的模型解釋的話,全局變量就是混合組件參數(shù),而局部變量就是每個(gè)數(shù)據(jù)點(diǎn) x_i 的聚類分配。

我們假設(shè)基于全局變量 \beta,每個(gè)數(shù)據(jù)點(diǎn) (x_i,z_i) 的聯(lián)合分布,都有指數(shù)族形式
p(z_i,x_i|\beta)=h(z_i,x_i)\exp\{\beta^\top t(z_i,x_i)-\alpha(\beta)\}\tag{42}
其中 t(z_i,x_i) 為充分統(tǒng)計(jì)量。

接下來(lái),我們可以假設(shè)全局變量的先驗(yàn)分布是公式(42)的共軛分布
p(\beta)=h(\beta)\exp\{\alpha^\top[\beta,-a(\beta)]-a(\alpha)\}\tag{43}
這一分布的自然參數(shù)為 \alpha=[\alpha_1,\alpha_2] ,充分統(tǒng)計(jì)量為全局變量及其對(duì)數(shù)歸一化的負(fù)數(shù)。

有了上述的共軛先驗(yàn),我們也能讓得到全局變量的 complete conditional 也在同一分布
\begin{aligned} p(\boldsymbol\beta,\boldsymbol z,\boldsymbol x)&=p(\boldsymbol\beta)\prod^n_i p(z_i,x_i|\boldsymbol\beta)\\ &=h(\beta)\exp\{\alpha^\top[\beta,-a(\beta)]-a(\alpha)\}\cdot\prod^n_i h(z_i,x_i)\exp\{\beta^\top t(z_i,x_i)-\alpha(\beta)\}\\ &=[h(\boldsymbol\beta)\prod^n_i h(z_i,x_i)]\cdot\exp\{\begin{bmatrix}\alpha_1+\sum^n_{i=1}t(z_i,x_i)\\\alpha_2+n\end{bmatrix}[\beta,-a(\beta)]-a(\hat\alpha)\}\\ &=H(\boldsymbol\beta,\boldsymbol x,\boldsymbol z)\exp\{\hat\alpha^\top[\beta,-a(\beta)]-a(\hat\alpha)\} \end{aligned}\tag{44}
其中,基本測(cè)量為 H(\boldsymbol\beta,\boldsymbol x,\boldsymbol z)=h(\boldsymbol\beta)\prod^n_i h(z_i,x_i),自然參數(shù)為 \hat\alpha=[\alpha_1+\sum^n_{i=1}t(z_i,x_i),\alpha_2+n]

而對(duì)于局部變量 z_i 的 complete conditional ,在 i.i.d. 假設(shè)下有等式
p(z_i|x_i,\boldsymbol\beta,\boldsymbol z_{-i},\boldsymbol x_{-i})=p(z_i|x_i,\boldsymbol\beta)\tag{45}
我們假設(shè)其服從指數(shù)族分布
p(z_i|x_i,\boldsymbol\beta)=h(z_j)\exp\{\eta(\beta,x_i)^\top z_i-a(\eta(\beta,x_i))\}\tag{46}

條件共軛模型的變分推斷

接下來(lái)讓我們將這個(gè)模型引入 CAVI 算法框架。我們將 \beta 的變分后驗(yàn)分布近似表示為 q(\beta|\lambda)\lambda全局變分參數(shù)),它與后驗(yàn)分布有相同的指數(shù)族分布;將 z_i 的變分后驗(yàn)分布近似為 q(z_i|\psi_i) ,其中 \psi_i 為數(shù)據(jù)點(diǎn) i局部變分參數(shù),它與局部 complete condititonal 有相同的指數(shù)族分布。

在 CAVI 算法中,我們將迭代地進(jìn)行局部變分參數(shù)和全局變分參數(shù)的更新。

局部變分參數(shù)的更新

這里我們用到前面的公式(40),可以得到更新公式
\psi_i=\mathbb E_\lambda[\eta(\boldsymbol\beta,x_i)]\tag{47}
得到的結(jié)果為公式(45)中自然參數(shù)的期望。

全局變分參數(shù)的更新

全局變分參數(shù)的更新利用類似的方法,更新公式為
\lambda=[\alpha_1+\sum^n_{i=1}\mathbb E_{\psi_i}[t(z_i,x_i)],\alpha_2+n]^\top\tag{48}
得到的結(jié)果為公式(44)中自然參數(shù)的期望。

ELBO 的計(jì)算

CAVI 通過(guò)迭代更新局部變分參數(shù)和全局變分參數(shù),每次迭代我們可以計(jì)算 ELBO ,來(lái)決定模型是否收斂。將公式(44)帶入 ELBO 公式(13),我們可以得到條件共軛模型的 ELBO
ELBO=(\alpha_1+\sum^n_{i=1}\mathbb E_{\psi_i}[t(z_i,x_i)])^\top\mathbb E_\lambda[\boldsymbol\beta]-(\alpha_2+n)\mathbb E_\lambda[a(\boldsymbol\beta)]-\mathbb E[\log q(\boldsymbol\beta,\boldsymbol z)]\tag{49}
后面一項(xiàng)可以表示為
\mathbb E[\log q(\boldsymbol\beta,\boldsymbol z)]=\lambda^\top\mathbb E_\lambda[t(\boldsymbol\beta)]-a(\lambda)+\sum^n_{i=1}\psi_i^\top\mathbb E_{\psi_i}[z_i]-a(\psi_i)\tag{50}
論文中附錄 C 還有描述了基于 LDA 的 CAVI 算法,有興趣的小朋友可以看一下論文,這里不過(guò)多贅述。

CAVI 的問(wèn)題

CAVI 給了變分推斷問(wèn)題一個(gè)解決問(wèn)題的框架,引入指數(shù)族分布使得模型更加簡(jiǎn)化,似乎到這里問(wèn)題已經(jīng)解決得差不多了,但事實(shí)上真的是這樣嗎?

實(shí)際上,在真實(shí)場(chǎng)景中,我們要應(yīng)對(duì)的數(shù)據(jù)可能是成百上千甚至是上十萬(wàn)的,這就給 CAVI 這一算法框架帶來(lái)了極大的挑戰(zhàn)。 CAVI 在計(jì)算過(guò)程中,每一次迭代都需要遍歷所有數(shù)據(jù),隨著數(shù)據(jù)量的增加,計(jì)算量也越來(lái)越大,這顯然是不符合我們的需要。

所以,我們還需要另外一套計(jì)算方法,對(duì)算法的效率進(jìn)行優(yōu)化。這也是我下一篇博客會(huì)講到的兩種方法——隨機(jī)變分推斷(Stochastic variational inference,SVI)和變分自編碼器(Variational Auto-encoder,VAE)。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容