Pytorch中torch.unsqueeze()和torch.squeeze()函數(shù)解析

一. torch.squeeze()函數(shù)解析

1. 官網(wǎng)鏈接

torch.squeeze(),如下圖所示:

2. torch.squeeze()函數(shù)解析

torch.squeeze(input, dim=None, out=None) 

squeeze()函數(shù)的功能是維度壓縮。返回一個(gè)tensor(張量),其中 input 中維度大小為1的所有維都已刪除。
舉個(gè)例子:如果 input 的形狀為 (A×1×B×C×1×D),那么返回的tensor的形狀則為 (A×B×C×D)
當(dāng)給定 dim 時(shí),那么只在給定的維度(dimension)上進(jìn)行壓縮操作,注意給定的維度大小必須是1,否則不能進(jìn)行壓縮。
舉個(gè)例子:如果 input 的形狀為 (A×1×B),squeeze(input, dim=0)后,返回的tensor不變,因?yàn)榈?維的大小為A,不是1;squeeze(input, 1)后,返回的tensor將被壓縮為 (A×B)。

3. 代碼舉例

3.1 輸入size=(2, 1, 2, 1, 2)的張量

x = torch.randn(size=(2, 1, 2, 1, 2))
x.shape

輸出結(jié)果如下:
torch.Size([2, 1, 2, 1, 2])

3.2 把x中維度大小為1的所有維都已刪除

y = torch.squeeze(x)#表示把x中維度大小為1的所有維都已刪除
y.shape

輸出結(jié)果如下:
torch.Size([2, 2, 2])

3.3 把x中第一維刪除,但是第一維大小為2,不為1,因此結(jié)果刪除不掉

y = torch.squeeze(x,0)#表示把x中第一維刪除,但是第一維大小為2,不為1,因此結(jié)果刪除不掉
y.shape

輸出結(jié)果如下:
torch.Size([2, 1, 2, 1, 2])

3.4 把x中第二維刪除,因?yàn)榈诙S大小是1,因此可以刪掉

y = torch.squeeze(x,1)#表示把x中第二維刪除,因?yàn)榈诙S大小是1,因此可以刪掉
y.shape

輸出結(jié)果如下:
torch.Size([2, 2, 1, 2])

3.5 把x中最后一維刪除,但是最后一維大小為2,不為1,因此結(jié)果刪除不掉

y = torch.squeeze(x,dim=-1)#表示把x中最后一維刪除,但是最后一維大小為2,不為1,因此結(jié)果刪除不掉
y.shape

輸出結(jié)果如下:
torch.Size([2, 1, 2, 1, 2])

二.torch.unsqueeze()函數(shù)解析

1. 官網(wǎng)鏈接

torch.unsqueeze(),如下圖所示:

2. torch.unsqueeze()函數(shù)解析

torch.unsqueeze(input, dim) → Tensor

unsqueeze()函數(shù)起升維的作用,參數(shù)dim表示在哪個(gè)地方加一個(gè)維度,注意dim范圍在:[-input.dim() - 1, input.dim() + 1]之間,比如輸入input是一維,則dim=0時(shí)數(shù)據(jù)為行方向擴(kuò),dim=1時(shí)為列方向擴(kuò),再大錯(cuò)誤。

3. 代碼舉例

3.1 輸入一維張量,在第0維(行)擴(kuò)展,第0維大小為1

x = torch.tensor([1, 2, 3, 4])
y = torch.unsqueeze(x, 0)#在第0維擴(kuò)展,第0維大小為1
y,y.shape

輸出結(jié)果如下:
(tensor([[1, 2, 3, 4]]), torch.Size([1, 4]))

3.2 在第1維(列)擴(kuò)展,第1維大小為1

y = torch.unsqueeze(x, 1)#在第1維擴(kuò)展,第1維大小為1
y,y.shape

輸出結(jié)果如下:
(tensor([[1],
         [2],
         [3],
         [4]]),
 torch.Size([4, 1]))

3.3 在第最后一維(也就是倒數(shù)第一維進(jìn)行)擴(kuò)展,最后一維大小為1

y = torch.unsqueeze(x, -1)#在第最后一維擴(kuò)展,最后一維大小為1
y,y.shape

輸出結(jié)果如下:
(tensor([[1],
         [2],
         [3],
         [4]]),
 torch.Size([4, 1]))

參考知識(shí)文章

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容