
一. torch.squeeze()函數(shù)解析
1. 官網(wǎng)鏈接
torch.squeeze(),如下圖所示:

2. torch.squeeze()函數(shù)解析
torch.squeeze(input, dim=None, out=None)
squeeze()函數(shù)的功能是維度壓縮。返回一個(gè)tensor(張量),其中 input 中維度大小為1的所有維都已刪除。
舉個(gè)例子:如果 input 的形狀為 (A×1×B×C×1×D),那么返回的tensor的形狀則為 (A×B×C×D)
當(dāng)給定 dim 時(shí),那么只在給定的維度(dimension)上進(jìn)行壓縮操作,注意給定的維度大小必須是1,否則不能進(jìn)行壓縮。
舉個(gè)例子:如果 input 的形狀為 (A×1×B),squeeze(input, dim=0)后,返回的tensor不變,因?yàn)榈?維的大小為A,不是1;squeeze(input, 1)后,返回的tensor將被壓縮為 (A×B)。
3. 代碼舉例
3.1 輸入size=(2, 1, 2, 1, 2)的張量
x = torch.randn(size=(2, 1, 2, 1, 2))
x.shape
輸出結(jié)果如下:
torch.Size([2, 1, 2, 1, 2])
3.2 把x中維度大小為1的所有維都已刪除
y = torch.squeeze(x)#表示把x中維度大小為1的所有維都已刪除
y.shape
輸出結(jié)果如下:
torch.Size([2, 2, 2])
3.3 把x中第一維刪除,但是第一維大小為2,不為1,因此結(jié)果刪除不掉
y = torch.squeeze(x,0)#表示把x中第一維刪除,但是第一維大小為2,不為1,因此結(jié)果刪除不掉
y.shape
輸出結(jié)果如下:
torch.Size([2, 1, 2, 1, 2])
3.4 把x中第二維刪除,因?yàn)榈诙S大小是1,因此可以刪掉
y = torch.squeeze(x,1)#表示把x中第二維刪除,因?yàn)榈诙S大小是1,因此可以刪掉
y.shape
輸出結(jié)果如下:
torch.Size([2, 2, 1, 2])
3.5 把x中最后一維刪除,但是最后一維大小為2,不為1,因此結(jié)果刪除不掉
y = torch.squeeze(x,dim=-1)#表示把x中最后一維刪除,但是最后一維大小為2,不為1,因此結(jié)果刪除不掉
y.shape
輸出結(jié)果如下:
torch.Size([2, 1, 2, 1, 2])
二.torch.unsqueeze()函數(shù)解析
1. 官網(wǎng)鏈接
torch.unsqueeze(),如下圖所示:

2. torch.unsqueeze()函數(shù)解析
torch.unsqueeze(input, dim) → Tensor
unsqueeze()函數(shù)起升維的作用,參數(shù)dim表示在哪個(gè)地方加一個(gè)維度,注意dim范圍在:[-input.dim() - 1, input.dim() + 1]之間,比如輸入input是一維,則dim=0時(shí)數(shù)據(jù)為行方向擴(kuò),dim=1時(shí)為列方向擴(kuò),再大錯(cuò)誤。
3. 代碼舉例
3.1 輸入一維張量,在第0維(行)擴(kuò)展,第0維大小為1
x = torch.tensor([1, 2, 3, 4])
y = torch.unsqueeze(x, 0)#在第0維擴(kuò)展,第0維大小為1
y,y.shape
輸出結(jié)果如下:
(tensor([[1, 2, 3, 4]]), torch.Size([1, 4]))
3.2 在第1維(列)擴(kuò)展,第1維大小為1
y = torch.unsqueeze(x, 1)#在第1維擴(kuò)展,第1維大小為1
y,y.shape
輸出結(jié)果如下:
(tensor([[1],
[2],
[3],
[4]]),
torch.Size([4, 1]))
3.3 在第最后一維(也就是倒數(shù)第一維進(jìn)行)擴(kuò)展,最后一維大小為1
y = torch.unsqueeze(x, -1)#在第最后一維擴(kuò)展,最后一維大小為1
y,y.shape
輸出結(jié)果如下:
(tensor([[1],
[2],
[3],
[4]]),
torch.Size([4, 1]))
參考知識(shí)文章