Pytorch:Tensor的合并與分割

關(guān)鍵方法一覽

方法 作用 區(qū)別
cat 合并 保持原有維度的數(shù)量
stack 合并 原有維度數(shù)量加1
split 分割 按照長度去分割
chunk 分割 等分

要點細述

cat

catconcatenate(連接)的縮寫,而不是指(貓)。作用是把2個tensor按照特定的維度連接起來。
要求:除被拼接的維度外,其他維度必須相同

Code Demo

import torch
a=torch.randn(3,4) #隨機生成一個shape(3,4)的tensort
b=torch.randn(2,4) #隨機生成一個shape(2,4)的tensor

torch.cat([a,b],dim=0) 
#返回一個shape(5,4)的tensor
#把a和b拼接成一個shape(5,4)的tensor,
#可理解為沿著行增加的方向(即縱向)拼接

stack

stack會增加一個新的維度,來表示拼接后的2個tensor,直觀些理解的話,咱們不妨把一個2維的tensor理解成一張長方形的紙張,cat相當于是把兩張紙縫合在一起,形成一張更大的紙,而stack相當于是把兩張紙上下堆疊在一起。
要求:兩個tensor拼接前的形狀完全一致

Code Demo

a=torch.randn(3,4)
b=torch.randn(3,4)

c=torch.stack([a,b],dim=0)
#返回一個shape(2,3,4)的tensor,新增的維度2分別指向a和b

d=torch.stack([a,b],dim=1)
#返回一個shape(3,2,4)的tensor,新增的維度2分別指向相應(yīng)的a的第i行和b的第i行

助記:
這里的關(guān)鍵詞參數(shù)dim的理解和cat方法中有些區(qū)別。

cat方法中可以理解為原tensor的維度,dim=0,就是沿著原來的0軸進行拼接,dim=1,就是沿著原來的1軸進行拼接。

stack方法中的dim則是指向新增維度的位置,dim=0,就是在新形成的tensor的維度的第0個位置新插入維度

split

split是根據(jù)長度去拆分tensor

Code Demo

a=torch.randn(3,4)

a.split([1,2],dim=0)
#把維度0按照長度[1,2]拆分,形成2個tensor,
#shape(1,4)和shape(2,4)

a.split([2,2],dim=1)
#把維度1按照長度[2,2]拆分,形成2個tensor,
#shape(3,2)和shape(3,2)

chunk

chunk可以理解為均等分的split,但是當維度長度不能被等分份數(shù)整除時,雖然不會報錯,但可能結(jié)果與預(yù)期的不一樣,建議只在可以被整除的情況下運用

Code Demo

a=torch.randn(4,6)

a.chunk(2,dim=0)
#返回一個shape(2,6)的tensor
a.chunk(2,dim=1)
#返回一個shape(4,3)的tensor
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容