第48章 使用稀松函數(shù)擬合數(shù)據(jù)及優(yōu)化器的使用

本章接續(xù)上一章,使用稀松矩陣函數(shù)來進(jìn)行回歸擬合,另外再介紹一下jax.example_libraries.optimizers的優(yōu)化器。

在真實(shí)場景中,往往遇到大量值為0的特殊矩陣,為了節(jié)約存儲(chǔ)空間,減少計(jì)算量,在進(jìn)行模型創(chuàng)建和處理的過程中,一個(gè)好的做法就是使用稀松函數(shù)處理。

使用稀松函數(shù)對數(shù)據(jù)擬合

與之前實(shí)戰(zhàn)類似,本任務(wù)也分為數(shù)據(jù)準(zhǔn)備、模型設(shè)計(jì)及模型訓(xùn)練等步驟。

數(shù)據(jù)準(zhǔn)備

為了簡單起見,采用one_hot方式生成年若干條數(shù)據(jù),并希望模型計(jì)算的結(jié)果對輸入的one_hot數(shù)據(jù)進(jìn)行恢復(fù)。


import jax
from jax.experimental import sparse

def setup():

    key = jax.random.PRNGKey(15)
    number_classes = 10
    classes = jax.numpy.arange(number_classes)
    
    inputs = []
    genuines = []
    
    for i in range(1024):
    
        x = jax.random.choice(key + 1, classes, shape = (1,))
        x = x[0]
        x_one_hoted = jax.nn.one_hot(x, num_classes = number_classes)
        
        inputs.append(x_one_hoted)
        genuines.append(x)
        
    params = [jax.random.normal(key = key, shape = (number_classes, 1)), jax.random.normal(key = key, shape = (1,))]
    
    sparsed_inputs = sparse.BCOO.fromdense(jax.numpy.array(inputs))
    genuines = jax.numpy.array(genuines)
    
    epochs = 10000
    learning_rate = 1e-3
    
    return (key, number_classes, params, epochs, learning_rate), (inputs, sparsed_inputs, genuines)

模型設(shè)計(jì)

使用一個(gè)簡單的邏輯回歸模型對數(shù)據(jù)進(jìn)行擬合,其模型設(shè)計(jì)如下,


# Activation function
def sigmoid(inputs):

    return 1 / 2 * (jax.numpy.tanh(inputs / 2) + 1)
    
# Prediction model
def predict(params, inputs):

    predictions = jax.numpy.dot(inputs, params[0]) + params[1]
    
    return sigmoid(predictions)

def loss_function(params, sparsed_inputs, genuines):

    sparsed_predict = sparse.sparsify(predict)
    predictions = sparsed_predict(params, sparsed_inputs)
    
    losses = genuines * jax.numpy.log(predictions) + (1 - genuines) * jax.numpy.log(1 - predictions)
    losses = -jax.numpy.mean(losses)
    
    return losses

訓(xùn)練模型

完成了數(shù)據(jù)準(zhǔn)備和模型的設(shè)計(jì),開始訓(xùn)練模型,


def train():

    (key, number_classes, params, epochs, learning_rate), (inputs, sparsed_inputs, genuines) = setup()

    losses = loss_function(params, sparsed_inputs, genuines)
    
    print("losses post the train = ", losses)
    print("params prior to the train = ", params)
    
    for i in range(epochs):
    
        loss_function_grad = jax.grad(loss_function)
        
        gradients = loss_function_grad(params, sparsed_inputs, genuines)
        params = [param - gradient * learning_rate for param, gradient in zip(params, gradients)]
        
        if (i + 1) % 100 == 0:
        
            print(f"losses after {i + 1} = ", losses)
        
    losses = loss_function(params, sparsed_inputs, genuines)
    
    print("losses post the train = ", losses)

def main():

    train()
    
if __name__ == "__main__":

    main()

上述代碼僅僅為了演示目的而使用了一個(gè)最賤的邏輯回歸模型對數(shù)據(jù)進(jìn)行擬合,更具體的應(yīng)用還需要在司機(jī)中深入掌握。

jax.example_libraries.optimizers優(yōu)化器

下面介紹jax.example_libraries.optimizers優(yōu)化器模塊的使用。

在第2章《上手JAX》中的深度學(xué)習(xí)簡單例子中使用了optimizers模塊,這個(gè)模塊就是對JAX的優(yōu)化器Optimizers的封裝。該模塊包含累一些方便使用的優(yōu)化器,特別是初始化和更新函數(shù),可以與ndarray或任意嵌套的jax.numpy函數(shù)和數(shù)據(jù)類型一起使用。

在第2章MNIST模型中使用下面的函數(shù),


optimizer_init_function, optimizer_update_function, get_params_function = optimizers.adam(step_size = step_size)

從定義的返回值來看,該函數(shù)返回了3個(gè)值,其實(shí)是指向函數(shù)的“指針”,即函數(shù)地址,

  • optimizer_init_function(initial_params),對優(yōu)化器數(shù)據(jù)初始化設(shè)置,主要是對封裝后的模型進(jìn)行參數(shù)的初始化。
  • optimizer_update_function(step, gradients, optimizer_state),三個(gè)參數(shù)分別是,
    • Step,表示步驟。
    • gradients表示梯度函數(shù)。
    • optimzier_state,既是優(yōu)化器的輸入值又是優(yōu)化器的輸出值。
  • get_params_function,返回優(yōu)化器中的參數(shù)。

下面是一個(gè)jax.example_libraries.optimizers優(yōu)化器使用的一般流程,


import jax.example_libraries.optimizers

def model():

    optimizer_init_function, optimzier_update_function, get_params_function = jax.example_libraries.optimizers.adam(step_size = 1e-3)
    
    _, initial_params = init_random_params(key, input_shape)
    
    optimizer_state = optimizer_init_function = optimizer_init_function(initial_params)
    # ...
    
    params = get_params_function(optimizer_state)
    loss_function_grad = jax.grad(loss_function)
    gradients = loss_function_grad(params, inputs, targets)
    # Optimize the inputs and update the params
    optimizer_state = optimzier_update_function(_, gradients, optimizer_state)
    
def main():

    model()
    
if __name__ == "__main__":

    main()

jax.example_libraries.optimizers提供了多種優(yōu)化函數(shù),

  • jax.example_libraries.optimizers.adagrad(step_size, momentum=0.9)
  • jax.example_libraries.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)
  • jax.example_libraries.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)
  • jax.example_libraries.optimizers.clip_grads(grad_tree, max_norm)
  • Clip gradients stored as a pytree of arrays to maximum norm max_norm.
  • jax.example_libraries.optimizers.constant(step_size)
  • jax.example_libraries.optimizers.exponential_decay(step_size, decay_steps, decay_rate)
  • jax.example_libraries.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)
  • jax.example_libraries.optimizers.l2_norm(tree)
  • jax.example_libraries.optimizers.make_schedule(scalar_or_schedule)
  • jax.example_libraries.optimizers.momentum(step_size, mass)
  • jax.example_libraries.optimizers.nesterov(step_size, mass)

本系列課程結(jié)束后的實(shí)戰(zhàn)里,可以酌情使用這些優(yōu)化器。

jax.example_libraries.stax使用

jax.example_libraries.stax包含了目前神經(jīng)網(wǎng)絡(luò)計(jì)算所需要的絕大部分計(jì)算函數(shù),并且jax.example_libraries.stax.serial函數(shù)的作用就是將不同的函數(shù)封裝起來,成為一個(gè)可以用于神經(jīng)網(wǎng)絡(luò)訓(xùn)練的組合模型。

下面是一個(gè)簡單的用法,


import jax.example_libraries.stax

def model():

    init_random_params, predict = jax.example_libraries.stax.serial(
        
            jax.example_libraries.stax.Dense(1024),
            jax.example_libraries.stax.Relu,
            
            jax.example_libraries.stax.Dense(1024),
            jax.example_libraries.stax.Relu,
            
            jax.example_libraries.stax.Dense(10),
            
            jax.example_libraries.stax.LogSoftmax,
        )
    
def main():

    model()
    
if __name__ == "__main__":

    main()

這是使用stax庫封裝一個(gè)神經(jīng)網(wǎng)絡(luò)模型。其中實(shí)現(xiàn)了全連接層以及多個(gè)激活層,基本上和第2章相同。

結(jié)論

細(xì)心的人可能注意到了,無論是優(yōu)化器函數(shù)stax封裝都是一些便捷方法封裝,供那些不想從零開始編寫各種深度學(xué)習(xí)模型組件的開發(fā)者使用,熟悉PyTorch或者TensorFlow的人應(yīng)該熟悉這種形式,也就說框架提供了各式各樣的組件供使用者直接使用,而不關(guān)心該組件到底怎么實(shí)現(xiàn),

當(dāng)然,這和JAX初衷不同,JAX是想讓開發(fā)者完成掌控深度學(xué)習(xí)模塊的設(shè)計(jì),這也是為什么這些模塊放在實(shí)驗(yàn)性包或者示例包里,而不是核心庫。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容