多任務模型融合如何平衡?

平衡目標

  • 不同任務loss量級相近
  • 不同任務相近速率學習
  • 不同任務數(shù)據(jù)量級接近
  • 不同任務重要性程度近似評估
  • 不同任務不確定性估計

廢話少說 上圖個球的


image.png

1. 加權融合

1.1 手動加權

image.png

人肉調(diào)權重。

1.2 動態(tài)加權平均

核心思想:利用loss變化率,平衡多任務學習速度。
[End-to-End Multi-Task Learning with Attention],CVPR 2019,Cites:107
https://arxiv.org/pdf/1803.10704v1.pdf
實現(xiàn):
https://github.com/lorenmt/mtan
本文提出了一種新的多任務學習體系結(jié)構(gòu),允許學習特定任務的特征級注意力。提出了MTAN(Multi-Task Attention Netwrok)網(wǎng)絡,由一個包含全局特征池化的共享網(wǎng)絡和基于特定任務的soft-attention模塊組成。這些模塊從全局共享特征中學習特定任務的特征,同時允許特征在不同任務間共享。
MTAN結(jié)構(gòu)主要包括兩大部分,一個任務共享的主網(wǎng)絡和K個特定任務的子網(wǎng)絡,共享網(wǎng)絡可以根據(jù)特定的任務進行設計,而每個特定于任務的子網(wǎng)絡由一組注意力模塊組成,這些模塊與共享網(wǎng)絡相連接。每個注意力模塊對共享網(wǎng)絡的特定層應用一個soft attention mask,以學習特定于任務的特征?;谶@種設計,共享主網(wǎng)絡可以看做是一個跨任務的特征表示,每一個attention mask都可以被看作是對主網(wǎng)絡的特征選擇器,決定哪些共享特征被用到自己的子任務中去。

image.png

最后損失函數(shù):
image.png

作者嘗試了不同的權重方案對模型效果的影響
image.png

DWA,使每個子任務首先計算前個epoch對應損失的比值,然后除以一個固定的值T,進行exp映射后,計算各個損失所占比。
image.png

K代表任務的數(shù)量;T是一個常數(shù),T=1時,w等同于softmax的結(jié)果;T足夠大時,w趨近1,各個任務的loss權重相同。該方法,只需要記錄不同階段的loss值,從而避免了為了獲取不同任務的梯度,運算較快。

2. loss梯度數(shù)量級差異

loss值的大小不重要,重要的是每一個loss產(chǎn)生的梯度的數(shù)量級不能差的特別大。如果梯度數(shù)量級差的很多的,可以給loss加權重。

2.1 梯度正則化

《Gradnorm: Gradient normalization for adaptive loss balancing in deep multitask networks》,ICML 2018,Cites:177
https://arxiv.org/abs/1711.02257
核心思想:
同時考慮標簽損失與梯度損失,同時將梯度表示為權重的函數(shù),進行全局標準化優(yōu)化。
實現(xiàn):
https://github.com/brianlan/pytorch-grad-norm
與靜態(tài)加權不同,我們認為多任務的w是對于參數(shù)t的函數(shù),且有不同的wi(t)對多任務分配loss權重。

image.png

本文定義了兩種類型的損失:標簽損失和梯度損失,獨立優(yōu)化,不進行疊加。Label Loss即多任務學習中,每個任務的真實的數(shù)據(jù)標簽與網(wǎng)絡預測標簽計算的loss。Gradient Loss,用來衡量每個任務的loss的權重 wi(t)的好壞,Gradient Loss是關于權重wi(t)的函數(shù)。t表示網(wǎng)絡訓練中的第t步,權重為關于t的一個變量。
image.png

  • W是整個多任務學習網(wǎng)絡參數(shù)的子集,在文章中的實驗,作者選擇了網(wǎng)絡share bottom部分最后一層的參數(shù);
  • Giw(t)是任務i梯度標準化的值,是任務i的權重與loss 的乘積對參數(shù)W求梯度的L2范數(shù), 可以衡量某個任務loss的量級;
  • G---w(t)是全局梯度標準化的值(即所有任務梯度標準化值的期望值),通過所有求均值實現(xiàn)。
  • Li(0)與Li(t)分別代表子任務i的第0步和第t步的loss,L ~ i(t)在一定程度上衡量了任務i的反向的訓練速度,L~~i(t)越大,表示網(wǎng)絡訓練越慢。
  • image.png

    表示了各個任務反向訓練速度的期望;

  • 反向訓練速率除以反向訓練期望,得到相對反向訓練速率r,r是任務的相對反向訓練速度,r越大,表示任務i在所有任務中訓練越慢。
    從而:


    image.png

    計算完Gradient Loss后,通過以下函數(shù)對wi(t)進行更新:


    image.png

    image.png

2.2差異化學習率

對于多任務損失減小速度不一致的情況,需要自適應多任務的學習率,在 Adagrad, RMSProp, Adam 等等優(yōu)化算法中,自適應學習率主要針對高維空間模型參數(shù)變化不同的方向,使用不同的學習率。
對更新快的任務,使用小一點的學習率,對更新慢的任務,使用大一點的學習率。


image.png

trick:


image.png

3. 引入Bayesian框架下的平衡

3.1 確定性平衡

尤其針對于多任務指標提升有蹺蹺板現(xiàn)象的模型,如ESSM模型在做CTR、CVR目標時,引入pCTR與pCTCVR,可以基于先驗任務的估計可以作為下游任務的一個特征,從而組合loss。

3.2 動態(tài)任務優(yōu)先級

[Dynamic task prioritization for multitask learning],ECCV 2018,Cites:53
https://openaccess.thecvf.com/content_ECCV_2018/papers/Michelle_Guo_Focus_on_the_ECCV_2018_paper.pdf
本文提出了多任務學習的動態(tài)任務優(yōu)先級。這允許模型在訓練過程中動態(tài)的對困難任務進行優(yōu)先級排序,其中困難任務與性能成反比,并且困難隨著時間的推移而變化。與傳統(tǒng)認知上優(yōu)先學習簡單任務相反,本文通過多個實驗證明了優(yōu)先學習困難任務的重要性。
核心思想:DTP以為更難的任務對模型影響程度更高,希望讓更難學的任務具有更高的權重。
引入幾個概念:

image.png

  • 任務困難程度D,且D與性能指標k成反比,即


    image.png
  • 基礎公式:


    image.png
  • 任務難易程度降序排列:


    image.png

    image.png
  • 性能指標KPIs:
    對于每一個任務Tt,用Kt代表關鍵性能指標,也就是我們常聽到的KPI,其中 Kt∈[0,1]。那么什么樣的標準可以用來當做KPI呢?Kt必須是一個有意義的度量標準,比如準確率或者AP等。此外,還借鑒tensorflow上的知識,通過Kt計算一個滑動平均模型,公式如下:


    image.png

    阿爾法代表衰減因子,取[0,1],決定了模型更新的速度,越大任務越趨于穩(wěn)定。
    權重公式:


    image.png

    需要獲取不同step的KPI值,從而避免了為了獲取不同任務的梯度,運算較快。但是DTP沒有考慮不同任務的loss的量級,需要額外的操作把各個任務的量級調(diào)整到差不多,且需要額外計算KPI。
    猜想:
    是否可以結(jié)合梯度正則化,來平衡任務量級+考量KPI指標來進行多任務融合。

3.3不確定性平衡

本文希望讓“簡單”的任務有更高的權重。
貝葉斯建模中存在的兩類不確定性:

  • 認知不確定性(Epistemic uncertainty): 由于缺少訓練數(shù)據(jù)而引起的不確定性。

  • 偶然不確定性(Aleatoric uncertainty): 由于訓練數(shù)據(jù)無法解釋信息而引起的不確定性。
    對于偶然不確定性,又可以分為如下兩個子類:

  • 數(shù)據(jù)依賴(Data-dependant)或異方差(Heteroscedastic)不確定性,簡單的說就是對于不同的輸入,網(wǎng)絡輸出的噪聲大小不同。

  • 任務依賴(Task-dependant)或同方差(Homoscedastic)不確定性,是一個對所有輸入數(shù)據(jù)保持不變的量,并且在不同任務之間變化。

對每個task,受限于任務數(shù)據(jù)與模型刻畫能力,都存在task相關的不確定性,不確定性被量化為模型輸出與樣本標簽之間的噪聲。
[Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics]
https://arxiv.org/abs/1705.07115v3
實現(xiàn): https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example.ipynb
基礎假設:

  1. 回歸任務,噪聲服從高斯分布,噪聲強弱衡量指標高斯分布方差。
  2. 分類任務,噪聲服從玻爾茲曼分布,衡量標準為熱力學溫度。(這個在原文中沒有展開)
    // 回歸任務不確定性


    image.png

    // 分類任務不確定性


    image.png

    組合loss:
    image.png

    image.png
image.png

取log類似于正則項,防止不確定性估計過大,能夠?qū)⒉淮_定性壓縮在一定數(shù)值范圍內(nèi),不確定性估計越大,任務的不確定性越大,則任務的權重越小,即噪聲大且難學的任務權重會變小。
猜想:
與梯度正則化不同的是,不確定性估計


image.png

收斂受到初始化的影響較小,是否可以在梯度正則化中L(0)初始化時引入對不確定性的估計。

4. 梯度沖突系列方法

之前的有關梯度的多任務融合方式,主要針對任務loss梯度量級差異、不同任務學習速度差異而做研究,更進一步,討論梯度更新方向與梯度更新沖突的方法,正high起來。

從帕累托解集的角度來看

其實,從梯度角度出發(fā)的MTL平衡研究主要在討論,梯度從量化差異上升到多維空間離散解集,再到連續(xù)解集的過程。


image.png

從梯度沖突問題角度來看

1. 不同任務梯度之間的更新方向不同引起的梯度沖突,引出了本模塊要講的內(nèi)容
2. 非方向性沖突,由于梯度大小不同,如果一個梯度特別大,另一個梯度特別小,那么較大的梯度會形成主導從而影響較小梯度。
3. 在曲率很大的位置,高梯度值的任務的改進可能被高估,而高梯度值的任務的性能下降可能被低估

4.1 PCGrad

https://arxiv.org/pdf/2001.06782.pdf
實現(xiàn):
https://github.com/tianheyu927/PCGrad

  • 核心思想:
    將任務的梯度投影到具有沖突梯度的任何其他任務的梯度的法線平面上。


    image.png
  1. 首先,計算任務i梯度和隨機一個其他任務j梯度之間的余弦相似度,如果為負值表示是相互沖突的梯度,如上圖(a)
  2. 如果余弦相似度為負值,則通過下式計算任務i梯度在任務j梯度的法線平面上的投影(如上圖(b)):


    image.png
  3. 如果梯度沒有沖突,即余弦相似度為非負,則直接使用原始梯度即可,如上圖d。
  4. 選擇另外一個任務,作為任務i,重復上述流程,直到所有任務的梯度都修正過一遍。
  5. 將各任務修正后的梯度相加,得到最終的梯度。
  • 更新規(guī)則:


    image.png

4.2 MGDA 帕累托單解

Multiple-gradient descent algorithm (mgda) for multiobjective optimization. Comptes Rendus Mathematique, 350:313–318, 2012.
實現(xiàn):
在MMOE模型里增加MGDA的玩法。
https://github.com/king821221/tf-mmoe-mgda

image.png

4.3 CAGrad多梯度優(yōu)化

https://papers.nips.cc/paper/2021/file/9d27fdf2477ffbff837d73ef7ae23db9-Paper.pdf
實現(xiàn):
https://github.com/Cranial-XIX/CAGrad
核心思想:
通過構(gòu)造新的梯度更新方向來避免梯度沖突,但之前出現(xiàn)的梯度修正算法往往偏離了最小化平均loss的目標,導致最終結(jié)果無法收斂到全局最優(yōu)。文章在平均梯度方向鄰域內(nèi)求解梯度更新方向,最大化所有子目標中最小的局部提升。

  • 多任務學習與梯度沖突:


    image.png
  • 算法


    image.png

    image.png

5. 帕累托平衡

多目標優(yōu)化的解通常是一組均衡解(即一組由眾多 Pareto最優(yōu)解組成的最優(yōu)解集合 ,集合中的各個元素稱為 Pareto最優(yōu)解或非劣最優(yōu)解)。

5.1 幾個概念:

  • 什么是Pareto?
    上wiki:
    https://en.wikipedia.org/wiki/Pareto_efficiency
    (自己看去吧,三五三十五,三五太難了)
    Pareto optimality,即帕累托最優(yōu)。一個 Pareto 最優(yōu)解指的是在優(yōu)化任務有多個目標時,在這些目標中權衡得最好的解。如下圖所示,兩個目標的誤差值越小越好。由于模型能力的限制,其最左下角是理論上無法得到的不可行區(qū)域,所以最接近不可行區(qū)域的解即為 Pareto 最優(yōu)解。綠色的解由于能夠找到在兩個任務上都比它表現(xiàn)更好的解,所以它不是 Pareto 最優(yōu)解。所有的 Pareto 最優(yōu)解所組成的面就是 Pareto 前沿(Pareto front)。
    image.png
  • 為什么要有Pareto最優(yōu)?之前的MTL方法不香么?
    之前的MTL平衡方法主要關注的是如何在有多個任務時找到一個能夠權衡利弊的解,但由于 Pareto 的存在,一個獨立解是難以滿足多種偏好的,Pareto旨在找到多個Pareto解,進而利用 Pareto 前沿的連續(xù)性質(zhì),連續(xù)地從一個 Pareto 解出發(fā)找到其他的解。
  • 非劣解——多目標優(yōu)化問題并不存在一個最優(yōu)解,所有可能的解都稱為非劣解,也稱為Pareto解。
  • Pareto最優(yōu)解——無法在改進任何目標函數(shù)的同時不削弱至少一個其他目標函數(shù)。這種解稱作非支配解或Pareto最優(yōu)解。
    通俗的講,帕累托最優(yōu)解即不存在一個在任意子目標均不弱于它且至少在一個子目標上強于它的解,優(yōu)化帕累托最優(yōu)解的任意子目標必然弱化至少一個其他子目標。帕累托穩(wěn)定解即存在所有子目標上梯度的凸組合為零的可行解。

5.2 方法論

多目標優(yōu)化問題不存在唯一的全局最優(yōu)解 ,過多的非劣解是無法直接應用的 ,所以在求解時就是要尋找一個最終解。

  • 求最終解主要有三類方法:
  1. 生成法,即先求出大量的非劣解,構(gòu)成非劣解的一個子集,然后按照決策者的意圖找出最終解;(生成法主要有加權法﹑約束法﹑加權法和約束法結(jié)合的混合法以及多目標遺傳算法)。
  2. 交互法,不先求出很多的非劣解,而是通過分析者與決策者對話的方式,逐步求出最終解。
    3.目標重要度,算法以此為依據(jù),將多目標問題轉(zhuǎn)化為單目標問題進行求解。
  • 多目標優(yōu)化算法歸結(jié)起來有傳統(tǒng)優(yōu)化算法和智能優(yōu)化算法兩大類:
  1. 傳統(tǒng)優(yōu)化算法包括加權法、約束法和線性規(guī)劃法等,實質(zhì)上就是將多目標函數(shù)轉(zhuǎn)化為單目標函數(shù),通過采用單目標優(yōu)化的方法達到對多目標函數(shù)的求解。
  2. 智能優(yōu)化算法包括進化算法(Evolutionary Algorithm, 簡稱EA)、粒子群算法(Particle Swarm Optimization, PSO)等。
    兩者的區(qū)別——傳統(tǒng)優(yōu)化技術一般每次能得到Pareo解集中的一個,而用智能算法來求解,可以得到更多的Pareto解,這些解構(gòu)成了一個最優(yōu)解集,稱為Pareto最優(yōu)解(任一個目標函數(shù)值的提高都必須以犧牲其他目標函數(shù)值為代價的解集)。

5.3 帕累托多解

https://proceedings.neurips.cc/paper/2019/file/685bfde03eb646c27ed565881917c71c-Paper.pdf
該文章提出多個 Pareto 解的重要性,但由于沒有利用 Pareto 前沿的性質(zhì),該方法生成的每一個解都是從頭訓練的,并且無法形成連續(xù)的 Pareto 前沿。

image.png

5.4 連續(xù)帕累托優(yōu)化

paper:
https://arxiv.org/pdf/2006.16434.pdf
視頻:
https://icml.cc/virtual/2020/poster/5856

  • 核心思想:
  1. 利用 Hessian 矩陣對多目標問題的 Pareto 集(權衡多目標的理論最優(yōu)集)進行一階近似,并實驗說明沿該近似方向更新可以使已有的 Pareto 解保持其性質(zhì)。
  2. 提出使用 Hessian 向量積(Hessian-Vector Product)和 Krylov 子空間迭代法(共軛梯度法就是其中的一種)來高效解決上述近似問題,并將該方法應用于深度學習。
  3. 對生成的 Pareto 集進行線性插值,并由此得到無窮多的連續(xù)一階近似解,以滿足多目標問題下的不同偏好。
  • Pareto一階近似
    當一個解已經(jīng)到達 Pareto 最優(yōu),其各目標的梯度一定共面(當任務數(shù)為2時為共線)。從另一個角度說,這些梯度一定可以找到一組正權重使其加權和為0,寫為公式:


    image.png

    論文在參數(shù)空間尋找一個更新方向的參數(shù)表示:


    image.png

    v即參數(shù)更新方向,t為參數(shù)化變量,H為線性變化矩陣,目的將v表示為梯度加權的線性表示。
    下面介紹H這個鬼東西。
  • Hessian是什么鬼?Hessian 向量積又是什么鬼?
    H是加權過的各個目標函數(shù)的 Hessian 矩陣,這個矩陣的每個元素都是一個二階導數(shù),形式為:


    image.png

    就是為了避免直接計算H求H而做的一種方法,n為參數(shù)量,平方搞一搞,算死球了個喵的,哎,但是Hessian 和一個向量的點乘可以算,這個就是Hessian 向量積。


    image.png
  • Krylov 子空間迭代法求解v:
  1. MINRES
  2. 無需顯式構(gòu)造左側(cè)矩陣,只需要調(diào)用它和一個向量的點積。
  3. 迭代求解,可以 early-stopping 加快求解速度。
  • 如何保證連續(xù)?
    雖然我們已經(jīng)得到了更新方向,但基于這些更新方向生成的解依舊是離散的。幸運的是,由于我們是基于 Pareto 前沿在局部上是連續(xù)的這一假設,我們可以認為從同一解更新出的解都在同一連續(xù)曲面上,這也就意味著我們可以對其進行線性插值來生成連續(xù)曲面。如果我們有很多不同的初始 Pareto 解,我們就可以對它們分別進行更新獲得多個相交的 Pareto 前沿,然后篩選出最后的大片連續(xù)曲面。
  • Pareto 連續(xù)性方法總結(jié)
  1. 計算復雜度高,任務數(shù)量無法過度擴展,但這個面對一般工程3-4task其實還好。
  2. 在 Pareto 局部連續(xù)性的假設上進行的操作,但深度過參數(shù)化網(wǎng)絡往往高度非線性,難以從一個局部 Pareto 曲面跳到另一個局部 Pareto 曲面,容易陷入某 Pareto 曲面的局部最優(yōu)解。

僅供學習 無關利益 大論灌水 寶刀屠龍

REF:
https://www.zhihu.com/question/359962155/answer/928941168
https://zhuanlan.zhihu.com/p/362860026
https://www.zhihu.com/question/293188597/answer/1183278070
https://zhuanlan.zhihu.com/p/425672909
https://zhuanlan.zhihu.com/p/56613537
https://www.yanxishe.com/columnDetail/26367
https://www.zhihu.com/question/294635686/answer/606259229
https://www.zhihu.com/question/375794498/answer/1052779937
https://zhuanlan.zhihu.com/p/269492239
https://zhuanlan.zhihu.com/p/56986840
https://blog.csdn.net/Leon_winter/article/details/105014677
https://zhuanlan.zhihu.com/p/82234448
https://zhuanlan.zhihu.com/p/71012037
https://zhuanlan.zhihu.com/p/159000150
https://blog.csdn.net/weixin_43202635/article/details/82700342
https://www.zhihu.com/search?type=content&q=Pareto%20Multi-Task%20Learning
https://zhuanlan.zhihu.com/p/442858141
https://zhuanlan.zhihu.com/p/395220852
https://zhuanlan.zhihu.com/p/258327403
https://www.zhihu.com/column/c_1360363335737843712

?著作權歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容