MCTS,即蒙特卡羅樹(shù)搜索,是一類(lèi)搜索算法樹(shù)的統(tǒng)稱(chēng),可以較為有效地解決一些搜索空間巨大的問(wèn)題。
如一個(gè)8*8的棋盤(pán),第一步棋有64種著法,那么第二步則有63種,依次類(lèi)推,假如我們把第一步棋作為根節(jié)點(diǎn),那么其子節(jié)點(diǎn)就有63個(gè),再往下的子節(jié)點(diǎn)就有62個(gè)……
如果不加干預(yù),樹(shù)結(jié)構(gòu)將會(huì)繁雜,MCTS采用策略來(lái)對(duì)獲勝性較小的著法不予考慮,如第二步的63種著法中有10種是不可能勝利的,那么這十個(gè)子節(jié)點(diǎn)不予再次分配子節(jié)點(diǎn)。
MCTS的主要步驟分為四個(gè):
1, 選擇(Selection)
即找一個(gè)最好的值得探索的結(jié)點(diǎn),通常是先選擇沒(méi)有探索過(guò)的結(jié)點(diǎn),如果都探索過(guò)了,再選擇UCB值最大的進(jìn)行選擇(UCB是由一系列算法計(jì)算得到的值,這里先不詳細(xì)講,可以簡(jiǎn)單視為value)
2, 擴(kuò)展(Expansion)
已經(jīng)選擇好了需要進(jìn)行擴(kuò)展的結(jié)點(diǎn),那么就對(duì)其進(jìn)行擴(kuò)展,即對(duì)其一個(gè)子節(jié)點(diǎn)最為下一步棋的假設(shè),一般為隨機(jī)取一個(gè)可選的節(jié)點(diǎn)進(jìn)行擴(kuò)展。
3, 模擬(Simulation)
擴(kuò)展出了子節(jié)點(diǎn),就可以根據(jù)該子節(jié)點(diǎn)繼續(xù)進(jìn)行模擬了,我們隨機(jī)選擇一個(gè)可選的位置作為模擬下一步的落子,將其作為子節(jié)點(diǎn),然后依據(jù)該子節(jié)點(diǎn),繼續(xù)尋找可選的位置作為子節(jié)點(diǎn),依次類(lèi)推,直到博弈已經(jīng)判斷出了勝負(fù),將勝負(fù)信息作為最終得分。
4, 回溯更新(Backpropagation)
將最終的得分累加到父節(jié)點(diǎn),不斷從下向上累加更新。
對(duì)于UCB值,計(jì)算方法很簡(jiǎn)單,公式如下:
其中v'表示當(dāng)前樹(shù)節(jié)點(diǎn),v表示父節(jié)點(diǎn),Q表示這個(gè)樹(shù)節(jié)點(diǎn)的累計(jì)quality值,N表示這個(gè)樹(shù)節(jié)點(diǎn)的visit次數(shù),C是一個(gè) 常量參數(shù),通常值設(shè)為1/√2
接下來(lái)再討論怎么使用Python實(shí)現(xiàn)MCTS樹(shù)。
首先樹(shù)的每個(gè)節(jié)點(diǎn)Node需要記錄其父節(jié)點(diǎn)Node parent,和子節(jié)點(diǎn)Node children[],用于計(jì)算UCB的這個(gè)節(jié)點(diǎn)的quality值和visit次數(shù)。
def __init__(self):
self.parent = None
self.children = []
self.visit_times = 0
self.quality_value = 0.0
self.state = None
state中除了需要記錄每一步的選擇,還需要記錄每一步的層數(shù)round值與reward值。
class State(object):
def __init__(self):
self.value = 0
self.round = 0
self.choices = []
整棵樹(shù)需要實(shí)現(xiàn)的功能則是,在一個(gè)環(huán)境下,選擇出一個(gè)最有可能獲勝的策略。選擇的方法則是通過(guò)以上介紹的四個(gè)步驟不停模擬得到每個(gè)選擇的value。
其中,tree_policy函數(shù)實(shí)現(xiàn)了Selection和Expansion,default_poliy函數(shù)實(shí)現(xiàn)的是Simulation過(guò)程,backup函數(shù)是BackPropagation的實(shí)現(xiàn)。
def MCTS(node):
computation_budget = 3
for i in range(computation_budget):
# 1\. 找到最合適的可擴(kuò)展子節(jié)點(diǎn)
expand_node = tree_policy(node)
# 2\. 隨機(jī)選擇下一步策略對(duì)此子節(jié)點(diǎn)進(jìn)行模擬
reward = default_policy(expand_node)
# 3\. 將模擬結(jié)果向上回傳
backup(expand_node, reward)
# 最終得到勝利的可能性最大的子節(jié)點(diǎn)
best_next_node = best_child(node, False)
return best_next_node
tree_policy:選擇最合適的子節(jié)點(diǎn),選擇策略如下:
1,如果當(dāng)前的根節(jié)點(diǎn)是葉子節(jié)點(diǎn),即沒(méi)有子節(jié)點(diǎn)可以擴(kuò)展,以開(kāi)頭下棋的例子來(lái)講,即是已經(jīng)判斷出了勝負(fù)或者棋盤(pán)已滿(mǎn)的情況下,則直接返回當(dāng)前節(jié)點(diǎn)。
2,如果還有沒(méi)有選擇過(guò)的葉子節(jié)點(diǎn)(下一步的某個(gè)位置的著法還沒(méi)有被模擬過(guò)),就在沒(méi)有選擇過(guò)的方法中選擇一個(gè)返回。
3,如果所有可選擇的結(jié)點(diǎn)都已經(jīng)選擇過(guò)(當(dāng)前環(huán)境下所有的著法都已經(jīng)試過(guò)),那么往下選擇UCB值最大的子節(jié)點(diǎn),直到滿(mǎn)足1或2的情況,到達(dá)葉子節(jié)點(diǎn)或者出現(xiàn)未選擇過(guò)的結(jié)點(diǎn)。
def tree_policy(node):
# 是否是葉子節(jié)點(diǎn)
while not node.get_state().is_terminal():
# 如果全部可選的結(jié)點(diǎn)都選擇過(guò)
if node.is_all_expand():
# 選擇UCB最大的值
node = best_child(node, True)
else:
# 隨機(jī)選擇一個(gè)節(jié)點(diǎn)返回
sub_node = expand(node)
return sub_node
# 返回找到的最佳子節(jié)點(diǎn)
return node
default_policy:對(duì)當(dāng)前情況進(jìn)行模擬,直到判斷出勝負(fù)。
策略為:輸入需要擴(kuò)展的結(jié)點(diǎn),隨機(jī)操作后 創(chuàng)建新的結(jié)點(diǎn),直到最后遇到葉子節(jié)點(diǎn),得到該次模擬的reward,然后將reward返回。
def default_policy(node):
# 獲取當(dāng)前點(diǎn)的環(huán)境狀態(tài)
current_state = node.get_state()
# 如果沒(méi)有遇到葉子節(jié)點(diǎn),就一直循環(huán)
while current_state.is_terminal() == False:
# 隨機(jī)選取一個(gè)子節(jié)點(diǎn),返回新的環(huán)境參數(shù)
current_state = current_state.get_next_state_with_random_choice()
# 結(jié)束后,根據(jù)當(dāng)前的環(huán)境判斷勝負(fù),即獲得的reward值,并將其返回
final_state_reward = current_state.compute_reward()
return final_state_reward
關(guān)于這個(gè)算法,我簡(jiǎn)單做了一個(gè)實(shí)現(xiàn),每次從數(shù)組[1, -1, 2, -2]之間隨機(jī)取一個(gè)數(shù)做累加,共累計(jì)MAX_DEPTH層,使最終的和最大,我們根據(jù)運(yùn)行結(jié)果可以看到,開(kāi)始-1, -2的概率比較大,但是隨著訓(xùn)練層數(shù)的增大,越來(lái)越小,而1,2的比例會(huì)越來(lái)越大。
import sys
import math
import random
MAX_CHOICE = 4
MAX_DEPTH = 50
CHOICES = [1, -1, 2, -2]
class State(object):
def __init__(self):
self.value = 0
self.round = 0
self.choices = []
def new_state(self):
choice = random.choice(CHOICES)
state = State()
state.value = self.value + choice
state.round = self.round + 1
state.choices = self.choices + [choice]
return state
def __repr__(self):
return "State: {}, value: {}, choices: {}".format(
hash(self), self.value, self.choices)
class Node(object):
def __init__(self):
self.parent = None
self.children = []
self.quality = 0.0
self.visit = 0
self.state = None
def add_child(self, node):
self.children.append(node)
node.parent = self
def __repr__(self):
return "Node: {}, Q/N: {}/{}, state: {}".format(
hash(self), self.quality, self.visit, self.state)
def expand(node):
states = [nodes.state for nodes in node.children]
state = node.state.new_state()
while state in states:
state = node.state.new_state()
child_node = Node()
child_node.state = state
node.add_child(child_node)
return child_node
# 選擇, 擴(kuò)展
def tree_policy(node):
# 選擇是否是葉子節(jié)點(diǎn),
while node.state.round < MAX_DEPTH:
if len(node.children) < MAX_CHOICE:
node = expand(node)
return node
else:
node = best_child(node)
return node
# 模擬
def default_policy(node):
now_state = node.state
while now_state.round < MAX_DEPTH:
now_state = now_state.new_state()
return now_state.value
def backup(node, reward):
while node != None:
node.visit += 1
node.quality += reward
node = node.parent
def best_child(node):
best_score = -sys.maxsize
best = None
for sub_node in node.children:
C = 1 / math.sqrt(2.0)
left = sub_node.quality / sub_node.visit
right = 2.0 * math.log(node.visit) / sub_node.visit
score = left + C * math.sqrt(right)
if score > best_score:
best = sub_node
best_score = score
return best
def mcts(node):
times = 5
for i in range(times):
expand = tree_policy(node)
reward = default_policy(expand)
backup(expand, reward)
best = best_child(node)
return best
def main():
init_state = State()
init_node = Node()
init_node.state = init_state
current_node = init_node
for i in range(MAX_DEPTH):
a = 0.0
b = 0.0
c = 0.0
d = 0.0
current_node = mcts(current_node)
for j in range(len(current_node.state.choices)):
if current_node.state.choices[j] == -2:
a += 1
if current_node.state.choices[j] == -1:
b += 1
if current_node.state.choices[j] == 1:
c += 1
if current_node.state.choices[j] == 2:
d += 1
print("-2的概率為", round(a/(i + 1.0), 2),
"-1的概率為", round(b/(i + 1.0), 2),
"1的概率為", round(c/(i + 1.0), 2),
"2的概率為", round(d/(i + 1.0), 2))
if __name__ == "__main__":
main()
運(yùn)行結(jié)果:
