从 Q 表到函数逼近

一直想要系统学习RL的相关知识,脑子里想着能不能完全依靠 Codex 进行学习,于是就有了下面的教程。

完整内容:https://github.com/lanheixingkong/rl-course

前几课里,我们一直用表格保存价值。

第三课Q-learning,在不知道环境模型时学习行动价值 里,代码保存的是:


   
   
    
   
   Q(s,a)

第四课SARSA vs Q-learning,看懂 on-policy 和 off-policy里,虽然 target 不同,但两者学习的仍然是:


   
   
    
   
   Q(s,a)

到目前为止,Q(s,a) 都是用表格存的:


   
   
    
   
   q[state][action] = 一个数字

这在小地图里很清楚。但真实问题往往不能这么做。

如果状态很多,甚至是连续的,Q 表会变得非常大,或者根本无法枚举。这时就需要进入一个新阶段:


   
   
    
   
   function approximation,函数逼近

这节课先不直接上神经网络,而是用最简单的线性模型过渡。

1. 本课问题

表格 Q-learning 的逻辑是:


   
   
    
   
   每个 state/action 单独存一个 Q 值

例如:


   
   
    
   
   Q((7,0), U)
Q((7,0), R)
Q((6,0), U)
Q((6,0), R)

都是不同格子里的数字。

更新其中一个值,不会直接改变另一个值。

这很容易理解,但问题是:


   
   
    
   
   如果 state 很多怎么办?

假设有:


   
   
    
   
   10000 个状态
10 个动作

表格方法就要存:


   
   
    
   
   10000 * 10 = 100000 个 Q 值

如果状态是连续值,比如:


   
   
    
   
   位置
速度
角度
温度
库存数量
用户行为向量

那状态可能根本不能被完整列出来。

所以第五课要解决的问题是:


   
   
    
   
   如果不能给每个 state/action 单独存 Q 值,
能不能用一个函数来估计 Q(s,a)?

2. 本课案例

这一课仍然使用 GridWorld,但地图变大:


   
   
    
   
   8 x 8 GridWorld

起点在左下角:


   
   
    
   
   start = (7,0)

终点在右上角:


   
   
    
   
   goal = (0,7)

地图中还有墙和坑。

代码会训练两个 agent:

agent
表示方式
含义
table
Q 表
每个 (state, action) 单独存值
linear
线性函数
用特征和权重计算 Q(s,a)

这节课不是为了证明线性方法一定比表格方法好。

它的目的更具体:


   
   
    
   
   让你看懂 Q(s,a) 不一定来自表格,
也可以来自一个可训练的函数。

3. 运行实验

在项目根目录运行:


   
   
    
   
   python lessons/05_function_approximation/linear_q_gridworld.py

你会看到类似输出:


   
   
    
   
   Function Approximation GridWorld config:
  agent            : both
  rows x cols      : 8 x 8
  episodes         : 1000
  max_steps        : 120
  alpha_table      : 0.2
  alpha_linear     : 0.05
  gamma            : 0.95
  epsilon          : 0.2
  step_reward      : -0.04
  slip_probability : 0.0
  seed             : 0
  start            : (7, 0)
  goal             : (0, 7)
  pits             : [(6, 6), (7, 6)]
  walls            : 5 cells

table   episode  1000 | avg return last 100:  0.343 | avg steps:  17.4 | success rate:  1.00
linear  episode  1000 | avg return last 100:  0.270 | avg steps:  17.3 | success rate:  0.96

table summary:
  parameters stored             : 224

table greedy reachability:
  from start (7, 0): goal in 14 steps
  from all non-terminal states: 49/56 reach goal

linear summary:
  parameters stored             : 10

linear greedy reachability:
  from start (7, 0): goal in 14 steps
  from all non-terminal states: 56/56 reach goal

4. 先看输出

config

agent = both 表示同一次运行里训练两种表示方式:


   
   
    
   
   table
linear

rows x cols = 8 x 8 表示地图大小。

alpha_table 和 alpha_linear 是两个学习率。它们分开设置,是因为表格方法和函数逼近对学习率的敏感程度不同。

训练日志


   
   
    
   
   table   episode  1000 | avg return last 100:  0.343 | avg steps:  17.4 | success rate:  1.00
linear  episode  1000 | avg return last 100:  0.270 | avg steps:  17.3 | success rate:  0.96

含义和前几课一样:

字段
含义
avg return last 100
最近 100 局平均累计奖励
avg steps
最近 100 局平均走了多少步
success rate
最近 100 局到达 goal 的比例

这里不要只看谁高一点。

本课更重要的是下面这个指标:


   
   
    
   
   parameters stored

parameters stored

表格方法输出:


   
   
    
   
   parameters stored : 224

意思是:


   
   
    
   
   它存了 224 个独立 Q 值。

线性方法输出:


   
   
    
   
   parameters stored : 10

意思是:


   
   
    
   
   它只存了 10 个共享权重。

这个对比是本课重点。

表格方法的参数量随状态和动作增长。

线性方法的参数量取决于特征数量。

success rate 和 greedy reachability

输出里还有一个容易误解的地方:


   
   
    
   
   success rate last 100
greedy reachability

它们不是同一个指标。

success rate last 100 统计的是训练过程:


   
   
    
   
   最近 100 个训练 episode,
从固定 start 出发,
有多少次到达 goal。

它只说明“从默认起点出发,训练时表现如何”。

它不说明“从地图上每个格子出发,最终 greedy policy 是否都能到 goal”。

坐标的读法是:


   
   
    
   
   (row, col)
row 从上到下数;
col 从左到右数。

所以 (0,0) 是左上角,(7,0) 是左下角,(0,7) 是右上角 goal。

所以可能出现:


   
   
    
   
   table success rate last 100 = 1.00
但 table greedy reachability = 49/56

这不是矛盾。

它说明表格方法已经把从 start = (7,0) 到 goal = (0,7) 的常用路径学好了,所以最近 100 局都成功;但一些不在常用路径上、很少被访问的格子,Q 表里对应动作还没学好。

表格方法不会自动把一个格子的经验推广到另一个格子。

失败预览显示的是“从这个格子出发,按 greedy policy 一步步走下去的路径”,不是只显示第一步动作。

例如:


   
   
    
   
   (6, 5)->(7, 5)->(7, 5)->loop

含义是:


   
   
    
   
   (6,5) 的动作是 D,走到 (7,5);
(7,5) 的动作也是 D,但已经在最下面一行;
继续向下会撞边界,留在 (7,5),于是形成 loop。

再比如:


   
   
    
   
   (6, 7)->(6, 7)->loop

因为 (6,7) 的动作是 R,但它已经在最右一列,向右会撞边界并留在原地。

也可能出现:


   
   
    
   
   linear success rate last 100 = 0.96
但 linear greedy reachability = 56/56

这也不是矛盾。

训练时使用的是 epsilon-greedy:


   
   
    
   
   epsilon = 0.2

也就是说,训练过程中仍然有 20% 概率随机探索。随机探索可能掉坑、绕路,导致训练 episode 失败。

但 greedy reachability 检查的是训练结束后的最终 greedy policy:


   
   
    
   
   不再随机探索;
每个状态都选择当前 Q 值最大的动作。

所以训练成功率可以不是 100%,但最终 greedy policy 仍然能从所有非终止格子走到 goal。

线性方法还有一个额外原因:它使用共享特征和权重,容易学出全局规则,例如:


   
   
    
   
   先往上靠近第一行;
或者往右靠近最后一列;
再走向 goal。

这种规则会推广到很多没有被充分单独训练过的格子。这就是函数逼近的泛化。

5. 表格方法到底在存什么

代码里的表格类是:


   
   
    
   
   class TabularQ:

初始化时:


   
   
    
   
   self.q = {
    state: {action: 0.0 for action in ACTIONS}
    for
 state in env.states()
    if
 not env.is_terminal(state)
}

这表示:


   
   
    
   
   每个非终止 state
每个 action
都存一个独立的数字

读取 Q 值:


   
   
    
   
   return self.q[state][action]

更新 Q 值:


   
   
    
   
   self.q[state][action] = old_q + alpha * td_error

这很像在 Excel 表格里改一个单元格。

它的优点是:


   
   
    
   
   清楚;
稳定;
每个 Q 值相互独立;
容易 debug。

缺点是:


   
   
    
   
   状态越多,表越大;
没见过的状态没有经验;
连续状态无法直接完整建表。

6. 线性函数在做什么

线性方法的类是:


   
   
    
   
   class LinearQ:

它不存每个 (state, action) 的 Q 值,而是存一组权重:


   
   
    
   
   self.weights = {name: 0.0 for name in self.feature_names}

代码里使用的特征包括:

特征
含义
bias
固定为 1,让模型有基础偏移
row
下一状态所在行的归一化位置
col
下一状态所在列的归一化位置
goal_closeness
下一状态离 goal 有多近
pit_closeness
下一状态离 pit 有多近
bump
动作是否撞墙或撞边界
action_U/R/D/L
当前动作是哪一个

逐行解释 features()

features() 的作用是:


   
   
    
   
   把一个具体的 (state, action)
转换成一组数字特征,
让线性模型可以用这些数字估计 Q(s,a)。

第一步:


   
   
    
   
   next_state = self.env.next_state(state, action)

这里不是只看当前 state,而是先计算:


   
   
    
   
   如果在 state 执行 action,会走到哪里?

所以很多特征描述的是 next_state。例如:


   
   
    
   
   state = (7,0)
action = U
next_state = (6,0)

接下来:


   
   
    
   
   max_distance = max(1, (self.env.rows - 1) + (self.env.cols - 1))

这是地图中最大可能的曼哈顿距离,用来把距离缩放到大约 0~1 范围。max(1, ...) 是防御性写法,避免除以 0。


   
   
    
   
   goal_distance = manhattan(next_state, self.env.goal)

计算 next_state 到 goal 的曼哈顿距离:


   
   
    
   
   manhattan((r1,c1), (r2,c2)) = |r1-r2| + |c1-c2|

   
   
    
   
   nearest_pit_distance = min((manhattan(next_state, pit) for pit in self.env.pits), default=max_distance)

计算 next_state 到最近 pit 的距离。如果地图里没有 pit,就使用 max_distance 作为默认值。

然后开始构造特征:


   
   
    
   
   "bias": 1.0

固定为 1,相当于模型的基础分。


   
   
    
   
   "row": next_state[0] / max(1, self.env.rows - 1)
"col"
: next_state[1] / max(1, self.env.cols - 1)

表示下一格的行、列位置,并归一化到大约 0~1

在 8x8 地图里:


   
   
    
   
   row = 0 / 7 = 0.0  表示顶部
row = 7 / 7 = 1.0  表示底部
col = 0 / 7 = 0.0  表示最左边
col = 7 / 7 = 1.0  表示最右边

   
   
    
   
   "goal_closeness": 1.0 - goal_distance / max_distance

表示下一格离 goal 有多近。越接近 goal,这个值越大。


   
   
    
   
   "pit_closeness": 1.0 - nearest_pit_distance / max_distance

表示下一格离最近 pit 有多近。越接近 pit,这个值越大。它本身不表示好坏,具体好坏由训练学出来的权重决定。通常靠近 pit 是坏事,所以它的权重可能会学成负数。


   
   
    
   
   "bump": 1.0 if next_state == state else 0.0

表示是否撞墙或撞边界。如果动作执行后位置没变,bump = 1.0


   
   
    
   
   "action_U": 1.0 if action == "U" else 0.0

action_U/R/D/L 是动作的 one-hot 表示。例如动作是 U


   
   
    
   
   action_U = 1.0
action_R = 0.0
action_D = 0.0
action_L = 0.0

所以,features() 自己不学习。它只是把 (state, action) 变成模型能使用的一组数字。真正学习的是 weights

然后用这些特征和权重计算 Q 值:


   
   
    
   
   return sum(self.weights[name] * value for name, value in self.features(state, action).items())

也就是:


   
   
    
   
   Q(s,a) = w1*x1 + w2*x2 + ... + wk*xk

这里:


   
   
    
   
   x = features
w = weights

这就是最简单的函数逼近。

7. 线性方法怎么学习

训练循环仍然是 Q-learning:


   
   
    
   
   target = reward + gamma * next_best
td_error = target - old_q

区别在更新方式。

表格方法更新一格:


   
   
    
   
   q[state][action]

线性方法更新一组权重:


   
   
    
   
   self.weights[name] += alpha * td_error * feature_value

这句话可以这样理解:


   
   
    
   
   如果这次估计偏低,td_error 为正,就把相关特征的权重往上调;
如果这次估计偏高,td_error 为负,就把相关特征的权重往下调;
某个特征值越大,它受到这次更新的影响越大。

所以线性方法不是记住一个格子,而是在调整一套通用规则。

8. 为什么这叫泛化

表格方法更新:


   
   
    
   
   Q((7,0), U)

通常只影响这一格。

线性方法更新:


   
   
    
   
   weights

这些权重会参与很多状态动作的 Q 值计算。

例如 goal_closeness 的权重变了,所有“离 goal 更近或更远”的动作估计都会受到影响。

这就是泛化:


   
   
    
   
   一次经验,不只影响当前 state/action,
还会影响具有相似特征的其他 state/action。

泛化是函数逼近的优势,也是风险。

9. 函数逼近的优势

函数逼近的优势:


   
   
    
   
   参数更少;
能处理更大的状态空间;
能从相似状态中共享经验;
能处理连续状态;
可以进一步换成神经网络。

这就是为什么它是深度 RL 的入口。

如果没有函数逼近,就很难把 Q-learning 扩展到图像、复杂游戏、机器人控制这类问题。

10. 函数逼近的代价

函数逼近也有代价。

表格方法里,一个 Q 值错了,通常只影响一个格子。

线性方法里,一个权重错了,会影响很多估计。

所以它可能出现:


   
   
    
   
   估计互相干扰;
训练更不稳定;
学习率更敏感;
特征设计不好时学不到好策略;
结果不如 Q 表容易解释。

这就是为什么后面的 DQN 不是简单地把 Q 表换成神经网络就完了。

DQN 还需要:


   
   
    
   
   experience replay
target network

这些技巧都是为了让函数逼近下的 Q-learning 更稳定。

这一课的线性方法也可能出现类似提醒:训练更久不一定总是更好。因为共享权重会让不同状态动作互相影响,后面的更新可能破坏前面已经学到的估计。这不是本课代码特有的问题,而是函数逼近进入 RL 后需要认真处理的稳定性问题。

11. 查看单步更新

运行:


   
   
    
   
   python lessons/05_function_approximation/linear_q_gridworld.py --episodes 3 --debug-episodes 1 --log-every 0 --max-steps 15 --show-weights

你会看到:


   
   
    
   
   table episode 1:
  step  state   action  reward  next    old Q   target  td err  new Q

linear episode 1:
  step  state   action  reward  next    old Q   target  td err  new Q

linear weights:
  bias          : ...
  row           : ...
  col           : ...
  goal_closeness: ...

两种方法的 debug 表看起来很像,因为它们都在做:


   
   
    
   
   old_q -> target -> td_error -> new_q

但 new_q 的来源不同:


   
   
    
   
   table:
  new_q 来自 q[state][action] 这一格被改掉

linear:
  new_q 来自 weights 被改掉后重新计算

12. 参数实验

实验 1:只运行表格方法


   
   
    
   
   python lessons/05_function_approximation/linear_q_gridworld.py --agent table

目的:确认这仍然是前面学过的 Q-learning,只是地图更大。

实验 2:只运行线性方法


   
   
    
   
   python lessons/05_function_approximation/linear_q_gridworld.py --agent linear --show-weights

重点看:


   
   
    
   
   parameters stored 是否仍然是 10?
linear weights 有哪些正负变化?
greedy policy 是否能走向 goal?

实验 3:扩大地图


   
   
    
   
   python lessons/05_function_approximation/linear_q_gridworld.py --rows 12 --cols 12 --episodes 2000 --log-every 1000

观察:


   
   
    
   
   table 的 parameters stored 是否明显增长?
linear 的 parameters stored 是否仍然是 10?

这能说明:


   
   
    
   
   表格参数量跟 state/action 数量相关;
线性参数量跟 feature 数量相关。

如果 12x12 地图里 linear 的成功率很低,这也是有效观察。它说明:


   
   
    
   
   参数少不等于一定效果好;
特征太简单时,线性模型可能表达不了复杂地图里的好策略。

实验 4:调线性学习率


   
   
    
   
   python lessons/05_function_approximation/linear_q_gridworld.py --agent linear --alpha-linear 0.01 --log-every 1000
python lessons/05_function_approximation/linear_q_gridworld.py --agent linear --alpha-linear 0.2 --log-every 1000

观察:


   
   
    
   
   学习是否变慢?
训练日志是否更波动?
greedy policy 是否更不稳定?

函数逼近通常比表格方法更怕不合适的学习率。

实验 5:训练更久


   
   
    
   
   python lessons/05_function_approximation/linear_q_gridworld.py --episodes 6000 --log-every 3000

观察:


   
   
    
   
   linear 是否一定比默认 1000 局更好?
训练日志是否可能后期变差?

如果变差,不要急着认为程序错了。它正好说明函数逼近可能不如表格方法稳定。

实验 6:加入打滑


   
   
    
   
   python lessons/05_function_approximation/linear_q_gridworld.py --slip-probability 0.1 --log-every 1000

观察:


   
   
    
   
   随机环境下,table 和 linear 的 success rate 是否下降?
linear 的 greedy policy 是否更容易出现不稳定动作?

13. 本课总结

本课最重要的一句话是:


   
   
    
   
   Q(s,a) 不一定来自表格,也可以来自一个函数。

表格方法:


   
   
    
   
   每个 state/action 单独存一个 Q 值;
清楚稳定;
但状态多时无法扩展。

线性函数逼近:


   
   
    
   
   用 features 和 weights 计算 Q(s,a);
参数少;
能泛化;
但可能互相干扰,也更依赖特征设计。

这节课是 DQN 的前置桥梁。

DQN 可以先粗略理解为:


   
   
    
   
   Q-learning + 神经网络函数逼近

区别是:


   
   
    
   
   本课用人工设计的 features + 线性 weights;
DQN 用神经网络从状态中学习更复杂的表示。

14. 进入下一课前的检查

你应该能用自己的话回答:

  1. 1. 为什么 Q 表在大状态空间里会出问题?
  2. 2. table 和 linear 都是在估计什么?
  3. 3. parameters stored 为什么重要?
  4. 4. 线性方法里的 features 和 weights 分别是什么?
  5. 5. 为什么线性方法更新一次会影响很多状态动作?
  6. 6. 函数逼近的优势是什么?
  7. 7. 函数逼近的风险是什么?
  8. 8. DQN 和本课的关系是什么?

如果这些问题能回答清楚,就可以进入下一阶段:


   
   
    
   
   用神经网络估计 Q(s,a)

如果觉得内容不错,欢迎你点一下「在看」,或是将文章分享给其他有需要的人^^

相关好文推荐:

每次看见有人说能够识别出一段文字是不是AI生成的,我都忍不住想笑

飞书会取代微信吗?

AI 时代的软件与软件公司应该长什么样?

0条留言

留言