Chapter08 trajectory_sampling

The code is adapted from ShangtongZhang's chapter08/trajectory_sampling.py.

Using an example MDP, it compares the performance of uniform sampling and on-policy sampling.

Problem description

The problem is built on a randomly generated MDP. Suppose the MDP has n non-terminal states, indexed [0, n), plus a terminal state n, and each state has two actions, 0 and 1. On every transition, each action sends the state to the terminal state with probability 0.1 (TERMINATION_PROB in the code); otherwise the action moves, with equal probability, to one of b branch states, where each branch is itself bound to a state drawn uniformly from |S|. The MDP is sketched below:

(figure: schematic of the MDP's branching transitions)
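
In other words, the transition dynamics implied by the code below are, for every non-terminal state s and action a (writing T for the terminal state and s_1, ..., s_b for the b branch states bound to (s, a); this notation is introduced here, not in the code):

$$
p(T \mid s, a) = 0.1, \qquad
p(s_i \mid s, a) = \frac{1 - 0.1}{b}, \quad i = 1, \dots, b,
$$

with reward 0 on termination and a reward r(s, a, s_i) drawn once, at construction time, from a unit normal distribution for each branch.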

Import modules and define constants

```python
import numpy as np
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm import tqdm

# 2 actions
ACTIONS = [0, 1]

# each transition has a probability to terminate, with reward 0
TERMINATION_PROB = 0.1

# maximum number of expected updates
MAX_STEPS = 20000

# epsilon for the epsilon-greedy behavior policy
EPSILON = 0.1
```

Define the argmax function, which returns a randomly chosen index among the maximal entries of value

```python
# break tie randomly
def argmax(value):
    max_q = np.max(value)
    return np.random.choice([a for a, q in enumerate(value) if q == max_q])
```
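
A quick sanity check (the array and loop below are an illustrative addition, not part of the original code): ties are broken uniformly at random.

```python
# hypothetical example: two tied action values
q_row = np.array([0.5, 0.5])
picks = [argmax(q_row) for _ in range(1000)]
print(np.mean(picks))  # close to 0.5: index 0 and index 1 are each picked about half the time
```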

Define the Task class, which simulates the interaction between the agent and the environment

```python
class Task():
    # @n_states: number of non-terminal states
    # @b: branch
    # Each episode starts with state 0, and state n_states is a terminal state
    def __init__(self, n_states, b):
        self.n_states = n_states
        self.b = b

        # transition matrix, each state-action pair leads to b possible states
        # np.random.randint returns an array of the given size with entries in range(n_states)
        # How next_state is chosen from self.transition:
        # for a given state and action there are b branches, and each branch is bound
        # (at construction time) to a state drawn uniformly from |S|; so step() first
        # picks one of the b branches at random, and self.transition then gives the
        # state that branch leads to
        self.transition = np.random.randint(n_states, size=(n_states, len(ACTIONS), b))

        # it is not clear how to set the reward, I use a unit normal distribution here
        # reward is determined by (s, a, s')
        self.reward = np.random.randn(n_states, len(ACTIONS), b)

    def step(self, state, action):
        if np.random.rand() < TERMINATION_PROB:
            # jump directly to the terminal state, with reward 0
            return self.n_states, 0
        branch = np.random.randint(self.b)
        return self.transition[state, action, branch], self.reward[state, action, branch]
```
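
A minimal usage sketch (the sizes and seed below are arbitrary choices for illustration, not part of the original code):

```python
np.random.seed(0)                          # arbitrary seed, for reproducibility only
task = Task(n_states=1000, b=3)
next_state, reward = task.step(state=0, action=0)
# either (1000, 0) if the transition terminated, or (some state < 1000, a Gaussian reward)
print(next_state, reward)
```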

Compute the average episode return obtained by following the greedy policy derived from the value function

```python
# Evaluate the value of the start state for the greedy policy
# derived from @q under the MDP @task
def evaluate_pi(q, task):
    # use Monte Carlo method to estimate the state value
    runs = 1000
    returns = []
    for _ in range(runs):
        rewards = 0
        state = 0
        while state < task.n_states:
            action = argmax(q[state])
            state, r = task.step(state, action)
            rewards += r
        returns.append(rewards)
    return np.mean(returns)
```
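
As an illustrative call (not part of the original script): with q all zero, argmax breaks the tie between the two actions at random, so this estimates the start state's value under a uniformly random policy.

```python
# hypothetical check on an untrained (all-zero) value function
task = Task(n_states=1000, b=3)
q = np.zeros((task.n_states, len(ACTIONS)))
print(evaluate_pi(q, task))   # Monte Carlo estimate of v(0) for the greedy (here: random) policy
```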

Update the value function with expected updates, sweeping the state-action pairs uniformly

```python
# perform expected update from a uniform state-action distribution of the MDP @task
# evaluate the learned q value every @eval_interval steps
# i.e. at the given evaluation interval, interleave evaluations of the greedy policy's
# return with the expected updates, so the update scheme itself can be assessed
def uniform(task, eval_interval):
    performance = []
    q = np.zeros((task.n_states, 2))
    for step in tqdm(range(MAX_STEPS)):
        # // is integer division
        # note that every state-action pair is swept in turn, provided MAX_STEPS is large enough
        state = step // len(ACTIONS) % task.n_states
        action = step % len(ACTIONS)

        # next_states is an array of shape (b,)
        next_states = task.transition[state, action]
        # task.reward[state, action] has shape (b,)
        # q[next_states, :] has shape (b, 2); after np.max(..., axis=1) it has shape (b,)
        # this is the expected update (Sutton & Barto, p.141)
        # one detail: the true transition probability is (1 - TERMINATION_PROB) / b,
        # and np.mean supplies the 1/b factor
        q[state, action] = (1 - TERMINATION_PROB) * np.mean(
            task.reward[state, action] + np.max(q[next_states, :], axis=1))

        if step % eval_interval == 0:
            # at each evaluation interval, evaluate the current greedy policy
            v_pi = evaluate_pi(q, task)
            performance.append([step, v_pi])

    return zip(*performance)
```
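
For reference, the line updating q above is the one-step expected update for this MDP. Writing s_1, ..., s_b for the branch states bound to (s, a) (notation introduced here), and using the fact that the terminal transition contributes reward 0 and has value 0:

$$
q(s, a) \leftarrow \sum_{s'} p(s' \mid s, a)\Big[r(s, a, s') + \max_{a'} q(s', a')\Big]
= (1 - 0.1)\,\frac{1}{b}\sum_{i=1}^{b}\Big[r(s, a, s_i) + \max_{a'} q(s_i, a')\Big],
$$

so the factor (1 - TERMINATION_PROB) and the 1/b hidden inside np.mean together make up the true transition probability (1 - TERMINATION_PROB)/b.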

Update the value function using on-policy sampling

```python
# perform expected update from an on-policy distribution of the MDP @task
# evaluate the learned q value every @eval_interval steps
def on_policy(task, eval_interval):
    performance = []
    q = np.zeros((task.n_states, 2))
    # note that both update schemes start from state 0
    state = 0
    # instead of sweeping state-action pairs uniformly, states are generated by simulating
    # episodes with an epsilon-greedy policy over the current value function, and the
    # expected updates follow the states visited along the way
    for step in tqdm(range(MAX_STEPS)):
        if np.random.rand() < EPSILON:
            action = np.random.choice(ACTIONS)
        else:
            action = argmax(q[state])

        next_state, _ = task.step(state, action)

        # the expected update itself is identical to the one in uniform()
        next_states = task.transition[state, action]
        q[state, action] = (1 - TERMINATION_PROB) * np.mean(
            task.reward[state, action] + np.max(q[next_states, :], axis=1))

        if next_state == task.n_states:
            next_state = 0
        state = next_state

        if step % eval_interval == 0:
            v_pi = evaluate_pi(q, task)
            performance.append([step, v_pi])

    return zip(*performance)
```
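
Before running the full experiment, the two schemes can be compared on a single task; the task size and evaluation interval below are arbitrary choices for illustration, not the settings used in the figure.

```python
# hypothetical single-task comparison (not part of the original experiment)
task = Task(n_states=1000, b=3)
steps_u, perf_u = uniform(task, eval_interval=200)
steps_o, perf_o = on_policy(task, eval_interval=200)
print(perf_u[-1], perf_o[-1])   # value of the start state at the last evaluation point
```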

Compare the performance of the two sampling schemes with a plot

```python
def figure_8_9():
    num_states = [1000, 10000]
    branch = [1, 3, 10]
    methods = [on_policy, uniform]

    # average across 30 tasks
    # with 30 tasks that would be 2 * 3 * 2 * 30 = 360 runs of MAX_STEPS each,
    # which takes far too long here, so only 5 tasks are used
    # n_tasks = 30
    n_tasks = 5
    # number of evaluation points
    x_ticks = 100

    plt.figure(figsize=(10, 20))
    for i, n in enumerate(num_states):
        plt.subplot(2, 1, i + 1)
        for b in branch:
            tasks = [Task(n, b) for _ in range(n_tasks)]
            for method in methods:
                value = []
                for task in tasks:
                    steps, v = method(task, MAX_STEPS / x_ticks)
                    value.append(v)
                value = np.mean(np.asarray(value), axis=0)
                plt.plot(steps, value, label='b = %d, %s' % (b, method.__name__))
        plt.title('%d states' % (n))

        plt.ylabel('value of start state')
        plt.legend()

    plt.subplot(2, 1, 2)
    plt.xlabel('computation time, in expected updates')

    plt.savefig('./figure_8_9.png')
    plt.show()

figure_8_9()
```
100%|██████████| 20000/20000 [00:22<00:00, ~880it/s]
(one tqdm progress bar per run; 60 runs of 20000 expected updates in total, each taking roughly 22-23 seconds)

(figure: figure_8_9.png, value of the start state versus number of expected updates for the two sampling schemes)
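
For a quicker qualitative check, the experiment can be shrunk; the reduced configuration below (one task per branching factor, 1000 states only, single panel) is an arbitrary illustration, not the setup used for the figure above.

```python
# hypothetical reduced run: a single task per branching factor
plt.figure()
for b in [1, 3, 10]:
    task = Task(1000, b)
    for method in [on_policy, uniform]:
        steps, v = method(task, MAX_STEPS / 100)
        plt.plot(steps, v, label='b = %d, %s' % (b, method.__name__))
plt.xlabel('computation time, in expected updates')
plt.ylabel('value of start state')
plt.legend()
plt.show()
```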