Pytorch中torch.sort()和torch.argsort()函數(shù)解析

一. torch.sort()函數(shù)解析

1. 官網(wǎng)鏈接

torch.sort(),如下圖所示:

2. torch.sort()函數(shù)解析

torch.sort(input, dim=- 1, descending=False, stable=False, *, out=None)

輸入input,在dim維進(jìn)行排序,默認(rèn)是dim=-1對(duì)最后一維進(jìn)行排序,descending表示是否按降序排,默認(rèn)為False,輸出排序后的值以及對(duì)應(yīng)值在原輸入imput中的下標(biāo)

3. 代碼舉例

3.1 dim = -1 表示對(duì)每行中的元素進(jìn)行升序排序,descending=False表示升序排序

x = torch.randn(3, 4)
sorted, indices = torch.sort(x)
x,sorted,indices

輸出結(jié)果如下:
(tensor([[-1.3864,  0.5811, -0.1056, -0.3237],
         [-0.2136, -1.4806,  0.4986,  0.9382],
         [-0.2820,  0.1171, -0.3983, -0.8061]]),
 tensor([[-1.3864, -0.3237, -0.1056,  0.5811],
         [-1.4806, -0.2136,  0.4986,  0.9382],
         [-0.8061, -0.3983, -0.2820,  0.1171]]),
 tensor([[0, 3, 2, 1],
         [1, 0, 2, 3],
         [3, 2, 0, 1]]))

3.2 dim = 0 表示對(duì)每列中的元素進(jìn)行升序排序,descending=False表示升序排序

x = torch.randn(3, 4)
sorted, indices = torch.sort(x,dim=0)
x,sorted,indices

輸出結(jié)果如下:
(tensor([[ 0.7081,  1.0502,  2.0434, -0.2592],
         [ 1.2052,  0.8809,  0.5771,  1.2978],
         [-1.5873, -0.4808, -2.1774, -0.2503]]),
 tensor([[-1.5873, -0.4808, -2.1774, -0.2592],
         [ 0.7081,  0.8809,  0.5771, -0.2503],
         [ 1.2052,  1.0502,  2.0434,  1.2978]]),
 tensor([[2, 2, 2, 0],
         [0, 1, 1, 2],
         [1, 0, 0, 1]]))

3.3 dim = 0 表示對(duì)每列中的元素進(jìn)行降序排序,descending=True表示降序排序

x = torch.randn(3, 4)
sorted, indices = torch.sort(x,dim=0,descending=True)
x,sorted,indices

輸出結(jié)果如下:
(tensor([[ 0.9142, -0.2178,  0.5602,  2.3951],
         [-0.6977,  0.4915,  0.3988,  0.6406],
         [ 0.4880,  1.1646, -0.3466,  0.5801]]),
 tensor([[ 0.9142,  1.1646,  0.5602,  2.3951],
         [ 0.4880,  0.4915,  0.3988,  0.6406],
         [-0.6977, -0.2178, -0.3466,  0.5801]]),
 tensor([[0, 2, 0, 0],
         [2, 1, 1, 1],
         [1, 0, 2, 2]]))

3.4 dim = 1 表示對(duì)每行中的元素進(jìn)行降序排序,descending=True表示降序排序

x = torch.randn(3, 4)
sorted, indices = torch.sort(x,dim=1,descending=True)
x,sorted,indices

輸出結(jié)果如下:
(tensor([[-0.3048, -1.9915, -0.0888,  0.3881],
         [ 1.0677, -1.3520,  0.2944, -0.0772],
         [-0.9409, -0.9630, -0.7946,  1.4400]]),
 tensor([[ 0.3881, -0.0888, -0.3048, -1.9915],
         [ 1.0677,  0.2944, -0.0772, -1.3520],
         [ 1.4400, -0.7946, -0.9409, -0.9630]]),
 tensor([[3, 2, 0, 1],
         [0, 2, 3, 1],
         [3, 2, 0, 1]]))

二.torch.argsort()函數(shù)解析

1. 官網(wǎng)鏈接

torch.argsort(),如下圖所示:

image.png

2. torch.argsort()函數(shù)解析

用法跟上面torch.sort()函數(shù)一樣,不同的是torch.argsort()返回只是排序后的值所對(duì)應(yīng)原輸入input的下標(biāo),即torch.sort()返回的indices

3. 代碼舉例

dim = 1 表示對(duì)每行中的元素進(jìn)行降序排序,descending=True表示降序排序,輸出結(jié)果為返回排序后的值所對(duì)應(yīng)原輸入input的下標(biāo)indices

x = torch.randn(3, 4)
indices = torch.argsort(x,dim=1,descending=True)
x,indices

輸出結(jié)果如下:
(tensor([[-0.6069, -0.9252, -0.9177,  0.6997],
         [ 0.3245, -0.0665,  0.4600,  0.0722],
         [-1.0662,  2.2669, -0.1171, -0.9208]]),
 tensor([[3, 0, 2, 1],
         [2, 0, 3, 1],
         [1, 2, 3, 0]]))

參考知識(shí)文章

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容