現(xiàn)象
使用 Hugging Face Trainer 在單機多卡環(huán)境下對 LLAMA2-7B 進行 LoRA finetuning 時,在第一次保存 checkpoint 時,程序 assert out,關鍵 error trace log 如下
[1711969312.876608051] rank2.python: Reading from remote process' memory failed. Disabling CMA support
[1711969312.876606027] rank4.python: Reading from remote process' memory failed. Disabling CMA support
[1711969312.876618213] rank5.python: Reading from remote process' memory failed. Disabling CMA support
rank5: Assertion failure at psm3/ptl_am/ptl.c:196: nbytes == req->req_data.recv_msglen
rank2: Assertion failure at psm3/ptl_am/ptl.c:196: nbytes == req->req_data.recv_msglen
根因
順藤摸瓜
-
accelerate的 FSDP 在保存 checkpoint 時,會調用其自己的save_fsdp_optimizer方法,該方法首先調用了 PyTorch 的FSDP.optim_state_dict方法以獲取并確保每個rank上都有其需要的最新的optimizer的state_dict,然后根據相應的fsdp_state_dict_type設置將其保存。Assert out 就發(fā)生在FSDP.optim_state_dict調用中。 - 找到 PyTorch
FSDP.optim_state_dict的實現(xiàn),發(fā)現(xiàn) assert out 發(fā)生在調用FullyShardedDataParallel._optim_state_dict_impl時。 - 再轉至
FullyShardedDataParallel._optim_state_dict_impl的實現(xiàn),發(fā)現(xiàn) assert out 發(fā)生在其調用_optim_state_dict時。 - 繼續(xù)轉至
_optim_state_dict的實現(xiàn),發(fā)現(xiàn) assert out 發(fā)生在其調用_map_param_key_to_optim_keys時。 - 繼續(xù)轉至
_map_param_key_to_optim_keys的實現(xiàn),發(fā)現(xiàn) assert out 發(fā)生在調用dist.broadcast_object_list。
至此,瓜已得,需分析 dist.broadcast_object_list。
抽絲剝繭
-
首先需要分析
dist.broadcast_object_listbroadcast 了啥,看下代碼:key_obj_list: List[Optional[List[_OptimStateKey]]] = ( [all_optim_state_keys] if rank == 0 else [None]) dist.broadcast_object_list(key_obj_list, src=0, group=group)由代碼可知,broad cast 的是一堆
_OptimStateKey,而_OptimStateKey是一個字符串組成的tuple,每個字符串里放的是optimizer 中每個模型參數(shù)的狀態(tài)(即 momentum, variance 等)的 unflat 的 fully qualified name。這些東西是在 CPU 上的,需要由 rank 0 廣播到其余 rank,以對齊參數(shù)名。 -
好了,知道數(shù)據是在 CPU 上的,那就知道為啥在 checkpointing 之前是好的了,因為此前都是涉及到 GPU 上 tensor 的 collective communication,那塊看來是好的。Intel CPU 和 GPU 平臺的 collective communication 后端走的是 oneCCL,其中 CPU 上數(shù)據的單機多卡 broadcast 走的是什么方案呢?再去看一眼 log:
[1711969312.876608051] rank2.python: Reading from remote process' memory failed. Disabling CMA support從 log 中大致可以猜出通信方案是 shared memory(SHM),不然不會有
Reading from remote process' memory failed,這很合理,因為是單機多卡;且采用的 SHM 方案是 CMA(Cross Memory Attach),這是 Linux 內核實現(xiàn)的一種 kernel assisted zero copy SHM 機制,示意如下(摘自此論文):
那就是 CMA 出了啥問題。以上只是猜想,猜想只是起點,總是要實證。既然 oneCCL 是集合通信后端,我們就要分析一下它。從這兒可以知道:oneCCL 有兩個 transport 后端, 即 OFI 和 MPI。從這兒又可以知道,intel MPI 的實現(xiàn)現(xiàn)在也基于 OFI 了,而 OFI 的實現(xiàn)是 libfabric, 如下:
那么我們就去 libfabric 的代碼庫中找找有沒有以下 log 相關的代碼:[1711969312.876608051] rank2.python: Reading from remote process' memory failed. Disabling CMA support rank2: Assertion failure at psm3/ptl_am/ptl.c:196: nbytes == req->req_data.recv_msglen然后就從這里到了如下代碼:
size_t nbytes = psm3_cma_get(pid, (void *)req->rts_sbuf, req->req_data.buf, req->req_data.recv_msglen); if (nbytes == -1) { ptl->psmi_kassist_mode = PSMI_KASSIST_OFF; _HFI_ERROR("Reading from remote process' memory failed. Disabling CMA support\n"); } else { psmi_assert_always(nbytes == req->req_data.recv_msglen); cma_succeed = 1; } psmi_assert_always(nbytes == req->req_data.recv_msglen);從代碼可以看到,
psm3_cma_get調用返回錯誤,首先觸發(fā)了Reading from remote process' memory failed. Disabling CMA support\n錯誤信息打印,隨后又通過psmi_assert_alwaysassert out 了,與我們看到的 log 完全一樣。至此,絲抽完了,已經找到問題發(fā)生的地方了。 轉到
psm3_cma_get的實現(xiàn)代碼,可知是process_vm_readv返回錯誤了。查看process_vm_readv的手冊可以看到如下表述:
Permission to read from or write to another process is governed by a ptrace access mode PTRACE_MODE_ATTACH_REALCREDS check; see ptrace(2).
因為 CMA 涉及到進程訪問別的進程的內存,一個有可能的合理懷疑就是當前進程沒有權限訪問另一個進程的內存,這個也通過 CMA patch the commit message 得到了印證,其中寫道:
Currently mem_read allows only processes who are currently ptrace'ing the target and are still able to ptrace the target to read from the target.
那就上谷歌搜一下 cma ptrace 看下 CMA 需要怎樣的 ptrace 設置,果然首個鏈接就找到了答案。
Same issue... Try the following:
https://groups.io/g/OpenHPC-users/topic/openmpi_and_shared_memory/16489081?p=,,,20,0,0,0::recentpostdate%2Fsticky,,,20,2,0,16489081
$ echo 0 > /proc/sys/kernel/yama/ptrace_scope
or
$ sudo echo 0 > /proc/sys/kernel/yama/ptrace_scope
or
$ echo 0 | sudo tee /proc/sys/kernel/yama/ptrace_scope
嘗試一下,搞定!
解法
$ echo 0 > /proc/sys/kernel/yama/ptrace_scope
ptrace_scope 說明見此。所測系統(tǒng)之前 ptrace_scope 值是 1。也就是說,非 rank 0 的進程要讀 rank 0 的 SHM,必須滿足 rank 0 是它們的后代進程才行,這顯然不符合當前工作負載的實情。所以需要設成 0,以使得主要這些進程是同一個 uid 下的就可以讀 SHM。
最后的話
當前來看,結果很重要;長遠來看,過程很重要!這是工程的真諦。

