1、torch.mul(a, b)是矩陣a和b對(duì)應(yīng)位相乘,a和b的維度必須相等,比如a的維度是(1, 2),b的維度是(1, 2),返回的仍是(1, 2)的矩陣;
2、torch.mm(a, b)是矩陣a和b矩陣相乘,比如a的維度是(1, 2),b的維度是(2, 3),返回的就是(1, 3)的矩陣。
PS:更接地氣來(lái)說(shuō)區(qū)別就是點(diǎn)乘,和矩陣乘法的區(qū)別
import torch
a = torch.rand(1, 2)
b = torch.rand(1, 2)
c = torch.rand(2, 3)
print(torch.mul(a, b)) # 返回 1*2 的tensor
print(torch.mm(a, c)) # 返回 1*3 的tensor
print(torch.mul(a, c)) # 由于a、b維度不同,報(bào)錯(cuò)
torch.bmm()
torch.matmul()
torch.bmm()強(qiáng)制規(guī)定維度和大小相同
torch.matmul()沒(méi)有強(qiáng)制規(guī)定維度和大小,可以用利用廣播機(jī)制進(jìn)行不同維度的相乘操作
當(dāng)進(jìn)行操作的兩個(gè)tensor都是3D時(shí),兩者等同。
torch.bmm()
官網(wǎng):https://pytorch.org/docs/stable/torch.html#torch.bmm
torch.bmm(input, mat2, out=None) → Tensor
torch.bmm()是tensor中的一個(gè)相乘操作,類(lèi)似于矩陣中的A*B。
參數(shù):
input,mat2:兩個(gè)要進(jìn)行相乘的tensor結(jié)構(gòu),兩者必須是3D維度的,每個(gè)維度中的大小是相同的。
output:輸出結(jié)果
并且相乘的兩個(gè)矩陣,要滿足一定的維度要求:input(p,m,n) * mat2(p,n,a) ->output(p,m,a)。這個(gè)要求,可以類(lèi)比于矩陣相乘。前一個(gè)矩陣的列等于后面矩陣的行才可以相乘。
例子:
import torch
x = torch.rand(2,3,6)
y = torch.rand(2,6,7)
print(torch.bmm(x,y).size())
output:
torch.Size([2, 3, 7])
###############################
y = torch.rand(2,5,7) ##維度不匹配,報(bào)錯(cuò)
print(torch.bmm(x,y).size())
output:
Expected tensor to have size 6 at dimension 1, but got size 5 for argument #2 'batch2' (while checking arguments for bmm)
torch.matmul()
torch.matmul(input, other, out=None) → Tensor
torch.matmul()也是一種類(lèi)似于矩陣相乘操作的tensor聯(lián)乘操作。但是它可以利用python 中的廣播機(jī)制,處理一些維度不同的tensor結(jié)構(gòu)進(jìn)行相乘操作。這也是該函數(shù)與torch.bmm()區(qū)別所在。
參數(shù):
input,other:兩個(gè)要進(jìn)行操作的tensor結(jié)構(gòu)
output:結(jié)果
一些規(guī)則約定:
(1)若兩個(gè)都是1D(向量)的,則返回兩個(gè)向量的點(diǎn)積
import torch
x = torch.rand(2)
y = torch.rand(2)
print(torch.matmul(x,y),torch.matmul(x,y).size())
output:
tensor(0.1353) torch.Size([])
(2)若兩個(gè)都是2D(矩陣)的,則按照(矩陣相乘)規(guī)則返回2D
x = torch.rand(2,4)
y = torch.rand(4,3) ###維度也要對(duì)應(yīng)才可以乘
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
output:
tensor([[0.9128, 0.8425, 0.7269],
[1.4441, 1.5334, 1.3273]])
torch.Size([2, 3])
(3)若input維度1D,other維度2D,則先將1D的維度擴(kuò)充到2D(1D的維數(shù)前面+1),然后得到結(jié)果后再將此維度去掉,得到的與input的維度相同。即使作擴(kuò)充(廣播)處理,input的維度也要和other維度做對(duì)應(yīng)關(guān)系。
import torch
x = torch.rand(4) #1D
y = torch.rand(4,3) #2D
print(x.size())
print(y.size())
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
### 擴(kuò)充x =>(,4)
### 相乘x(,4) * y(4,3) =>(,3)
### 去掉1D =>(3)
output:
torch.Size([4])
torch.Size([4, 3])
tensor([0.9600, 0.5736, 1.0430])
torch.Size([3])
(4)若input是2D,other是1D,則返回兩者的點(diǎn)積結(jié)果。(個(gè)人覺(jué)得這塊也可以理解成給other添加了維度,然后再去掉此維度,只不過(guò)維度是(3, )而不是規(guī)則(3)中的( ,4)了,但是可能就是因?yàn)閮?nèi)部機(jī)制不同,所以官方說(shuō)的是點(diǎn)積而不是維度的升高和下降)
import torch
x = torch.rand(3) #1D
y = torch.rand(4,3) #2D
print(torch.matmul(y,x),'\n',torch.matmul(y,x).size()) #2D*1D
output:
torch.Size([3])
torch.Size([4, 3])
tensor([0.8278, 0.5970, 1.0370, 0.2681])
torch.Size([4])
(5)如果一個(gè)維度至少是1D,另外一個(gè)大于2D,則返回的是一個(gè)批矩陣乘法( a batched matrix multiply)。
(a)若input是1D,other是大于2D的,則類(lèi)似于規(guī)則(3)。
import torch
x = torch.randn(2, 3, 4)
y = torch.randn(3)
print(torch.matmul(y, x),'\n',torch.matmul(y, x).size()) #1D*3D
output:
tensor([[-0.9747, -0.6660, -1.1704, -1.0522],
[ 0.0901, -1.5353, 1.5601, -0.0252]])
torch.Size([2, 4])
(b)若other是1D,input是大于2D的,則類(lèi)似于規(guī)則(4)。
import torch
x = torch.randn(2, 3, 4)
y = torch.randn(4)
print(torch.matmul(x, y),'\n',torch.matmul(x, y).size()) # 3D*1D
output:
tensor([[ 0.6217, -0.1259, -0.2377],
[ 0.6874, 0.0733, 0.1793]])
torch.Size([2, 3])
(c)若input和other都是3D的,則與torch.bmm()函數(shù)功能一樣。
import torch
x = torch.randn(2,2,4)
y = torch.randn(2,4,5)
print(torch.matmul(x, y).size(),'\n',torch.bmm(x, y).size())
print(torch.equal(torch.matmul(x,y),torch.bmm(x,y)))
output:
torch.Size([2, 2, 5])
torch.Size([2, 2, 5])
True
(d)如果input中某一維度滿足可以廣播(擴(kuò)充),那么也是可以進(jìn)行相乘操作的。例如 input(j,1,n,m)* other (k,m,p) = output(j,k,n,p)。
import torch
x = torch.randn(10,1,2,4)
y = torch.randn(2,4,5)
print(torch.matmul(x, y).size())
output:
torch.Size([10, 2, 2, 5])
這個(gè)例子中,可以理解為x中dim=1這個(gè)維度可以擴(kuò)充(廣播),y中可以添加一個(gè)維度,然后在進(jìn)行批乘操作。