FedReID: 聯(lián)邦學(xué)習(xí)在行人重識(shí)別上的首次深入實(shí)踐

行人重識(shí)別的訓(xùn)練需要收集大量的人體數(shù)據(jù)到一個(gè)中心服務(wù)器上,這些數(shù)據(jù)包含了個(gè)人敏感信息,因此會(huì)造成隱私泄露問題。聯(lián)邦學(xué)習(xí)是一種保護(hù)隱私的分布式訓(xùn)練方法,可以應(yīng)用到行人重識(shí)別上,以解決這個(gè)問題。但是在現(xiàn)實(shí)場(chǎng)景中,將聯(lián)邦學(xué)習(xí)應(yīng)用到行人重識(shí)別上因?yàn)閿?shù)據(jù)異構(gòu)性,會(huì)導(dǎo)致精度下降和收斂的問題。

數(shù)據(jù)異構(gòu)性:數(shù)據(jù)非獨(dú)立分布 (non-IID) 和 各端數(shù)據(jù)量不同。

這是篇來自 ACMMM20 Oral 的論文,主要通過構(gòu)建一個(gè) benchmark,并基于 benchmark 結(jié)果的深入分析,提出兩個(gè)優(yōu)化方法,提升現(xiàn)實(shí)場(chǎng)景下聯(lián)邦學(xué)習(xí)在行人重識(shí)別上碰到的數(shù)據(jù)異構(gòu)性問題。

論文地址:Performance Optimization for Federated Person Re-identification via Benchmark Analysis
開源代碼:https://github.com/cap-ntu/FedReID

本文主要對(duì)這篇文章的這三個(gè)方面內(nèi)容做簡要介紹:

  1. Benchmark: 包括數(shù)據(jù)集、新的算法、場(chǎng)景等
  2. Benchmark 的結(jié)果分析
  3. 優(yōu)化方法:知識(shí)蒸餾、權(quán)重重分配

Benchmark

數(shù)據(jù)集

數(shù)據(jù)集由9個(gè)最常用的 行人重識(shí)別 數(shù)據(jù)集構(gòu)成,具體的信息如下:

Datasets

這些數(shù)據(jù)集的數(shù)據(jù)量、ID數(shù)量、領(lǐng)域都不同,能夠有效的模擬現(xiàn)實(shí)情況下的數(shù)據(jù)異構(gòu)性問題。

算法

傳統(tǒng)聯(lián)邦學(xué)習(xí)算法 Federated Averaging (FedAvg) 要求端邊全模型同步,但是 ReID 的分類層的維度由 ID數(shù)量決定,很可能是不同的。所以這篇論文提出了只同步部分的模型 Federated Partial Averaging (FedPav).

Federated Partial Averaging

FedPav 的每一輪訓(xùn)練可以通過4個(gè)步驟完成:

  1. Server 下發(fā)一個(gè)全局模型到每個(gè) Client
  2. 每個(gè) Client 收到全局模型后,將全局模型加上本地的分類器,用本地?cái)?shù)據(jù)進(jìn)行訓(xùn)練,每個(gè) Client 得到一個(gè) local model
  3. Client 將 local model 的 backbone 上傳到 Server
  4. Server 對(duì)所有 client 收到的 model 進(jìn)行加權(quán)平均。

完整的算法可以參考下圖:

Benchmark 結(jié)果

通過 Benchmark 的實(shí)驗(yàn),論文里描述了不少聯(lián)邦學(xué)習(xí)和行人重識(shí)別結(jié)合的洞見。這邊著重提出兩點(diǎn)因數(shù)據(jù)異構(gòu)性導(dǎo)致的問題。

1. 大數(shù)據(jù)集在聯(lián)邦學(xué)習(xí)中的精度低于單個(gè)數(shù)據(jù)集訓(xùn)練的精度

  • FedPav: 聯(lián)邦學(xué)習(xí)總模型的精度
  • FedPav Local Model: 聯(lián)邦學(xué)習(xí)各邊端模型模型上傳前在各自邊端測(cè)試的精度
  • Local Training: 基準(zhǔn),每個(gè)數(shù)據(jù)集單獨(dú)訓(xùn)練和測(cè)試的精度

Local Training 效果比聯(lián)邦學(xué)習(xí)的效果好,說明這些大數(shù)據(jù)集沒法在聯(lián)邦學(xué)習(xí)中受益。需要有更好的算法來提高精度。

2. 聯(lián)邦學(xué)習(xí)訓(xùn)練不收斂

通過這兩個(gè)數(shù)據(jù)集測(cè)試曲線可以看出,因?yàn)閿?shù)據(jù)異構(gòu)性的影響,精度波動(dòng)較大,收斂性差。

優(yōu)化方法

采用知識(shí)蒸餾,提高收斂

因?yàn)閿?shù)據(jù)的異構(gòu)性的原因,導(dǎo)致參與聯(lián)邦學(xué)習(xí)多方上傳前的本地模型的性能優(yōu)于云端服務(wù)器進(jìn)行模型融合后的模型性能,另外數(shù)據(jù)異構(gòu)性還導(dǎo)致了訓(xùn)練的不穩(wěn)定性和難收斂的問題。針對(duì)這個(gè)問題,本方案提出使用知識(shí)蒸餾的方法,將參與聯(lián)邦學(xué)習(xí)的多方的本地模型當(dāng)成教師模型,云端服務(wù)器的模型作為學(xué)生模型,用知識(shí)蒸餾的方法更好的將教師模型的知識(shí)傳遞到學(xué)生模型,以此提高了模型訓(xùn)練的穩(wěn)定性和收斂性。完整算法可以參考下圖:

Knowledge Distillation
image-20201016033811427.png

下面的實(shí)驗(yàn)結(jié)果顯示,采用知識(shí)蒸餾(橙線)的訓(xùn)練收斂效果能夠得到有效提高。

提出權(quán)重重分配,提高精度

原算法在 Server 上做模型整合,采用的是加權(quán)平均的方法,用每個(gè) Client 的數(shù)據(jù)量作為權(quán)重,進(jìn)行加權(quán)平均。每個(gè) Client 的數(shù)據(jù)量差距可能非常大,有的占比 40%,有的占比不到 1%,所以該論文提出了進(jìn)行權(quán)重分配。調(diào)整聯(lián)邦學(xué)習(xí)模型融合時(shí)各方模型更新的權(quán)重:給訓(xùn)練效果越好的邊端,分配更大的權(quán)重,在模型融合時(shí)產(chǎn)生更大的影響。訓(xùn)練效果的衡量是通過比較每一方本地訓(xùn)練前后模型用一批數(shù)據(jù)做推理產(chǎn)生的特征的余弦距離,余弦距離越大,該訓(xùn)練產(chǎn)生的變化越大,該分配的權(quán)重越大。完整算法可以參考下圖:

Cosine Distance Weight

下表格的實(shí)驗(yàn)結(jié)果顯示,權(quán)重重分配使所有邊端模型的性能都超過 Local Training,帶來普遍的性能提升。

總結(jié)

針對(duì)數(shù)據(jù)隱私問題,這篇論文將聯(lián)邦學(xué)習(xí)應(yīng)用到行人重識(shí)別,并做了深入的研究分析。構(gòu)建了一個(gè) Benchmark,并基于實(shí)驗(yàn)結(jié)果帶來的洞見,提出了使用<u>知識(shí)蒸餾</u>和<u>權(quán)重重分配</u>的方法來解決數(shù)據(jù)異構(gòu)性帶來的性能問題。

算法細(xì)節(jié)和更多實(shí)驗(yàn)結(jié)果,推薦閱讀原論文和開源代碼。

資源

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容