Chapter07 random walk 19-states

The code is adapted from ShangtongZhang's chapter07/random_walk.py.

Applying n-step TD methods to the random-walk problem.

Problem description

This example applies n-step TD methods with different values of n to the random-walk problem from Chapter06, except that the original 5-state problem is enlarged to 19 states. The goal is to compare the performance of the different n-step algorithms.

Import modules and define constants

import numpy as np
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm import tqdm

# the 19 states, excluding the two terminal states
N_STATES = 19

# discount
GAMMA = 1

# all states but terminal states
STATES = np.arange(1, N_STATES + 1)

# start from the middle state
# states 0 and 20 are terminal, so the non-terminal states are 1-19
START_STATE = 10

# two terminal states
# an action leading to the left terminal state has reward -1
# an action leading to the right terminal state has reward 1
END_STATES = [0, N_STATES + 1]

# the true values V* are computed from the Bellman equation; the same approach is also used in:
# https://xinge650.github.io/2018/11/22/Chapter03-gird-world/
# https://xinge650.github.io/2018/12/01/Chapter06-TD-0-vs-constant-alpha-MC/
TRUE_VALUE = np.arange(-20, 22, 2) / 20.0
TRUE_VALUE[0] = TRUE_VALUE[-1] = 0
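
As a sanity check on the comment above, here is a minimal sketch (not part of the original code) that recovers TRUE_VALUE by iterative policy evaluation of the Bellman equation under the uniformly random policy; the helper name bellman_check is made up for illustration.

# A minimal sketch (not from the original code): recover TRUE_VALUE by iterative
# policy evaluation of the Bellman equation under the uniformly random policy.
def bellman_check(tol=1e-12):
    v = np.zeros(N_STATES + 2)  # v[0] and v[N_STATES + 1] stay 0 (terminal states)
    while True:
        delta = 0.0
        for s in STATES:
            # the reward is -1 when stepping into the left terminal, +1 into the right one
            left_reward = -1 if s - 1 == 0 else 0
            right_reward = 1 if s + 1 == N_STATES + 1 else 0
            new_v = 0.5 * (left_reward + GAMMA * v[s - 1]) + \
                    0.5 * (right_reward + GAMMA * v[s + 1])
            delta = max(delta, abs(new_v - v[s]))
            v[s] = new_v
        if delta < tol:
            break
    return v

# np.allclose(bellman_check(), TRUE_VALUE) should be True (up to numerical tolerance)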

Use the n-step TD method to predict the value function of policy π
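
The function below implements the n-step return from the book; writing it out (a standard definition, not something added by the original post) makes the returns loop and the bootstrapping condition easier to follow:

$$
G_{t:t+n} = R_{t+1} + \gamma R_{t+2} + \cdots + \gamma^{n-1} R_{t+n} + \gamma^{n} V(S_{t+n}), \qquad 0 \le t \le T - n,
$$

and when $t + n \ge T$ the bootstrapping term is dropped, so $G_{t:t+n} = G_t$, the ordinary full return used by constant-α MC.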

# n-step TD method
# @value: values for each state, will be updated
# @n: number of steps
# @alpha: step size
def temporal_difference(value, n, alpha):
    # initial starting state
    state = START_STATE

    # arrays to store states and rewards for an episode
    # space isn't a major consideration, so I didn't use the mod trick
    states = [state]
    rewards = [0]

    # track the time
    time = 0

    # the length of this episode
    T = float('inf')
    while True:
        # go to next time step
        # time always indexes the next state
        time += 1

        if time < T:
            # choose an action randomly
            if np.random.binomial(1, 0.5) == 1:
                next_state = state + 1
            else:
                next_state = state - 1

            if next_state == 0:
                reward = -1
            elif next_state == 20:
                reward = 1
            else:
                reward = 0

            # store new state and new reward
            states.append(next_state)
            rewards.append(reward)

            if next_state in END_STATES:
                # T records the time index of the terminal state
                # like MC, n-step TD keeps interacting until the episode ends,
                # but the value updates start n steps behind, before termination,
                # so information still propagates faster than with MC
                T = time

        # get the time of the state to update
        update_time = time - n
        if update_time >= 0:
            returns = 0.0
            # accumulate the discounted rewards
            for t in range(update_time + 1, min(T, update_time + n) + 1):
                returns += pow(GAMMA, t - update_time - 1) * rewards[t]
            # add the bootstrapped state value to the return
            # if update_time + n > T, no bootstrapping term is added and the update
            # degenerates into a constant-alpha MC update that uses the actual return as the target
            if update_time + n <= T:
                returns += pow(GAMMA, n) * value[states[update_time + n]]
            state_to_update = states[update_time]
            # update the state value
            if state_to_update not in END_STATES:
                value[state_to_update] += alpha * (returns - value[state_to_update])
        if update_time == T - 1:
            break
        state = next_state
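
As a quick usage example (not in the original post), one way to see the prediction improving is to run a few episodes with a fixed n and α and track the estimate against TRUE_VALUE; the snippet below is a minimal sketch, and the specific n = 4 and α = 0.2 are arbitrary choices for illustration.

# Minimal usage sketch: run 4-step TD with alpha = 0.2 for ten episodes and
# report the RMS error against the analytically known TRUE_VALUE.
np.random.seed(0)
value = np.zeros(N_STATES + 2)
for episode in range(10):
    temporal_difference(value, 4, 0.2)
    rms_error = np.sqrt(np.sum(np.power(value - TRUE_VALUE, 2)) / N_STATES)
    print('episode %d, RMS error %.3f' % (episode + 1, rms_error))
# the error should generally trend downward over the episodes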

Plot the results to compare the performance of the algorithm for different n

# Figure 7.2, it will take quite a while
def figure_7_2():
    # all possible n steps
    steps = np.power(2, np.arange(0, 10))

    # all possible alphas
    alphas = np.arange(0, 1.1, 0.1)

    # each run has 10 episodes
    episodes = 10

    # perform 100 independent runs
    runs = 100

    # track the errors for each (step, alpha) combination
    errors = np.zeros((len(steps), len(alphas)))
    for run in tqdm(range(0, runs)):
        for step_ind, step in zip(range(len(steps)), steps):
            for alpha_ind, alpha in zip(range(len(alphas)), alphas):
                # print('run:', run, 'step:', step, 'alpha:', alpha)
                value = np.zeros(N_STATES + 2)
                for ep in range(0, episodes):
                    temporal_difference(value, step, alpha)
                    # accumulate the RMS error after each episode
                    errors[step_ind, alpha_ind] += np.sqrt(np.sum(np.power(value - TRUE_VALUE, 2)) / N_STATES)
    # take the average over episodes and runs
    errors /= episodes * runs

    for i in range(0, len(steps)):
        plt.plot(alphas, errors[i, :], label='n = %d' % (steps[i]))
    plt.xlabel('alpha')
    plt.ylabel('RMS error')
    plt.ylim([0.25, 0.55])
    plt.legend()

    plt.savefig('./figure_7_2.png')
    plt.show()

figure_7_2()
100%|██████████| 100/100 [08:19<00:00,  4.90s/it]

(Figure 7.2: RMS error, averaged over the 19 states, episodes, and runs, as a function of α for n = 1, 2, 4, ..., 512.)