【Pytorch】torch.stack()的使用

pytorch中,常見的拼接函數(shù)主要是兩個,分別是:

  1. stack()

  2. cat()

實際使用中,這兩個函數(shù)互相輔助:關(guān)于cat()參考torch.cat(),但是本文主要說stack()。

函數(shù)的意義:使用stack可以保留兩個信息:[1. 序列] 和 [2. 張量矩陣] 信息,屬于【擴張再拼接】的函數(shù);可以認(rèn)為把一個個矩陣按時間序列壓緊成一個矩陣。 常出現(xiàn)在自然語言處理(NLP)和圖像卷積神經(jīng)網(wǎng)絡(luò)(CV)中。

1. stack()

官方解釋:沿著一個新維度對輸入張量序列進行連接。 序列中所有的張量都應(yīng)該為相同形狀。

淺顯說法:把多個2維的張量湊成一個3維的張量;多個3維的湊成一個4維的張量…以此類推,也就是在增加新的維度進行堆疊。

outputs = torch.stack(inputs, dim=0) → Tensor

參數(shù)

  • inputs : 待連接的張量序列。 注:python的序列數(shù)據(jù)只有listtuple。

  • dim : 新的維度, 必須在0len(outputs)之間。 注:len(outputs)是生成數(shù)據(jù)的維度大小,也就是outputs的維度值。

2. 重點

  1. 函數(shù)中的輸入inputs只允許是序列;且序列內(nèi)部的張量元素,必須shape相等

----舉例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必須tensor_1.shape == tensor_2.shape

  1. dim是選擇生成的維度,必須滿足0<=dim<len(outputs);len(outputs)是輸出后的tensor的維度大小

不懂的看例子,再回過頭看就懂了。

3. 例子

1.準(zhǔn)備2個tensor數(shù)據(jù),每個的shape都是[3,3]

 T1 = torch.tensor([[1, 2, 3],
  [4, 5, 6],
  [7, 8, 9]])
 # 假設(shè)是時間步T2
 T2 = torch.tensor([[10, 20, 30],
  [40, 50, 60],
  [70, 80, 90]])

2.測試stack函數(shù)

 print(torch.stack((T1,T2),dim=0).shape)
 print(torch.stack((T1,T2),dim=1).shape)
 print(torch.stack((T1,T2),dim=2).shape)
 print(torch.stack((T1,T2),dim=3).shape)
 # outputs:
 torch.Size([2, 3, 3])
 torch.Size([3, 2, 3])
 torch.Size([3, 3, 2])
 '選擇的dim>len(outputs),所以報錯'
 IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

可以復(fù)制代碼運行試試:拼接后的tensor形狀,會根據(jù)不同的dim發(fā)生變化。

dim shape
0 [2, 3, 3]
1 [3, 2, 3]
2 [3, 3, 2]
3 溢出報錯

4. 總結(jié)

  1. 函數(shù)作用: 函數(shù)stack()序列數(shù)據(jù)內(nèi)部的張量進行擴維拼接,指定維度由程序員選擇、大小是生成后數(shù)據(jù)的維度區(qū)間。

  2. 存在意義: 在自然語言處理和卷及神經(jīng)網(wǎng)絡(luò)中, 通常為了保留–[序列(先后)信息] 和 [張量的矩陣信息] 才會使用stack。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容