Chapter04 grid-world using DP

代码引用自 ShangtongZhang 的 chapter04/grid_world.py

grid_world(policy evaluation)

问题描述

4X4 网格:

png

左上角和右下角是终止状态(terminal state)。如果某个 action 会使 agent 移出网格,则 state 保持不变(留在原位置);所有 action 的 reward 都是 -1。

引入模块并定义常量

1
2
3
4
5
6
7
8
9
10
11
12
13
14

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.table import Table
%matplotlib inline

# Side length of the square grid; a state is a (row, col) pair of indices.
WORLD_SIZE = 4
# Moves expressed as (row, col) offsets, in the order: left, up, right, down.
ACTIONS = [np.array(offset) for offset in ((0, -1), (-1, 0), (0, 1), (1, 0))]
# Weight given to each action when averaging over the four moves.
ACTION_PROB = 0.25

action控制代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def is_terminal(state):
    """Return whether ``state`` is one of the two terminal corner cells.

    ``state`` is a two-element sequence. The terminal cells are ``(0, 0)``
    and ``(WORLD_SIZE - 1, WORLD_SIZE - 1)``.
    """
    row, col = state
    last = WORLD_SIZE - 1
    at_first_corner = row == 0 and col == 0
    at_last_corner = row == last and col == last
    return at_first_corner or at_last_corner

def step(state, action):
    """Apply ``action`` to ``state`` and return ``(next_state, reward)``.

    ``state`` is a two-element sequence and ``action`` an offset that is added
    to it element-wise. If the resulting cell lies outside the
    ``WORLD_SIZE`` x ``WORLD_SIZE`` grid, the original position is returned
    instead. ``next_state`` is a plain list and the reward is always -1.
    """
    current = np.array(state)
    candidate = (current + action).tolist()
    row, col = candidate

    row_off_grid = row < 0 or row >= WORLD_SIZE
    col_off_grid = col < 0 or col >= WORLD_SIZE
    # A move that would leave the grid keeps the agent where it was.
    next_state = current.tolist() if (row_off_grid or col_off_grid) else candidate

    return next_state, -1

绘制方格图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def draw_image(image):
    """Render a 2-D array as a table of cells on a new matplotlib figure.

    Every element of ``image`` becomes a white cell showing its value.
    1-based row numbers are added in an extra column (index -1) and 1-based
    column numbers in an extra row (index -1). The figure is not saved or
    shown here; the caller does that through ``plt``.

    Args:
        image: 2-D NumPy array; its ``shape`` must unpack into
            ``(nrows, ncols)``.
    """
    _, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])

    nrows, ncols = image.shape
    width, height = 1.0 / ncols, 1.0 / nrows

    # One cell per array element, labelled with the element's value.
    for (i, j), val in np.ndenumerate(image):
        tb.add_cell(i, j, width, height, text=val,
                    loc='center', facecolor='white')

    # Row labels: one per row, placed in the extra column -1.
    for i in range(nrows):
        tb.add_cell(i, -1, width, height, text=i + 1, loc='right',
                    edgecolor='none', facecolor='none')
    # Column labels: one per column, placed in the extra row -1. Iterating
    # over ncols (rather than the row count) keeps non-square arrays correct.
    for j in range(ncols):
        tb.add_cell(-1, j, width, height / 2, text=j + 1, loc='center',
                    edgecolor='none', facecolor='none')
    ax.add_table(tb)

进行 policy evaluation,即计算每个状态的 state-value(状态价值)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def compute_state_value(in_place=False):
    """Evaluate the policy that weights every action by ``ACTION_PROB``.

    The grid is swept repeatedly. In each sweep every non-terminal cell is
    replaced by the ``ACTION_PROB``-weighted sum, over all ``ACTIONS``, of
    ``reward + value(next_state)`` (no discounting). Sweeping stops once the
    summed absolute change over all cells in a sweep falls below 1e-4.

    Args:
        in_place: If True, each update reads from the array being written, so
            values refreshed earlier in the same sweep are used immediately.
            If False, updates read from a snapshot taken after the previous
            sweep.

    Returns:
        A ``(state_values, iteration)`` tuple: the final
        ``WORLD_SIZE`` x ``WORLD_SIZE`` value array and the number of sweeps
        performed, counting the final one.
    """
    updated = np.zeros((WORLD_SIZE, WORLD_SIZE))
    previous = updated.copy()
    sweeps = 1
    while True:
        # Pick which array the backups read from during this sweep.
        source = updated if in_place else previous
        for row in range(WORLD_SIZE):
            for col in range(WORLD_SIZE):
                if is_terminal([row, col]):
                    # Terminal cells keep their initial value of zero.
                    continue
                backup = 0
                for action in ACTIONS:
                    (next_row, next_col), reward = step([row, col], action)
                    backup += ACTION_PROB * (reward + source[next_row, next_col])
                updated[row, col] = backup

        # Measure the change against the pre-sweep snapshot, then refresh the
        # snapshot whether or not we are done.
        total_change = np.sum(np.abs(updated - previous))
        previous = updated.copy()
        if total_change < 1e-4:
            break
        sweeps += 1

    return previous, sweeps

运行并显示评估结果

1
2
3
4
5
6
7
8
9
10
11
12
def figure_4_1():
    """Run both evaluation variants, plot the result and report sweep counts.

    The values drawn are those from the two-array (``in_place=False``) run,
    rounded to two decimals. The sweep counts of both runs are printed, then
    the figure is written to ``./figure_4_1.png`` and shown.
    """
    sync_values, sync_sweeps = compute_state_value(in_place=False)
    # Only the sweep count of the in-place run is needed.
    in_place_sweeps = compute_state_value(in_place=True)[1]

    draw_image(np.round(sync_values, decimals=2))
    print('In-place: %d iterations' % (in_place_sweeps))
    print('Synchronous: %d iterations' % (sync_sweeps))

    plt.savefig('./figure_4_1.png')
    plt.show()

# Run the experiment: draws the value table and prints both iteration counts.
figure_4_1()
In-place: 142 iterations
Synchronous: 218 iterations

png