MCTS樹(shù)學(xué)習(xí)

MCTS,即蒙特卡羅樹(shù)搜索,是一類(lèi)搜索算法樹(shù)的統(tǒng)稱(chēng),可以較為有效地解決一些搜索空間巨大的問(wèn)題。

如一個(gè)8*8的棋盤(pán),第一步棋有64種著法,那么第二步則有63種,依次類(lèi)推,假如我們把第一步棋作為根節(jié)點(diǎn),那么其子節(jié)點(diǎn)就有63個(gè),再往下的子節(jié)點(diǎn)就有62個(gè)……

如果不加干預(yù),樹(shù)結(jié)構(gòu)將會(huì)繁雜,MCTS采用策略來(lái)對(duì)獲勝性較小的著法不予考慮,如第二步的63種著法中有10種是不可能勝利的,那么這十個(gè)子節(jié)點(diǎn)不予再次分配子節(jié)點(diǎn)。

MCTS的主要步驟分為四個(gè):

1, 選擇(Selection)

即找一個(gè)最好的值得探索的結(jié)點(diǎn),通常是先選擇沒(méi)有探索過(guò)的結(jié)點(diǎn),如果都探索過(guò)了,再選擇UCB值最大的進(jìn)行選擇(UCB是由一系列算法計(jì)算得到的值,這里先不詳細(xì)講,可以簡(jiǎn)單視為value)

2, 擴(kuò)展(Expansion)

已經(jīng)選擇好了需要進(jìn)行擴(kuò)展的結(jié)點(diǎn),那么就對(duì)其進(jìn)行擴(kuò)展,即對(duì)其一個(gè)子節(jié)點(diǎn)最為下一步棋的假設(shè),一般為隨機(jī)取一個(gè)可選的節(jié)點(diǎn)進(jìn)行擴(kuò)展。

3, 模擬(Simulation)

擴(kuò)展出了子節(jié)點(diǎn),就可以根據(jù)該子節(jié)點(diǎn)繼續(xù)進(jìn)行模擬了,我們隨機(jī)選擇一個(gè)可選的位置作為模擬下一步的落子,將其作為子節(jié)點(diǎn),然后依據(jù)該子節(jié)點(diǎn),繼續(xù)尋找可選的位置作為子節(jié)點(diǎn),依次類(lèi)推,直到博弈已經(jīng)判斷出了勝負(fù),將勝負(fù)信息作為最終得分。

4, 回溯更新(Backpropagation)

將最終的得分累加到父節(jié)點(diǎn),不斷從下向上累加更新。

對(duì)于UCB值,計(jì)算方法很簡(jiǎn)單,公式如下:

image

其中v'表示當(dāng)前樹(shù)節(jié)點(diǎn),v表示父節(jié)點(diǎn),Q表示這個(gè)樹(shù)節(jié)點(diǎn)的累計(jì)quality值,N表示這個(gè)樹(shù)節(jié)點(diǎn)的visit次數(shù),C是一個(gè) 常量參數(shù),通常值設(shè)為1/√2

接下來(lái)再討論怎么使用Python實(shí)現(xiàn)MCTS樹(shù)。

首先樹(shù)的每個(gè)節(jié)點(diǎn)Node需要記錄其父節(jié)點(diǎn)Node parent,和子節(jié)點(diǎn)Node children[],用于計(jì)算UCB的這個(gè)節(jié)點(diǎn)的quality值和visit次數(shù)。

    def __init__(self):
        self.parent = None
        self.children = []

        self.visit_times = 0
        self.quality_value = 0.0

        self.state = None

state中除了需要記錄每一步的選擇,還需要記錄每一步的層數(shù)round值與reward值。

class State(object):
    def __init__(self):
        self.value = 0
        self.round = 0
        self.choices = []

整棵樹(shù)需要實(shí)現(xiàn)的功能則是,在一個(gè)環(huán)境下,選擇出一個(gè)最有可能獲勝的策略。選擇的方法則是通過(guò)以上介紹的四個(gè)步驟不停模擬得到每個(gè)選擇的value。

其中,tree_policy函數(shù)實(shí)現(xiàn)了Selection和Expansion,default_poliy函數(shù)實(shí)現(xiàn)的是Simulation過(guò)程,backup函數(shù)是BackPropagation的實(shí)現(xiàn)。

def MCTS(node):

    computation_budget = 3

    for i in range(computation_budget):

        # 1\. 找到最合適的可擴(kuò)展子節(jié)點(diǎn)        
        expand_node = tree_policy(node)

        # 2\. 隨機(jī)選擇下一步策略對(duì)此子節(jié)點(diǎn)進(jìn)行模擬       
        reward = default_policy(expand_node)

        # 3\. 將模擬結(jié)果向上回傳
        backup(expand_node, reward)

    # 最終得到勝利的可能性最大的子節(jié)點(diǎn)

     best_next_node = best_child(node, False)

     return best_next_node

tree_policy:選擇最合適的子節(jié)點(diǎn),選擇策略如下:

1,如果當(dāng)前的根節(jié)點(diǎn)是葉子節(jié)點(diǎn),即沒(méi)有子節(jié)點(diǎn)可以擴(kuò)展,以開(kāi)頭下棋的例子來(lái)講,即是已經(jīng)判斷出了勝負(fù)或者棋盤(pán)已滿(mǎn)的情況下,則直接返回當(dāng)前節(jié)點(diǎn)。

2,如果還有沒(méi)有選擇過(guò)的葉子節(jié)點(diǎn)(下一步的某個(gè)位置的著法還沒(méi)有被模擬過(guò)),就在沒(méi)有選擇過(guò)的方法中選擇一個(gè)返回。

3,如果所有可選擇的結(jié)點(diǎn)都已經(jīng)選擇過(guò)(當(dāng)前環(huán)境下所有的著法都已經(jīng)試過(guò)),那么往下選擇UCB值最大的子節(jié)點(diǎn),直到滿(mǎn)足1或2的情況,到達(dá)葉子節(jié)點(diǎn)或者出現(xiàn)未選擇過(guò)的結(jié)點(diǎn)。

def tree_policy(node):

    # 是否是葉子節(jié)點(diǎn)
    while not node.get_state().is_terminal():

         # 如果全部可選的結(jié)點(diǎn)都選擇過(guò)
         if node.is_all_expand():
             # 選擇UCB最大的值
             node = best_child(node, True)

         else:

             # 隨機(jī)選擇一個(gè)節(jié)點(diǎn)返回
             sub_node = expand(node)
             return sub_node

    # 返回找到的最佳子節(jié)點(diǎn)
    return node

default_policy:對(duì)當(dāng)前情況進(jìn)行模擬,直到判斷出勝負(fù)。

策略為:輸入需要擴(kuò)展的結(jié)點(diǎn),隨機(jī)操作后 創(chuàng)建新的結(jié)點(diǎn),直到最后遇到葉子節(jié)點(diǎn),得到該次模擬的reward,然后將reward返回。

def default_policy(node): 
        # 獲取當(dāng)前點(diǎn)的環(huán)境狀態(tài)

        current_state = node.get_state() 

        # 如果沒(méi)有遇到葉子節(jié)點(diǎn),就一直循環(huán)
        while current_state.is_terminal() == False: 
                  # 隨機(jī)選取一個(gè)子節(jié)點(diǎn),返回新的環(huán)境參數(shù) 
                  current_state = current_state.get_next_state_with_random_choice()

        # 結(jié)束后,根據(jù)當(dāng)前的環(huán)境判斷勝負(fù),即獲得的reward值,并將其返回 
        final_state_reward = current_state.compute_reward()

        return final_state_reward

關(guān)于這個(gè)算法,我簡(jiǎn)單做了一個(gè)實(shí)現(xiàn),每次從數(shù)組[1, -1, 2, -2]之間隨機(jī)取一個(gè)數(shù)做累加,共累計(jì)MAX_DEPTH層,使最終的和最大,我們根據(jù)運(yùn)行結(jié)果可以看到,開(kāi)始-1, -2的概率比較大,但是隨著訓(xùn)練層數(shù)的增大,越來(lái)越小,而1,2的比例會(huì)越來(lái)越大。

import sys
import math
import random

MAX_CHOICE = 4
MAX_DEPTH = 50
CHOICES = [1, -1, 2, -2]

class State(object):
    def __init__(self):
        self.value = 0
        self.round = 0
        self.choices = []

    def new_state(self):
        choice = random.choice(CHOICES)
        state = State()
        state.value = self.value + choice
        state.round = self.round + 1
        state.choices = self.choices + [choice]

        return state

    def __repr__(self):
        return "State: {}, value: {}, choices: {}".format(
            hash(self), self.value, self.choices)

class Node(object):
    def __init__(self):
        self.parent = None
        self.children = []

        self.quality = 0.0
        self.visit = 0

        self.state = None

    def add_child(self, node):
        self.children.append(node)
        node.parent = self

    def __repr__(self):
        return "Node: {}, Q/N: {}/{}, state: {}".format(
            hash(self), self.quality, self.visit, self.state)

def expand(node):

    states = [nodes.state for nodes in node.children]
    state = node.state.new_state()

    while state in states:
        state = node.state.new_state()

    child_node = Node()
    child_node.state = state
    node.add_child(child_node)

    return child_node

# 選擇, 擴(kuò)展
def tree_policy(node):

    # 選擇是否是葉子節(jié)點(diǎn),
    while node.state.round < MAX_DEPTH:
        if len(node.children) < MAX_CHOICE:
            node = expand(node)
            return node
        else:
            node = best_child(node)

    return node

# 模擬
def default_policy(node):
    now_state = node.state
    while now_state.round < MAX_DEPTH:
        now_state = now_state.new_state()

    return now_state.value

def backup(node, reward):

    while node != None:
        node.visit += 1
        node.quality += reward
        node = node.parent

def best_child(node):

    best_score = -sys.maxsize
    best = None

    for sub_node in node.children:

        C = 1 / math.sqrt(2.0)
        left = sub_node.quality / sub_node.visit
        right = 2.0 * math.log(node.visit) / sub_node.visit
        score = left + C * math.sqrt(right)

        if score > best_score:
            best = sub_node
            best_score = score

    return best

def mcts(node):

    times = 5
    for i in range(times):

        expand = tree_policy(node)
        reward = default_policy(expand)
        backup(expand, reward)

    best = best_child(node)

    return best

def main():
    init_state = State()
    init_node = Node()
    init_node.state = init_state
    current_node = init_node

    for i in range(MAX_DEPTH):
        a = 0.0
        b = 0.0
        c = 0.0
        d = 0.0
        current_node = mcts(current_node)

        for j in range(len(current_node.state.choices)):
            if current_node.state.choices[j] == -2:
                a += 1
            if current_node.state.choices[j] == -1:
                b += 1
            if current_node.state.choices[j] == 1:
                c += 1
            if current_node.state.choices[j] == 2:
                d += 1
        print("-2的概率為", round(a/(i + 1.0), 2),
              "-1的概率為", round(b/(i + 1.0), 2),
              "1的概率為", round(c/(i + 1.0), 2),
              "2的概率為", round(d/(i + 1.0), 2))

if __name__ == "__main__":
    main()

運(yùn)行結(jié)果:


最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容