本文主要參考反向傳播之一:softmax函數(shù),添加相應(yīng)的pytorch的實現(xiàn)
softmax函數(shù)
定義及簡單實現(xiàn)
記輸入為,標(biāo)簽為
, 通過softmax之后的輸出為
,則:
其中,
pytorch庫函數(shù)和手動實現(xiàn):
def seed_torch(seed=1234):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False
def softmax(x, dim=1):
x=x-torch.max(x, dim=dim)[0].unsqueeze(dim=dim) # 防止溢出
res = torch.exp(x) / torch.sum(torch.exp(x), dim=dim).unsqueeze(dim=dim)
return res
if __name__ == '__main__':
seed_torch(1234)
x=torch.rand(4,7, requires_grad=True)
print(torch.softmax(x,dim=-1))
print(softmax(x, dim=-1))
求導(dǎo)推導(dǎo)

當(dāng)i=j

當(dāng)i≠j
綜上:
其中,
公式推導(dǎo)來自:https://zhuanlan.zhihu.com/p/37740860。圖中的
交叉熵
計算公式
說明:
-
表示分類任務(wù)中的類別數(shù)
-
表示的該樣本對應(yīng)的標(biāo)簽,是一個
維的向量,一般使用one-hot編碼,
表示其第
維的值,為0或者1
-
表示模型計算出的分類概率向量,
維,每一維表示預(yù)測得到該類的概率,即
,由softmax函數(shù)計算得到
pytorch 實現(xiàn):
seed_torch(1234)
x=torch.rand(4,7, requires_grad=True) # 4個樣本,共7類
y=torch.LongTensor([1,3,5,0]) # 對應(yīng)的標(biāo)簽
criterion = torch.nn.CrossEntropyLoss() # pytroch庫
out = criterion(x,y)
print(out)
# 自己實現(xiàn)
gt = torch.zeros(4,7).scatter(1, y.view(4,1),1) # 生成one-hot標(biāo)簽,scatter的用法可參考jianshu.com/p/b4e9fd4048f4
loss = -(torch.log(softmax(x, dim=-1)) * gt).sum() / 4 # 對樣本求平均
print(loss)
輸出:

反向傳播
所以:
pytorch實現(xiàn)
seed_torch(1234)
x=torch.rand(4,7) # 4個樣本,共7類
x1=x.clone() # 4個樣本,共7類
x2=x.clone() # 4個樣本,共7類
x1.requires_grad=True
x2.requires_grad=True
y=torch.LongTensor([1,3,5,0]) # 對應(yīng)的標(biāo)簽
criterion = torch.nn.CrossEntropyLoss()
out = criterion(x1,y)
out.backward()
print('pytorch庫 loss:', out, 'grad:', x1.grad)
gt = torch.zeros(4,7).scatter(1, y.view(4,1),1)
loss = -(torch.log(softmax(x2, dim=-1)) * gt).sum() / 4 # 對樣本求平均
loss.backward()
print('手動實現(xiàn) loss:', loss, 'grad:', x2.grad)
eta = (torch.softmax(x, dim=-1) - gt) / 4
print('直接計算 grad:', eta)
輸出:
