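I was stuck on one exception for ages and searched everywhere for the cause. After half a day of digging, the truth finally came out: it had been written clearly in the documentation all along. Why didn't I read the docs! Why didn't I read the docs! Why didn't I read the docs!
Here's what happened:
My machine's current environment is TensorFlow 1.7, CUDA 9.0 and cuDNN 7.1. I had trained agents with this setup without any problems so far. That changed today, when I trained an agent with visual observations. Training itself went fine, with no errors at all. But when I switched the brain type to Internal so I could use the trained model inside Unity, this error appeared: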
TFException: NodeDef mentions attr 'dilations' not in Op<name=Conv2D; signature=input:T, filter:T -> output:T; attr=T:type,allowed=[DT_HALF, DT_FLOAT]; attr=strides:list(int); attr=use_cudnn_on_gpu:bool,default=true; attr=padding:string,allowed=["SAME", "VALID"]; attr=data_format:string,default="NHWC",allowed=["NHWC", "NCHW"]>; NodeDef: conv2d_2/Conv2D = Conv2D[T=DT_FLOAT, data_format="NHWC", dilations=[1, 1, 1, 1], padding="VALID", strides=[1, 4, 4, 1], use_cudnn_on_gpu=true](visual_observation_0, conv2d_2/kernel/read). (Check whether your GraphDef-interpreting binary is up to date with your GraphDef-generating binary.).
TensorFlow.TFStatus.CheckMaybeRaise (TensorFlow.TFStatus incomingStatus, System.Boolean last) (at <6ed6db22f8874deba74ffe3e566039be>:0)
TensorFlow.TFGraph.Import (TensorFlow.TFBuffer graphDef, TensorFlow.TFImportGraphDefOptions options, TensorFlow.TFStatus status) (at <6ed6db22f8874deba74ffe3e566039be>:0)
TensorFlow.TFGraph.Import (System.Byte[] buffer, TensorFlow.TFImportGraphDefOptions options, TensorFlow.TFStatus status) (at <6ed6db22f8874deba74ffe3e566039be>:0)
TensorFlow.TFGraph.Import (System.Byte[] buffer, System.String prefix, TensorFlow.TFStatus status) (at <6ed6db22f8874deba74ffe3e566039be>:0)
CoreBrainInternal.InitializeCoreBrain (Communicator communicator) (at Assets/ML-Agents/Scripts/CoreBrainInternal.cs:123)
Brain.InitializeBrain (Academy aca, Communicator communicator) (at Assets/ML-Agents/Scripts/Brain.cs:209)
Academy.InitializeEnvironment () (at Assets/ML-Agents/Scripts/Academy.cs:230)
Academy.Awake () (at Assets/ML-Agents/Scripts/Academy.cs:208)
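I followed the stack trace through the code and everything looked fine! A failure in Import most likely means the TF version that trained the model doesn't match the TensorFlowSharp version used inside Unity, but I had no idea which version that was! So off to Google I went.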
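I finally found the answer in an issue on the ML-Agents GitHub repository (https://github.com/Unity-Technologies/ml-agents/issues/609): only TF 1.4 is currently supported! When I trained with TF 1.7, the exported graph used something TF 1.4 doesn't know about: the 'dilations' attribute on the Conv2D op, exactly as the error message says. So why had everything worked before? Because I had never trained with visual observations, which is what puts Conv2D layers into the graph (note the conv2d_2/Conv2D node fed by visual_observation_0 in the error).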
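To catch this before switching the brain to Internal, one could scan the exported model file for the offending attribute name. The sketch below is a rough heuristic of my own, not part of ML-Agents: it relies on the fact that attribute names in a serialized GraphDef are stored as plain strings, so a byte search for "dilations" flags a graph that a TF 1.4 runtime will probably reject. The helper names are hypothetical, and a hit can be a false positive (for example a node that happens to be named "dilations").

```python
from pathlib import Path

# Attribute names that (per the error above) the TF 1.4 runtime in Unity
# does not recognize on Conv2D. Hypothetical list; extend as needed.
SUSPECT_ATTRS = (b"dilations",)


def find_suspect_attrs(graph_bytes: bytes) -> list:
    """Return the suspect attribute names that occur anywhere in the bytes.

    Heuristic: NodeDef attr keys are serialized as raw UTF-8 strings, so a
    substring match is enough to flag a likely problem. Not a real parser.
    """
    return [name.decode("utf-8") for name in SUSPECT_ATTRS if name in graph_bytes]


def check_model_file(path: str) -> list:
    """Read an exported model file (e.g. the .bytes graph) and scan it."""
    return find_suspect_attrs(Path(path).read_bytes())


if __name__ == "__main__":
    import sys

    for model_path in sys.argv[1:]:
        hits = check_model_file(model_path)
        status = "possibly incompatible with TF 1.4: " + ", ".join(hits) if hits else "no suspect attrs"
        print(f"{model_path}: {status}")
```

Run it as `python check_graph.py path/to/model.bytes` (file name assumed); an empty result only means the known-bad attribute name was not found, not that the import is guaranteed to succeed.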
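So that was the truth, and the only option is to downgrade TF to 1.4. But it doesn't stop at TF... I also have to downgrade to the matching CUDA 8.0 and cuDNN 6.0! That will break my other conda environments that depend on CUDA too! Oh well...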
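A small pre-flight check before training would have saved me the round trip. The sketch below is a hypothetical helper (not an ML-Agents API) that encodes only the requirements stated in this post: TF 1.4 for the Unity plugin, paired with CUDA 8.0 and cuDNN 6.0 for the GPU build. You would pass it the string from `tensorflow.__version__`; the version parsing is my own assumption about typical version strings such as "1.7.0" or "1.4.0rc1".

```python
import re

# Requirements as stated in this post (hypothetical constants).
REQUIRED_TF = (1, 4)
REQUIRED_CUDA = "8.0"
REQUIRED_CUDNN = "6.0"


def major_minor(version: str) -> tuple:
    """Parse '1.7.0', '1.4.1', '1.4.0rc1', etc. into (major, minor)."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def check_tf_version(version: str) -> str:
    """Return "" if the version matches, otherwise a warning message."""
    if major_minor(version) == REQUIRED_TF:
        return ""
    return (
        f"TensorFlow {version} found, but the Unity plugin expects "
        f"{REQUIRED_TF[0]}.{REQUIRED_TF[1]}.x (GPU build: CUDA {REQUIRED_CUDA}, "
        f"cuDNN {REQUIRED_CUDNN}); models trained with it may fail to import."
    )
```

For example, calling `check_tf_version("1.7.0")` returns the warning string, while `check_tf_version("1.4.0")` returns an empty string, so a training script could print the result and abort if it is non-empty.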
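Something this important surely had to be covered in the official docs! I went back and dug through the official installation guide, and sure enough... https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Installation-Windows.md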