
一. torch.repeat_interleave()函數(shù)解析
1.函數(shù)說(shuō)明
官網(wǎng):torch.repeat_interleave(),函數(shù)說(shuō)明如下圖所示:

2. 函數(shù)原型
torch.repeat_interleave(input, repeats, dim=None) → Tensor
3. 函數(shù)功能
沿著指定的維度重復(fù)張量的元素
4. 輸入?yún)?shù):
1)input (類型:torch.Tensor):輸入張量
2)repeats(類型:int或torch.Tensor):每個(gè)元素的重復(fù)次數(shù)
3)dim(類型:int)需要重復(fù)的維度。默認(rèn)情況下dim=None,表示將把給定的輸入張量展平(flatten)為向量,然后將每個(gè)元素重復(fù)repeats次,并返回重復(fù)后的張量。
5. 注意
1)如果不指定dim,則默認(rèn)將輸入張量扁平化(維數(shù)是1,因此這時(shí)repeats必須是一個(gè)數(shù),不能是數(shù)組),并且返回一個(gè)扁平化的輸出數(shù)組。
2)返回的數(shù)組與輸入數(shù)組維數(shù)相同,并且除了給定的維度dim,其他維度大小與輸入數(shù)組相應(yīng)維度大小相同
3)repeats:如果傳入數(shù)組,則必須是tensor格式。并且只能是一維數(shù)組,數(shù)組長(zhǎng)度與輸入數(shù)組input的dim維度大小相同
6. 代碼例子
6.1 輸入一維張量,不指定dim,重復(fù)次數(shù)為2次,表示將把給定的輸入張量展平(flatten)為向量,然后將每個(gè)元素重復(fù)2次,并返回重復(fù)后的張量。
a = torch.randn(5)
a,torch.repeat_interleave(a,2)
輸出結(jié)果如下所示:
(tensor([ 0.4030, -1.1536, -2.4513, 1.1454, -0.8818]),
tensor([ 0.4030, 0.4030, -1.1536, -1.1536, -2.4513, -2.4513, 1.1454, 1.1454,
-0.8818, -0.8818]))
6.2 輸入二維張量,不指定dim,重復(fù)次數(shù)為2次,表示將把給定的輸入張量展平(flatten)為向量,然后將每個(gè)元素重復(fù)2次,并返回重復(fù)后的張量。
a = torch.randn(3,2)
a,a.repeat_interleave(2)
輸出結(jié)果如下:
(tensor([[-1.03, -0.32],
[ 0.43, 0.78],
[ 0.91, -0.11]]),
tensor([-1.03, -1.03, -0.32, -0.32, 0.43, 0.43, 0.78, 0.78, 0.91, 0.91,
-0.11, -0.11]))
6.3 輸入二維張量,指定dim=0,重復(fù)次數(shù)為3次,表示把輸入張量每行元素重復(fù)3次
a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=0)
輸出結(jié)果如下:
(tensor([[ 0.14, 1.47],
[-1.52, -0.62],
[-0.24, -0.27]]),
tensor([[ 0.14, 1.47],
[ 0.14, 1.47],
[ 0.14, 1.47],
[-1.52, -0.62],
[-1.52, -0.62],
[-1.52, -0.62],
[-0.24, -0.27],
[-0.24, -0.27],
[-0.24, -0.27]]))
6.4 輸入二維張量,指定dim=1,重復(fù)次數(shù)為3次,表示把輸入張量每列元素重復(fù)3次
a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=1)
輸出結(jié)果如下:
(tensor([[-0.81, 0.56],
[-2.41, -0.56],
[ 0.38, -0.90]]),
tensor([[-0.81, -0.81, -0.81, 0.56, 0.56, 0.56],
[-2.41, -2.41, -2.41, -0.56, -0.56, -0.56],
[ 0.38, 0.38, 0.38, -0.90, -0.90, -0.90]]))
6.5 輸入二維張量,指定dim=0,重復(fù)次數(shù)為一個(gè)張量列表[n1,n2,n3],表示在(dim=0)對(duì)應(yīng)行上面重復(fù)n1,n2,n3遍,張量列表的長(zhǎng)度必須與dim=0的維度的長(zhǎng)度一樣,否則會(huì)報(bào)錯(cuò)
a = torch.randn(3,2)
a,torch.repeat_interleave(a,torch.tensor([2,3,4]),dim=0)#表示第一行重復(fù)2遍,第二行重復(fù)3遍,第三行重復(fù)4遍
輸出結(jié)果如下:
(tensor([[-0.79, 0.54],
[-0.47, -0.25],
[-0.13, 1.03]]),
tensor([[-0.79, 0.54],
[-0.79, 0.54],
[-0.47, -0.25],
[-0.47, -0.25],
[-0.47, -0.25],
[-0.13, 1.03],
[-0.13, 1.03],
[-0.13, 1.03],
[-0.13, 1.03]]))
7. 與torch.repeat()函數(shù)區(qū)別:
兩個(gè)函數(shù)方法最大的區(qū)別就是repeat_interleave是一個(gè)元素一個(gè)元素地重復(fù),而repeat是一組元素一組元素地重復(fù).
參考知識(shí)文章