Pytorch中torch.repeat_interleave()函數(shù)解析

一. torch.repeat_interleave()函數(shù)解析

1.函數(shù)說(shuō)明

官網(wǎng)torch.repeat_interleave(),函數(shù)說(shuō)明如下圖所示:

2. 函數(shù)原型

torch.repeat_interleave(input, repeats, dim=None) → Tensor

3. 函數(shù)功能

沿著指定的維度重復(fù)張量的元素

4. 輸入?yún)?shù):

1)input (類型:torch.Tensor):輸入張量
2)repeats(類型:int或torch.Tensor):每個(gè)元素的重復(fù)次數(shù)
3)dim(類型:int)需要重復(fù)的維度。默認(rèn)情況下dim=None,表示將把給定的輸入張量展平(flatten)為向量,然后將每個(gè)元素重復(fù)repeats次,并返回重復(fù)后的張量。

5. 注意

1)如果不指定dim,則默認(rèn)將輸入張量扁平化(維數(shù)是1,因此這時(shí)repeats必須是一個(gè)數(shù),不能是數(shù)組),并且返回一個(gè)扁平化的輸出數(shù)組。
2)返回的數(shù)組與輸入數(shù)組維數(shù)相同,并且除了給定的維度dim,其他維度大小與輸入數(shù)組相應(yīng)維度大小相同
3)repeats:如果傳入數(shù)組,則必須是tensor格式。并且只能是一維數(shù)組,數(shù)組長(zhǎng)度與輸入數(shù)組input的dim維度大小相同

6. 代碼例子

6.1 輸入一維張量,不指定dim,重復(fù)次數(shù)為2次,表示將把給定的輸入張量展平(flatten)為向量,然后將每個(gè)元素重復(fù)2次,并返回重復(fù)后的張量。

a = torch.randn(5)
a,torch.repeat_interleave(a,2)

輸出結(jié)果如下所示:
(tensor([ 0.4030, -1.1536, -2.4513,  1.1454, -0.8818]),
 tensor([ 0.4030,  0.4030, -1.1536, -1.1536, -2.4513, -2.4513,  1.1454,  1.1454,
         -0.8818, -0.8818]))

6.2 輸入二維張量,不指定dim,重復(fù)次數(shù)為2次,表示將把給定的輸入張量展平(flatten)為向量,然后將每個(gè)元素重復(fù)2次,并返回重復(fù)后的張量。

a = torch.randn(3,2)
a,a.repeat_interleave(2)

輸出結(jié)果如下:
(tensor([[-1.03, -0.32],
         [ 0.43,  0.78],
         [ 0.91, -0.11]]),
 tensor([-1.03, -1.03, -0.32, -0.32,  0.43,  0.43,  0.78,  0.78,  0.91,  0.91,
         -0.11, -0.11]))

6.3 輸入二維張量,指定dim=0,重復(fù)次數(shù)為3次,表示把輸入張量每行元素重復(fù)3次

a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=0)

輸出結(jié)果如下:
(tensor([[ 0.14,  1.47],
         [-1.52, -0.62],
         [-0.24, -0.27]]),
 tensor([[ 0.14,  1.47],
         [ 0.14,  1.47],
         [ 0.14,  1.47],
         [-1.52, -0.62],
         [-1.52, -0.62],
         [-1.52, -0.62],
         [-0.24, -0.27],
         [-0.24, -0.27],
         [-0.24, -0.27]]))

6.4 輸入二維張量,指定dim=1,重復(fù)次數(shù)為3次,表示把輸入張量每列元素重復(fù)3次

a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=1)

輸出結(jié)果如下:
(tensor([[-0.81,  0.56],
         [-2.41, -0.56],
         [ 0.38, -0.90]]),
 tensor([[-0.81, -0.81, -0.81,  0.56,  0.56,  0.56],
         [-2.41, -2.41, -2.41, -0.56, -0.56, -0.56],
         [ 0.38,  0.38,  0.38, -0.90, -0.90, -0.90]]))

6.5 輸入二維張量,指定dim=0,重復(fù)次數(shù)為一個(gè)張量列表[n1,n2,n3],表示在(dim=0)對(duì)應(yīng)行上面重復(fù)n1,n2,n3遍,張量列表的長(zhǎng)度必須與dim=0的維度的長(zhǎng)度一樣,否則會(huì)報(bào)錯(cuò)

a = torch.randn(3,2)
a,torch.repeat_interleave(a,torch.tensor([2,3,4]),dim=0)#表示第一行重復(fù)2遍,第二行重復(fù)3遍,第三行重復(fù)4遍

輸出結(jié)果如下:
(tensor([[-0.79,  0.54],
         [-0.47, -0.25],
         [-0.13,  1.03]]),
 tensor([[-0.79,  0.54],
         [-0.79,  0.54],
         [-0.47, -0.25],
         [-0.47, -0.25],
         [-0.47, -0.25],
         [-0.13,  1.03],
         [-0.13,  1.03],
         [-0.13,  1.03],
         [-0.13,  1.03]]))

7. 與torch.repeat()函數(shù)區(qū)別:

兩個(gè)函數(shù)方法最大的區(qū)別就是repeat_interleave是一個(gè)元素一個(gè)元素地重復(fù),而repeat是一組元素一組元素地重復(fù).

參考知識(shí)文章

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容