Chapter06 TD(0) vs constant-alpha-MC

The code is taken from ShangtongZhang's chapter06/random_walk.py.

This example compares the learning performance of TD(0) and the constant-α MC method.

Problem description

A Markov reward process (MRP) is used to compare the learning performance of the constant-α MC method and TD(0).

State transition diagram of the MRP:

(figure: random walk MRP state transition diagram)

State C is the start state; an episode that terminates on the right has return 1, one that terminates on the left has return 0, and every other step has reward 0.

So the true value of each state is the probability of reaching the rightmost terminal state from it; for A through E these are 1/6, 2/6, ..., 5/6.

The calculation is easy: letting P(s) denote the probability of reaching the right terminal state from state s, we can write down five equations:

P(E) = 1/2 + 1/2 * P(D)
P(D) = 1/2 * P(C) + 1/2 * P(E)
P(C) = 1/2 * P(B) + 1/2 * P(D)
P(B) = 1/2 * P(A) + 1/2 * P(C)
P(A) = 1/2 * P(B)

Solving them yields the values quoted above. This is, of course, just the Bellman equation we covered in Chapter03:

(figure: Bellman equation)
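
As a quick numerical check (my own addition, not part of the referenced script), the five equations can be written as a linear system and solved directly with numpy; the solution matches the true values 1/6, ..., 5/6:

import numpy as np

# rows encode P(A) - 1/2*P(B) = 0, P(B) - 1/2*P(A) - 1/2*P(C) = 0, ...,
# P(E) - 1/2*P(D) = 1/2, with the unknowns x = [P(A), P(B), P(C), P(D), P(E)]
A = np.array([
    [ 1.0, -0.5,  0.0,  0.0,  0.0],
    [-0.5,  1.0, -0.5,  0.0,  0.0],
    [ 0.0, -0.5,  1.0, -0.5,  0.0],
    [ 0.0,  0.0, -0.5,  1.0, -0.5],
    [ 0.0,  0.0,  0.0, -0.5,  1.0],
])
b = np.array([0.0, 0.0, 0.0, 0.0, 0.5])
print(np.linalg.solve(A, b))   # approximately [1/6, 2/6, 3/6, 4/6, 5/6]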

Importing modules and defining constants

# 6.2 Advantages of TD Prediction Methods

import numpy as np
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm import tqdm

# 0 is the left terminal state
# 6 is the right terminal state
# 1 ... 5 represents A ... E
VALUES = np.zeros(7)
VALUES[1:6] = 0.5
# For convenience, we assume all rewards are 0
# and the left terminal state has value 0, the right terminal state has value 1
# This trick has been used in Gambler's Problem
VALUES[6] = 1

# set up true state values
TRUE_VALUE = np.zeros(7)
TRUE_VALUE[1:6] = np.arange(1, 6) / 6.0
TRUE_VALUE[6] = 1

ACTION_LEFT = 0
ACTION_RIGHT = 1
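
A brief illustration of the terminal-value trick mentioned in the comments (my own addition): because the right terminal state is given value 1 and the left one value 0, the TD target reward + V(next state) already encodes the 0/1 outcome even though every reward is 0.

# stepping right from E into the right terminal state:
# TD target = reward + V(terminal) = 0 + VALUES[6] = 1,
# which plays the same role as receiving a reward of 1 at termination
print(0 + VALUES[6])   # 1.0
print(0 + VALUES[0])   # 0.0, stepping left from A into the left terminal state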

Updating the state values with TD(0)

# @values: current state values, will be updated in place if @batch is False
# @alpha: step size
# @batch: if True, do not update @values here (the updates are applied later in a batch)
def temporal_difference(values, alpha=0.1, batch=False):
    state = 3
    trajectory = [state]
    # for TD, the rewards list stores the reward produced by each step
    rewards = [0]
    while True:
        old_state = state
        if np.random.binomial(1, 0.5) == ACTION_LEFT:
            state -= 1
        else:
            state += 1
        # Assume all rewards are 0
        reward = 0
        trajectory.append(state)
        # TD update
        if not batch:
            values[old_state] += alpha * (reward + values[state] - values[old_state])
        if state == 6 or state == 0:
            break
        rewards.append(reward)
    return trajectory, rewards
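
The line inside the loop is the tabular TD(0) update with γ = 1: V(S_t) ← V(S_t) + α[R_{t+1} + V(S_{t+1}) − V(S_t)]. A minimal usage sketch (mine, not in the referenced script): run a number of episodes on a copy of VALUES and watch the estimates for A–E drift toward 1/6, ..., 5/6.

values = np.copy(VALUES)
for _ in range(100):       # 100 TD(0) episodes with the default alpha = 0.1
    temporal_difference(values)
print(values[1:6])         # noisy estimates for A..E, roughly approaching TRUE_VALUE[1:6]
print(TRUE_VALUE[1:6])     # [1/6, 2/6, 3/6, 4/6, 5/6] for comparison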

Updating the state values with Monte Carlo

# @values: current state values, will be updated in place if @batch is False
# @alpha: step size
# @batch: if True, do not update @values here (the updates are applied later in a batch)
def monte_carlo(values, alpha=0.1, batch=False):
    state = 3
    trajectory = [3]

    # if end up with left terminal state, all returns are 0
    # if end up with right terminal state, all returns are 1
    while True:
        if np.random.binomial(1, 0.5) == ACTION_LEFT:
            state -= 1
        else:
            state += 1
        trajectory.append(state)
        if state == 6:
            # for MC, the return is the total reward collected from state s until the end of the episode, i.e. G_t
            returns = 1.0
            break
        elif state == 0:
            returns = 0.0
            break

    if not batch:
        for state_ in trajectory[:-1]:
            # MC update
            values[state_] += alpha * (returns - values[state_])
    # since every intermediate reward is 0, each state on the trajectory has the same return,
    # determined only by which terminal state the episode reaches
    return trajectory, [returns] * (len(trajectory) - 1)
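
Because all intermediate rewards are 0 and γ = 1, every state visited in an episode shares the same return G_t (0 or 1), and the update applied above is the constant-α MC rule V(S_t) ← V(S_t) + α[G_t − V(S_t)]. A small sketch of the returned values (mine, not in the referenced script):

values = np.copy(VALUES)
trajectory, returns = monte_carlo(values)
print(trajectory)   # random walk from C (state 3) ending in 0 or 6, e.g. [3, 2, 3, 4, 5, 6]
print(returns)      # one identical return per non-terminal state visited, all 0.0 or all 1.0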

Estimating the state values with TD(0) and plotting the convergence

# Example 6.2 left
def compute_state_value():
    episodes = [0, 1, 10, 100]
    current_values = np.copy(VALUES)
    plt.figure(1)
    for i in tqdm(range(episodes[-1] + 1)):
        if i in episodes:
            plt.plot(current_values, label=str(i) + ' episodes')
        temporal_difference(current_values)
    plt.plot(TRUE_VALUE, label='true values')
    plt.xlabel('state')
    plt.ylabel('estimated value')
    plt.legend()

# test
# the estimates gradually converge toward the true values
# compute_state_value()
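
Because the episodes are generated randomly, each run of this cell produces a slightly different figure; fixing numpy's global random seed first makes the plot reproducible (the seed value here is arbitrary and my own addition):

np.random.seed(0)        # any fixed seed makes the episode sequence repeatable
compute_state_value()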

Computing and comparing the RMS (root-mean-square) error of TD(0) and constant-α MC

# Example 6.2 right
def rms_error():
    # Same alpha value can appear in both arrays
    td_alphas = [0.15, 0.1, 0.05]
    mc_alphas = [0.01, 0.02, 0.03, 0.04]
    episodes = 100 + 1
    runs = 100
    # concatenating the two lists lets one loop handle both methods
    for i, alpha in enumerate(td_alphas + mc_alphas):
        total_errors = np.zeros(episodes)
        if i < len(td_alphas):
            method = 'TD'
            linestyle = 'solid'
        else:
            method = 'MC'
            linestyle = 'dashdot'
        for r in tqdm(range(runs)):
            errors = []
            current_values = np.copy(VALUES)
            # 'ep' avoids shadowing the outer loop variable 'i'
            for ep in range(0, episodes):
                errors.append(np.sqrt(np.sum(np.power(TRUE_VALUE - current_values, 2)) / 5.0))
                if method == 'TD':
                    temporal_difference(current_values, alpha=alpha)
                else:
                    monte_carlo(current_values, alpha=alpha)
            total_errors += np.asarray(errors)
        total_errors /= runs
        plt.plot(total_errors, linestyle=linestyle, label=method + ', alpha = %.02f' % (alpha))
    plt.xlabel('episodes')
    plt.ylabel('RMS')
    plt.legend()

# test
# TD converges faster and reaches a lower RMS error
# rms_error()
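
The quantity appended to errors is the root-mean-square error over the five non-terminal states: the sum runs over all seven entries, but the two terminal estimates always equal their true values (0 and 1), so dividing by 5.0 gives the mean squared error over A–E. An equivalent helper, written out for clarity (the name rms_vs_true is my own):

def rms_vs_true(estimate):
    # RMS error of the estimates for the non-terminal states A..E
    return np.sqrt(np.mean((TRUE_VALUE[1:6] - estimate[1:6]) ** 2))

print(rms_vs_true(VALUES))   # error of the initial all-0.5 estimate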

Converging the state values with ordinary (per-episode) updates and plotting the results

def example_6_2():
    plt.figure(figsize=(10, 20))
    plt.subplot(2, 1, 1)
    compute_state_value()

    plt.subplot(2, 1, 2)
    rms_error()
    plt.tight_layout()

    plt.savefig('./example_6_2.png')
    plt.show()

example_6_2()
100%|██████████| 101/101 [00:00<00:00, 7677.14it/s]
100%|██████████| 100/100 [00:00<00:00, 173.89it/s]
100%|██████████| 100/100 [00:00<00:00, 179.95it/s]
100%|██████████| 100/100 [00:00<00:00, 169.18it/s]
100%|██████████| 100/100 [00:00<00:00, 247.49it/s]
100%|██████████| 100/100 [00:00<00:00, 239.68it/s]
100%|██████████| 100/100 [00:00<00:00, 236.73it/s]
100%|██████████| 100/100 [00:00<00:00, 229.08it/s]

(figure: example_6_2.png — top: estimated values after 0/1/10/100 TD episodes; bottom: RMS error curves for TD and MC)

Using batch updating for the state-value estimates and plotting the results
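
In batch updating, the value estimates are frozen while each new episode is generated (batch=True below). After every episode, the algorithm sweeps repeatedly over all trajectories seen so far, accumulating the increments

Δ(S_t) += α[R_{t+1} + V(S_{t+1}) − V(S_t)]   (batch TD)
Δ(S_t) += α[G_t − V(S_t)]                    (batch MC)

over every stored transition and then applying them as a single update V ← V + Δ, until the total change per sweep falls below 1e-3. The small step size α = 0.001 keeps this inner loop stable.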

# Figure 6.2 (Example 6.3)
# @method: 'TD' or 'MC'
def batch_updating(method, episodes, alpha=0.001):
    # perform 100 independent runs
    runs = 100
    total_errors = np.zeros(episodes)
    for r in tqdm(range(0, runs)):
        current_values = np.copy(VALUES)
        errors = []
        # track shown trajectories and reward/return sequences
        trajectories = []
        rewards = []
        for ep in range(episodes):
            # batch=True means the value estimates are frozen while the episode is generated
            if method == 'TD':
                trajectory_, rewards_ = temporal_difference(current_values, batch=True)
            else:
                trajectory_, rewards_ = monte_carlo(current_values, batch=True)
            trajectories.append(trajectory_)
            rewards.append(rewards_)
            while True:
                # keep feeding our algorithm with trajectories seen so far until state value function converges
                updates = np.zeros(7)
                for trajectory_, rewards_ in zip(trajectories, rewards):
                    # len(trajectory_) - 1 so the terminal state is never updated
                    for i in range(0, len(trajectory_) - 1):
                        if method == 'TD':
                            updates[trajectory_[i]] += rewards_[i] + current_values[trajectory_[i + 1]] - current_values[trajectory_[i]]
                        else:
                            updates[trajectory_[i]] += rewards_[i] - current_values[trajectory_[i]]
                updates *= alpha
                if np.sum(np.abs(updates)) < 1e-3:
                    break
                # perform batch updating
                current_values += updates
            # calculate rms error
            errors.append(np.sqrt(np.sum(np.power(current_values - TRUE_VALUE, 2)) / 5.0))
        total_errors += np.asarray(errors)
    total_errors /= runs
    return total_errors

def figure_6_2():
    episodes = 100 + 1
    td_errors = batch_updating('TD', episodes)
    mc_errors = batch_updating('MC', episodes)

    plt.plot(td_errors, label='TD')
    plt.plot(mc_errors, label='MC')
    plt.xlabel('episodes')
    plt.ylabel('RMS error')
    plt.legend()

    plt.savefig('./figure_6_2.png')
    plt.show()


figure_6_2()
100%|██████████| 100/100 [00:45<00:00,  2.29it/s]
100%|██████████| 100/100 [00:37<00:00,  2.49it/s]

(figure: figure_6_2.png — RMS error of batch TD and batch MC)