關(guān)鍵方法一覽
| 方法 | 作用 | 區(qū)別 |
|---|---|---|
| cat | 合并 | 保持原有維度的數(shù)量 |
| stack | 合并 | 原有維度數(shù)量加1 |
| split | 分割 | 按照長(zhǎng)度去分割 |
| chunk | 分割 | 等分 |
要點(diǎn)細(xì)述
cat
cat是concatenate(連接)的縮寫,而不是指(貓)。作用是把2個(gè)tensor按照特定的維度連接起來。
要求:除被拼接的維度外,其他維度必須相同
Code Demo
import torch
a=torch.randn(3,4) #隨機(jī)生成一個(gè)shape(3,4)的tensort
b=torch.randn(2,4) #隨機(jī)生成一個(gè)shape(2,4)的tensor
torch.cat([a,b],dim=0)
#返回一個(gè)shape(5,4)的tensor
#把a(bǔ)和b拼接成一個(gè)shape(5,4)的tensor,
#可理解為沿著行增加的方向(即縱向)拼接
stack
stack會(huì)增加一個(gè)新的維度,來表示拼接后的2個(gè)tensor,直觀些理解的話,咱們不妨把一個(gè)2維的tensor理解成一張長(zhǎng)方形的紙張,cat相當(dāng)于是把兩張紙縫合在一起,形成一張更大的紙,而stack相當(dāng)于是把兩張紙上下堆疊在一起。
要求:兩個(gè)tensor拼接前的形狀完全一致
Code Demo
a=torch.randn(3,4)
b=torch.randn(3,4)
c=torch.stack([a,b],dim=0)
#返回一個(gè)shape(2,3,4)的tensor,新增的維度2分別指向a和b
d=torch.stack([a,b],dim=1)
#返回一個(gè)shape(3,2,4)的tensor,新增的維度2分別指向相應(yīng)的a的第i行和b的第i行
助記:
這里的關(guān)鍵詞參數(shù)dim的理解和cat方法中有些區(qū)別。
cat方法中可以理解為原tensor的維度,dim=0,就是沿著原來的0軸進(jìn)行拼接,dim=1,就是沿著原來的1軸進(jìn)行拼接。
stack方法中的dim則是指向新增維度的位置,dim=0,就是在新形成的tensor的維度的第0個(gè)位置新插入維度
split
split是根據(jù)長(zhǎng)度去拆分tensor
Code Demo
a=torch.randn(3,4)
a.split([1,2],dim=0)
#把維度0按照長(zhǎng)度[1,2]拆分,形成2個(gè)tensor,
#shape(1,4)和shape(2,4)
a.split([2,2],dim=1)
#把維度1按照長(zhǎng)度[2,2]拆分,形成2個(gè)tensor,
#shape(3,2)和shape(3,2)
chunk
chunk可以理解為均等分的split,但是當(dāng)維度長(zhǎng)度不能被等分份數(shù)整除時(shí),雖然不會(huì)報(bào)錯(cuò),但可能結(jié)果與預(yù)期的不一樣,建議只在可以被整除的情況下運(yùn)用
Code Demo
a=torch.randn(4,6)
a.chunk(2,dim=0)
#返回一個(gè)shape(2,6)的tensor
a.chunk(2,dim=1)
#返回一個(gè)shape(4,3)的tensor