因果推斷深度學(xué)習(xí)工具箱 - CounterFactual Regression with Importance Sampling Weights

文章名稱

CounterFactual Regression with Importance Sampling Weights

核心要點(diǎn)

文章主要針對(duì)binary treatment的場(chǎng)景,能夠用來(lái)估計(jì)CATE(當(dāng)然也可以估計(jì)ATE)。作者基于CFR[1],提出利用上下文感知的重要性采樣來(lái)取代CFR的固定權(quán)重,來(lái)平衡selection bias。相比于BNN和CFR利用頻率統(tǒng)計(jì)得到的樣本權(quán)重,文章提出的方法能夠?qū)崿F(xiàn)selection bias的平衡,彌補(bǔ)IPM loss較小平衡能力不足的問(wèn)題。CFR-IS采用兩階段交替學(xué)習(xí)。首先,利用給定權(quán)重,訓(xùn)練類似BNN和CFR的loss。隨后,通過(guò)最小化NLL得到更優(yōu)的權(quán)重。

方法細(xì)節(jié)

問(wèn)題引入

BNN和CFR主要利用IPM來(lái)平衡不同treatment下的分布差異,具體loss如下圖所示。但是由于這種平衡是建立在P(\Phi(x), t)的聯(lián)合分布上的,t的影響可能會(huì)被忽略,而且高維特征會(huì)導(dǎo)致有treatment引起的分布距離比較小,不能夠提供足夠的loss,來(lái)進(jìn)行selection bias的平衡。

CFR loss

同時(shí),BNN和CFR在構(gòu)建factual loss(估計(jì)樣本實(shí)際輸出)的時(shí)候,采用了頻率統(tǒng)計(jì)得到的權(quán)重,即圖中的w_i,其計(jì)算方法如下圖所示??梢钥闯鲞@個(gè)weight是一個(gè)頻率統(tǒng)計(jì)值,本質(zhì)是一個(gè)propensity score的倒數(shù)。
CFR weight

CFR weight(2)

而經(jīng)過(guò)loss的改寫(xiě),發(fā)現(xiàn)這部分權(quán)重的目標(biāo)是平衡樣本不均(參見(jiàn)引用[1]),并不能起到balancing當(dāng)中的re-weigthing的作用。因此,總體作者認(rèn)為對(duì)selection bias的矯正是不充分的。所以,提出利用重要性采樣的方法來(lái)學(xué)習(xí)樣本權(quán)重實(shí)現(xiàn)不同treatment下的covariates均衡(大家都是這條路,做法不同而已)。
CFR loss reformation

具體做法

因此,作者把兩個(gè)不同的treatment下的分布,看做是兩個(gè)不同分布的采樣。為了對(duì)齊兩個(gè)分布的學(xué)習(xí)效果,我們把counterfactual的covariates分布p(y, \phi | \neg t)當(dāng)做是目標(biāo)分布p(x),把實(shí)際觀測(cè)到的樣本分布p(y, \phi | t)當(dāng)做采樣分布q(x)。例如,當(dāng)我們處理t = 0的數(shù)據(jù)是,t = 0的covariates分布就是采樣分布,而t = 1是目標(biāo)分布。

importance sampling

當(dāng)控制住\phi = \Phi(x)之后,下圖中因果圖的后門(mén)被阻斷(后門(mén)準(zhǔn)則),那么ty是獨(dú)立的。
belif net

因此,得到不同treatment下y\phi的聯(lián)合分布的比值等于不同treatment下\phi的比值。這樣我們構(gòu)造了一個(gè)有covariates得到的隱向量\Phi(x)決定的重要性采樣權(quán)重。
counterfactual IS

為了能夠在觀測(cè)數(shù)據(jù)上也表現(xiàn)得好(也就是預(yù)測(cè)好factual),作者在權(quán)重上加1,表示采樣分布和目標(biāo)分布是同一個(gè)。
weight

但是,我們發(fā)現(xiàn)直接估計(jì)這個(gè)weight不現(xiàn)實(shí),因?yàn)槭且烙?jì)一個(gè)隱向量在不同treatment下出現(xiàn)的概率的比值。無(wú)論是直接估計(jì)概率密度函數(shù),還是用高斯建模概率的密度函數(shù)要么計(jì)算量大,要么假設(shè)太強(qiáng),不準(zhǔn)確。所以作者采用貝葉斯法則轉(zhuǎn)化了weight的估計(jì)方式,如下圖所示。其中,\pi_{0}(t|\phi)表示propensity score,可以用LR或者神經(jīng)網(wǎng)絡(luò)得到。
weight reformation

propensity \pi

學(xué)習(xí)propensity的loss就是簡(jiǎn)單的NLL。作者采用交替優(yōu)化CFR loss和propensity loss的方法進(jìn)行學(xué)(也許可以一起學(xué),類似Dragnnet)。
propensity loss

具體的網(wǎng)絡(luò)結(jié)構(gòu)如圖所示,
network structure

代碼實(shí)現(xiàn)

pseudo code

(留坑待填...)

心得體會(huì)

why IS work?

個(gè)人理解,IS就是把眼分布的數(shù)據(jù)用來(lái)?yè)Q到目標(biāo)分布來(lái)估計(jì)目標(biāo)結(jié)果。這里weight是用在factual loss的那個(gè)部分,也就是說(shuō),我們假設(shè)樣本可能來(lái)自counterfactual分布,在這種情況下還用觀測(cè)結(jié)果作為事實(shí)來(lái)代表counterfactual的值,就需要用IS。并且IS之后,就可以把估計(jì)factual loss當(dāng)做是在估計(jì)counterfactual loss。

add 1 to weight

在權(quán)重上+1,就把一個(gè)樣本分成了兩個(gè)。因?yàn)椋?img class="math-inline" src="https://math.jianshu.com/math?formula=(1%2Bw_%7Bi%7D)%20x%20%3D%20x%20%2B%20w_i%20x" alt="(1+w_{i}) x = x + w_i x" mathimg="1">。本質(zhì)是表示如果這個(gè)樣本實(shí)際就是從觀測(cè)分布來(lái)的,那么就不需要加權(quán),但需要被用來(lái)估計(jì)factual。

文章引用

[1] Shalit, U., Johansson, F.D., & Sontag, D. (2017). Estimating individual treatment effect: generalization bounds and algorithms. ICML.

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容