`

线性回归与梯度下降算法

阅读更多
知识点:

线性回归概念
梯度下降算法
        l  批量梯度下降算法

        l  随机梯度下降算法

        l  算法收敛判断方法

1.1   线性回归

在统计学中,线性回归(Linear Regression)是利用称为线性回归方程的最小 平方函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析。这种函数是一个或多个称为回归系数的模型参数的线性组合。

回归分析中,只包括一个自变量和一个因变量,且二者的关系可用一条直线近似表示,这种回归分析称为一元线性回归分析。如果回归分析中包括两个或两个以上的自变量,且因变量和自变量之间是线性关系,则称为多元线性回归分析。

下面我们来举例何为一元线性回归分析,图1为某地区的房屋面积(feet)与价格($)的一个数据集,在该数据集中,只有一个自变量面积(feet),和一个因变量价格($),所以我们可以将数据集呈现在二维空间上,如图2所示。利用该数据集,我们的目的是训练一个线性方程,无限逼近所有数据点,然后利用该方程与给定的某一自变量(本例中为面积),可以预测因变量(本例中为房价)。本例中,训练所得的线性方程如图3所示。

                      图1、房价与面积对应数据集

                           图2、二维空间上的房价与面积对应图

                           图3、线性逼近

同时,分析得到的线性方程为:


接下来还是该案例,举一个多元线性回归的例子。如果增添了一个自变量:房间数,那么数据集可以如下所示:


                                 图4、房价与面积、房间数对应数据集

那么,分析得到的线性方程应如下所示:



因此,无论是一元线性方程还是多元线性方程,可统一写成如下的格式:



上式中x0=1,而求线性方程则演变成了求方程的参数ΘT。

线性回归假设特征和结果满足线性关系。其实线性关系的表达能力非常强大,每个特征对结果的影响强弱可以有前面的参数体现,而且每个特征变量可以首先映射到一个函数,然后再参与线性计算,这样就可以表达特征与结果之间的非线性关系。

1.2   梯度下降算法

为了得到目标线性方程,我们只需确定公式(3)中的ΘT,同时为了确定所选定的的ΘT效果好坏,通常情况下,我们使用一个损失函数(loss function)或者说是错误函数(error function)来评估h(x)函数的好坏。该错误函数如公式(4)所示。



如何调整ΘT以使得J(Θ)取得最小值有很多方法,其中有完全用数学描述的最小二乘法(min square)和梯度下降法。

1.2.1   批量梯度下降算法

由之前所述,求ΘT的问题演变成了求J(Θ)的极小值问题,这里使用梯度下降法。而梯度下降法中的梯度方向由J(Θ)对Θ的偏导数确定,由于求的是极小值,因此梯度方向是偏导数的反方向。



公式(5)中α为学习速率,当α过大时,有可能越过最小值,而α当过小时,容易造成迭代次数较多,收敛速度较慢。假如数据集中只有一条样本,那么样本数量,所以公式(5)中



所以公式(5)就演变成:



当样本数量m不为1时,将公式(5)中由公式(4)带入求偏导,那么每个参数沿梯度方向的变化值由公式(7)求得。



初始时ΘT可设为,然后迭代使用公式(7)计算ΘT中的每个参数,直至收敛为止。由于每次迭代计算ΘT时,都使用了整个样本集,因此我们称该梯度下降算法为批量梯度下降算法(batch gradient descent)。

1.2.2  随机梯度下降算法

当样本集数据量m很大时,批量梯度下降算法每迭代一次的复杂度为O(mn),复杂度很高。因此,为了减少复杂度,当m很大时,我们更多时候使用随机梯度下降算法(stochastic gradient descent),算法如下所示:



即每读取一条样本,就迭代对ΘT进行更新,然后判断其是否收敛,若没收敛,则继续读取样本进行处理,如果所有样本都读取完毕了,则循环重新从头开始读取样本进行处理。

这样迭代一次的算法复杂度为O(n)。对于大数据集,很有可能只需读取一小部分数据,函数J(Θ)就收敛了。比如样本集数据量为100万,有可能读取几千条或几万条时,函数就达到了收敛值。所以当数据量很大时,更倾向于选择随机梯度下降算法。

但是,相较于批量梯度下降算法而言,随机梯度下降算法使得J(Θ)趋近于最小值的速度更快,但是有可能造成永远不可能收敛于最小值,有可能一直会在最小值周围震荡,但是实践中,大部分值都能够接近于最小值,效果也都还不错。

1.2.3  算法收敛判断方法

参数ΘT的变化距离为0,或者说变化距离小于某一阈值(ΘT中每个参数的变化绝对值都小于一个阈值)。为减少计算复杂度,该方法更为推荐使用。
J(Θ)不再变化,或者说变化程度小于某一阈值。计算复杂度较高,但是如果为了精确程度,那么该方法更为推荐使用。
分享到:
评论

相关推荐

    分享一下利用sklearn进行线性回归与梯度下降算法代码实践

    ### 知识点一:线性回归的基本概念 线性回归是一种用于预测连续型目标变量的方法,通过拟合数据中的最佳线性关系来进行预测。...这些知识点对于初学者理解和掌握线性回归以及相关算法具有重要的意义。

    线性回归与梯度下降法

    线性回归与梯度下降法 机器学习中的一种重要算法是线性回归,它可以用来预测连续值标签的输出。在线性回归中,我们假设输入特征和输出变量之间存在线性关系,然后使用梯度下降法来确定模型参数。 线性回归 线性...

    多元线性回归梯度下降算法的MATLAB实现

    用于多变量线性回归的梯度下降算法的 MATLAB 实现。此代码示例包括, 特征缩放选项 基于梯度范数容差或固定迭代次数的算法终止选择 具有随机指数的随机特征向量(确切的函数关系不是线性的,而是具有特征向量的随机...

    机器学习(线性回归和梯度下降算法的python实现).zip

    在机器学习领域,线性回归和梯度下降算法是基础且重要的概念,它们在数据分析和预测模型构建中扮演着核心角色。本资料包"机器学习(线性回归和梯度下降算法的python实现)"旨在帮助你理解并掌握这两种算法的Python实现...

    梯度下降算法线性回归数据

    梯度下降算法线性回归数据

    线性回归与梯度下降

    ### 线性回归与梯度下降 #### 线性回归定义 线性回归是一种基本的统计预测方法,主要用于预测连续数值型的目标变量。它假设目标变量与自变量之间存在线性关系,即可以通过一个或多个自变量的线性组合来预测目标...

    1-3 线性回归-梯度下降算法运行展示.html

    1-3 线性回归-梯度下降算法运行展示.html

    线性回归的梯度下降.zip

    而梯度下降法是优化算法的一种,尤其在训练机器学习模型时,如线性回归,用于寻找最小化损失函数(代价函数)的最优参数。 线性回归模型的数学表示通常为 \( y = wx + b \),其中 \( y \) 是目标变量,\( x \) 是...

    17.4.14(线性回归的梯度下降算法和正规方程)1

    在这个主题中,我们将深入探讨线性回归的两个重要方面:梯度下降算法和正规方程。 首先,让我们来理解梯度下降算法。这是一个优化算法,常用于求解最小化问题,如线性回归中的最小二乘误差。在训练线性回归模型时,...

    多元线性回归及其算法实现(梯度下降法)

    上一篇文章讲述了梯度下降法的数学思想,趁热打铁,这篇博客笔者将使用梯度下降法完成多元线性回归,话不多说,直接开始。 我们假设我们的目标函数是长这样的: import numpy as np import pandas as pd # 读入...

    梯度下降法实现线性回归

    在本示例中,我们将探讨如何利用梯度下降法来实现线性回归模型,该模型是机器学习中的一个核心概念。MATLAB作为一种强大的数值计算环境,是实现这一算法的理想工具。 首先,让我们理解线性回归的基本思想。线性回归...

    机器学习中用梯度下降法实现线性回归的MATLAB源代码.rar

    梯度下降法是线性回归中最常用的优化算法之一,用于寻找模型参数的最佳值,以最小化损失函数。在这个案例中,我们将深入探讨如何使用MATLAB实现这个过程。 吴恩达是一位知名的计算机科学家和教育家,他的在线课程...

    梯度下降_梯度下降_

    总结来说,梯度下降是优化问题中的关键算法,尤其在机器学习中的线性回归模型中起着重要作用。掌握梯度下降的原理和实现方法,有助于我们更好地理解和应用这些模型。在Python中,结合科学计算库如NumPy和Pandas,...

    机器学习中的梯度下降算法及其优化变体-可实现的-有问题请联系博主,博主会第一时间回复!!!

    内容概要:本文详细介绍了梯度下降算法及其几种重要变体(SGD、Momentum、ASGD、Rprop、AdaGrad、AdaDelta、RM SProp、Adam、Nadam、AdamW、RAdam)在机器学习中的应用。这些优化算法因其结构简单、稳定性好且易于...

    机器学习_梯度下降算法实现

    在机器学习中,梯度下降被广泛应用于线性回归、逻辑回归、支持向量机、神经网络等模型的参数优化。例如,在神经网络中,通过梯度下降更新权重和偏置,以最小化损失函数,从而提高网络的预测精度。 五、测试数据的...

Global site tag (gtag.js) - Google Analytics