強(qiáng)化學(xué)習(xí)基礎(chǔ)篇(二十二)DP小型網(wǎng)格問(wèn)題

強(qiáng)化學(xué)習(xí)基礎(chǔ)篇(二十二)DP小型網(wǎng)格問(wèn)題

該問(wèn)題基于《Reinforcement Learning: An Introduction》在第四章的例4.1。

1、問(wèn)題描述

考慮下面的這個(gè)4*4的網(wǎng)格圖

image.png

非終止?fàn)顟B(tài)集合S={1,2,...,14}。每個(gè)狀態(tài)有四種可能的動(dòng)作,A={up,down ,left,right}。每個(gè)動(dòng)作會(huì)導(dǎo)致?tīng)顟B(tài)轉(zhuǎn)移,但當(dāng)動(dòng)作會(huì)導(dǎo)致智能體移出網(wǎng)格時(shí),狀態(tài)保持不變。比如,p(6,-1 \mid 5,right)=1p(7,-1 \mid 7,right)=1和對(duì)于任意r \in R,都有p(10,r \mid 5,right)=0。這是一個(gè)無(wú)折扣的分幕式任務(wù)。在到達(dá)終止?fàn)顟B(tài)之前,所有動(dòng)作的收益均為-1。終止?fàn)顟B(tài)在圖中以陰影顯示(盡管圖中顯示了兩個(gè)格子,但實(shí)際僅有一個(gè)終止?fàn)顟B(tài))。對(duì)于所有的狀態(tài)ss'以及動(dòng)作a,期望的收益函數(shù)均為r(s,a,s')=-1。假設(shè)智能體采取等概率隨機(jī)策略(所有動(dòng)作等可能執(zhí)行),我們需要計(jì)算在迭代策略評(píng)估中價(jià)值函數(shù)序列的收斂情況。

2、實(shí)現(xiàn)過(guò)程

2.1、環(huán)境定義

首先導(dǎo)入庫(kù)函數(shù)以及定義環(huán)境信息:

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.table import Table

# 定義網(wǎng)格世界大小
WORLD_SIZE = 4

# 把動(dòng)作定義為對(duì)x,y坐標(biāo)的增減改變
# left, up, right, down
ACTIONS = [np.array([0, -1]),  # 向上
           np.array([-1, 0]),  # 向左
           np.array([0, 1]),   # 向下
           np.array([1, 0])]   # 向右
# 該問(wèn)題中每個(gè)動(dòng)作選擇的概率為0.25
ACTION_PROB = 0.25
# 定義畫(huà)圖會(huì)用到的動(dòng)作
ACTIONS_FIGS=[ '←', '↑', '→', '↓']

然后定義左上角和右下角兩個(gè)坐標(biāo)(0,0)與(WORLD_SIZE - 1,WORLD_SIZE - 1)兩個(gè)點(diǎn)為terminal

def is_terminal(state):
    '''
    返回是否為terminal
    '''
    x, y = state
    return (x == 0 and y == 0) or (x == WORLD_SIZE - 1 and y == WORLD_SIZE - 1)

2.2、定義動(dòng)作執(zhí)行過(guò)程

def step(state, action):
    # 當(dāng)?shù)竭_(dá)terminal時(shí),下一步狀態(tài)不變,獎(jiǎng)勵(lì)為0
    if is_terminal(state):
        return state, 0
    # 計(jì)算下一個(gè)狀態(tài)
    next_state = (np.array(state) + action).tolist()
    x, y = next_state
    # 當(dāng)運(yùn)動(dòng)未知超出格子世界則在原位置不變
    if x < 0 or x >= WORLD_SIZE or y < 0 or y >= WORLD_SIZE:
        next_state = state

    reward = -1
    return next_state, reward

2.3、輔助函數(shù)

以下輔助函數(shù)主要用于畫(huà)圖

def draw_image(image):
    fig, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])

    nrows, ncols = image.shape
    width, height = 1.0 / ncols, 1.0 / nrows

    # Add cells
    for (i, j), val in np.ndenumerate(image):
        tb.add_cell(i, j, width, height, text=val,
                    loc='center', facecolor='white')

        # Row and column labels...
    for i in range(len(image)):
        tb.add_cell(i, -1, width, height, text=i+1, loc='right',
                    edgecolor='none', facecolor='none')
        tb.add_cell(-1, i, width, height/2, text=i+1, loc='center',
                    edgecolor='none', facecolor='none')
    ax.add_table(tb)

以下輔助函數(shù)用戶(hù)策略描述:

def draw_policy(optimal_values):
    fig, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])

    nrows, ncols = optimal_values.shape
    width, height = 1.0 / ncols, 1.0 / nrows

    # Add cells
    for (i, j), val in np.ndenumerate(optimal_values):
        next_vals=[]
        for action in ACTIONS:
            next_state, _ = step([i, j], action)
            next_vals.append(optimal_values[next_state[0],next_state[1]])

        best_actions=np.where(next_vals == np.max(next_vals))[0]
        val=''
        for ba in best_actions:
            val+=ACTIONS_FIGS[ba]
        
        # add state labels
        if [i, j] == [0,0]:
            val = "terminal"
        if [i, j] == [WORLD_SIZE - 1,WORLD_SIZE - 1]:
            val = "terminal"
        
        tb.add_cell(i, j, width, height, text=val,
                loc='center', facecolor='white')

    # Row and column labels...
    for i in range(len(optimal_values)):
        tb.add_cell(i, -1, width, height, text=i+1, loc='right',
                    edgecolor='none', facecolor='none')
        tb.add_cell(-1, i, width, height/2, text=i+1, loc='center',
                   edgecolor='none', facecolor='none')

    ax.add_table(tb)

2.4、使用迭代策略評(píng)估算法估算狀態(tài)值函數(shù)

使用迭代策略評(píng)估算法估算狀態(tài)值函數(shù),遵循的算法如下:

image.png

這里同時(shí)考慮了是否使用In-Place動(dòng)態(tài)規(guī)劃(In-place dynamic programming)。

在基于同步動(dòng)態(tài)規(guī)劃的值迭代算法中,存儲(chǔ)了兩個(gè)值函數(shù)的備份,分別是v_{new}(s)v_{old}(s)。
v_{new}(s)=\max_a(r+\gamma \sum_{s' \in S}p(s'|s,a)v_{old}(s'))
即在計(jì)算過(guò)程中,通過(guò)賦值的方式使舊的狀態(tài)值作為下一次計(jì)算新的狀態(tài)值。
而In-place動(dòng)態(tài)規(guī)劃(In-Place Dynamic Programming,IPDP)則是去掉舊的狀態(tài)值v_{old}(s),只保留最新的狀態(tài)值v_{new}(s),在更新的過(guò)程中可以減少存儲(chǔ)空間的浪費(fèi)。
v(s)=\max_a(r+\gamma \sum_{s' \in S}p(s'|s,a)v(s'))

直接原地更新下一個(gè)狀態(tài)值v(s),而不像同步迭代那樣需要額外存儲(chǔ)新的狀態(tài)值v_{new}(s)。在這種情況下,按何種次序更新?tīng)顟B(tài)值有時(shí)候會(huì)更具有意義。

def compute_state_value(in_place=True, discount=1.0):
    # 初始化狀態(tài)值函數(shù)為0
    new_state_values = np.zeros((WORLD_SIZE, WORLD_SIZE))
    # 在中間幾個(gè)迭代進(jìn)行可視化繪圖
    draw_iteration=[0,1,2,3,10]
    iteration = 0
    while True:
        if iteration in draw_iteration:
            draw_image(np.round(new_state_values, decimals=2))
        # 判斷是否使用In-Place動(dòng)態(tài)規(guī)劃(In-place dynamic programming)
        if in_place:
            state_values = new_state_values
        else:
            state_values = new_state_values.copy()
        old_state_values = state_values.copy()
        
        # 遍歷所有狀態(tài)
        for i in range(WORLD_SIZE):
            for j in range(WORLD_SIZE):
                value = 0
                # 遍歷所有動(dòng)作,按DP算法更新
                for action in ACTIONS:
                    (next_i, next_j), reward = step([i, j], action)
                    value += ACTION_PROB * (reward + discount * state_values[next_i, next_j])
                new_state_values[i, j] = value
        
        # 誤差小于門(mén)限則停止更新
        max_delta_value = abs(old_state_values - new_state_values).max()
        if max_delta_value < 1e-4:
            draw_image(np.round(new_state_values, decimals=2))
            break

        iteration += 1

    return new_state_values, iteration

3、實(shí)驗(yàn)結(jié)果

運(yùn)行如下代碼測(cè)試在使用In-place動(dòng)態(tài)規(guī)劃(In-Place Dynamic Programming,IPDP)的結(jié)果:

 _, asycn_iteration = compute_state_value(in_place=True)

結(jié)果整理后如下:

In-place: 113 iterations
image.png
image.png

運(yùn)行如下代碼測(cè)試在不使用In-place動(dòng)態(tài)規(guī)劃(In-Place Dynamic Programming,IPDP)的結(jié)果:

values, sync_iteration = compute_state_value(in_place=False)

結(jié)果整理后如下:

Synchronous: 172 iterations
image.png
image.png
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀(guān)點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容