从 Q 表到函数逼近
一直想要系统学习RL的相关知识,脑子里想着能不能完全依靠 Codex 进行学习,于是就有了下面的教程。
完整内容:https://github.com/lanheixingkong/rl-course
前几课里,我们一直用表格保存价值。
第三课Q-learning,在不知道环境模型时学习行动价值 里,代码保存的是:
Q(s,a)
第四课SARSA vs Q-learning,看懂 on-policy 和 off-policy里,虽然 target 不同,但两者学习的仍然是:
Q(s,a)
到目前为止,Q(s,a) 都是用表格存的:
q[state][action] = 一个数字
这在小地图里很清楚。但真实问题往往不能这么做。
如果状态很多,甚至是连续的,Q 表会变得非常大,或者根本无法枚举。这时就需要进入一个新阶段:
function approximation,函数逼近
这节课先不直接上神经网络,而是用最简单的线性模型过渡。
1. 本课问题
表格 Q-learning 的逻辑是:
每个 state/action 单独存一个 Q 值
例如:
Q((7,0), U)
Q((7,0), R)
Q((6,0), U)
Q((6,0), R)
都是不同格子里的数字。
更新其中一个值,不会直接改变另一个值。
这很容易理解,但问题是:
如果 state 很多怎么办?
假设有:
10000 个状态
10 个动作
表格方法就要存:
10000 * 10 = 100000 个 Q 值
如果状态是连续值,比如:
位置
速度
角度
温度
库存数量
用户行为向量
那状态可能根本不能被完整列出来。
所以第五课要解决的问题是:
如果不能给每个 state/action 单独存 Q 值,
能不能用一个函数来估计 Q(s,a)?
2. 本课案例
这一课仍然使用 GridWorld,但地图变大:
8 x 8 GridWorld
起点在左下角:
start = (7,0)
终点在右上角:
goal = (0,7)
地图中还有墙和坑。
代码会训练两个 agent:
|
|
|
|
|---|---|---|
table |
|
(state, action) 单独存值
|
linear |
|
Q(s,a)
|
这节课不是为了证明线性方法一定比表格方法好。
它的目的更具体:
让你看懂 Q(s,a) 不一定来自表格,
也可以来自一个可训练的函数。
3. 运行实验
在项目根目录运行:
python lessons/05_function_approximation/linear_q_gridworld.py
你会看到类似输出:
Function Approximation GridWorld config:
agent : both
rows x cols : 8 x 8
episodes : 1000
max_steps : 120
alpha_table : 0.2
alpha_linear : 0.05
gamma : 0.95
epsilon : 0.2
step_reward : -0.04
slip_probability : 0.0
seed : 0
start : (7, 0)
goal : (0, 7)
pits : [(6, 6), (7, 6)]
walls : 5 cells
table episode 1000 | avg return last 100: 0.343 | avg steps: 17.4 | success rate: 1.00
linear episode 1000 | avg return last 100: 0.270 | avg steps: 17.3 | success rate: 0.96
table summary:
parameters stored : 224
table greedy reachability:
from start (7, 0): goal in 14 steps
from all non-terminal states: 49/56 reach goal
linear summary:
parameters stored : 10
linear greedy reachability:
from start (7, 0): goal in 14 steps
from all non-terminal states: 56/56 reach goal
4. 先看输出
config
agent = both 表示同一次运行里训练两种表示方式:
table
linear
rows x cols = 8 x 8 表示地图大小。
alpha_table 和 alpha_linear 是两个学习率。它们分开设置,是因为表格方法和函数逼近对学习率的敏感程度不同。
训练日志
table episode 1000 | avg return last 100: 0.343 | avg steps: 17.4 | success rate: 1.00
linear episode 1000 | avg return last 100: 0.270 | avg steps: 17.3 | success rate: 0.96
含义和前几课一样:
|
|
|
|---|---|
avg return last 100 |
|
avg steps |
|
success rate |
|
这里不要只看谁高一点。
本课更重要的是下面这个指标:
parameters stored
parameters stored
表格方法输出:
parameters stored : 224
意思是:
它存了 224 个独立 Q 值。
线性方法输出:
parameters stored : 10
意思是:
它只存了 10 个共享权重。
这个对比是本课重点。
表格方法的参数量随状态和动作增长。
线性方法的参数量取决于特征数量。
success rate 和 greedy reachability
输出里还有一个容易误解的地方:
success rate last 100
greedy reachability
它们不是同一个指标。
success rate last 100 统计的是训练过程:
最近 100 个训练 episode,
从固定 start 出发,
有多少次到达 goal。
它只说明“从默认起点出发,训练时表现如何”。
它不说明“从地图上每个格子出发,最终 greedy policy 是否都能到 goal”。
坐标的读法是:
(row, col)
row 从上到下数;
col 从左到右数。
所以 (0,0) 是左上角,(7,0) 是左下角,(0,7) 是右上角 goal。
所以可能出现:
table success rate last 100 = 1.00
但 table greedy reachability = 49/56
这不是矛盾。
它说明表格方法已经把从 start = (7,0) 到 goal = (0,7) 的常用路径学好了,所以最近 100 局都成功;但一些不在常用路径上、很少被访问的格子,Q 表里对应动作还没学好。
表格方法不会自动把一个格子的经验推广到另一个格子。
失败预览显示的是“从这个格子出发,按 greedy policy 一步步走下去的路径”,不是只显示第一步动作。
例如:
(6, 5)->(7, 5)->(7, 5)->loop
含义是:
(6,5) 的动作是 D,走到 (7,5);
(7,5) 的动作也是 D,但已经在最下面一行;
继续向下会撞边界,留在 (7,5),于是形成 loop。
再比如:
(6, 7)->(6, 7)->loop
因为 (6,7) 的动作是 R,但它已经在最右一列,向右会撞边界并留在原地。
也可能出现:
linear success rate last 100 = 0.96
但 linear greedy reachability = 56/56
这也不是矛盾。
训练时使用的是 epsilon-greedy:
epsilon = 0.2
也就是说,训练过程中仍然有 20% 概率随机探索。随机探索可能掉坑、绕路,导致训练 episode 失败。
但 greedy reachability 检查的是训练结束后的最终 greedy policy:
不再随机探索;
每个状态都选择当前 Q 值最大的动作。
所以训练成功率可以不是 100%,但最终 greedy policy 仍然能从所有非终止格子走到 goal。
线性方法还有一个额外原因:它使用共享特征和权重,容易学出全局规则,例如:
先往上靠近第一行;
或者往右靠近最后一列;
再走向 goal。
这种规则会推广到很多没有被充分单独训练过的格子。这就是函数逼近的泛化。
5. 表格方法到底在存什么
代码里的表格类是:
class TabularQ:
初始化时:
self.q = {
state: {action: 0.0 for action in ACTIONS}
for state in env.states()
if not env.is_terminal(state)
}
这表示:
每个非终止 state
每个 action
都存一个独立的数字
读取 Q 值:
return self.q[state][action]
更新 Q 值:
self.q[state][action] = old_q + alpha * td_error
这很像在 Excel 表格里改一个单元格。
它的优点是:
清楚;
稳定;
每个 Q 值相互独立;
容易 debug。
缺点是:
状态越多,表越大;
没见过的状态没有经验;
连续状态无法直接完整建表。
6. 线性函数在做什么
线性方法的类是:
class LinearQ:
它不存每个 (state, action) 的 Q 值,而是存一组权重:
self.weights = {name: 0.0 for name in self.feature_names}
代码里使用的特征包括:
|
|
|
|---|---|
bias |
|
row |
|
col |
|
goal_closeness |
|
pit_closeness |
|
bump |
|
action_U/R/D/L |
|
逐行解释 features()
features() 的作用是:
把一个具体的 (state, action)
转换成一组数字特征,
让线性模型可以用这些数字估计 Q(s,a)。
第一步:
next_state = self.env.next_state(state, action)
这里不是只看当前 state,而是先计算:
如果在 state 执行 action,会走到哪里?
所以很多特征描述的是 next_state。例如:
state = (7,0)
action = U
next_state = (6,0)
接下来:
max_distance = max(1, (self.env.rows - 1) + (self.env.cols - 1))
这是地图中最大可能的曼哈顿距离,用来把距离缩放到大约 0~1 范围。max(1, ...) 是防御性写法,避免除以 0。
goal_distance = manhattan(next_state, self.env.goal)
计算 next_state 到 goal 的曼哈顿距离:
manhattan((r1,c1), (r2,c2)) = |r1-r2| + |c1-c2|
nearest_pit_distance = min((manhattan(next_state, pit) for pit in self.env.pits), default=max_distance)
计算 next_state 到最近 pit 的距离。如果地图里没有 pit,就使用 max_distance 作为默认值。
然后开始构造特征:
"bias": 1.0
固定为 1,相当于模型的基础分。
"row": next_state[0] / max(1, self.env.rows - 1)
"col": next_state[1] / max(1, self.env.cols - 1)
表示下一格的行、列位置,并归一化到大约 0~1。
在 8x8 地图里:
row = 0 / 7 = 0.0 表示顶部
row = 7 / 7 = 1.0 表示底部
col = 0 / 7 = 0.0 表示最左边
col = 7 / 7 = 1.0 表示最右边
"goal_closeness": 1.0 - goal_distance / max_distance
表示下一格离 goal 有多近。越接近 goal,这个值越大。
"pit_closeness": 1.0 - nearest_pit_distance / max_distance
表示下一格离最近 pit 有多近。越接近 pit,这个值越大。它本身不表示好坏,具体好坏由训练学出来的权重决定。通常靠近 pit 是坏事,所以它的权重可能会学成负数。
"bump": 1.0 if next_state == state else 0.0
表示是否撞墙或撞边界。如果动作执行后位置没变,bump = 1.0。
"action_U": 1.0 if action == "U" else 0.0
action_U/R/D/L 是动作的 one-hot 表示。例如动作是 U:
action_U = 1.0
action_R = 0.0
action_D = 0.0
action_L = 0.0
所以,features() 自己不学习。它只是把 (state, action) 变成模型能使用的一组数字。真正学习的是 weights。
然后用这些特征和权重计算 Q 值:
return sum(self.weights[name] * value for name, value in self.features(state, action).items())
也就是:
Q(s,a) = w1*x1 + w2*x2 + ... + wk*xk
这里:
x = features
w = weights
这就是最简单的函数逼近。
7. 线性方法怎么学习
训练循环仍然是 Q-learning:
target = reward + gamma * next_best
td_error = target - old_q
区别在更新方式。
表格方法更新一格:
q[state][action]
线性方法更新一组权重:
self.weights[name] += alpha * td_error * feature_value
这句话可以这样理解:
如果这次估计偏低,td_error 为正,就把相关特征的权重往上调;
如果这次估计偏高,td_error 为负,就把相关特征的权重往下调;
某个特征值越大,它受到这次更新的影响越大。
所以线性方法不是记住一个格子,而是在调整一套通用规则。
8. 为什么这叫泛化
表格方法更新:
Q((7,0), U)
通常只影响这一格。
线性方法更新:
weights
这些权重会参与很多状态动作的 Q 值计算。
例如 goal_closeness 的权重变了,所有“离 goal 更近或更远”的动作估计都会受到影响。
这就是泛化:
一次经验,不只影响当前 state/action,
还会影响具有相似特征的其他 state/action。
泛化是函数逼近的优势,也是风险。
9. 函数逼近的优势
函数逼近的优势:
参数更少;
能处理更大的状态空间;
能从相似状态中共享经验;
能处理连续状态;
可以进一步换成神经网络。
这就是为什么它是深度 RL 的入口。
如果没有函数逼近,就很难把 Q-learning 扩展到图像、复杂游戏、机器人控制这类问题。
10. 函数逼近的代价
函数逼近也有代价。
表格方法里,一个 Q 值错了,通常只影响一个格子。
线性方法里,一个权重错了,会影响很多估计。
所以它可能出现:
估计互相干扰;
训练更不稳定;
学习率更敏感;
特征设计不好时学不到好策略;
结果不如 Q 表容易解释。
这就是为什么后面的 DQN 不是简单地把 Q 表换成神经网络就完了。
DQN 还需要:
experience replay
target network
这些技巧都是为了让函数逼近下的 Q-learning 更稳定。
这一课的线性方法也可能出现类似提醒:训练更久不一定总是更好。因为共享权重会让不同状态动作互相影响,后面的更新可能破坏前面已经学到的估计。这不是本课代码特有的问题,而是函数逼近进入 RL 后需要认真处理的稳定性问题。
11. 查看单步更新
运行:
python lessons/05_function_approximation/linear_q_gridworld.py --episodes 3 --debug-episodes 1 --log-every 0 --max-steps 15 --show-weights
你会看到:
table episode 1:
step state action reward next old Q target td err new Q
linear episode 1:
step state action reward next old Q target td err new Q
linear weights:
bias : ...
row : ...
col : ...
goal_closeness: ...
两种方法的 debug 表看起来很像,因为它们都在做:
old_q -> target -> td_error -> new_q
但 new_q 的来源不同:
table:
new_q 来自 q[state][action] 这一格被改掉
linear:
new_q 来自 weights 被改掉后重新计算
12. 参数实验
实验 1:只运行表格方法
python lessons/05_function_approximation/linear_q_gridworld.py --agent table
目的:确认这仍然是前面学过的 Q-learning,只是地图更大。
实验 2:只运行线性方法
python lessons/05_function_approximation/linear_q_gridworld.py --agent linear --show-weights
重点看:
parameters stored 是否仍然是 10?
linear weights 有哪些正负变化?
greedy policy 是否能走向 goal?
实验 3:扩大地图
python lessons/05_function_approximation/linear_q_gridworld.py --rows 12 --cols 12 --episodes 2000 --log-every 1000
观察:
table 的 parameters stored 是否明显增长?
linear 的 parameters stored 是否仍然是 10?
这能说明:
表格参数量跟 state/action 数量相关;
线性参数量跟 feature 数量相关。
如果 12x12 地图里 linear 的成功率很低,这也是有效观察。它说明:
参数少不等于一定效果好;
特征太简单时,线性模型可能表达不了复杂地图里的好策略。
实验 4:调线性学习率
python lessons/05_function_approximation/linear_q_gridworld.py --agent linear --alpha-linear 0.01 --log-every 1000
python lessons/05_function_approximation/linear_q_gridworld.py --agent linear --alpha-linear 0.2 --log-every 1000
观察:
学习是否变慢?
训练日志是否更波动?
greedy policy 是否更不稳定?
函数逼近通常比表格方法更怕不合适的学习率。
实验 5:训练更久
python lessons/05_function_approximation/linear_q_gridworld.py --episodes 6000 --log-every 3000
观察:
linear 是否一定比默认 1000 局更好?
训练日志是否可能后期变差?
如果变差,不要急着认为程序错了。它正好说明函数逼近可能不如表格方法稳定。
实验 6:加入打滑
python lessons/05_function_approximation/linear_q_gridworld.py --slip-probability 0.1 --log-every 1000
观察:
随机环境下,table 和 linear 的 success rate 是否下降?
linear 的 greedy policy 是否更容易出现不稳定动作?
13. 本课总结
本课最重要的一句话是:
Q(s,a) 不一定来自表格,也可以来自一个函数。
表格方法:
每个 state/action 单独存一个 Q 值;
清楚稳定;
但状态多时无法扩展。
线性函数逼近:
用 features 和 weights 计算 Q(s,a);
参数少;
能泛化;
但可能互相干扰,也更依赖特征设计。
这节课是 DQN 的前置桥梁。
DQN 可以先粗略理解为:
Q-learning + 神经网络函数逼近
区别是:
本课用人工设计的 features + 线性 weights;
DQN 用神经网络从状态中学习更复杂的表示。
14. 进入下一课前的检查
你应该能用自己的话回答:
-
1. 为什么 Q 表在大状态空间里会出问题? -
2. table和linear都是在估计什么? -
3. parameters stored为什么重要? -
4. 线性方法里的 features和weights分别是什么? -
5. 为什么线性方法更新一次会影响很多状态动作? -
6. 函数逼近的优势是什么? -
7. 函数逼近的风险是什么? -
8. DQN 和本课的关系是什么?
如果这些问题能回答清楚,就可以进入下一阶段:
用神经网络估计 Q(s,a)
如果觉得内容不错,欢迎你点一下「在看」,或是将文章分享给其他有需要的人^^
相关好文推荐:

0条留言