引用来自ShangtongZhang的代码chapter13/short_corridor.py
使用一个只含3个non-terminal state的简单grid-world问题讨论参数化policy算法的性能。
问题描述
每个reward=-1,第二个state的action是反转的,即left->right state, right->left state
引入模块,并给出第一个state的true value计算公式(based on Bellman equation)
1 | import numpy as np |
创建环境类,实现agent和环境之间的交互
1 | class ShortCorridor: |
创建agent类,实现参数化policy算法
1 | class ReinforceAgent: |
构造使用baseline的agent类,实现approximation policy with baseline算法
1 | class ReinforceBaselineAgent(ReinforceAgent): |
完成一个episode的训练
1 | def trial(num_episodes, agent_generator): |
1 |
|
100%|██████████| 30/30 [00:23<00:00, 1.18it/s]
100%|██████████| 30/30 [00:23<00:00, 1.26it/s]
100%|██████████| 30/30 [00:29<00:00, 1.09it/s]