強(qiáng)化學(xué)習(xí)基礎(chǔ)篇(十六)蒙特卡洛預(yù)測(cè)算法在21點(diǎn)游戲的應(yīng)用
本節(jié)將介紹Monte Carlo prediction算法在Blackjack游戲中的進(jìn)行預(yù)測(cè)的過(guò)程。主要基于一個(gè)最簡(jiǎn)單的策略進(jìn)行評(píng)估,即“超過(guò)18點(diǎn)就不在要牌,低于18點(diǎn)就繼續(xù)要牌”。我將使用兩種類型的算法進(jìn)行評(píng)估,一個(gè)是首次訪問(wèn)型蒙特卡洛預(yù)測(cè)算法(First-visit MC prediction),另一個(gè)是每次訪問(wèn)型蒙特卡洛預(yù)測(cè)算法(Every-visit MC prediction)。
1、 首次訪問(wèn)型MC預(yù)測(cè)算法
回顧一下前面介紹的首次訪問(wèn)型MC預(yù)測(cè)算法。

2、21點(diǎn)游戲
21點(diǎn)游戲使用一副或多副標(biāo)準(zhǔn)的52張紙牌,每張牌都規(guī)定一個(gè)點(diǎn)值。2~10的牌其點(diǎn)值按面值計(jì)算。J、Q和K都算作10點(diǎn),A可算作1點(diǎn),也可算作11點(diǎn)。玩家的目標(biāo)是所抽牌的總點(diǎn)數(shù)比莊家的牌更接近21點(diǎn),但不超過(guò)21點(diǎn)。
首次發(fā)牌每人2張牌。莊家以順時(shí)針?lè)较蛳虮娡婕遗砂l(fā)一張暗牌(即不被揭開(kāi)的牌),隨后向自己派發(fā)一張暗牌;接著莊家會(huì)以順時(shí)針?lè)较蛳虮娡婕遗砂l(fā)一張明牌(即被揭開(kāi)的牌),之后向自己也派發(fā)一張明牌。當(dāng)眾人手上各擁一張暗牌和一張明牌時(shí),莊家就以順時(shí)針?lè)较蛑鹞辉儐?wèn)玩家是否再要牌(以明牌方式派發(fā))。在要牌的過(guò)程中。如果互家所有的牌加起來(lái)超過(guò)21點(diǎn),玩家就輸了(Bust),游戲介紹,該玩家的注碼歸莊家。
如果玩家無(wú)Bust,莊家詢問(wèn)完所有玩家之后,就必須揭開(kāi)自己上上的暗牌。若莊家總點(diǎn)數(shù)少于17點(diǎn),就必須繼續(xù)要牌;如果莊家Bust,便向沒(méi)有Bust的玩家,賠出該玩家所投的同等注碼。如果莊家無(wú)Bust且大于等于17點(diǎn),那么莊家與玩家比較點(diǎn)數(shù)決勝負(fù),大的為贏。點(diǎn)數(shù)相同,則為平手。
在該21點(diǎn)游戲例子中,收集經(jīng)驗(yàn)軌跡時(shí),首先需要確認(rèn)該游戲基于基策路下,進(jìn)行經(jīng)驗(yàn)數(shù)據(jù)收集。
為了便于理解,我們使用一個(gè)簡(jiǎn)單的策略,當(dāng)玩家手上的牌超過(guò)18點(diǎn)時(shí),返回0,表示不再要牌;當(dāng)點(diǎn)數(shù)少于18點(diǎn)時(shí),繼續(xù)要牌,并返回1。
def simple_policy(state):
"""
定義個(gè)簡(jiǎn)單的策略,當(dāng)玩家手上的牌超過(guò)18點(diǎn)時(shí),返回0,
表示不再要牌(hold);當(dāng)點(diǎn)數(shù)少于18點(diǎn)時(shí),繼續(xù)要牌(hit),并返回1。
"""
player_score, _, _ = state
return 0 if player_score >= 18 else 1
旅游戲的狀態(tài)是玩家的點(diǎn)數(shù)(Player),莊家的點(diǎn)數(shù)(Dealer)和是否有Blackjack(Ace)。具體到代碼中,player為玩家點(diǎn)數(shù),dealer為莊家點(diǎn)數(shù),ace為True時(shí)表明牌A算作11點(diǎn),
對(duì)于21點(diǎn)游戲,簡(jiǎn)化版的玩家動(dòng)作只有兩種:一種是拿牌,另一種是停牌。
拿牌(HIT):如果玩家拿牌,表示玩家希望再拿一張或多張牌,使總點(diǎn)數(shù)更接近21點(diǎn)。如果拿牌后玩家的總點(diǎn)數(shù)超過(guò)21點(diǎn),玩家就會(huì)Bust。
停牌(STAND):如果玩家停牌,表示玩家選擇不再抽牌并希望當(dāng)前總點(diǎn)數(shù)能夠打敗莊家。
3.First-visit MC prediction源碼清單
# coding: utf-8
import numpy as np
import gym
import sys
from collections import defaultdict
from Plot3D import plot_3D
def mc_firstvisit_prediction(policy, env, num_episodes,
episode_endtime= 10, discount = 1.0):
"""
該函數(shù)主要實(shí)現(xiàn)首次訪問(wèn)蒙特卡洛預(yù)測(cè)算法
"""
r_sum = defaultdict(float)
r_count = defaultdict(float)
r_V = defaultdict(float)
# 按照設(shè)定的num_episodes數(shù)量進(jìn)行對(duì)應(yīng)迭代
for each_episode in range(num_episodes):
# 打印迭代進(jìn)展
print("Episode {}/{}".format(each_episode,num_episodes),end = "\r")
sys.stdout.flush()
# 將episode初始化為列表
episode = []
# 重置環(huán)境
state = env.reset()
# 按照輸入的策略,采集episode
for _ in range(episode_endtime):
# 獲取當(dāng)前狀態(tài)是預(yù)定策略將會(huì)采取的動(dòng)作
action = policy(state)
# 與環(huán)境交互,獲取下一個(gè)狀態(tài),獎(jiǎng)勵(lì),以及結(jié)束標(biāo)志位
next_state, reward, done, info = env.step(action)
# 將(s,a,r)對(duì)插入episode中
episode.append((state, action, reward))
if done:
break
state = next_state
# 計(jì)算首次訪問(wèn)蒙特卡洛算法的值
for visit_pos, data in enumerate(episode):
# 遍歷episode過(guò)程中,state_visit為當(dāng)前遍歷的訪問(wèn)狀態(tài)
state_visit = data[0]
# x[2]為reward, 這里為對(duì)首次訪問(wèn)后的所有獎(jiǎng)勵(lì)都做帶discount的累加
G = sum([x[2] * np.power(discount, i) for i, x in enumerate(episode[visit_pos:])])
# 計(jì)算累積平均獎(jiǎng)勵(lì)
r_sum[state_visit] += G
r_count[state_visit] += 1.0
r_V[state_visit] = r_sum[state_visit] / r_count[state_visit]
return r_V
def simple_policy(state):
"""
定義個(gè)簡(jiǎn)單的策略,當(dāng)玩家手上的牌超過(guò)18點(diǎn)時(shí),返回0,
表示不再要牌(hold);當(dāng)點(diǎn)數(shù)少于18點(diǎn)時(shí),繼續(xù)要牌(hit),并返回1。
"""
player_score, _, _ = state
return 0 if player_score >= 18 else 1
def process_data_for_Blackjackproblem(V,ace=True):
"""
為Blackjack問(wèn)題進(jìn)行3D畫圖處理
"""
min_x = min(k[0] for k in V.keys())
max_x = max(k[0] for k in V.keys())
min_y = min(k[1] for k in V.keys())
max_y = max(k[1] for k in V.keys())
x_range = np.arange(min_x, max_x + 1)
y_range = np.arange(min_y, max_y + 1)
X, Y = np.meshgrid(x_range, y_range)
if ace:
Z = np.apply_along_axis(lambda _ : V[(_[0], _[1], True)], 2, np.dstack([X,Y]))
else:
Z = np.apply_along_axis(lambda _ : V[(_[0], _[1], False)], 2, np.dstack([X,Y]))
return X, Y, Z
if __name__ == "__main__":
# 調(diào)用gym的Blackjack-v0環(huán)境
env = gym.make("Blackjack-v0")
# 允許100萬(wàn)次的首次訪問(wèn)蒙特卡洛預(yù)測(cè)算法,并返回值函數(shù)
v1= mc_firstvisit_prediction(simple_policy, env, num_episodes=1000000)
print(v1)
# 進(jìn)行3D畫圖數(shù)據(jù)處理
X, Y, Z = process_data_for_Blackjackproblem(v1, ace=True)
fig = plot_3D(X, Y, Z, xlabel="Player sum", ylabel="Dealer sum", zlabel="Value", title="Usable Ace")
fig.show()
fig.savefig("./log/Usable_Ace.jpg")
X, Y, Z = process_data_for_Blackjackproblem(v1, ace= False)
fig = plot_3D(X, Y, Z, xlabel="Player sum", ylabel="Dealer sum", zlabel="Value", title="No Usable Ace")
fig.show()
fig.savefig("./log/No_Usable_Ace.jpg")
-
mc_firstvisit_prediction()方法定義了4個(gè)輸入:
? policy:定義的策略
? env:環(huán)境
? num_episodes:采樣的幕的數(shù)量
? episode_endtime= 10:設(shè)定個(gè)幕的最大數(shù)量
? discount :折扣因子
首先使用defaultdict定義了過(guò)程中使用的字典,defaultdict相比dict,當(dāng)字典里的key不存在但被查找時(shí),返回的不是keyError而是一個(gè)空默認(rèn)值。
gym返回的狀態(tài)有一個(gè)三元組組成,分別表示玩家當(dāng)前手上牌的總數(shù),莊家的牌點(diǎn)數(shù)(1為ace),以及玩家是否有ace的標(biāo)志。例如
表示玩家當(dāng)前手上牌的總數(shù)為14,莊家明牌點(diǎn)數(shù)為5,玩家當(dāng)前無(wú)ace。
-
以下代碼通過(guò)與環(huán)境交互可以獲得真實(shí)的episode。
for _ in range(episode_endtime): # 獲取當(dāng)前狀態(tài)是預(yù)定策略將會(huì)采取的動(dòng)作 action = policy(state) # 與環(huán)境交互,獲取下一個(gè)狀態(tài),獎(jiǎng)勵(lì),以及結(jié)束標(biāo)志位 next_state, reward, done, info = env.step(action) # 將(s,a,r)對(duì)插入episode中 episode.append((state, action, reward)) if done: break state = next_state結(jié)果:
[((4, 6, False), 1, 0.0), ((7, 6, False), 1, 0.0), ((16, 6, False), 1, 0.0), ((17, 6, False), 1, -1.0)]結(jié)果為一幕從開(kāi)始到結(jié)束的關(guān)于(state, action,reward)的三元組。
-
允許100萬(wàn)幕的首次訪問(wèn)蒙特卡洛預(yù)測(cè)的價(jià)值函數(shù)如下所示,其中包含了每個(gè)狀態(tài)的價(jià)值函數(shù)。
{ (19, 2, False): 0.3803322395406071, (14, 8, False): -0.37651049395802416, (20, 8, False): 0.7971809523809524, (12, 6, False): -0.3004375863657301, (18, 6, False): 0.28861230889847117, (18, 10, True): -0.252530199151159, .............................. (12, 5, True): 0.013071895424836602, (12, 3, True): 0.03160270880361174 }
4. First-visit MC prediction測(cè)試結(jié)果
下面兩張圖為我們對(duì)簡(jiǎn)單策略“超過(guò)18點(diǎn)就不在要牌,低于18點(diǎn)就繼續(xù)要牌”下,對(duì)應(yīng)的狀態(tài)值在有可用Ace與無(wú)可用Ace下的價(jià)值在三維空間的分布情況。顏色越深,狀態(tài)值越高。


5.每次訪問(wèn)型蒙特卡洛預(yù)測(cè)算法(Every-visit MC prediction)源碼清單
# coding: utf-8
import numpy as np
import gym
import sys
from collections import defaultdict
from Plot3D import plot_3D
def mc_everyvisit_prediction(policy, env, num_episodes,
episode_endtime= 10, discount = 1.0):
"""
該函數(shù)主要實(shí)現(xiàn)每次訪問(wèn)蒙特卡洛預(yù)測(cè)算法
"""
r_sum = defaultdict(float)
r_count = defaultdict(float)
r_V = defaultdict(float)
# 按照設(shè)定的num_episodes數(shù)量進(jìn)行對(duì)應(yīng)迭代
for each_episode in range(num_episodes):
# 打印迭代進(jìn)展
print("Episode {}/{}".format(each_episode,num_episodes),end = "\r")
sys.stdout.flush()
# 將episode初始化為列表
episode = []
state = env.reset()
# 按照輸入的策略,采集episode
for _ in range(episode_endtime):
action = policy(state)
# 與環(huán)境交互,獲取下一個(gè)狀態(tài),獎(jiǎng)勵(lì),以及結(jié)束標(biāo)志位
next_state, reward, done, info = env.step(action)
# 將(s,a,r)對(duì)插入episode中
episode.append((state, action, reward))
if done:
break
state = next_state
# 計(jì)算首次訪問(wèn)蒙特卡洛算法的價(jià)值
for visit_pos, data in enumerate(episode):
state_visit = data[0]
# x[2]為reward, 這里對(duì)所有獎(jiǎng)勵(lì)都做帶discount的累加
G = sum([x[2] * np.power(discount, i) for i, x in enumerate(episode)])
# 計(jì)算累積平均獎(jiǎng)勵(lì)
r_sum[state_visit] += G
r_count[state_visit] += 1.0
r_V[state_visit] = r_sum[state_visit] / r_count[state_visit]
return r_V
def simple_policy(state):
"""
定義個(gè)簡(jiǎn)單的策略,當(dāng)玩家手上的牌超過(guò)18點(diǎn)時(shí),返回0,
表示不再要牌(hold);當(dāng)點(diǎn)數(shù)少于18點(diǎn)時(shí),繼續(xù)要牌(hit),并返回1。
"""
player_score, _, _ = state
return 0 if player_score >= 18 else 1
def process_data_for_Blackjackproblem(V,ace=True):
"""
為Blackjack問(wèn)題進(jìn)行3D畫圖處理
"""
min_x = min(k[0] for k in V.keys())
max_x = max(k[0] for k in V.keys())
min_y = min(k[1] for k in V.keys())
max_y = max(k[1] for k in V.keys())
x_range = np.arange(min_x, max_x + 1)
y_range = np.arange(min_y, max_y + 1)
X, Y = np.meshgrid(x_range, y_range)
if ace:
Z = np.apply_along_axis(lambda _ : V[(_[0], _[1], True)], 2, np.dstack([X,Y]))
else:
Z = np.apply_along_axis(lambda _ : V[(_[0], _[1], False)], 2, np.dstack([X,Y]))
return X, Y, Z
if __name__ == "__main__":
# 調(diào)用gym的Blackjack-v0環(huán)境
env = gym.make("Blackjack-v0")
# 允許100萬(wàn)次的首次訪問(wèn)蒙特卡洛預(yù)測(cè)算法,并返回值函數(shù)
v1= mc_everyvisit_prediction(simple_policy, env, num_episodes=1000000)
print(v1)
# 進(jìn)行3D畫圖數(shù)據(jù)處理
X, Y, Z = process_data_for_Blackjackproblem(v1, ace=True)
fig = plot_3D(X, Y, Z, xlabel="Player sum", ylabel="Dealer sum", zlabel="Value", title="Usable Ace")
fig.savefig("./log/EveryVisit_Usable_Ace_1M.jpg")
X, Y, Z = process_data_for_Blackjackproblem(v1, ace= False)
fig = plot_3D(X, Y, Z, xlabel="Player sum", ylabel="Dealer sum", zlabel="Value", title="No Usable Ace")
fig.savefig("./log/EveryVisit_No_Usable_Ace_1M.jpg")
-
首次訪問(wèn)型蒙特卡洛預(yù)測(cè)算法(First-visit MC prediction)與每次訪問(wèn)型蒙特卡洛預(yù)測(cè)算法(Every-visit MC prediction)幾乎完全相同,唯一的區(qū)別在于計(jì)算未來(lái)折扣獎(jiǎng)勵(lì)方式的不同。
在每次訪問(wèn)蒙特卡洛預(yù)測(cè)算法中,每采集完一條經(jīng)驗(yàn)軌跡后,同樣首次訪問(wèn)型蒙特卡洛預(yù)測(cè)算法的方式對(duì)未來(lái)折扣累積獎(jiǎng)勵(lì)進(jìn)行計(jì)算,作為狀態(tài)值的期望。其中,算法使用到的參數(shù)都完全相同,r_sum表示該條經(jīng)驗(yàn)軌跡的總回報(bào),r_count表示該條經(jīng)驗(yàn)軌跡的統(tǒng)計(jì)次數(shù),r_V表示總體的狀態(tài)值。
每次訪同蒙特卡洛預(yù)測(cè)算法的核心在于:無(wú)論狀態(tài)
出現(xiàn)多少次,每一次的些勵(lì)返回值都被納入平均未來(lái)折扣累積獎(jiǎng)勵(lì)的計(jì)算:
G = sum([x[2] * np.power(discount, i) for i, x in enumerate(episode)])這里與首次獎(jiǎng)勵(lì)的計(jì)算方式有著細(xì)微的表述差異:
G = sum([x[2] * np.power(discount, i) for i, x in enumerate(episode[visit_pos:])])
5.Every-visit MC prediction測(cè)試結(jié)果
下面兩張圖為我們對(duì)簡(jiǎn)單策略“超過(guò)18點(diǎn)就不在要牌,低于18點(diǎn)就繼續(xù)要牌”下,對(duì)應(yīng)的狀態(tài)值在有可用Ace與無(wú)可用Ace下的價(jià)值在三維空間的分布情況。顏色越深,狀態(tài)值越高。
其實(shí)看起來(lái)測(cè)試結(jié)果與首次訪問(wèn)MC預(yù)測(cè)算法測(cè)試結(jié)果差異不大。

[圖片上傳失敗...(image-a5c558-1603118659973)]