因果推斷深度學(xué)習(xí)工具箱 - Perfect Match: A Simple Method for Learning Representations For Counterfactual Infe...

文章名稱

Perfect Match: A Simple Method for Learning Representations For Counterfactual Inference With Neural Networks

核心要點(diǎn)

現(xiàn)有的深度學(xué)習(xí)的overly complex,作者通過propensity matching的方法,用目標(biāo)樣本其他treatment下的最近鄰樣本,構(gòu)造訓(xùn)練的mini-batch,通過數(shù)據(jù)增廣的方式來解決觀測數(shù)據(jù)下因果推斷的2個(gè)基本問題,1)缺失的反事實(shí);2)混淆偏差。
比較大的優(yōu)勢是,這種方法不同于介紹過的文章,可以應(yīng)用于multiple treatment。

方法細(xì)節(jié)

問題引入

文章建立在potential outcome框架下,并且需要滿足unconfoundness的假設(shè),即Y \perp \!\!\! \perp T | X。為了需要估計(jì)因果效應(yīng)T(x) = \mathbb{E}[Y_i(1) - Y_i(0) | X=x_i],比較老的方法采用直接建模的方式,比如\hat{T}(x) = f(x_i, 1) - f(x_i, 0),也就是我們常說的single learner(如果兩個(gè)f帶有下表就是T-learner)。這種建模方式的弊端是高維的特征x_i會(huì)淹沒低緯度的干預(yù)t_i。
如果采用T-learner,不會(huì)存在干預(yù)被淹沒的問題,也比較靈活,卻引入了模型誤差帶來的因果效應(yīng)估計(jì)的偏差,并且犧牲了統(tǒng)計(jì)效率,不能夠充分利用樣本。

具體做法

首先,作者擴(kuò)展了TARNET,把two heads擴(kuò)展成為multiple heads,模仿TARNET解決treatment在x維度較高的時(shí)候,被淹沒的情況。但是這個(gè)改進(jìn)非常subtile[汗]。
其次,作者利用propensity score做balancing,構(gòu)造虛擬的隨機(jī)實(shí)驗(yàn)mini-batch。其實(shí)是利用最近鄰matching的方法,做數(shù)據(jù)增廣,期望在梯度回傳的時(shí)候減少overfit,來解決由于混淆變量引起的訓(xùn)練樣本分布不均,以及預(yù)測時(shí)分布遷移的問題。
同時(shí),作者定義(拓展)了一些評(píng)價(jià)指標(biāo),首先,利用真實(shí)值和估計(jì)值,拓展了PEHE到\hat{\epsilon}_{mPEHE},其中,在multiple treatment的時(shí)候,采用的是pairwise的平均值。這種指標(biāo)需要我們知道真實(shí)的各種counterfactual,除非模擬數(shù)據(jù),不然是不現(xiàn)實(shí)的。因此,模型選擇的部分,作者也提出了基于NN的\hat{\epsilon}_{NN-PEHE},

$\epsilon_{PEHE}$
metrics for multiple treatments

NN-PEHE

最后作者也證明了為什么這樣的訓(xùn)練數(shù)據(jù)下,利用SGD能夠得到causal effect的一致性估計(jì)。證明的核心邏輯是,利用各種因果效應(yīng)可以被識(shí)別的假設(shè),推導(dǎo)出我們是在做條件期望的極限。當(dāng)N趨于無窮大時(shí),極大概率會(huì)有一個(gè)樣本是和當(dāng)前樣本特征一模一樣,但treatment不一樣的。我們可以利用這樣的樣本估計(jì)因果效應(yīng)。個(gè)人覺得,建立在positive的假設(shè)下,這個(gè)證明應(yīng)該是沒問題的。。

proof of consistency

代碼實(shí)現(xiàn)

文章中的偽代碼,思路上還是比較直接的,每個(gè)mini-batch,利用propensity score尋找最近的樣本,返回mini-batch。后續(xù)直接用改進(jìn)的TARNET進(jìn)行訓(xùn)練。


pseudo code

To be continued...

心得體會(huì)

model selection criteria

文章另外比較大的貢獻(xiàn)是提供了一些模型評(píng)價(jià)指標(biāo),可以用來做模型選擇,并且公開了可以用來驗(yàn)證multiple treatment下模型性能的基準(zhǔn)數(shù)據(jù)集。雖然個(gè)人覺得\epsilon_{nnPEHE},其實(shí)就是作者訓(xùn)練的思路,有點(diǎn)作弊的嫌疑。但是,還是對(duì)觀測數(shù)據(jù)下的模型篩選,提供了一個(gè)思路(雖然這個(gè)思路,很在就有了,參見reference[1],但是作者詳細(xì)定義了指標(biāo),也與非nn的指標(biāo)進(jìn)行了統(tǒng)一)。

nearest neighbor matching

構(gòu)造mini-batch的時(shí)候,可以采用多種matching的方法,包括最近鄰,k近鄰等等,甚至不用propensity score作為balancing score,這些方法都可以從傳統(tǒng)的balancing里借鑒,甚至結(jié)合一些其他的balancing weighting學(xué)習(xí)的方法(后續(xù)會(huì)介紹,比如利用adversarial training)。這種trick也許在工業(yè)界,能有不錯(cuò)的效果。
同時(shí),這種方法和另外一些新興的imputing的方法有異曲同工之妙。

matching in minibatch&efficient in heavy overlap region

個(gè)人理解PM是matching的一種minibatch版本。在樣本特征分布重合度較高的地方,會(huì)被加強(qiáng)。因?yàn)樘卣鞣植贾睾隙容^高意味著對(duì)每一個(gè)樣本,有充足的其他treatment下樣本可以用來學(xué)習(xí)反事實(shí)。最極端的情況是,正好有特征完全重合的樣本,可以用來估計(jì)該樣本的causal effect。之前介紹的propensity dropout也是希望充分利用overlap度較高的樣本訓(xùn)練模型,從這個(gè)角度說,兩偏文章分別利用了兩種深度學(xué)習(xí)技巧augmentation和dropout來解決因果推斷的基本問題,簡單直接好理解,角度也比較新穎。
另外,考慮到神經(jīng)網(wǎng)絡(luò)需要大量的樣本進(jìn)行訓(xùn)練,propensity dropout確實(shí)也可能存在作者所說的樣本利用率欠缺的問題,考慮到神經(jīng)網(wǎng)絡(luò)需要大量的樣本進(jìn)行訓(xùn)練。其實(shí)也就是深度神經(jīng)網(wǎng)絡(luò)的訓(xùn)練技巧,數(shù)據(jù)增廣方法的各種花樣也許都可以用來結(jié)合一下構(gòu)造樣本。
作者也提到mini-batch的方法類似于minibatch sampling strategy,只不過是用在了causal inference的場景。這種mini-batch的方法優(yōu)于整體做augmentation,因?yàn)?,整體augmentation之后,還需要再采樣mini-batch,相同covariates的樣本可能并不會(huì)被分到同一個(gè)mini-batch,反而沒有起到虛擬隨機(jī)實(shí)驗(yàn)的模擬效果。

simple to use

PM方法確實(shí)非常簡單直接,因?yàn)椴恍枰淖兙W(wǎng)絡(luò)結(jié)構(gòu)、損失函數(shù),并且沒有添加任何額外的計(jì)算,所以理論上是可以和任何神經(jīng)網(wǎng)絡(luò)相關(guān)的causal inference方法組合的。但是,由于訓(xùn)練是改變了樣本周邊的分布,相當(dāng)于加權(quán)了和當(dāng)前樣本相關(guān)的周邊的別的treatment的樣本,如果和其他調(diào)整樣本分布的方法,比如re-weighting的方法一起使用時(shí),需要考慮re-weighting的學(xué)習(xí)過程是否收到影響。

文章引用

[1] Kapelner, A., Bleich, J., Levine, A., Cohen, Z., DeRubeis, R., & Berk, R. (2021). Evaluating the Effectiveness of Personalized Medicine With Software. Frontiers in Big Data, 4.
[2] Shalit, U., Johansson, F.D., & Sontag, D. (2017). Estimating individual treatment effect: generalization bounds and algorithms. ICML.
[3] https://github.com/d909b/perfect_match/tree/master/perfect_match/models

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容