pytorch scatter的用法

官方給的用法:

scatter(dim, index, src)
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

一個例子

import torch
input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)

輸出:

tensor([[ 0.0461,  0.4024, -1.0115,  0.2167],
        [-0.6123,  0.5036,  0.2310,  0.6931]])
tensor([[ 0.2167,  0.4024, -1.0115,  0.0461,  0.0000],
        [ 0.2310, -0.6123,  0.5036,  0.6931,  0.0000]])

scatter(scatter_)是將input tensor按照index賦值給output tensor來達(dá)到更新output的效果的。
我們從index下手,
index[0][0]=3,由于dim=1,那么我們?nèi)nput[0][0]=0.0461, 賦值給output[0][index[0][0]]=output[0][3]
index[0][3]=0, input[0][3]=0.2167,賦值給output[0][inde[0][3]]=output[0][0]
index[1][2]=0, input[1][2]=0.2310, 賦值給output[1][index[1][2]]=output[1][0]
index[1][3]=3, input[1][3]=0.6931, 賦值給output[1][index[1][3]]=output[1][3]

也就是index的下標(biāo)和input的下標(biāo)是一致的,取出來的這個值賦值給誰呢,這個是index對應(yīng)的值以及dim來確定的,如果dim=1, 那么更新的是output[i][index[i][j]]=input[i][j],官方文檔給的是三維的情況,dim是多少,那么index的值就放在第幾維。
scatter一個很重要的應(yīng)用就是生成one-hot矩陣
假設(shè)總共有5類,現(xiàn)在一個batch有3個樣本,分別對應(yīng)的標(biāo)簽為1,2,0。那么生成的one-hot矩陣應(yīng)該是這樣的:

index=torch.tensor([[1], [2], [0]])
y=torch.zeros(3, 5)
y=y.scatter(1, index, 1)
print(y)

輸出:

tensor([[0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0.]])
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 介紹 相比TensorFlow的靜態(tài)圖開發(fā),Pytorch的動態(tài)圖特性使得開發(fā)起來更加人性化,選擇Pytorch的...
    dawsonenjoy閱讀 25,223評論 2 18
  • stack使用stack是為了保留兩個信息: 序列(先后)和 張量矩陣信息。比如在循環(huán)神經(jīng)網(wǎng)絡(luò)中,網(wǎng)絡(luò)的輸出數(shù)據(jù)...
    lzjngu閱讀 477評論 0 0
  • 1.pytorch中的索引 index_select(x, dim, indices)dim代表維度,indice...
    yumiii_閱讀 5,505評論 0 0
  • scatter_(input, dim, index, src)將src中數(shù)據(jù)根據(jù)index中的索引按照dim的方...
    cjhfhb閱讀 6,078評論 0 0
  • 我是黑夜里大雨紛飛的人啊 1 “又到一年六月,有人笑有人哭,有人歡樂有人憂愁,有人驚喜有人失落,有的覺得收獲滿滿有...
    陌忘宇閱讀 8,848評論 28 54

友情鏈接更多精彩內(nèi)容