文章名稱
CounterFactual Regression with Importance Sampling Weights
核心要點(diǎn)
文章主要針對(duì)binary treatment的場(chǎng)景,能夠用來(lái)估計(jì)CATE(當(dāng)然也可以估計(jì)ATE)。作者基于CFR[1],提出利用上下文感知的重要性采樣來(lái)取代CFR的固定權(quán)重,來(lái)平衡selection bias。相比于BNN和CFR利用頻率統(tǒng)計(jì)得到的樣本權(quán)重,文章提出的方法能夠?qū)崿F(xiàn)selection bias的平衡,彌補(bǔ)IPM loss較小平衡能力不足的問(wèn)題。CFR-IS采用兩階段交替學(xué)習(xí)。首先,利用給定權(quán)重,訓(xùn)練類似BNN和CFR的loss。隨后,通過(guò)最小化NLL得到更優(yōu)的權(quán)重。
方法細(xì)節(jié)
問(wèn)題引入
BNN和CFR主要利用IPM來(lái)平衡不同treatment下的分布差異,具體loss如下圖所示。但是由于這種平衡是建立在的聯(lián)合分布上的,
的影響可能會(huì)被忽略,而且高維特征會(huì)導(dǎo)致有treatment引起的分布距離比較小,不能夠提供足夠的loss,來(lái)進(jìn)行selection bias的平衡。

同時(shí),BNN和CFR在構(gòu)建factual loss(估計(jì)樣本實(shí)際輸出)的時(shí)候,采用了頻率統(tǒng)計(jì)得到的權(quán)重,即圖中的


而經(jīng)過(guò)loss的改寫(xiě),發(fā)現(xiàn)這部分權(quán)重的目標(biāo)是平衡樣本不均(參見(jiàn)引用[1]),并不能起到balancing當(dāng)中的re-weigthing的作用。因此,總體作者認(rèn)為對(duì)selection bias的矯正是不充分的。所以,提出利用重要性采樣的方法來(lái)學(xué)習(xí)樣本權(quán)重實(shí)現(xiàn)不同treatment下的covariates均衡(大家都是這條路,做法不同而已)。

具體做法
因此,作者把兩個(gè)不同的treatment下的分布,看做是兩個(gè)不同分布的采樣。為了對(duì)齊兩個(gè)分布的學(xué)習(xí)效果,我們把counterfactual的covariates分布當(dāng)做是目標(biāo)分布
,把實(shí)際觀測(cè)到的樣本分布
當(dāng)做采樣分布
。例如,當(dāng)我們處理
的數(shù)據(jù)是,
的covariates分布就是采樣分布,而
是目標(biāo)分布。

當(dāng)控制住

因此,得到不同treatment下

為了能夠在觀測(cè)數(shù)據(jù)上也表現(xiàn)得好(也就是預(yù)測(cè)好factual),作者在權(quán)重上加1,表示采樣分布和目標(biāo)分布是同一個(gè)。

但是,我們發(fā)現(xiàn)直接估計(jì)這個(gè)weight不現(xiàn)實(shí),因?yàn)槭且烙?jì)一個(gè)隱向量在不同treatment下出現(xiàn)的概率的比值。無(wú)論是直接估計(jì)概率密度函數(shù),還是用高斯建模概率的密度函數(shù)要么計(jì)算量大,要么假設(shè)太強(qiáng),不準(zhǔn)確。所以作者采用貝葉斯法則轉(zhuǎn)化了weight的估計(jì)方式,如下圖所示。其中,


學(xué)習(xí)propensity的loss就是簡(jiǎn)單的NLL。作者采用交替優(yōu)化CFR loss和propensity loss的方法進(jìn)行學(xué)(也許可以一起學(xué),類似Dragnnet)。

具體的網(wǎng)絡(luò)結(jié)構(gòu)如圖所示,

代碼實(shí)現(xiàn)

(留坑待填...)
心得體會(huì)
why IS work?
個(gè)人理解,IS就是把眼分布的數(shù)據(jù)用來(lái)?yè)Q到目標(biāo)分布來(lái)估計(jì)目標(biāo)結(jié)果。這里weight是用在factual loss的那個(gè)部分,也就是說(shuō),我們假設(shè)樣本可能來(lái)自counterfactual分布,在這種情況下還用觀測(cè)結(jié)果作為事實(shí)來(lái)代表counterfactual的值,就需要用IS。并且IS之后,就可以把估計(jì)factual loss當(dāng)做是在估計(jì)counterfactual loss。
add 1 to weight
在權(quán)重上+1,就把一個(gè)樣本分成了兩個(gè)。因?yàn)椋?img class="math-inline" src="https://math.jianshu.com/math?formula=(1%2Bw_%7Bi%7D)%20x%20%3D%20x%20%2B%20w_i%20x" alt="(1+w_{i}) x = x + w_i x" mathimg="1">。本質(zhì)是表示如果這個(gè)樣本實(shí)際就是從觀測(cè)分布來(lái)的,那么就不需要加權(quán),但需要被用來(lái)估計(jì)factual。
文章引用
[1] Shalit, U., Johansson, F.D., & Sontag, D. (2017). Estimating individual treatment effect: generalization bounds and algorithms. ICML.