"""用動(dòng)態(tài)規(guī)劃法求解強(qiáng)化學(xué)習(xí)問(wèn)題

強(qiáng)化學(xué)習(xí)環(huán)境: gym 'CartPole-v1'
"""

from collections import defaultdict
import gym
import numpy as np
import joblib
from pathlib import Path
from random import random

# Interior bin edges (endpoints dropped) used to discretize each of the four
# CartPole observation components: cart position, cart velocity, pole angle,
# pole angular velocity.
cart_pos_bin = np.linspace(-2.4, 2.4, num=6)[1:-1]
cart_vel_bin = np.linspace(-3, 3, num=4)[1:-1]
pole_ang_bin = np.linspace(-0.21, 0.21, num=8)[1:-1]
pole_vel_bin = np.linspace(-2.0, 2.0, num=6)[1:-1]

def state_coding(observation):
    """Map a continuous 4-component observation to a discrete state tuple.

    Each component is bucketed independently via ``np.digitize`` against its
    bin edges; the resulting tuple of bucket indices is hashable and serves
    as a Q-table key.
    """
    edge_sets = (cart_pos_bin, cart_vel_bin, pole_ang_bin, pole_vel_bin)
    return tuple(np.digitize([component], edges)[0]
                 for component, edges in zip(observation, edge_sets))

def choose_action(s):
    """Sample an action (0 or 1) for state *s*.

    Draws uniformly in [0, 1) and returns 0 with probability equal to the
    softmax weight of action 0's Q-value; otherwise returns 1.
    """
    prob_zero = softmax_for_choose(Q_table[s][:2])[0]
    draw = random()
    return 0 if 0 < draw < prob_zero else 1

def softmax_for_backoff(x):
    """Softmax over the *negated* inputs: smaller values get larger weight.

    Used to weight action values when computing a state's expected value.
    """
    weights = np.exp(-x)
    return weights / weights.sum()

def softmax_for_choose(x):
    """Standard softmax: larger values get larger weight.

    Used as the action-selection distribution in `choose_action`.
    """
    weights = np.exp(x)
    return weights / weights.sum()

def E(s):
    """Expected value of state *s*.

    The two action values are weighted by the backoff softmax (which favors
    the *lower* action value) and reduced with a dot product.
    """
    action_values = Q_table[s][:2]
    weights = softmax_for_backoff(action_values)
    return action_values.dot(weights)

def update_action_chain(action_chain):
    """Back up values along one episode's [state, action, reward] chain.

    Walks the chain from the terminal state backwards, setting each visited
    state's Q-value for the taken action to the successor's reward plus the
    successor state's expected value (stored in slot -1 of each Q-table row).
    """
    terminal_state = action_chain[-1][0]
    # Terminal state: both action values 0, expected value pinned to -5.
    Q_table[terminal_state] = np.array([0., 0., -5.])
    for step in reversed(range(len(action_chain) - 1)):
        state, taken_action, _ = action_chain[step]
        next_state, _, next_reward = action_chain[step + 1]
        # One-step backup: immediate reward plus successor's expected value.
        Q_table[state][taken_action] = next_reward + Q_table[next_state][-1]
        Q_table[state][-1] = E(state)

def test(n=50, visiable=False):
    """Run *n* greedy episodes and return the mean episode length (score).

    Actions are chosen greedily (argmax over the two Q-values); each episode
    runs until the environment reports ``done``.

    Bug fix: the original had ``return np.mean(scores)`` indented inside the
    ``for`` loop, so only the first episode ever ran regardless of *n* —
    making calls like ``test(n=20000)`` meaningless.  The return now sits
    outside the loop so all *n* episodes contribute to the mean.
    """
    scores = []
    for _ in range(n):
        score = 0
        observation = env.reset()
        s = state_coding(observation)
        action = np.argmax(Q_table[s][:2])
        while True:
            if visiable:
                env.render()
            observation_, _, done, _ = env.step(action)
            s_ = state_coding(observation_)
            action_ = np.argmax(Q_table[s_][:2])
            score += 1
            if done:
                scores.append(score)
                break
            s, action = s_, action_
    return np.mean(scores)

def train(n=100):
    """Run *n* training episodes, then return the greedy-policy test score.

    Each episode samples actions from the softmax policy, records
    ``[state, action, reward]`` triples, and on termination backs the episode
    up through `update_action_chain` (with the terminal reward forced to -5).

    Bug fix: the original looped ``for _ in range(100)``, silently ignoring
    the *n* parameter; it now iterates ``range(n)``.
    """
    for _ in range(n):
        action_chain = []
        observation = env.reset()
        s = state_coding(observation)
        action, reward = choose_action(s), 1
        while True:
            observation_, reward_, done, _ = env.step(action)
            s_ = state_coding(observation_)
            action_ = choose_action(s_)
            if done:
                reward_ = -5  # terminal penalty
            # NOTE(review): the initial (s, action, reward) triple is never
            # appended, so the episode's first state is never backed up —
            # confirm this is intended before changing it.
            action_chain.append([s_, action_, reward_])
            if done:
                update_action_chain(action_chain)
                break
            s, action, reward = s_, action_, reward_
    return test()

# Environment and Q-table. Each state's row is a length-3 vector:
# [Q(a=0), Q(a=1), E(s)] — slot -1 caches the state's expected value.
env = gym.make('CartPole-v1')
Q_table = defaultdict(lambda:np.zeros((3,)))

# Train until the greedy policy reaches the CartPole-v1 maximum score (500)
# on both a quick check and a large evaluation, then render episodes forever.
# NOTE(review): neither `while True` loop ever breaks, so `env.close()`
# below is unreachable — the script only stops on interrupt.
while True:
    if train(100)==500  and test(n=20000) == 500:
        while True:test(n=1,visiable=True)

env.close()

# ©著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者