論文筆記|SoftMatch:解決半監(jiān)督學(xué)習(xí)中偽標(biāo)簽質(zhì)量與數(shù)量的平衡問(wèn)題

論文標(biāo)題:SoftMatch: Addressing the Quantity-Quality Trade-off in Semi-supervised Learning

論文鏈接:https://arxiv.org/abs/2301.10921

代碼鏈接:https://github.com/Hhhhhhao/SoftMatch

論文來(lái)源:ICLR2023

作者單位:Carnegie Mellon University, Saarland Informatics Campus, Microsoft Research Asia, Mohamed bin Zayed University of AI

摘要

??之前的半監(jiān)督學(xué)習(xí)方法在選取偽標(biāo)簽時(shí),往往不能兼顧偽標(biāo)簽的質(zhì)量和數(shù)量(利用率)。針對(duì)半監(jiān)督學(xué)習(xí)中如何進(jìn)行偽標(biāo)注(pseudo-labeling)的問(wèn)題,本文從給無(wú)監(jiān)督損失項(xiàng)加權(quán)的角度解決了偽標(biāo)簽的質(zhì)量與數(shù)量權(quán)衡(trade-off)的問(wèn)題。本文設(shè)計(jì)了一個(gè)截?cái)嗟母咚购瘮?shù),根據(jù)模型對(duì)無(wú)標(biāo)注數(shù)據(jù)預(yù)測(cè)結(jié)果,對(duì)每個(gè)無(wú)標(biāo)注樣本的無(wú)監(jiān)督損失加權(quán)。此外還提出了一個(gè)均勻性對(duì)齊模塊來(lái)解決各類偽標(biāo)簽分布不平衡的問(wèn)題。

動(dòng)機(jī)

??之前半監(jiān)督學(xué)習(xí)中大量使用的偽標(biāo)注(pseudo-labeling)方法通過(guò)模型對(duì)無(wú)標(biāo)注數(shù)據(jù)的預(yù)測(cè),根據(jù)預(yù)先定義的置信度閾值來(lái)選取可信的偽標(biāo)簽。然而這種機(jī)制存在偽標(biāo)簽質(zhì)量和數(shù)量之間的trade-off,比如在FixMatch中使用了一個(gè)較高的固定置信度閾值(0.95)來(lái)選取偽標(biāo)簽。雖然這種方法確保了偽標(biāo)簽的質(zhì)量,但是它丟棄了相當(dāng)一部分unconfident但是正確的偽標(biāo)簽。如下圖所示,有近71%的正確偽標(biāo)簽被丟棄。相反,動(dòng)態(tài)閾值(Flexmatch、Adamatch、Dash等方法)鼓勵(lì)提高偽標(biāo)簽的利用率,但降低了偽標(biāo)簽的質(zhì)量。如下圖所示,有近16%的利用的偽標(biāo)簽是錯(cuò)誤的。


??本文顯示地定義了描述偽標(biāo)簽質(zhì)量與數(shù)量的方法,并從一個(gè)樣本加權(quán)的角度分析了之前的研究方法。本文認(rèn)為之前的方法缺乏對(duì)偽標(biāo)簽分布的加權(quán)函數(shù)施加的復(fù)雜假設(shè)。置信度閾值可以被視為一種階梯函數(shù),其根據(jù)樣本置信度來(lái)指定二類(binary)的權(quán)重,即假設(shè)超過(guò)置信度閾值的偽標(biāo)簽同等地正確,然而其它樣本都是錯(cuò)誤的?;诖耍疚奶岢隽薙oftMatch半監(jiān)督學(xué)習(xí)框架來(lái)同時(shí)提高偽標(biāo)簽的質(zhì)量和利用率。

先驗(yàn)知識(shí)

問(wèn)題描述

??在半監(jiān)督學(xué)習(xí)中,對(duì)于一個(gè)batch的有標(biāo)注數(shù)據(jù)和無(wú)標(biāo)注數(shù)據(jù),模型通常根據(jù)有監(jiān)督分類損失和無(wú)監(jiān)督損失\mathcal{L}=\mathcal{L}_{s}+\mathcal{L}_{u}來(lái)進(jìn)行優(yōu)化。
\mathcal{L}_{s}=\frac{1}{B_{L}} \sum_{i=1}^{B_{L}} \mathcal{H}\left(\mathbf{y}_{i}, \mathbf{p}\left(\mathbf{y} \mid \mathbf{x}_{i}^{l}\right)\right)

其中\mathcal{H}代表交叉熵,B_{L}代表有標(biāo)注數(shù)據(jù)的batch大小。對(duì)于無(wú)監(jiān)督損失,本文采用了模型對(duì)強(qiáng)增強(qiáng)數(shù)據(jù)\Omega\left(\mathbf{x}^{u}\right)的預(yù)測(cè)結(jié)果和弱增強(qiáng)數(shù)據(jù)\omega\left(\mathbf{x}^{u}\right)的偽標(biāo)簽計(jì)算加權(quán)的交叉熵?fù)p失:
\mathcal{L}_{u}=\frac{1}{B_{U}} \sum_{i=1}^{B_{U}} \lambda\left(\mathbf{p}_{i}\right) \mathcal{H}\left(\hat{\mathbf{p}}_{i}, \mathbf{p}\left(\mathbf{y} \mid \Omega\left(\mathbf{x}_{i}^{u}\right)\right)\right)
其中\mathbf{p}是對(duì)\mathbf{p}\left(\mathbf{y} \mid \omega\left(\mathbf{x}^{u}\right)\right)的縮寫,\hat{\mathbf{p}}是one-hot的偽標(biāo)簽,\lambda(\mathbf{p})是無(wú)標(biāo)注樣本的值域在[0,\lambda_{\max }]的加權(quán)函數(shù),B_{U}是無(wú)標(biāo)注數(shù)據(jù)的batch大小。

從樣本加權(quán)角度來(lái)考慮數(shù)量和質(zhì)量之間的權(quán)衡
  • 定義2.1 (偽標(biāo)簽的數(shù)量):所有無(wú)標(biāo)注數(shù)據(jù)樣本權(quán)重的期望:
    f(\mathbf{p})=\mathbb{E}_{\mathcal{D}_{U}}[\lambda(\mathbf{p})] \in\left[0, \lambda_{\max }\right]
  • 定義2.2 (偽標(biāo)簽的質(zhì)量):偽標(biāo)簽正確率的期望,此處用\mathbf{y}^{u}表示無(wú)標(biāo)注數(shù)據(jù)\mathbf{x}^{u}的類標(biāo)簽:
    g(\mathbf{p})=\sum_{i}^{N_{U}} \mathbb{1}\left(\hat{\mathbf{p}}_{i}=\mathbf{y}_{i}^{u}\right) \frac{\lambda\left(\mathbf{p}_{i}\right)}{\sum_{j}^{N_{U}} \lambda\left(\mathbf{p}_{j}\right)}=\mathbb{E}_{\bar{\lambda}(\mathbf{p})}\left[\mathbb{1}\left(\hat{\mathbf{p}}=\mathbf{y}^{u}\right)\right] \in[0,1]
    其中\bar{\lambda}(\mathbf{p})=\lambda(\mathbf{p}) / \sum \lambda(\mathbf{p})表示概率質(zhì)量函數(shù)(probability mass function ,PMF)。

??在之前的方法中,\lambda(\mathbf{p})幾乎沒被研究過(guò),本文首先概括\lambda(\mathbf{p}), \bar{\lambda}(\mathbf{p}), f(\mathbf{p}), g(\mathbf{p}),如下表所示:

SoftMatch

用于樣本加權(quán)的高斯函數(shù)

??本文假設(shè)邊緣分布的概率質(zhì)量函數(shù)\bar{\lambda}(\mathbf{p})服從一個(gè)動(dòng)態(tài)的、以在第t輪訓(xùn)練的\mu_{t}為均值,\sigma_{t}為標(biāo)準(zhǔn)差截?cái)喔咚狗植迹?br> \lambda(\mathbf{p})=\left\{\begin{array}{ll} \lambda_{\max } \exp \left(-\frac{\left(\max (\mathbf{p})-\mu_{t}\right)^{2}}{2 \sigma_{t}^{2}}\right), & \text { if } \max (\mathbf{p})<\mu_{t}, \\ \lambda_{\max }, & \text { otherwise. } \end{array}\right.

??然而高斯函數(shù)的參數(shù)\mu_{t}\sigma_{t}是未知的,本文根據(jù)模型對(duì)無(wú)標(biāo)注數(shù)據(jù)的歷史預(yù)測(cè)結(jié)果來(lái)估計(jì)\mu\sigma^{2}。在第t輪迭代,計(jì)算empirical的均值和方差:
\hat{\mu}_=\hat{\mathbb{E}}_{B_{U}}[\max (\mathbf{p})]=\frac{1}{B_{U}} \sum_{i=1}^{B_{U}} \max \left(\mathbf{p}_{i}\right)
\hat{\sigma}_^{2}=\hat{\operatorname{Var}}_{B_{U}}[\max (\mathbf{p})]=\frac{1}{B_{U}} \sum_{i=1}^{B_{U}}\left(\max \left(\mathbf{p}_{i}\right)-\hat{\mu}_\right)^{2}

??隨后利用之前batches的指數(shù)移動(dòng)平均來(lái)獲得一個(gè)更穩(wěn)定的估計(jì):
\hat{\mu}_{t}=m \hat{\mu}_{t-1}+(1-m) \hat{\mu}_
\hat{\sigma}_{t}^{2}=m \hat{\sigma}_{t-1}^{2}+(1-m) \frac{B_{U}}{B_{U}-1} \hat{\sigma}_^{2}
其中\hat{\mu}_{0}被初始化為1/C,\hat{\sigma}_{0}^{2}被初始化為1.0。利用估計(jì)的\hat{\mu}_{t}\hat{\sigma}_{t}^{2}代入上述高斯函數(shù)公式來(lái)計(jì)算樣本的權(quán)重。隨著模型學(xué)習(xí)得越好,預(yù)測(cè)結(jié)果越趨于穩(wěn)定,\hat{\mu}_{t}增加,\hat{\sigma}_{t}減少。截?cái)嗟母咚购瘮?shù)可被視為一種soft和adaptive版本的置信度閾值,因此本方法被命名為SoftMatch。

均勻性對(duì)齊(Uniform Alignment)

??由于不同類別存在不同的學(xué)習(xí)難度,生成的偽標(biāo)簽類別可能是不平衡的分布,可能限制PMF假設(shè)的泛化性。為解決此問(wèn)題,本文提出一致性對(duì)齊來(lái)鼓勵(lì)不同類別生成更類別更為平衡的偽標(biāo)簽。本文將偽標(biāo)簽的分布定義為模型對(duì)無(wú)標(biāo)注預(yù)測(cè)結(jié)果的期望\mathbb{E}_{\mathcal{D}_{U}}\left[\mathbf{p}\left(\mathbf{y} \mid \mathbf{x}^{u}\right)\right]。在訓(xùn)練過(guò)程中,利用一個(gè)batch中無(wú)標(biāo)注數(shù)據(jù)預(yù)測(cè)結(jié)果的指數(shù)移動(dòng)平均\hat{\mathbb{E}}_{B_{U}}\left[\mathbf{p}\left(\mathbf{y} \mid \mathbf{x}^{u}\right)\right]來(lái)進(jìn)行估計(jì)。對(duì)于每個(gè)無(wú)標(biāo)注樣本的預(yù)測(cè)概率\mathbf{p},進(jìn)行以下正則化:
\mathrm{UA}(\mathbf{p})=\text { Normalize }\left(\mathbf{p} \cdot \frac{\mathbf{u}(C)}{\hat{\mathbb{E}}_{B_{U}}[\mathbf{p}]}\right)
其中\mathbf{u}(C) \in \mathbb{R}^{C}是均勻分布,\text { Normalize }(\cdot)=(\cdot) / \sum(\cdot)確保歸一化的概率加和為1.0。最終的加權(quán)函數(shù)改寫為:
\lambda(\mathbf{p})=\left\{\begin{array}{ll} \lambda_{\max } \exp \left(-\frac{\left(\max (\mathrm{UA}(\mathbf{p}))-\hat{\mu}_{t}\right)^{2}}{2 \hat{\sigma}_{t}^{2}}\right), & \text { if } \max (\mathrm{UA}(\mathbf{p}))<\hat{\mu}_{t}, \\ \lambda_{\max }, & \text { otherwise. } \end{array}\right.

??在計(jì)算樣本權(quán)重時(shí),均勻性對(duì)齊模塊為less-predicted的偽標(biāo)簽施加更大的權(quán)重,為more-predicted偽標(biāo)簽施加更少的權(quán)重,以緩解不平衡問(wèn)題。

??均勻性對(duì)齊(UA)和分布對(duì)齊(DA)的重要區(qū)別在于無(wú)監(jiān)督學(xué)習(xí)損失的計(jì)算上。DA在進(jìn)行歸一化之后可能產(chǎn)生很多錯(cuò)誤的偽標(biāo)簽,它們被用作交叉熵?fù)p失中的soft target。而UA利用原始的預(yù)測(cè)值來(lái)計(jì)算偽標(biāo)簽,利用歸一化的預(yù)測(cè)結(jié)果來(lái)計(jì)算樣本的權(quán)重。算法整體流程如下圖所示:


實(shí)驗(yàn)

  • 類別平衡半監(jiān)督實(shí)驗(yàn)


  • 類別不平衡半監(jiān)督實(shí)驗(yàn)


  • 文本分類數(shù)據(jù)集的半監(jiān)督實(shí)驗(yàn)


  • 定性分析


  • 消融實(shí)驗(yàn)


?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容