PyTorch 基礎(chǔ)筆記

?PyTorch基礎(chǔ)筆記,變量與計算.

Torch 很好用, 但是 Lua 又不是特別流行, 所有開發(fā)團隊將 Lua 的 Torch 移植到了更流行的語言 Python 上,所以就有了PyTorch.

對比TensorFlow

PyTorch最大優(yōu)點就是建立的神經(jīng)網(wǎng)絡是動態(tài)的, 對比靜態(tài)的 Tensorflow, 他能更有效地處理一些問題, 比如說 RNN 變化時間長度的輸出. Tensorflow 自己說自己在分布式訓練上下了很大的功夫, 那我就默認 Tensorflow 在這一點上要超出 PyTorch, 但是 Tensorflow 的靜態(tài)計算圖使得他在 RNN 上有一點點被動 (雖然它用其他途徑解決了), 不過用 PyTorch 的時候, 你會對這種動態(tài)的 RNN 有更好的理解.

而且 Tensorflow 的高度工業(yè)化, 它的底層代碼… 你是看不懂的. PyTorch 好那么一點點, 如果你深入 API, 你至少能比看 Tensorflow 多看懂一點點 PyTorch 的底層在干嘛.

Numpy 和 Torch

Torch 自稱為神經(jīng)網(wǎng)絡界的 Numpy, 因為他能將 torch 產(chǎn)生的 tensor 放在 GPU 中加速運算 (前提是你有合適的 GPU), 就像 Numpy 會把 array 放在 CPU 中加速運算.所以神經(jīng)網(wǎng)絡用 Torch 的 tensor 形式數(shù)據(jù)是最好的. 就像 Tensorflow 當中的 tensor 一樣.

PyTorch把 torch 做的和 numpy 能很好的兼容.

比如這樣就能自由地轉(zhuǎn)換 numpy array 和 torch tensor 了:

import torch
import numpy as np

np_data = np.arange(6).reshape((2, 3))
torch_data = torch.from_numpy(np_data)
tensor2array = torch_data.numpy()
print(
    '\nnumpy array:', np_data,          # [[0 1 2], [3 4 5]]
    '\ntorch tensor:', torch_data,      #  0  1  2 \n 3  4  5    [torch.LongTensor of size 2x3]
    '\ntensor to array:', tensor2array, # [[0 1 2], [3 4 5]]
)

Torch中的數(shù)學計算

其實 torch 中 tensor 的運算和 numpy array 的如出一轍, 我們就以對比的形式來看.如果想了解 torch 中其它更多有用的運算符, 可以查閱API.

# abs 絕對值計算
data = [-1, -2, 1, 2]
tensor = torch.FloatTensor(data)  # 轉(zhuǎn)換成32位浮點 tensor
print(
    '\nabs',
    '\nnumpy: ', np.abs(data),          # [1 2 1 2]
    '\ntorch: ', torch.abs(tensor)      # [1 2 1 2]
)

# sin   三角函數(shù) sin
print(
    '\nsin',
    '\nnumpy: ', np.sin(data),      # [-0.84147098 -0.90929743  0.84147098  0.90929743]
    '\ntorch: ', torch.sin(tensor)  # [-0.8415 -0.9093  0.8415  0.9093]
)

# mean  均值
print(
    '\nmean',
    '\nnumpy: ', np.mean(data),         # 0.0
    '\ntorch: ', torch.mean(tensor)     # 0.0
)

除了簡單的計算, 矩陣運算才是神經(jīng)網(wǎng)絡中最重要的部分.

所以展示矩陣的乘法. 注意一下包含了一個 numpy 中可行, 但是 torch 中不可行的方式.

# matrix multiplication 矩陣點乘
data = [[1,2], [3,4]]
tensor = torch.FloatTensor(data)  # 轉(zhuǎn)換成32位浮點 tensor
# correct method
print(
    '\nmatrix multiplication (matmul)',
    '\nnumpy: ', np.matmul(data, data),     # [[7, 10], [15, 22]]
    '\ntorch: ', torch.mm(tensor, tensor)   # [[7, 10], [15, 22]]
)

# !!!!  下面是錯誤的方法 !!!!
data = np.array(data)
print(
    '\nmatrix multiplication (dot)',
    '\nnumpy: ', data.dot(data),        # [[7, 10], [15, 22]] 在numpy 中可行
    '\ntorch: ', tensor.dot(tensor)     # torch 會轉(zhuǎn)換成 [1,2,3,4].dot([1,2,3,4) = 30.0
)

變量Variable

在 Torch 中的 Variable 就是一個存放會變化的值的地理位置. 里面的值會不停的變化. 如果用一個 Variable 進行計算, 那返回的也是一個同類型的 Variable.

定義一個 Variable:

import torch
from torch.autograd import Variable # torch 中 Variable 模塊


tensor = torch.FloatTensor([[1,2],[3,4]])

variable = Variable(tensor, requires_grad=True)

print(tensor)
"""
 1  2
 3  4
[torch.FloatTensor of size 2x2]
"""

print(variable)
"""
Variable containing:
 1  2
 3  4
[torch.FloatTensor of size 2x2]
"""

Variable 計算, 梯度

再對比一下 tensor 的計算和 variable 的計算.

t_out = torch.mean(tensor*tensor)       # x^2
v_out = torch.mean(variable*variable)   # x^2
print(t_out)
print(v_out)    # 7.5

到目前為止, 看不出什么不同, 但是時刻記住, Variable 計算時, 它在背景幕布后面一步步默默地搭建著一個龐大的系統(tǒng),叫做計算圖, computational graph. 這個圖是用來將所有的計算步驟 (節(jié)點) 都連接起來,最后進行誤差反向傳遞的時候, 一次性將所有 variable 里面的修改幅度 (梯度) 都計算出來, tensor 沒有這個能力.

v_out = torch.mean(variable*variable) 就是在計算圖中添加的一個計算步驟, 計算誤差反向傳遞的時候有他一份功勞,
我們就來舉個例子:

v_out.backward()    # 模擬 v_out 的誤差反向傳遞

# Variable 是計算圖的一部分, 可以用來傳遞誤差.
# v_out = 1/4 * sum(variable*variable) 這是計算圖中的 v_out 計算步驟
# 針對于 v_out 的梯度就是, d(v_out)/d(variable) = 1/4*2*variable = variable/2

print(variable.grad)    # 初始 Variable 的梯度
'''
 0.5000  1.0000
 1.5000  2.0000
'''

獲取 Variable 里面的數(shù)據(jù)

直接print(variable)只會輸出 Variable 形式的數(shù)據(jù), 在很多時候是用不了的(比如想要用 plt 畫圖),
所以我們要轉(zhuǎn)換一下, 將它變成 tensor 形式.

print(variable)     #  Variable 形式
"""
Variable containing:
 1  2
 3  4
[torch.FloatTensor of size 2x2]
"""

print(variable.data)    # tensor 形式
"""
 1  2
 3  4
[torch.FloatTensor of size 2x2]
"""

print(variable.data.numpy())    # numpy 形式
"""
[[ 1.  2.]
 [ 3.  4.]]
"""
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容