pytorch 常用函數(shù)參數(shù)詳解

1、torch.cat(inputs, dim=0) -> Tensor?

參考鏈接:

[Pytorch] 詳解 torch.cat()

Pytorch學(xué)習(xí)筆記(一):torch.cat()模塊的詳解

函數(shù)作用:cat 是 concatnate 的意思:拼接,聯(lián)系在一起。在給定維度上對輸入的 Tensor 序列進(jìn)行拼接操作。torch.cat 可以看作是 torch.split 和 torch.chunk 的反操作

參數(shù):

inputs(sequence of Tensors):可以是任意相同類型的 Tensor 的 python 序列

dim(int, optional):defaults=0

dim=0: 按列進(jìn)行拼接?

dim=1: 按行進(jìn)行拼接

dim=-1: 如果行和列數(shù)都相同則按行進(jìn)行拼接,否則按照行數(shù)或列數(shù)相等的維度進(jìn)行拼接

假設(shè) a 和 b 都是 Tensor,且 a 的維度為 [2, 3],b 的維度為 [2, 4],則

torch.cat((a, b), dim=1) 的維度為 [2, 7]


2、torch.nn.CrossEntropyLoss()

函數(shù)作用:CrossEntropy 是交叉熵的意思,故而 CrossEntropyLoss 的作用是計算交叉熵。CrossEntropyLoss 函數(shù)是將 torch.nn.Softmax 和 torch.nn.NLLLoss 兩個函數(shù)組合在一起使用,故而傳入的預(yù)測值不需要先進(jìn)行 torch.nnSoftmax 操作。

參數(shù):

input(N, C):N 是 batch_size,C 則是類別數(shù),即在定義模型輸出時,輸出節(jié)點個數(shù)要定義為 [N, C]。其中特別注意的是 target 的數(shù)據(jù)類型需要是浮點數(shù),即 float32

target(N):N 是 batch_size,故 target 需要是 1D 張量。其中特別注意的是 target 的數(shù)據(jù)類型需要是 long,即 int64

例子:

loss = nn.CrossEntropyLoss()

input = torch.randn(3, 5, requires_grad=True, dtype=torch.float32)

target = torch.empty(3, dtype=torch.long).random_(5)

output = loss(input, target)

output

輸出為:

tensor(1.6916, grad_fn=<NllLossBackward>)

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容