解決辦法參考: Get the mean from a list of tensors
問題背景
最近跑一個Siamese-FC的復現程序,要求配置是python2.7+pytorch0.4,之前安裝的是Pytorch1.0,降低版本下載過慢多次失敗,最終選擇在Pytorch1.0版本下解決這個問題。
問題描述
項目地址:https://github.com/zzwang058/SiamFC-PyTorch
在運行run_Train_SiamFC.py中
print ("Epoch %d training loss: %f, validation loss: %f" % (i+1, np.mean(train_loss), np.mean(val_loss)))
是Pytorch1.0存在的問題,似乎是因為對張量求平均?
解決辦法如下:
將 np.mean(a) 替換為 torch.mean(torch.stack(a))