深度強(qiáng)化學(xué)習(xí)DQN詳解CartPole(1)

一、 獲取并處理環(huán)境圖像

本文所刨析的代碼是“pytorch官網(wǎng)的DQN示例”(頁面),用卷積層配合強(qiáng)化訓(xùn)練去學(xué)習(xí)小車立桿,所使用的環(huán)境是“小車立桿環(huán)境”(CartPole)(源碼)。先劇透個悲觀的結(jié)果,官網(wǎng)的這個示例,并不能解決小車問題。單好消息是,一個簡單的改動,就可以讓結(jié)果好很多。

小車立桿環(huán)境

先import 各種:

import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
from IPython import display

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

創(chuàng)建環(huán)境:

env = gym.make('CartPole-v0')

返回的這個env其實(shí)并非CartPole類本身,而是一個經(jīng)過包裝的環(huán)境:

env

<TimeLimit<CartPoleEnv<CartPole-v0>>>

據(jù)說gym的多數(shù)環(huán)境都用TimeLimit(源碼)包裝了,以限制Epoch,就是step的次數(shù)限制,比如限定為200次。所以小車保持平衡200步后,就會失敗。

env._max_episode_steps

200

用env.unwrapped可以得到原始的類,原始類想step多久就多久,不會200步后失?。?/p>

env.unwrapped

gym.envs.classic_control.cartpole.CartPoleEnv

環(huán)境 env 的 state 返回有4個變量:

env.state

array([0.00884328, 0.04488215, 0.00412898, 0.0128024 ])

它們分別是: (位置x,x加速度, 偏移角度theta, 角加速度)
初值值是4個[-0.05, 0.05)的隨機(jī)數(shù):

from gym.utils import seeding
np_random, seed = seeding.np_random(None)
np_random.uniform(low=-0.05, high=0.05, size=(4,))

其實(shí)就是:numpy.random.RandomState.uniform(low=-0.05, high=0.05, size=(4,))。uniform distribution是均勻分布,各個數(shù)的出場次數(shù)都大致相等。對應(yīng)的還有標(biāo)準(zhǔn)正態(tài)分布randn(),耿貝爾分布gumbel()等等。

這個環(huán)境的action有兩個 : 0 和 1

env.action_space.n

2

env.step(0) :小車向左
env.step(1) :小車向右

小車的世界,就一條x軸,變量env.x_threshold里存放著小車坐標(biāo)的最大值(=2.4),超過這個數(shù)值,世界結(jié)束,每step()一次,就會獎勵 1,直到上次done為True。這樣可以觀看到小車移動動畫:

env.reset()
for t in count(): 
    env.render()
    leftOrRight = random.randrange(env.action_space.n)
    _, reward, done, _ = env.step(leftOrRight)
    if done:
        break

有效世界的范圍是:[-x_threshold, x_threshold]。有效世界的總長度為4.8:

world_width = env.x_threshold * 2

可以用env.render()來繪制出這個有效世界,對應(yīng)的屏幕尺寸為高400、寬600:


400X600小車有效世界

世界坐標(biāo)0,是屏幕的中點(diǎn)(300處),世界轉(zhuǎn)屏幕系數(shù)為:

scale = screen_width / world_width

目前小車世界坐標(biāo)x,可以用state[0]取出,這樣通過scale,我們就可以計(jì)算出目前小車的屏幕坐標(biāo)了。用 x * scale 得到屏幕坐標(biāo)。

def get_cart_location(screen_width):
    #世界的總長度
    world_width = env.x_threshold * 2
    #世界轉(zhuǎn)屏幕系數(shù) : world_unit * scale = screen_unit
    scale = screen_width / world_width
    #世界中點(diǎn)在屏幕中間,所以偏移屏幕一半
    return int(env.state[0] * scale + screen_width / 2.0)

環(huán)境有個render函數(shù),可以繪制當(dāng)前場景。

  1. env = gym.make() 每個env有自己的繪制窗口
  2. 環(huán)境需要初始化env.reset()
  3. env.render()會打開一個繪制窗口,繪制當(dāng)前狀態(tài)
  4. 每次env.step()會更新狀態(tài)
  5. 用完以后需要調(diào)用env.close()關(guān)閉繪制窗口

render有一個參數(shù),如果指定為 mode='rgb_array'時,不但彈窗渲染,還會返回當(dāng)前窗口的像素值。整個開發(fā)過程,env自己的窗口都會一只存在,不用管它,每次render()它就會刷新,刷新完又“死”了。如果想隨時關(guān)掉,可以用close(),下次render()會自動打開。

env.reset()
screen = env.render(mode='rgb_array')
screen.shape

(400, 600, 3)

把screen畫出來看看:

plt.title('init state')
plt.imshow(screen)

小車大概在高40%(400X0.4=160)到80%(400X0.8=320)之間,所以整個畫面可以剪切一下。剪切前先調(diào)整一下圖片數(shù)據(jù)的順序,現(xiàn)在是 400高X600寬X3色,調(diào)整為 3色X400X600,便于后續(xù)往網(wǎng)絡(luò)里傳輸。numpy.transpose()函數(shù),可以指定新的維度順序,如(2,0,1) 就是將 維度 Y0,X1,C2調(diào)整為 C2,Y0,X1。 在pytorch里也有對應(yīng)的函數(shù),叫torch.Tensor.permute()。

def CutScreen(screen):
    Scr2 = screen.transpose((2, 0, 1))

再將高度按照160 - 320 截了:

    ScrCut = Scr2[:, int(screen_height*0.4):int(screen_height * 0.8)]

寬度只截取60%,左右各截30%:

    view_width = int(screen_width * 0.6)
    half_view_width = view_width // 2

如果小車左右還有30%的空間,則從小車位置前后截30%,如果小車太靠左則(或右則)沒有30%的空間,則從最左側(cè)(或最右側(cè))截取60%:

    cart_location = get_cart_location(screen_width)

    if cart_location < half_view_width:
        #太靠左了,左邊沒有30%空間,則從最左側(cè)截取  [:half_view_width)
        slice_range = slice(view_width) 

    elif cart_location > (screen_width - half_view_width):
        #太靠右了,同理 [-half_view_width:)
        slice_range = slice(-view_width, None)

    else:
        #左右兩側(cè)都有空間,則截小車在中間 [-half_view_width: +half_view_width)
        slice_range = slice(cart_location - half_view_width, cart_location + half_view_width)
    
    #最后將圖像X軸截了
    ScrCut = ScrCut[:, :, slice_range]
    return ScrCut

這樣截取函數(shù)就好了,看下截取出來的圖像。因?yàn)閜lt接受的是 (Y,X,顏色)所以我們還得把順序臨時調(diào)整回來:

C0 Y1 X2
Y1 X2 C0
CS = CutScreen(screen)
CS = CS.transpose((1, 2, 0))
plt.imshow(CS)

圖像還是太大,需要把圖片轉(zhuǎn)換為高40的圖片,可以用torchvision.transforms 的 Compose()和相關(guān)方法。
首先,目前為止,我們的screen都是numpy數(shù)組,需要用ToPILImage()轉(zhuǎn)換為PIL,python自帶圖像格式,然后才可以用torchvision去處理圖像,如 Resize(),最后記得轉(zhuǎn)換為pytorch使用的tensor格式:

resize = T.Compose([T.ToPILImage(),
                    T.Resize(40, interpolation=Image.CUBIC),
                    T.ToTensor()])

接下來獲取divice,以便pytorch可以使用顯卡GPU:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

定義我們的最終獲取圖像的處理過程:

def get_screen():
    screen = env.render(mode='rgb_array')
    screen = CutScreen(screen)

現(xiàn)在screen的格式是numpy數(shù)組,值范圍[0, 255],int8。而PIL接受的是float32的tensor,值范圍[0.0, 1.0],所以需要轉(zhuǎn)換一下,可以這樣:

    screen = torch.from_numpy(np.float32(screen)/255)

但是這樣會引起內(nèi)存數(shù)據(jù)拷貝。有一種inplace轉(zhuǎn)換數(shù)據(jù)類型的方法:

y = x.view('float32')

這樣y的內(nèi)存和x是一致的,修改y,也會修改掉x。但是這個函數(shù)有個要求,就是數(shù)據(jù)必須是contiguous的,而我們的screen,不是:

ValueError: To change to a dtype of a different size, the array must be C-contiguous

screen.flags

C_CONTIGUOUS : False
F_CONTIGUOUS : False
……

如果想要數(shù)據(jù)contiguous,就要用到ascontiguousarray()函數(shù),可以將內(nèi)存按照C方式對齊,這樣python就可以inplace轉(zhuǎn)換數(shù)據(jù)類了。所以示例用了這個方法,不用進(jìn)行內(nèi)存拷貝而達(dá)到同樣的效果:

    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)

最后放入resize。
因?yàn)閜ytorch.nn.Conv2d() 的輸入形式為(N, C, Y, X)
N表示batch數(shù)
C表示channel數(shù)
Y,X表示圖片的高和寬。
所以需要再增加一個N,最后再放入GPU:

    return resize(screen).unsqueeze(0).to(device)

unsqueeze()的作用是在n維之前增加一個維度,這里是在0維之前增加一個維度,增加前 screen尺寸是

torch.Size([3, 40, 90])

增加維度后,變?yōu)椋?/p>

torch.Size([1, 3, 40, 90])

再來實(shí)際看一下這個,尺寸等比縮小,高為40的圖片。想要plt get_screen()返回的東西,先要將其放回到CPU,然后去掉batch,調(diào)換方向把顏色放到后邊,再轉(zhuǎn)換為numpy:

scr = get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy()

plt.figure()
plt.imshow(scr)
plt.title('Example extracted screen')
plt.show()

40 X 90


OK。圖像處理完了,接下來要定義網(wǎng)絡(luò),訓(xùn)練網(wǎng)絡(luò)了。
第二部分(連接)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容