因果推斷深度學(xué)習(xí)工具箱 - Learning Representations for Counterfactual Inference

文章名稱

Learning Representations for Counterfactual Inference

核心要點

因果推斷的核心問題1)missing counterfactuals;2)imbalance covariates distribution under different intervention。只有知道了各種干預(yù)下的結(jié)果,才能計算出不同干預(yù)之間的因果效應(yīng)。為了準確的估計反事實,需要解決由于混淆變量引起的不同干預(yù)下,樣本特征分布不一致的問題,否則會具有selection bias,同時會帶來估計的高方差。
不同干預(yù)下樣本特征分布不一致,意味著P(X) \neq P(X|T=t),也意味著P(X|T=t_0) \neq P(X|T=t_1)。采用經(jīng)驗風險最小化的機器學(xué)習(xí)方法在觀察到的事實結(jié)果上可能擬合的很好,但在反事實上遇到了不同的特征分布,導(dǎo)致模型效果變差。例如,某個樣本的實際干預(yù)是T=t_1,模型擬合了P(Y|X, T=t_1),也就是事實數(shù)據(jù),但遇到反事實分布P(Y|X, T=t_0)時,就會估計的不準確。猛然一拍大腿,這個是不是有點像訓(xùn)練集和測試集分布不一致的問題?
這種不一致的問題在領(lǐng)域遷移里是基操,作者借鑒domain adaptation的思想,結(jié)合表示學(xué)習(xí),利用正則化的手段,使得P(\Phi(X)|T=t_0) = P(\Phi(X)|T=t_1),其中\Phi(X)是學(xué)習(xí)到的特征表示(分布平衡是在表示層做的)。有了這種表示,模型能夠更好地回答反事實的問題。并且,作者證明了這種方法是在最小化counterfactual的regret的上界。

方法細節(jié)

問題引入

因果推斷問題,旨在計算不同干預(yù)之間的效果差異,即Y_1(x) - Y_0(x),其中x是樣本的covariates。然而,我們只能觀測到一個factual outcome,y_{i}^{F} = t_{i} Y_{1}(x_i) + (1 - t_{i}) Y_{0}(x_i)。也就是說,觀測數(shù)據(jù)實際來自于兩個分布y_{i}^{F},P^F(x, t) = P(x) P(t|x)P^{CF} = P(x) P(\neg t|x),其中CF代表counterfactual。由于混淆變量的存在,這兩個分布是不同的。如果通過直接建模的方式來估計,無論是單個模型h(x_i, t_i) = \hat{y}_{i}^{F},還是多個模型h_{j}(x_i, t_i = j) = \hat{y}_{i:t_i=j}^{F},我們都需要把一個在不同分布上訓(xùn)練的模型,應(yīng)用到在另一個不同的分布上來估計counterfactual,就像訓(xùn)練集和測試集的分布不同一樣,導(dǎo)致模型效果不夠理想(實際上,由于觀測數(shù)據(jù)得到的P(X|T=t_0)P(X|T=t_1)也只是真實條件分布的采樣,最終會導(dǎo)致有更大的偏差)。由于這里的分布不一致,指的是covariates,也就是特征分布不一致,也就是所謂的covariates shift,是domain adaptation的一個特殊場景。
其實,在很多文章中都有過闡述,領(lǐng)域遷移(協(xié)變量遷移)與因果推斷的關(guān)系是密不可分。因此,作者從領(lǐng)域遷移的idea出發(fā),把因果推斷問題定義為領(lǐng)域遷移問題,通過正則化的方法來平衡不同干預(yù)下的covariates分布。其他利用re-weight,調(diào)整樣本權(quán)重的方法不同,文章提出的方法的正則化是在表示層進行的,也就是約束的是\Phi(x),\Phi是映射函數(shù),把covariates映射到representation。通常情況下表示層會是更稠密的向量,有更深層次的語義。

具體做法

learning process

為了更好地估計因果效應(yīng),我們需要學(xué)習(xí)兩個函數(shù)\Phi(x)h(\Phi(x), t)。這兩個函數(shù)需要在整個covariates分布上有良好的泛化能力。因此需要做到三點,

  • 估計好事實,對觀測到的實際outcome估計準確;
  • 估計好反事實,這里利用的是最近鄰的方法,來構(gòu)造反事實,即y^{CF}_{i:t_i = 0} = y^{CF}_{NN_i:t_i \neq 0},其中NN_i表示最近鄰的鄰居。本質(zhì)是在模擬樣本的反事實,有點類似于matching的方法。
  • 平衡好不同干預(yù)下的representation

整體的損失函數(shù)如下圖所示,分別對應(yīng)著上邊所說的三個要點。


loss in step 1

那么如何學(xué)習(xí)到好的樣本表示呢,作者闡述了兩種學(xué)習(xí)器,1)線性表示學(xué)習(xí)器;2)深度表示學(xué)習(xí)器。

To be continued

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容