官方給的用法:
scatter(dim, index, src)
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
一個例子
import torch
input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)
輸出:
tensor([[ 0.0461, 0.4024, -1.0115, 0.2167],
[-0.6123, 0.5036, 0.2310, 0.6931]])
tensor([[ 0.2167, 0.4024, -1.0115, 0.0461, 0.0000],
[ 0.2310, -0.6123, 0.5036, 0.6931, 0.0000]])
scatter(scatter_)是將input tensor按照index賦值給output tensor來達(dá)到更新output的效果的。
我們從index下手,
index[0][0]=3,由于dim=1,那么我們?nèi)nput[0][0]=0.0461, 賦值給output[0][index[0][0]]=output[0][3]
index[0][3]=0, input[0][3]=0.2167,賦值給output[0][inde[0][3]]=output[0][0]
index[1][2]=0, input[1][2]=0.2310, 賦值給output[1][index[1][2]]=output[1][0]
index[1][3]=3, input[1][3]=0.6931, 賦值給output[1][index[1][3]]=output[1][3]
也就是index的下標(biāo)和input的下標(biāo)是一致的,取出來的這個值賦值給誰呢,這個是index對應(yīng)的值以及dim來確定的,如果dim=1, 那么更新的是output[i][index[i][j]]=input[i][j],官方文檔給的是三維的情況,dim是多少,那么index的值就放在第幾維。
scatter一個很重要的應(yīng)用就是生成one-hot矩陣
假設(shè)總共有5類,現(xiàn)在一個batch有3個樣本,分別對應(yīng)的標(biāo)簽為1,2,0。那么生成的one-hot矩陣應(yīng)該是這樣的:
index=torch.tensor([[1], [2], [0]])
y=torch.zeros(3, 5)
y=y.scatter(1, index, 1)
print(y)
輸出:
tensor([[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[1., 0., 0., 0., 0.]])