(updated in 20210828)
文章名稱(chēng)
Deep Counterfactual Networks with Propensity-Dropout
核心要點(diǎn)
作者把causal inference視作multitask learning的問(wèn)題,并通過(guò)dropout的方式來(lái)進(jìn)行正則化,減少selection bias帶來(lái)的影響。這種正則化通過(guò)估計(jì)每一個(gè)樣本的propensity score,來(lái)調(diào)整網(wǎng)絡(luò)的復(fù)雜度。本質(zhì)上類(lèi)似于propensity weighting的方法,只是更具有神經(jīng)網(wǎng)絡(luò)的風(fēng)格。調(diào)整網(wǎng)絡(luò)復(fù)雜度,會(huì)在propensity score較低的時(shí)候減少過(guò)擬合的風(fēng)險(xiǎn),實(shí)現(xiàn)更好的泛化。
文章主要關(guān)注CATE(ITE),并且在binary treatment的場(chǎng)景下展開(kāi)討論。但是,可以模仿其他的multiple head的神經(jīng)網(wǎng)絡(luò)方法,遷移到multiple treatment的場(chǎng)景下。通過(guò),Generalized Propensity Score,可以泛化文中提出的propensity score dropout。
方法細(xì)節(jié)
問(wèn)題引入
文章建立在potential outcome框架下,并且需要滿(mǎn)足unconfoundness的假設(shè),即。為了需要估計(jì)因果效應(yīng)
,比較老的方法采用直接建模的方式,比如
,也就是我們常說(shuō)的single learner(如果兩個(gè)
帶有下表就是T-learner)。這種建模方式的弊端是高維的特征
會(huì)淹沒(méi)低緯度的干預(yù)
。
如果采用T-learner,不會(huì)存在干預(yù)被淹沒(méi)的問(wèn)題,也比較靈活,卻引入了模型誤差帶來(lái)的因果效應(yīng)估計(jì)的偏差,并且犧牲了統(tǒng)計(jì)效率,不能夠充分利用樣本。
具體做法
為了同時(shí)確保靈活性和樣本效率,作者把因果推斷問(wèn)題看做多任務(wù)學(xué)習(xí)的問(wèn)題,不同的counterfactual估計(jì)是不同的、卻相關(guān)的任務(wù)。不同的干預(yù)意味著不同的任務(wù),任務(wù)id和treatment id一一對(duì)應(yīng)。如下圖所示,左側(cè)的potential outcome network具有多層共享網(wǎng)絡(luò)來(lái)提升樣本效率,因?yàn)橥ㄟ^(guò)不同treatment的數(shù)據(jù)共同訓(xùn)練了這些層,提取了共同的因素。而后續(xù)單獨(dú)的輸出網(wǎng)絡(luò),有保證了靈活性和獨(dú)立性。相當(dāng)于在T-learner和S-learner之間做了一個(gè)折中,也是很多神經(jīng)網(wǎng)絡(luò)處理不同treatment的慣常方式。

但是,這樣的網(wǎng)絡(luò)并不能糾正由混淆變量帶來(lái)的偏差。類(lèi)似IPW和其他傳統(tǒng)的因果推斷方法,作者也利用propensity score進(jìn)行樣本權(quán)重調(diào)節(jié)。只不過(guò),調(diào)節(jié)不是發(fā)生在loss的權(quán)重上,而是改變網(wǎng)絡(luò)的復(fù)雜程度。在propensity score比較極端(非常接近0或者1)的情況下,利用dropout,使得網(wǎng)絡(luò)變得簡(jiǎn)單。在propensity score比較接近0.5的情況下,保持原來(lái)網(wǎng)絡(luò)的復(fù)雜程度,這樣可以充分利用,在不同干預(yù)下,特征分布重合度(overlap)較高的樣本。這種權(quán)重調(diào)節(jié),其實(shí)是減少了對(duì)具有極端propensity score樣本的學(xué)習(xí)充分程度,是另一種意義上的降權(quán)。這樣的降權(quán),可以減少對(duì)selection bias的過(guò)度擬合,提升網(wǎng)絡(luò)預(yù)測(cè)在不同counterfactual估計(jì)上的泛化能力。
作者定義
除此之外,通過(guò)Monte Carlo Dropout可以得到樣本的置信度的點(diǎn)估計(jì),估計(jì)步驟如下圖。圖中,

整個(gè)網(wǎng)絡(luò)是交替進(jìn)行訓(xùn)練的,在訓(xùn)練的過(guò)程中,share網(wǎng)絡(luò)會(huì)在每一個(gè)epoch中被充分訓(xùn)練,而不同干預(yù)的輸出網(wǎng)絡(luò)是每隔一個(gè)epoch單獨(dú)訓(xùn)練的,每一個(gè)epoch都會(huì)采用propensity dropout。整個(gè)流程如下圖所示,單數(shù)epoch左側(cè)的outcome網(wǎng)絡(luò)沒(méi)有被訓(xùn)練,而雙數(shù)epoch時(shí)右側(cè)將不會(huì)被訓(xùn)練。

總的來(lái)說(shuō),propensity dropout的方法是很有創(chuàng)意的結(jié)合了神經(jīng)網(wǎng)絡(luò)的dropout和propensity score weighting,同時(shí)延續(xù)了當(dāng)時(shí)multitask learning的counterfactual估計(jì)的主流思路。
代碼實(shí)現(xiàn)
文中訓(xùn)練的偽代碼如下圖所示,

具體的pytorch實(shí)現(xiàn)可以參見(jiàn),
https://github.com/Shantanu48114860/Deep-Counterfactual-Networks-with-Propensity-Dropout
首先是DCN網(wǎng)絡(luò)(作者稱(chēng)沒(méi)有propensity dropout的網(wǎng)絡(luò)為deep counterfactual network),可以看到是簡(jiǎn)單的FFN,不過(guò)有兩個(gè)獨(dú)立的輸出和一些共享層(代碼里是2層)。

DCN的訓(xùn)練是交替訓(xùn)練的,如下圖所示,在偶數(shù)epoch的時(shí)候,訓(xùn)練potential outcome

之后是propensity score網(wǎng)絡(luò),可以看到也是簡(jiǎn)單的2層FFN,當(dāng)然可以根據(jù)數(shù)據(jù)采用更復(fù)雜的網(wǎng)絡(luò)結(jié)構(gòu)。

Propensity Score網(wǎng)路的訓(xùn)練是在所有數(shù)據(jù)集上進(jìn)行的,訓(xùn)練方法如下圖所示。本質(zhì)就是利用觀(guān)測(cè)樣本中treatment的分布,去估計(jì)treatment assignment。

最后來(lái)看一下Propensity Dropout的部分??梢钥吹剑紫壤玫玫降腜ropensity Score,計(jì)算信息熵,然后按照公式得到dropout的概率值,進(jìn)而得到dropout mask向量。

Propensity Dropout過(guò)程中用到的一些工具函數(shù),也都比較簡(jiǎn)單直接。

值得注意的是,訓(xùn)練時(shí)候的Propensity Score是數(shù)據(jù)集中處理得到的,實(shí)際訓(xùn)練中,可以先通過(guò)訓(xùn)練好的propensity score網(wǎng)絡(luò)得到估計(jì)的每個(gè)樣本的propensity score。
心得體會(huì)
利用propensity score在樣本粒度調(diào)整模型復(fù)雜度
個(gè)人理解propensity score dropout,是作者把selection bias對(duì)counterfactual估計(jì)的影響,看做是某種過(guò)擬合的產(chǎn)物,是一種保守的反事實(shí)估計(jì)策略。為了減少混淆變量的影響,放棄了對(duì)觀(guān)測(cè)outcome的充分?jǐn)M合。
Doubly-Robust
雖然作者說(shuō)同時(shí)利用了利用了outcome和propensity score的信息,所以是doubly-robust的。但是,個(gè)人認(rèn)為沒(méi)有最多只能說(shuō)是隱式的進(jìn)行了propensity score weighting,但是并沒(méi)有進(jìn)行顯示的doubly-robust的建模。
文章引用
[1] https://github.com/Shantanu48114860/Deep-Counterfactual-Networks-with-Propensity-Dropout