news 2026/5/1 8:32:21

【机器学习】案例1.2——决策树进行鸢尾花分类

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【机器学习】案例1.2——决策树进行鸢尾花分类

1. 项目背景及解决问题的方案

1.1 项目背景

鸢尾花(Iris)数据集是机器学习领域的经典基准数据集,由统计学家Fisher于1936年提出,是多分类任务的入门级数据集。该数据集包含150个样本,对应3类鸢尾花(山鸢尾/Iris-setosa、变色鸢尾/Iris-versicolor、维吉尼亚鸢尾/Iris-virginica),每类各50个样本;每个样本包含4个数值型特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度。

从技术角度,决策树是一种基于树状结构做决策的分类/回归算法,具有可解释性强、无需特征归一化、直观易懂等优点,但核心痛点是:当决策树的深度过深时,模型会过度拟合训练数据的细节(如噪声),导致在测试集上的泛化能力下降(过拟合)。

本项目的核心目标:

  • 基于鸢尾花数据集,使用决策树分类器实现鸢尾花种类的精准分类;
  • 探究决策树深度对模型泛化能力(测试集错误率)的影响,验证“深度过深导致过拟合”的现象;
  • 掌握决策树模型的训练、评估、可视化及超参数(深度)调优的核心流程。
1.2 解决问题的方案(分步骤)
步骤核心动作具体实现
数据准备加载+预处理1. 加载sklearn内置的Iris数据集;
2. 转换为Pandas DataFrame,命名特征列并添加目标列;
3. 选择“花瓣长度、花瓣宽度”两个核心特征(区分度更高)。
数据集划分训练/测试集拆分按75%(训练):25%(测试)划分数据,设置random_state=42保证结果可复现。
基础模型训练决策树训练+评估1. 初始化决策树分类器(max_depth=8,gini准则);
2. 训练模型并预测测试集;
3. 计算测试集准确率、输出特征重要性;
4. 导出决策树可视化文件(dot格式)。
单样本验证自定义样本预测对花瓣长度=5、宽度=1.5的样本,预测分类概率和最终结果。
超参数探究深度对性能的影响1. 遍历深度1~14,训练不同深度的决策树;
2. 计算每个深度的测试集错误率;
3. 可视化深度与错误率的关系,验证过拟合。
可视化展示结果可视化设置中文字体,绘制深度-错误率折线图,直观展示规律。

2. 代码详细注释版

# 导入必要的库importpandasaspd# 数据处理库,用于结构化数据操作importnumpyasnp# 数值计算库,用于数组/矩阵操作fromsklearn.datasetsimportload_iris# 加载sklearn内置的鸢尾花数据集fromsklearn.treeimportDecisionTreeClassifier# 决策树分类器fromsklearn.treeimportexport_graphviz# 导出决策树为dot格式(可视化用)fromsklearn.treeimportDecisionTreeRegressor# 决策树回归器(本项目未使用,保留注释)fromsklearn.model_selectionimporttrain_test_split# 划分训练集/测试集fromsklearn.metricsimportaccuracy_score# 计算分类准确率importmatplotlib.pyplotasplt# 绘图库importmatplotlibasmpl# 绘图配置库# ===================== 步骤1:加载并预处理鸢尾花数据集 =====================# 加载鸢尾花数据集iris=load_iris()# 将特征数据转换为DataFrame(方便查看和处理)data=pd.DataFrame(iris.data)# 为特征列命名(对应数据集的4个特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度)data.columns=iris.feature_names# 添加目标列(鸢尾花种类,0=setosa,1=versicolor,2=virginica)data['Species']=load_iris().target# 打印数据集前几行(默认5行),查看数据结构print(data)# 特征选择:仅选取花瓣长度(第3列)和花瓣宽度(第4列)作为输入特征(iloc[:,2:4]表示行全选,列选2-3索引)x=data.iloc[:,2:4]# 目标变量:选取最后一列(Species)作为分类目标y=data.iloc[:,-1]# ===================== 步骤2:划分训练集和测试集 =====================# train_size=0.75:训练集占75%,测试集25%;random_state=42:固定随机种子,保证结果可复现x_train,x_test,y_train,y_test=train_test_split(x,y,train_size=0.75,random_state=42)# ===================== 步骤3:训练基础决策树模型并评估 =====================# 初始化决策树分类器:max_depth=8(树最大深度),criterion='gini'(基尼系数作为分裂准则)tree_clf=DecisionTreeClassifier(max_depth=8,criterion='gini')# 用训练集数据训练模型tree_clf.fit(x_train,y_train)# 用训练好的模型预测测试集y_test_hat=tree_clf.predict(x_test)# 计算并打印测试集准确率print("acc score:",accuracy_score(y_test,y_test_hat))# 打印特征重要性:数值越大表示该特征对分类的贡献越大print("特征重要性(花瓣长度/花瓣宽度):",tree_clf.feature_importances_)# 导出决策树为dot格式文件(可通过dot命令转换为PNG图片查看树结构)export_graphviz(tree_clf,# 训练好的决策树模型out_file="./iris_tree.dot",# 输出文件路径feature_names=iris.feature_names[2:4],# 特征名(仅花瓣长度/宽度)class_names=iris.target_names,# 类别名(setosa/versicolor/virginica)rounded=True,# 节点边框圆角filled=True# 节点填充颜色)# 备注:转换命令(需安装graphviz):./dot -Tpng ~/PycharmProjects/mlstudy/bjsxt/iris_tree.dot -o ~/PycharmProjects/mlstudy/bjsxt/iris_tree.png# ===================== 步骤4:单样本预测 =====================# 预测花瓣长度=5,宽度=1.5的样本属于各类的概率print("单样本分类概率:",tree_clf.predict_proba([[5,1.5]]))# 预测该样本的最终分类结果(输出类别索引)print("单样本分类结果:",tree_clf.predict([[5,1.5]]))# ===================== 步骤5:探究决策树深度对错误率的影响 =====================# 生成深度范围:1到14(包含14)depth=np.arange(1,15)# 存储每个深度对应的错误率err_list=[]# 遍历每个深度,训练模型并计算错误率fordindepth:print(f"当前训练的决策树深度:{d}")# 初始化对应深度的决策树分类器(基尼系数准则)clf=DecisionTreeClassifier(criterion='gini',max_depth=d)# 训练模型clf.fit(x_train,y_train)# 预测测试集y_test_hat=clf.predict(x_test)# 计算预测正确的样本(True/False数组)result=(y_test_hat==y_test)# 仅当深度=1时打印预测正确与否的结果(用于调试)ifd==1:print(f"深度=1时的预测正确结果:{result}")# 计算错误率:1 - 正确样本的均值err=1-np.mean(result)# 打印错误率(百分比)print(f"深度={d}时的错误率(百分比):{100*err:.2f}%")# 将错误率加入列表err_list.append(err)# ===================== 步骤6:可视化深度与错误率的关系 =====================# 设置matplotlib的中文字体(SimHei=黑体,避免中文乱码)mpl.rcParams['font.sans-serif']=['SimHei']# 设置图片背景色为白色plt.figure(facecolor='w')# 绘制折线图:红色圆点+实线,线宽=2plt.plot(depth,err_list,'ro-',lw=2)# 设置x轴标签plt.xlabel('决策树深度',fontsize=15)# 设置y轴标签plt.ylabel('错误率',fontsize=15)# 设置标题plt.title('决策树深度和过拟合',fontsize=18)# 显示网格线plt.grid(True)# 展示图片plt.show()# 决策树回归器(本项目未使用,保留代码注释)# tree_reg = DecisionTreeRegressor(max_depth=2)# tree_reg.fit(X, y)

3. 代码简洁版(核心逻辑,精简注释/打印)

importpandasaspdimportnumpyasnpfromsklearn.datasetsimportload_irisfromsklearn.treeimportDecisionTreeClassifier,export_graphvizfromsklearn.model_selectionimporttrain_test_splitfromsklearn.metricsimportaccuracy_scoreimportmatplotlib.pyplotaspltimportmatplotlibasmpl# 数据加载与预处理iris=load_iris()data=pd.DataFrame(iris.data,columns=iris.feature_names)data['Species']=iris.target x=data.iloc[:,2:4]# 花瓣长度/宽度y=data.iloc[:,-1]# 划分数据集x_train,x_test,y_train,y_test=train_test_split(x,y,train_size=0.75,random_state=42)# 基础模型训练tree_clf=DecisionTreeClassifier(max_depth=8,criterion='gini')tree_clf.fit(x_train,y_train)print("准确率:",accuracy_score(y_test,tree_clf.predict(x_test)))# 导出决策树可视化文件export_graphviz(tree_clf,out_file="./iris_tree.dot",feature_names=iris.feature_names[2:4],class_names=iris.target_names,rounded=True,filled=True)# 单样本预测print("单样本概率:",tree_clf.predict_proba([[5,1.5]]))print("单样本结果:",tree_clf.predict([[5,1.5]]))# 探究深度对错误率的影响depth=np.arange(1,15)err_list=[]fordindepth:clf=DecisionTreeClassifier(criterion='gini',max_depth=d)clf.fit(x_train,y_train)err=1-np.mean(clf.predict(x_test)==y_test)err_list.append(err)# 可视化mpl.rcParams['font.sans-serif']=['SimHei']plt.figure(facecolor='w')plt.plot(depth,err_list,'ro-',lw=2)plt.xlabel('决策树深度',fontsize=15)plt.ylabel('错误率',fontsize=15)plt.title('决策树深度和过拟合',fontsize=18)plt.grid(True)plt.show()

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/1 3:03:17

LobeChat部署常见问题汇总及解决方案(2024最新)

LobeChat 部署常见问题深度解析与实战优化(2024) 在 AI 应用快速落地的今天,越来越多开发者不再满足于“调用 API 输出文本”的原始模式。他们希望构建一个真正可用、好看、安全且可扩展的智能对话系统——而这就是 LobeChat 存在的意义。 它…

作者头像 李华
网站建设 2026/5/1 3:02:42

LobeChat能否实现拖拽上传?文件交互体验增强技巧

LobeChat能否实现拖拽上传?文件交互体验增强技巧 在如今的AI对话应用中,用户早已不满足于简单的“你问我答”。当面对一份几十页的PDF合同、一段复杂的代码文件,或是需要分析的数据表格时,谁还愿意一行行手动输入?一个…

作者头像 李华
网站建设 2026/5/1 4:04:42

FFmepg-- 34-ffplay源码-- ffplay 的音视频同步(AV Sync)机制

文章目录 默认同步策略:音频主时钟(Audio Master) 同步流程(视频线程视角) 时钟系统:Clock 结构与 set_clock() 完整调用过程 音频解码线程更新音频时钟(audio_thread) 主线程事件循环(event_loop → video_refresh) 本文系统地解析 ffplay 的音视频同步(AV Sync)机…

作者头像 李华
网站建设 2026/5/1 4:04:26

Qwen3-8B镜像部署全流程:从diskinfo查看存储到容器启动

Qwen3-8B镜像部署全流程:从存储检测到容器启动 在大语言模型(LLM)技术飞速发展的今天,如何将强大的AI能力落地到实际环境中,已成为开发者和企业面临的核心挑战。千亿参数级模型虽然性能惊人,但其高昂的算力…

作者头像 李华
网站建设 2026/5/1 4:05:02

火电厂环保设备全方位数据采集物联网方案

目前,大多数火电厂已配置齐全的环保设备,但这些脱硫、脱硝、除尘控制系统较为独立,存在数据孤岛,依赖管理人员进行调控与开关,无法统一调整操作。同时对于机组负荷也缺少监控管理的手段,往往存在能源浪费与…

作者头像 李华
网站建设 2026/5/1 4:02:07

LobeChat开源项目深度解析:打造个性化大模型交互前端

LobeChat开源项目深度解析:打造个性化大模型交互前端 在大语言模型(LLM)能力日益普及的今天,我们已经不再为“AI会不会写诗”而惊叹。真正的问题变成了:如何让这些强大的模型真正服务于人? GPT、通义千问…

作者头像 李华