本章接續(xù)上一章,使用稀松矩陣函數(shù)來進(jìn)行回歸擬合,另外再介紹一下jax.example_libraries.optimizers的優(yōu)化器。
在真實(shí)場景中,往往遇到大量值為0的特殊矩陣,為了節(jié)約存儲(chǔ)空間,減少計(jì)算量,在進(jìn)行模型創(chuàng)建和處理的過程中,一個(gè)好的做法就是使用稀松函數(shù)處理。
使用稀松函數(shù)對數(shù)據(jù)擬合
與之前實(shí)戰(zhàn)類似,本任務(wù)也分為數(shù)據(jù)準(zhǔn)備、模型設(shè)計(jì)及模型訓(xùn)練等步驟。
數(shù)據(jù)準(zhǔn)備
為了簡單起見,采用one_hot方式生成年若干條數(shù)據(jù),并希望模型計(jì)算的結(jié)果對輸入的one_hot數(shù)據(jù)進(jìn)行恢復(fù)。
import jax
from jax.experimental import sparse
def setup():
key = jax.random.PRNGKey(15)
number_classes = 10
classes = jax.numpy.arange(number_classes)
inputs = []
genuines = []
for i in range(1024):
x = jax.random.choice(key + 1, classes, shape = (1,))
x = x[0]
x_one_hoted = jax.nn.one_hot(x, num_classes = number_classes)
inputs.append(x_one_hoted)
genuines.append(x)
params = [jax.random.normal(key = key, shape = (number_classes, 1)), jax.random.normal(key = key, shape = (1,))]
sparsed_inputs = sparse.BCOO.fromdense(jax.numpy.array(inputs))
genuines = jax.numpy.array(genuines)
epochs = 10000
learning_rate = 1e-3
return (key, number_classes, params, epochs, learning_rate), (inputs, sparsed_inputs, genuines)
模型設(shè)計(jì)
使用一個(gè)簡單的邏輯回歸模型對數(shù)據(jù)進(jìn)行擬合,其模型設(shè)計(jì)如下,
# Activation function
def sigmoid(inputs):
return 1 / 2 * (jax.numpy.tanh(inputs / 2) + 1)
# Prediction model
def predict(params, inputs):
predictions = jax.numpy.dot(inputs, params[0]) + params[1]
return sigmoid(predictions)
def loss_function(params, sparsed_inputs, genuines):
sparsed_predict = sparse.sparsify(predict)
predictions = sparsed_predict(params, sparsed_inputs)
losses = genuines * jax.numpy.log(predictions) + (1 - genuines) * jax.numpy.log(1 - predictions)
losses = -jax.numpy.mean(losses)
return losses
訓(xùn)練模型
完成了數(shù)據(jù)準(zhǔn)備和模型的設(shè)計(jì),開始訓(xùn)練模型,
def train():
(key, number_classes, params, epochs, learning_rate), (inputs, sparsed_inputs, genuines) = setup()
losses = loss_function(params, sparsed_inputs, genuines)
print("losses post the train = ", losses)
print("params prior to the train = ", params)
for i in range(epochs):
loss_function_grad = jax.grad(loss_function)
gradients = loss_function_grad(params, sparsed_inputs, genuines)
params = [param - gradient * learning_rate for param, gradient in zip(params, gradients)]
if (i + 1) % 100 == 0:
print(f"losses after {i + 1} = ", losses)
losses = loss_function(params, sparsed_inputs, genuines)
print("losses post the train = ", losses)
def main():
train()
if __name__ == "__main__":
main()
上述代碼僅僅為了演示目的而使用了一個(gè)最賤的邏輯回歸模型對數(shù)據(jù)進(jìn)行擬合,更具體的應(yīng)用還需要在司機(jī)中深入掌握。
jax.example_libraries.optimizers優(yōu)化器
下面介紹jax.example_libraries.optimizers優(yōu)化器模塊的使用。
在第2章《上手JAX》中的深度學(xué)習(xí)簡單例子中使用了optimizers模塊,這個(gè)模塊就是對JAX的優(yōu)化器Optimizers的封裝。該模塊包含累一些方便使用的優(yōu)化器,特別是初始化和更新函數(shù),可以與ndarray或任意嵌套的jax.numpy函數(shù)和數(shù)據(jù)類型一起使用。
在第2章MNIST模型中使用下面的函數(shù),
optimizer_init_function, optimizer_update_function, get_params_function = optimizers.adam(step_size = step_size)
從定義的返回值來看,該函數(shù)返回了3個(gè)值,其實(shí)是指向函數(shù)的“指針”,即函數(shù)地址,
- optimizer_init_function(initial_params),對優(yōu)化器數(shù)據(jù)初始化設(shè)置,主要是對封裝后的模型進(jìn)行參數(shù)的初始化。
- optimizer_update_function(step, gradients, optimizer_state),三個(gè)參數(shù)分別是,
- Step,表示步驟。
- gradients表示梯度函數(shù)。
- optimzier_state,既是優(yōu)化器的輸入值又是優(yōu)化器的輸出值。
- get_params_function,返回優(yōu)化器中的參數(shù)。
下面是一個(gè)jax.example_libraries.optimizers優(yōu)化器使用的一般流程,
import jax.example_libraries.optimizers
def model():
optimizer_init_function, optimzier_update_function, get_params_function = jax.example_libraries.optimizers.adam(step_size = 1e-3)
_, initial_params = init_random_params(key, input_shape)
optimizer_state = optimizer_init_function = optimizer_init_function(initial_params)
# ...
params = get_params_function(optimizer_state)
loss_function_grad = jax.grad(loss_function)
gradients = loss_function_grad(params, inputs, targets)
# Optimize the inputs and update the params
optimizer_state = optimzier_update_function(_, gradients, optimizer_state)
def main():
model()
if __name__ == "__main__":
main()
jax.example_libraries.optimizers提供了多種優(yōu)化函數(shù),
- jax.example_libraries.optimizers.adagrad(step_size, momentum=0.9)
- jax.example_libraries.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)
- jax.example_libraries.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)
- jax.example_libraries.optimizers.clip_grads(grad_tree, max_norm)
- Clip gradients stored as a pytree of arrays to maximum norm max_norm.
- jax.example_libraries.optimizers.constant(step_size)
- jax.example_libraries.optimizers.exponential_decay(step_size, decay_steps, decay_rate)
- jax.example_libraries.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)
- jax.example_libraries.optimizers.l2_norm(tree)
- jax.example_libraries.optimizers.make_schedule(scalar_or_schedule)
- jax.example_libraries.optimizers.momentum(step_size, mass)
- jax.example_libraries.optimizers.nesterov(step_size, mass)
本系列課程結(jié)束后的實(shí)戰(zhàn)里,可以酌情使用這些優(yōu)化器。
jax.example_libraries.stax使用
jax.example_libraries.stax包含了目前神經(jīng)網(wǎng)絡(luò)計(jì)算所需要的絕大部分計(jì)算函數(shù),并且jax.example_libraries.stax.serial函數(shù)的作用就是將不同的函數(shù)封裝起來,成為一個(gè)可以用于神經(jīng)網(wǎng)絡(luò)訓(xùn)練的組合模型。
下面是一個(gè)簡單的用法,
import jax.example_libraries.stax
def model():
init_random_params, predict = jax.example_libraries.stax.serial(
jax.example_libraries.stax.Dense(1024),
jax.example_libraries.stax.Relu,
jax.example_libraries.stax.Dense(1024),
jax.example_libraries.stax.Relu,
jax.example_libraries.stax.Dense(10),
jax.example_libraries.stax.LogSoftmax,
)
def main():
model()
if __name__ == "__main__":
main()
這是使用stax庫封裝一個(gè)神經(jīng)網(wǎng)絡(luò)模型。其中實(shí)現(xiàn)了全連接層以及多個(gè)激活層,基本上和第2章相同。
結(jié)論
細(xì)心的人可能注意到了,無論是優(yōu)化器函數(shù)stax封裝都是一些便捷方法封裝,供那些不想從零開始編寫各種深度學(xué)習(xí)模型組件的開發(fā)者使用,熟悉PyTorch或者TensorFlow的人應(yīng)該熟悉這種形式,也就說框架提供了各式各樣的組件供使用者直接使用,而不關(guān)心該組件到底怎么實(shí)現(xiàn),
當(dāng)然,這和JAX初衷不同,JAX是想讓開發(fā)者完成掌控深度學(xué)習(xí)模塊的設(shè)計(jì),這也是為什么這些模塊放在實(shí)驗(yàn)性包或者示例包里,而不是核心庫。