用Python玩转强化学习:手把手教你用Policy Iteration和Value Iteration解决赌徒问题
假设你走进一家赌场,口袋里揣着50美元。每次下注后,硬币正面朝上你能赢得与赌注相同的金额,反面朝上则输掉赌注。目标很简单:要么带着100美元离开,要么输光所有钱。这个看似简单的场景背后,隐藏着一个经典的强化学习问题——赌徒问题。
1. 环境搭建与问题建模
在开始编码前,我们需要明确问题的数学表达。赌徒问题可以建模为一个马尔可夫决策过程(MDP),包含以下要素:
- 状态空间S:当前持有的金额,从0到99美元
- 动作空间A:可能的赌注金额,范围是1到min(s, 100-s)
- 转移概率:硬币正面概率ph,反面概率1-ph
- 奖励函数:仅在达到100美元时获得+1奖励,其他情况为0
让我们先用NumPy建立基础环境:
import numpy as np import matplotlib.pyplot as plt class GamblerEnv: def __init__(self, goal=100, ph=0.4): self.goal = goal self.ph = ph # 硬币正面概率 self.states = np.arange(goal + 1) # 0到100的所有状态 def get_actions(self, s): """获取当前状态下的可用动作""" return np.arange(1, min(s, self.goal - s) + 1)2. 策略迭代(Policy Iteration)实现
策略迭代包含两个交替进行的步骤:策略评估和策略改进。我们将分步实现:
2.1 策略评估
策略评估通过迭代计算当前策略下的状态价值函数:
def policy_evaluation(self, policy, theta=1e-9, max_iter=1000): V = np.zeros(self.goal + 1) V[self.goal] = 1.0 # 目标状态价值为1 for _ in range(max_iter): delta = 0 for s in self.states[1:self.goal]: # 跳过0和100 v = V[s] a = policy[s] # 计算期望回报 V[s] = self.ph * V[s + a] + (1 - self.ph) * V[s - a] delta = max(delta, abs(v - V[s])) if delta < theta: break return V2.2 策略改进
基于评估得到的价值函数改进策略:
def policy_improvement(self, V, policy): policy_stable = True for s in self.states[1:self.goal]: old_a = policy[s] action_returns = [] for a in self.get_actions(s): ret = self.ph * V[s + a] + (1 - self.ph) * V[s - a] action_returns.append(ret) # 选择回报最大的动作 policy[s] = self.get_actions(s)[np.argmax(action_returns)] if old_a != policy[s]: policy_stable = False return policy_stable2.3 完整策略迭代流程
将两个步骤组合成完整算法:
def policy_iteration(self, theta=1e-9): # 初始化随机策略 policy = np.zeros(self.goal + 1, dtype=int) for s in self.states[1:self.goal]: policy[s] = np.random.choice(self.get_actions(s)) while True: V = self.policy_evaluation(policy, theta) if self.policy_improvement(V, policy): break return policy, V3. 值迭代(Value Iteration)实现
值迭代将策略评估和策略改进结合在一个步骤中:
def value_iteration(self, theta=1e-9): V = np.zeros(self.goal + 1) V[self.goal] = 1.0 policy = np.zeros(self.goal + 1, dtype=int) while True: delta = 0 for s in self.states[1:self.goal]: v = V[s] action_returns = [] for a in self.get_actions(s): ret = self.ph * V[s + a] + (1 - self.ph) * V[s - a] action_returns.append(ret) V[s] = np.max(action_returns) delta = max(delta, abs(v - V[s])) if delta < theta: break # 提取最优策略 for s in self.states[1:self.goal]: action_returns = [] for a in self.get_actions(s): ret = self.ph * V[s + a] + (1 - self.ph) * V[s - a] action_returns.append(ret) policy[s] = self.get_actions(s)[np.argmax(action_returns)] return policy, V4. 结果分析与可视化
让我们比较两种算法的结果,并可视化关键指标:
4.1 策略对比
def compare_policies(env, ph_values=[0.4, 0.55]): fig, axes = plt.subplots(2, 2, figsize=(12, 10)) for i, ph in enumerate(ph_values): env.ph = ph # 运行两种算法 pi_policy, _ = env.policy_iteration() vi_policy, _ = env.value_iteration() # 绘制策略 axes[0,i].plot(pi_policy[1:100], label='Policy Iteration') axes[0,i].plot(vi_policy[1:100], label='Value Iteration') axes[0,i].set_title(f'Optimal Policy (ph={ph})') axes[0,i].set_xlabel('Capital') axes[0,i].set_ylabel('Stake') axes[0,i].legend() # 绘制策略差异 diff = np.abs(pi_policy - vi_policy) axes[1,i].plot(diff[1:100]) axes[1,i].set_title(f'Policy Difference (ph={ph})') axes[1,i].set_xlabel('Capital') axes[1,i].set_ylabel('Absolute Difference') plt.tight_layout() plt.show()4.2 收敛速度分析
我们可以记录两种算法的收敛过程:
def plot_convergence(env): # 策略迭代收敛过程 env.policy_iteration() # 确保历史数据被记录 pi_sweeps = len(env.pi_history) # 值迭代收敛过程 env.value_iteration() vi_sweeps = len(env.vi_history) plt.figure(figsize=(10, 5)) plt.bar(['Policy Iteration', 'Value Iteration'], [pi_sweeps, vi_sweeps]) plt.title('Number of Sweeps to Convergence') plt.ylabel('Sweeps') plt.show()5. 实战技巧与优化建议
在实际实现过程中,有几个关键点需要注意:
数值稳定性:
- 使用较小的θ值(如1e-9)确保收敛精度
- 对浮点数比较使用np.isclose而非直接==
性能优化:
# 向量化计算可以显著提升速度 def vectorized_update(self, V): new_V = V.copy() for s in self.states[1:self.goal]: actions = self.get_actions(s) returns = self.ph * V[s + actions] + (1 - self.ph) * V[s - actions] new_V[s] = np.max(returns) return new_V策略可视化技巧:
- 使用热图展示不同ph值下的策略变化
- 添加移动平均线使策略趋势更明显
调试建议:
- 打印每次迭代的最大值变化
- 检查边界条件(如s=1和s=99时的动作选择)
| 算法特性 | 策略迭代 | 值迭代 |
|---|---|---|
| 计算复杂度 | 每次迭代需完整策略评估 | 每次迭代直接更新最优值 |
| 收敛速度 | 通常需要较少外层迭代 | 可能需要更多次值更新 |
| 内存消耗 | 需要存储中间策略 | 只需维护值函数 |
| 适用场景 | 策略变化平稳的问题 | 动作空间较大的问题 |
在实现过程中发现,当ph=0.4时,两种算法得到的策略差异较大;而当ph=0.55时,策略更加相似。这是因为在有利的胜率下,最优策略趋向于更激进的投注方式。