news 2026/5/28 12:33:27

用Python玩转强化学习:手把手教你用Policy Iteration和Value Iteration解决赌徒问题

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
用Python玩转强化学习:手把手教你用Policy Iteration和Value Iteration解决赌徒问题

用Python玩转强化学习:手把手教你用Policy Iteration和Value Iteration解决赌徒问题

假设你走进一家赌场,口袋里揣着50美元。每次下注后,硬币正面朝上你能赢得与赌注相同的金额,反面朝上则输掉赌注。目标很简单:要么带着100美元离开,要么输光所有钱。这个看似简单的场景背后,隐藏着一个经典的强化学习问题——赌徒问题。

1. 环境搭建与问题建模

在开始编码前,我们需要明确问题的数学表达。赌徒问题可以建模为一个马尔可夫决策过程(MDP),包含以下要素:

  • 状态空间S:当前持有的金额,从0到99美元
  • 动作空间A:可能的赌注金额,范围是1到min(s, 100-s)
  • 转移概率:硬币正面概率ph,反面概率1-ph
  • 奖励函数:仅在达到100美元时获得+1奖励,其他情况为0

让我们先用NumPy建立基础环境:

import numpy as np import matplotlib.pyplot as plt class GamblerEnv: def __init__(self, goal=100, ph=0.4): self.goal = goal self.ph = ph # 硬币正面概率 self.states = np.arange(goal + 1) # 0到100的所有状态 def get_actions(self, s): """获取当前状态下的可用动作""" return np.arange(1, min(s, self.goal - s) + 1)

2. 策略迭代(Policy Iteration)实现

策略迭代包含两个交替进行的步骤:策略评估和策略改进。我们将分步实现:

2.1 策略评估

策略评估通过迭代计算当前策略下的状态价值函数:

def policy_evaluation(self, policy, theta=1e-9, max_iter=1000): V = np.zeros(self.goal + 1) V[self.goal] = 1.0 # 目标状态价值为1 for _ in range(max_iter): delta = 0 for s in self.states[1:self.goal]: # 跳过0和100 v = V[s] a = policy[s] # 计算期望回报 V[s] = self.ph * V[s + a] + (1 - self.ph) * V[s - a] delta = max(delta, abs(v - V[s])) if delta < theta: break return V

2.2 策略改进

基于评估得到的价值函数改进策略:

def policy_improvement(self, V, policy): policy_stable = True for s in self.states[1:self.goal]: old_a = policy[s] action_returns = [] for a in self.get_actions(s): ret = self.ph * V[s + a] + (1 - self.ph) * V[s - a] action_returns.append(ret) # 选择回报最大的动作 policy[s] = self.get_actions(s)[np.argmax(action_returns)] if old_a != policy[s]: policy_stable = False return policy_stable

2.3 完整策略迭代流程

将两个步骤组合成完整算法:

def policy_iteration(self, theta=1e-9): # 初始化随机策略 policy = np.zeros(self.goal + 1, dtype=int) for s in self.states[1:self.goal]: policy[s] = np.random.choice(self.get_actions(s)) while True: V = self.policy_evaluation(policy, theta) if self.policy_improvement(V, policy): break return policy, V

3. 值迭代(Value Iteration)实现

值迭代将策略评估和策略改进结合在一个步骤中:

def value_iteration(self, theta=1e-9): V = np.zeros(self.goal + 1) V[self.goal] = 1.0 policy = np.zeros(self.goal + 1, dtype=int) while True: delta = 0 for s in self.states[1:self.goal]: v = V[s] action_returns = [] for a in self.get_actions(s): ret = self.ph * V[s + a] + (1 - self.ph) * V[s - a] action_returns.append(ret) V[s] = np.max(action_returns) delta = max(delta, abs(v - V[s])) if delta < theta: break # 提取最优策略 for s in self.states[1:self.goal]: action_returns = [] for a in self.get_actions(s): ret = self.ph * V[s + a] + (1 - self.ph) * V[s - a] action_returns.append(ret) policy[s] = self.get_actions(s)[np.argmax(action_returns)] return policy, V

4. 结果分析与可视化

让我们比较两种算法的结果,并可视化关键指标:

4.1 策略对比

def compare_policies(env, ph_values=[0.4, 0.55]): fig, axes = plt.subplots(2, 2, figsize=(12, 10)) for i, ph in enumerate(ph_values): env.ph = ph # 运行两种算法 pi_policy, _ = env.policy_iteration() vi_policy, _ = env.value_iteration() # 绘制策略 axes[0,i].plot(pi_policy[1:100], label='Policy Iteration') axes[0,i].plot(vi_policy[1:100], label='Value Iteration') axes[0,i].set_title(f'Optimal Policy (ph={ph})') axes[0,i].set_xlabel('Capital') axes[0,i].set_ylabel('Stake') axes[0,i].legend() # 绘制策略差异 diff = np.abs(pi_policy - vi_policy) axes[1,i].plot(diff[1:100]) axes[1,i].set_title(f'Policy Difference (ph={ph})') axes[1,i].set_xlabel('Capital') axes[1,i].set_ylabel('Absolute Difference') plt.tight_layout() plt.show()

4.2 收敛速度分析

我们可以记录两种算法的收敛过程:

def plot_convergence(env): # 策略迭代收敛过程 env.policy_iteration() # 确保历史数据被记录 pi_sweeps = len(env.pi_history) # 值迭代收敛过程 env.value_iteration() vi_sweeps = len(env.vi_history) plt.figure(figsize=(10, 5)) plt.bar(['Policy Iteration', 'Value Iteration'], [pi_sweeps, vi_sweeps]) plt.title('Number of Sweeps to Convergence') plt.ylabel('Sweeps') plt.show()

5. 实战技巧与优化建议

在实际实现过程中,有几个关键点需要注意:

  1. 数值稳定性

    • 使用较小的θ值(如1e-9)确保收敛精度
    • 对浮点数比较使用np.isclose而非直接==
  2. 性能优化

    # 向量化计算可以显著提升速度 def vectorized_update(self, V): new_V = V.copy() for s in self.states[1:self.goal]: actions = self.get_actions(s) returns = self.ph * V[s + actions] + (1 - self.ph) * V[s - actions] new_V[s] = np.max(returns) return new_V
  3. 策略可视化技巧

    • 使用热图展示不同ph值下的策略变化
    • 添加移动平均线使策略趋势更明显
  4. 调试建议

    • 打印每次迭代的最大值变化
    • 检查边界条件(如s=1和s=99时的动作选择)
算法特性策略迭代值迭代
计算复杂度每次迭代需完整策略评估每次迭代直接更新最优值
收敛速度通常需要较少外层迭代可能需要更多次值更新
内存消耗需要存储中间策略只需维护值函数
适用场景策略变化平稳的问题动作空间较大的问题

在实现过程中发现,当ph=0.4时,两种算法得到的策略差异较大;而当ph=0.55时,策略更加相似。这是因为在有利的胜率下,最优策略趋向于更激进的投注方式。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/28 12:29:33

保姆级教程:用Python+LIBSVM复现西瓜书SVM习题(附完整代码与数据集)

从理论到实践&#xff1a;PythonLIBSVM实现西瓜书SVM习题全流程解析在机器学习领域&#xff0c;支持向量机(SVM)一直以其优秀的分类性能和清晰的数学原理备受推崇。周志华教授的《机器学习》(西瓜书)作为国内经典教材&#xff0c;其第六章对SVM的理论讲解深入浅出&#xff0c;但…

作者头像 李华
网站建设 2026/5/28 12:26:51

Arduino与TouchDesigner交互:吹气控制蒲公英光影装置全解析

1. 项目概述&#xff1a;当吹气遇见代码&#xff0c;一朵会发光的蒲公英如何诞生几年前&#xff0c;我在一个新媒体艺术展上看到一件作品&#xff1a;观众对着一个麦克风低语&#xff0c;墙上的光影便如涟漪般荡漾开。那一刻我意识到&#xff0c;将无形的物理动作转化为可见的数…

作者头像 李华
网站建设 2026/5/28 12:24:11

TimesFM动态协变量:技术深度解析与实践避坑指南

TimesFM动态协变量&#xff1a;技术深度解析与实践避坑指南 【免费下载链接】timesfm TimesFM (Time Series Foundation Model) is a pretrained time-series foundation model developed by Google Research for time-series forecasting. 项目地址: https://gitcode.com/Gi…

作者头像 李华