1、torch.cat(inputs, dim=0) -> Tensor?
參考鏈接:
Pytorch學(xué)習(xí)筆記(一):torch.cat()模塊的詳解
函數(shù)作用:cat 是 concatnate 的意思:拼接,聯(lián)系在一起。在給定維度上對輸入的 Tensor 序列進(jìn)行拼接操作。torch.cat 可以看作是 torch.split 和 torch.chunk 的反操作
參數(shù):
inputs(sequence of Tensors):可以是任意相同類型的 Tensor 的 python 序列
dim(int, optional):defaults=0
dim=0: 按列進(jìn)行拼接?
dim=1: 按行進(jìn)行拼接
dim=-1: 如果行和列數(shù)都相同則按行進(jìn)行拼接,否則按照行數(shù)或列數(shù)相等的維度進(jìn)行拼接
假設(shè) a 和 b 都是 Tensor,且 a 的維度為 [2, 3],b 的維度為 [2, 4],則
torch.cat((a, b), dim=1) 的維度為 [2, 7]
2、torch.nn.CrossEntropyLoss()
函數(shù)作用:CrossEntropy 是交叉熵的意思,故而 CrossEntropyLoss 的作用是計算交叉熵。CrossEntropyLoss 函數(shù)是將 torch.nn.Softmax 和 torch.nn.NLLLoss 兩個函數(shù)組合在一起使用,故而傳入的預(yù)測值不需要先進(jìn)行 torch.nnSoftmax 操作。
參數(shù):
input(N, C):N 是 batch_size,C 則是類別數(shù),即在定義模型輸出時,輸出節(jié)點個數(shù)要定義為 [N, C]。其中特別注意的是 target 的數(shù)據(jù)類型需要是浮點數(shù),即 float32
target(N):N 是 batch_size,故 target 需要是 1D 張量。其中特別注意的是 target 的數(shù)據(jù)類型需要是 long,即 int64
例子:
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True, dtype=torch.float32)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output
輸出為:
tensor(1.6916, grad_fn=<NllLossBackward>)