更新日期截止2020年5月22日,項目定期維護和更新,維護各種SOTA的Federated Learning的攻防模型。Github 地址https://github.com/shanxuanchen/attacking_federate_learning
前言
聯(lián)邦學習通過只對梯度的傳輸,可以在互不公開數(shù)據(jù)集的前提下訓練模型。然后,也正是這種隱匿性,讓Federated Learning非常脆弱,天然不得不在non-iid的數(shù)據(jù)環(huán)境中進行訓練(真實情況絕大部分是non-iid)。因此,黑客可以通過poison data,或者backdooor的方式攻擊模型,從而使模型無法收斂或者留有后門。
Krum
Krum是2017年NIPS上具有拜占庭容錯能力的SGD方法,Krum最核心的思想就是選擇每次都會選擇最靠近其余個梯度的梯度。這個表達可能有點拗口,通俗易懂點理解就是,找出一個
個最集中的梯度集合,其中使用歐拉距離來進行度量。
實現(xiàn)代碼如下:
def krum(users_grads, users_count, corrupted_count, distances=None,return_index=False, debug=False):
if not return_index:
assert users_count >= 2*corrupted_count + 1,('users_count>=2*corrupted_count + 3', users_count, corrupted_count)
non_malicious_count = users_count - corrupted_count
minimal_error = 1e20
minimal_error_index = -1
if distances is None:
distances = _krum_create_distances(users_grads)
for user in distances.keys():
errors = sorted(distances[user].values())
current_error = sum(errors[:non_malicious_count])
if current_error < minimal_error:
minimal_error = current_error
minimal_error_index = user
if return_index:
return minimal_error_index
else:
return users_grads[minimal_error_index]
Trimmed Mean
基于均值的拜占庭容錯SGD,核心思想很簡單,就是找最接近均值的個梯度。
實現(xiàn)代碼如下:
def trimmed_mean(users_grads, users_count, corrupted_count):
number_to_consider = int(users_grads.shape[0] - corrupted_count) - 1
current_grads = np.empty((users_grads.shape[1],), users_grads.dtype)
for i, param_across_users in enumerate(users_grads.T):
med = np.median(param_across_users)
good_vals = sorted(param_across_users - med, key=lambda x: abs(x))[:number_to_consider]
current_grads[i] = np.mean(good_vals) + med
return current_grads
Bulyan
Bulyan是目前SOTA的一個拜占庭容錯算法,它十分巧妙,簡單地來說,就是不斷循環(huán)選擇,然后跑一次Trimmed Mean。而特別的是,該算法是使用Krum來選擇
的。所以,目前SOTA的容錯算法是Krum + Trimmed Mean的一個結(jié)合。
算法如下:

代碼如下:
def bulyan(users_grads, users_count, corrupted_count):
assert users_count >= 4*corrupted_count + 3
set_size = users_count - 2*corrupted_count
selection_set = []
distances = _krum_create_distances(users_grads)
while len(selection_set) < set_size:
currently_selected = krum(users_grads, users_count - len(selection_set), corrupted_count, distances, True)
selection_set.append(users_grads[currently_selected])
# remove the selected from next iterations:
distances.pop(currently_selected)
for remaining_user in distances.keys():
distances[remaining_user].pop(currently_selected)
return trimmed_mean(np.array(selection_set), len(selection_set), 2*corrupted_count)
總結(jié)
代碼里是實現(xiàn)了的,但是由于篇幅問題,最強的攻擊模型《A Little Is Enough: Circumventing Defenses For Distributed Learning》留著下一篇。這篇論文攻擊了上面三種SOTA的防御機制,值得深入研究與探討。關(guān)于Byzantine SGD的攻防這個方向,其實代碼量不大,需要有對SGD收斂性證明的能力,與SGD防御的證明能力。