news 2026/5/1 10:02:28

线性拟合模型

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
线性拟合模型

线性拟合模型

一、数据准备部分

importnumpyasnpimportkerasimportmatplotlib.pyplotasplt train_X=np.asarray([30.0,40.0,60.0,80.0,100.0,120.0,140.0])train_Y=np.asarray([320.0,360.0,400.0,455.0,490.0,546.0,580.0])train_X/=100.0train_Y/=100.0
  • train_Xtrain_Y是人工构造的训练数据(x 和 y)。

  • 除以 100 是为了归一化(Normalization),将数据范围从 [30-140] 和 [320-580] 缩放到 [0.3-1.4] 和 [3.2-5.8]),有助于神经网络更快收敛。

  • 这是典型的监督学习回归问题:输入 x → 预测 y。

二、可视化函数

defplot_points(x,y,title_name):plt.title(title_name)plt.xlabel('x')plt.ylabel('y')plt.scatter(x,y)plt.show()defplot_line(W,b,title_name):plt.title(title_name)plt.xlabel('x')plt.ylabel('y')x=np.linspace(0.0,2.0,num=100)y=W*x+b plt.plot(x,y)plt.show()
  • plot_points:画散点图,展示原始数据。

  • plot_line:根据斜率W和截距b画出拟合直线。

三、模型构建

model=keras.models.Sequential()model.add(keras.layers.Dense(units=1,input_dim=1))
  • 只有一层:Dense全连接层
  • units=1:只有一个神经元(输出一个值)
  • input_dim=1:输入数据是一维的(一个特征)
  • 相当于数学公式:y = Wx + b,其中:
    • W:权重(weight),相当于斜率
    • b:偏置(bias),相当于截距

四、编译模型

model.compile(optimizer='sgd',loss='mean_squared_error')
  • optimizer='sgd':使用随机梯度下降优化器
    • SGD是最基础、最经典的优化算法
    • 相比adam,SGD更简单,适合这种简单线性问题
  • loss='mean_squared_error':使用均方误差作为损失函数
    • 计算公式:MSE = Σ(y_pred - y_true)² / n
    • 这是回归问题最常用的损失函数

五、训练模型

history=model.fit(x=train_X,y=train_Y,batch_size=1,epochs=10)
  • batch_size=1批大小为1(在线学习/随机梯度下降)
    • 每看一个样本就更新一次权重
    • 梯度更新频繁,波动较大
    • 内存占用小,适合小数据集
  • epochs=10:训练10轮
    • 把7个样本反复训练10遍
    • 总共训练 7 × 10 = 70 次更新

注意history会记录训练过程中的loss变化,可以用于后续分析

六. 结果可视化

plot_line(model.get_weights()[0][0][0],model.get_weights()[1][0],title_name='Current_Model')
  • model.get_weights()[0]:获取权重W(斜率)
    • [0][0][0]是因为权重的形状是(1,1),需要索引到具体数值
  • model.get_weights()[1]:获取偏置b(截距)
    • [0]是因为偏置的形状是(1,),需要索引到具体数值

这个模型在做什么?

1. 数学本质

这个模型其实就是用神经网络的方式来实现最小二乘法线性回归

  • 要找一条直线y = Wx + b
  • 让这条直线最接近所有数据点
  • "接近"的标准是:均方误差最小

2. 训练过程(SGD)

初始化:W=随机值,b=随机值for10:for每个样本(x_i,y_i):1.计算预测值:y_pred=W*x_i+b2.计算误差:error=y_pred-y_i3.计算梯度:dW=2*error*x_i# 对W的梯度db=2*error# 对b的梯度4.更新参数:W=W-learning_rate*dW b=b-learning_rate*db

完整代码:

importnumpyasnpimportkerasimportmatplotlib.pyplotasplt train_X=np.asarray([30.0,40.0,60.0,80.0,100.0,120.0,140.0])train_Y=np.asarray([320.0,360.0,400.0,455.0,490.0,546.0,580.0])train_X/=100.0train_Y/=100.0#用于对数据点进行可视化defplot_points(x,y,title_name):plt.title(title_name)plt.xlabel('x')plt.ylabel('y')plt.scatter(x,y)plt.show()defplot_line(W,b,title_name):plt.title(title_name)plt.xlabel('x')plt.ylabel('y')x=np.linspace(0.0,2.0,num=100)y=W*x+b plt.plot(x,y)plt.show()plot_points(train_X,train_Y,title_name='Training Points')#建立线性拟合模型,由斜率和偏移两个参数构成,相当于神经元数为1的一层全连接model=keras.models.Sequential()model.add(keras.layers.Dense(units=1,input_dim=1))#成本函数采用均差误差,优化方法使用随机梯度下降model.compile(optimizer='sgd',loss='mean_squared_error')#模型迭代10个轮次,用单样本的方式进行优化history=model.fit(x=train_X,y=train_Y,batch_size=1,epochs=10)plot_line(model.get_weights()[0][0][0],model.get_weights()[1][0],title_name='Current_Model')

附解释可视化函数部分
1.散点图
def plot_points(x, y, title_name):

  • 定义一个名为plot_points的函数。

    x:横坐标数据(如你的 train_X)
    y:纵坐标数据(如你的 train_Y)
    title_name:图表的标题(字符串)

​ plt.title(title_name) # 设置图表标题
​ plt.xlabel(‘x’) # 设置x轴标签
​ plt.ylabel(‘y’) # 设置y轴标签
​ plt.scatter(x, y) # 绘制散点图
​ plt.show() # 显示图表

2.直线图
def plot_line(W, b, title_name):
plt.title(title_name) # 设置图表标题
plt.xlabel(‘x’) # 设置x轴标签
plt.ylabel(‘y’) # 设置y轴标签

​ x = np.linspace(0.0, 2.0, num=100) # 生成100个等间距的x值
​ np: numpy模块的别名
​ .linspace(): 生成等差数列(linear space)
​ 参数:
​ 0.0: 起始值(start)
​ 2.0: 结束值(stop)
​ num=100: 生成100个点

​ y = W * x + b # 计算对应的y值

​ plt.plot(x, y) # 绘制折线图(这里是直线)
​ .plot(): 绘制折线图
​ 参数:(x, y)坐标点

​ plt.show() # 显示图表

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/30 5:11:51

职场笔杆子必看!2025公文写作软件TOP3对比

作为一名体制内笔杆子,写作公文的痛谁懂,临时的派稿任务,格式要求超严格,内容要求严谨合规,加班改稿也都是经常的事。 随着AI的不断发展,人工智能的写作能力越来越强,为写作带来显著的提效&…

作者头像 李华
网站建设 2026/4/30 10:22:03

Jenkins 2.528.3 与 GitLab 深度集成:实现自动构建

在 Jenkins 2.528.3 版本中,实现 GitLab 代码推送(Push)后自动触发构建,主要依赖于 GitLab Plugin 或 Generic Webhook Trigger Plugin。以下是两种主流方法的详细配置指南,帮助构建高效的自动化流水线。核心配置概览自…

作者头像 李华
网站建设 2026/5/1 3:42:26

算法基础-多源最短路

多源最短路 多源最短路:即图中每对顶点间的最短路径。floyd 算法本质是动态规划,⽤来求任意两个结点之间的最短路,也称插点法。通过不断在两点之间加 ⼊新的点,来更新最短路。 适⽤于任何图,不管有向⽆向,…

作者头像 李华
网站建设 2026/4/18 12:29:28

新生态・新动能:人工智能产业格局分析

‍当前,人工智能产业已成为驱动数字经济高质量发展的核心引擎,不断推动产业生态建设和效能提升,各地政府积极响应推进科研创新与算力基础设施建设,因地制宜出台特色政策。持续探索新型大模型,推动AI产业向更高水平迈进。 一、人工…

作者头像 李华
网站建设 2026/5/1 8:15:23

Spring Cloud Gateway 核心特性与实践指南

摘要 本文深入探讨Spring Cloud Gateway在微服务架构中的核心作用,包括路由、过滤、限流等关键功能的实现原理与实践应用。通过详细的代码示例和架构分析,帮助开发者掌握Spring Cloud Gateway的最佳实践方法。 1. 引言 1.1 Spring Cloud Gateway 简介 Sp…

作者头像 李华
网站建设 2026/5/1 7:55:21

血液H组二糖—解析血型奥秘与疾病标志的核心糖结构 146076-26-8

血液H组二糖是ABO血型系统中最关键的抗原决定前体结构,被视为血型特异性表达的分子基石。它不仅构成了人类红细胞表面最基本的抗原表位,更在细胞识别、微生物感染、肿瘤发展及免疫调节等一系列生物学过程中扮演着核心角色。作为寡糖研究中的重要标准品和…

作者头像 李华