Chapter08 expectation vs sample

The code is taken from ShangtongZhang's chapter08/expectation_vs_sample.py.

A simple example that compares the estimation error of expected updates and sample updates for a single state with b successor states.
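To make the two kinds of update concrete, here is a minimal sketch of one expected update and one sample update for a single state. It assumes, as `b_steps` below effectively does, zero rewards and no discounting, so the update target is just a successor value. The names `expected_update` and `sample_update` are hypothetical and do not come from ShangtongZhang's code.

```python
import numpy as np

# Hypothetical helpers for illustration only. Assumes zero reward and no
# discounting, so the update target is simply the value of a successor state.

def expected_update(successor_values, probs):
    """Expected update: weight every successor by its probability (b computations)."""
    return float(np.dot(probs, successor_values))


def sample_update(estimate, successor_values, probs, t, rng):
    """Sample update number t: draw one successor and move toward it with step size 1/t.

    Starting from any estimate at t = 1, step size 1/t makes the estimate
    equal to the running average of the sampled successor values.
    """
    v = rng.choice(successor_values, p=probs)  # one computation
    return estimate + (v - estimate) / t


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    values = np.array([1.0, 2.0, 3.0, 4.0])  # b = 4 successor values
    probs = np.full(4, 0.25)                  # uniform transition probabilities
    print("expected update:", expected_update(values, probs))
    estimate = 0.0
    for t in range(1, 9):                     # 8 sample updates
        estimate = sample_update(estimate, values, probs, t, rng)
    print("after 8 sample updates:", estimate)
```

The expected update is exact after touching all b successors; the sample update is cheap per step but its estimate fluctuates around the true value until many samples have been averaged.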

Import modules

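```python
import numpy as np
import matplotlib
# Jupyter/IPython magic: render figures inline; delete this line when running as a plain .py script
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm import tqdm
```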

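Simulate 2b sample updates for one state with b successor states, recording after each update the error between the sample estimate and the true (expected) value over the next states. Note that the code records the absolute error |estimate − true value|, not a squared error.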

```python
# for figure 8.8, run a simulation of 2 * @b steps
def b_steps(b):
    # set the value of the next b states
    # it is not clear how to set this
    # distribution: array of b values drawn from a standard normal distribution
    distribution = np.random.randn(b)

    # true value of the current state (what an expected update computes)
    true_v = np.mean(distribution)

    samples = []
    errors = []

    # sample 2b steps
    for t in range(2 * b):
        # draw one successor uniformly at random
        v = np.random.choice(distribution)
        samples.append(v)
        # the estimate is the sample average so far; record its absolute error
        errors.append(np.abs(np.mean(samples) - true_v))

    return errors
```
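`b_steps` recomputes `np.mean(samples)` over the whole list at every step, so one run costs O(b²) operations, which is why the b = 1000 runs dominate the running time. Below is a sketch of the same experiment with a running sum via `np.cumsum`, costing O(b) per run. `b_steps_vectorized` is a hypothetical name, and it uses NumPy's `Generator` API instead of the legacy `np.random` functions, so it does not reproduce the exact random draws of `b_steps`, only the same distribution of errors.

```python
import numpy as np

def b_steps_vectorized(b, rng=None):
    """Hypothetical O(b) variant of b_steps: running sum instead of repeated np.mean."""
    rng = np.random.default_rng() if rng is None else rng
    distribution = rng.standard_normal(b)           # values of the b successor states
    true_v = distribution.mean()                    # value an expected update would compute
    samples = rng.choice(distribution, size=2 * b)  # 2b successors, uniform, with replacement
    # estimate after t sample updates = average of the first t samples
    running_mean = np.cumsum(samples) / np.arange(1, 2 * b + 1)
    return np.abs(running_mean - true_v)            # absolute error after each update
```

Because it also returns 2b errors, it can replace `b_steps(b)` inside `figure_8_8` without other changes.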

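Plot how this error, averaged over 100 runs, changes with the number of computations (sample updates), measured in units of b: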

```python
def figure_8_8():
    runs = 100
    branch = [2, 10, 100, 1000]
    for b in branch:
        errors = np.zeros((runs, 2 * b))
        for r in tqdm(np.arange(runs)):
            errors[r] = b_steps(b)
        # average the error curve over all runs
        errors = errors.mean(axis=0)
        # x axis in units of b computations, so every curve spans 0..2b
        x_axis = (np.arange(len(errors)) + 1) / float(b)
        plt.plot(x_axis, errors, label='b = %d' % (b))

    plt.xlabel('number of computations')
    plt.xticks([0, 1.0, 2.0], ['0', 'b', '2b'])
    # label follows the book; b_steps actually records absolute error averaged over runs
    plt.ylabel('RMS error')
    plt.legend()

    plt.savefig('./figure_8_8.png')
    plt.show()

figure_8_8()
```

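[Figure: figure_8_8.png, RMS error vs. number of computations (ticks 0, b, 2b) for b = 2, 10, 100, 1000]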

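This example shows that the expected update is free of sampling error: once it has evaluated all b successor states, its estimate equals the true value, while the sample-update estimate still carries the error shown in the curves. But there is no free lunch: the expected update produces no estimate until all b computations are done, so its better accuracy is paid for with correspondingly more computation per update.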
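The shape of these curves can be checked against a closed form. Assuming, as in `b_steps`, that the b successor values are i.i.d. standard normal and that the estimate is the sample average of t draws made uniformly with replacement, the expected squared error is (b − 1)/(bt): the b values have expected variance (b − 1)/b around their own mean, and averaging t independent draws divides that by t. This matches the sqrt((b − 1)/(bt)) expression Sutton and Barto give for this experiment. The helper `theoretical_rms_error` is a hypothetical name, a sketch under those assumptions. Since `b_steps` records absolute rather than squared error, the simulated curves should sit somewhat below this RMS value.

```python
import numpy as np

def theoretical_rms_error(b, t):
    """Hypothetical helper: RMS error of a sample average after t uniform draws
    (with replacement) from b i.i.d. standard-normal successor values."""
    t = np.asarray(t, dtype=float)
    return np.sqrt((b - 1) / (b * t))


if __name__ == "__main__":
    for b in [2, 10, 100, 1000]:
        # error after one sample, after about b/10 samples, and after b samples
        t = np.array([1, max(1, b // 10), b])
        print(b, np.round(theoretical_rms_error(b, t), 3))
```

For b = 1000 this gives about 1.0 after one sample and about 0.1 after t = 100 samples, a tenth of the b computations a single expected update would need.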