`
kavy
  • 浏览: 891441 次
  • 性别: Icon_minigender_1
  • 来自: 上海
社区版块
存档分类
最新评论

Java实现一元线性回归

 
阅读更多

Java实现一元线性回归

 
发布时间:2006.04.28 22:26     来源:月光软件站    作者:

 

 

最近在写一个荧光图像分析软件,需要自己拟合方程。一元回归线公式的算法参考了《Java数值方法》,拟合度R^2(绝对系数)是自己写的,欢迎讨论。计算结果和Excel完全一致。

总共三个文件:

DataPoint.java

/**
 * A data point for interpolation and regression.
 */
public class DataPoint
{
    /** the x value */  public float x;
    /** the y value */  public float y;

    /**
     * Constructor.
     * @param x the x value
     * @param y the y value
     */
    public DataPoint(float x, float y)
    {
        this.x = x;
        this.y = y;

    }
}

/**
 * A least-squares regression line function.
 */

import java.util.*;
import java.math.BigDecimal;

public class RegressionLine 
 //implements Evaluatable
{
    /** sum of x */     private double sumX;
    /** sum of y */     private double sumY;
    /** sum of x*x */   private double sumXX;
    /** sum of x*y */   private double sumXY;
    /** sum of y*y */   private double sumYY;
    /** sum of yi-y */   private double sumDeltaY;
    /** sum of sumDeltaY^2 */   private double sumDeltaY2;
    /**误差 */
    private double sse;  
    private double sst;  
    private double E;
    private String[] xy ;
    
    private ArrayList listX ;
    private ArrayList listY ;
    
    private int XMin,XMax,YMin,YMax;
    
    /** line coefficient a0 */  private float a0;
    /** line coefficient a1 */  private float a1;

    /** number of data points */        private int     pn ;
    /** true if coefficients valid */   private boolean coefsValid;

    /**
     * Constructor.
     */
    public RegressionLine() {
     XMax = 0;
     YMax = 0;
     pn = 0;
     xy =new String[2];
     listX = new ArrayList();
     listY = new ArrayList();
    }

    /**
     * Constructor.
     * @param data the array of data points
     */
    public RegressionLine(DataPoint data[])
    { 
     pn = 0;
     xy =new String[2];
     listX = new ArrayList();
     listY = new ArrayList();
        for (int i = 0; i < data.length; ++i) {
            addDataPoint(data[i]);
        }
    }

    /**
     * Return the current number of data points.
     * @return the count
     */
    public int getDataPointCount() { return pn; }

    /**
     * Return the coefficient a0.
     * @return the value of a0
     */
    public float getA0()
    {
        validateCoefficients();
        return a0;
    }

    /**
     * Return the coefficient a1.
     * @return the value of a1
     */
    public float getA1()
    {
        validateCoefficients();
        return a1;
    }

    /**
     * Return the sum of the x values.
     * @return the sum
     */
    public double getSumX() { return sumX; }

    /**
     * Return the sum of the y values.
     * @return the sum
     */
    public double getSumY() { return sumY; }

    /**
     * Return the sum of the x*x values.
     * @return the sum
     */
    public double getSumXX() { return sumXX; }

    /**
     * Return the sum of the x*y values.
     * @return the sum
     */
    public double getSumXY() { return sumXY; }
    
    public double getSumYY() { return sumYY; }
    
    public int getXMin() {
  return XMin;
 }

 public int getXMax() {
  return XMax;
 }

 public int getYMin() {
  return YMin;
 }

 public int getYMax() {
  return YMax;
 }
    
    /**
     * Add a new data point: Update the sums.
     * @param dataPoint the new data point
     */
    public void addDataPoint(DataPoint dataPoint)
    {
        sumX  += dataPoint.x;
        sumY  += dataPoint.y;
        sumXX += dataPoint.x*dataPoint.x;
        sumXY += dataPoint.x*dataPoint.y;
        sumYY += dataPoint.y*dataPoint.y;
        
        if(dataPoint.x > XMax){
         XMax = (int)dataPoint.x;
        }
        if(dataPoint.y > YMax){
         YMax = (int)dataPoint.y;
        }
        
        //把每个点的具体坐标存入ArrayList中,备用
        
        xy[0] = (int)dataPoint.x+ "";
        xy[1] = (int)dataPoint.y+ "";
        if(dataPoint.x!=0 && dataPoint.y != 0){
        System.out.print(xy[0]+",");
        System.out.println(xy[1]);        
        
        try{
        //System.out.println("n:"+n);
        listX.add(pn,xy[0]);
        listY.add(pn,xy[1]);
        }
        catch(Exception e){
         e.printStackTrace();
        }                
        
        /*
        System.out.println("N:" + n);
        System.out.println("ArrayList listX:"+ listX.get(n));
        System.out.println("ArrayList listY:"+ listY.get(n));
        */
        }        
        ++pn;
        coefsValid = false;
     }

    /**
     * Return the value of the regression line function at x.
     * (Implementation of Evaluatable.)
     * @param x the value of x
     * @return the value of the function at x
     */
    public float at(int x)
    {
        if (pn < 2) return Float.NaN;

        validateCoefficients();
        return a0 + a1*x;
    }
    
    public float at(float x)
    {
        if (pn < 2) return Float.NaN;

        validateCoefficients();
        return a0 + a1*x;
    }

    /**
     * Reset.
     */
    public void reset()
    {
        pn = 0;
        sumX = sumY = sumXX = sumXY = 0;
        coefsValid = false;
    }

    /**
     * Validate the coefficients.
     * 计算方程系数 y=ax+b 中的a
     */
    private void validateCoefficients()
    {
        if (coefsValid) return;

        if (pn >= 2) {
            float xBar = (float) sumX/pn;
            float yBar = (float) sumY/pn;

            a1 = (float) ((pn*sumXY - sumX*sumY)
                            /(pn*sumXX - sumX*sumX));
            a0 = (float) (yBar - a1*xBar);
        }
        else {
            a0 = a1 = Float.NaN;
        }

        coefsValid = true;
    }
    
    /**
     * 返回误差
     */
    public double getR(){   
     //遍历这个list并计算分母
     for(int i = 0; i < pn -1; i++)    {         
      float Yi= (float)Integer.parseInt(listY.get(i).toString());
      float Y = at(Integer.parseInt(listX.get(i).toString())); 
      float deltaY = Yi - Y;    
      float deltaY2 = deltaY*deltaY;
      /*
      System.out.println("Yi:" + Yi);
      System.out.println("Y:" + Y);
      System.out.println("deltaY:" + deltaY);
      System.out.println("deltaY2:" + deltaY2);
      */
          
         sumDeltaY2 += deltaY2;
         //System.out.println("sumDeltaY2:" + sumDeltaY2);
         
     }     
      
     sst = sumYY - (sumY*sumY)/pn;     
        //System.out.println("sst:" + sst);
     E =1- sumDeltaY2/sst;
     
     
     return round(E,4) ;
    }
    
    //用于实现精确的四舍五入
    public double round(double v,int scale){

     if(scale<0){
     throw new IllegalArgumentException(
     "The scale must be a positive integer or zero");
     }
     
     BigDecimal b = new BigDecimal(Double.toString(v));
     BigDecimal one = new BigDecimal("1");
     return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).doubleValue();

    }    
    
    public  float round(float v,int scale){

     if(scale<0){
     throw new IllegalArgumentException(
     "The scale must be a positive integer or zero");
     }
     
     BigDecimal b = new BigDecimal(Double.toString(v));
     BigDecimal one = new BigDecimal("1");
     return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).floatValue();

    }    
}

演示程序:

LinearRegression.java

/**
 * <p><b>Linear Regression</b>
 * <br> 
 * Demonstrate linear regression by constructing the regression line for a set
 * of data points.
 * 
 * <p>require DataPoint.java,RegressionLine.java 
 * 
 * <p>为了计算对于给定数据点的最小方差回线,需要计算SumX,SumY,SumXX,SumXY; (注:SumXX = Sum (X^2))
 * <p><b>回归直线方程如下: f(x)=a1x+a0   </b>
 * <p><b>斜率和截距的计算公式如下:</b>
 * <br>n: 数据点个数
 * <p>a1=(n(SumXY)-SumX*SumY)/(n*SumXX-(SumX)^2)
 * <br>a0=(SumY - SumY * a1)/n 
 * <br>(也可表达为a0=averageY-a1*averageX)
 * 
 * <p><b>画线的原理:两点成一直线,只要能确定两个点即可</b><br>
 *  第一点:(0,a0) 再随意取一个x1值代入方程,取得y1,连结(0,a0)和(x1,y1)两点即可。
 * 为了让线穿过整个图,x1可以取横坐标的最大值Xmax,即两点为(0,a0),(Xmax,Y)。如果y=a1*Xmax+a0,y大于
 * 纵坐标最大值Ymax,则不用这个点。改用y取最大值Ymax,算得此时x的值,使用(X,Ymax), 即两点为(0,a0),(X,Ymax)
 * 
 * <p><b>拟合度计算:(即Excel中的R^2)</b>
 * <p> *R2 = 1 - E
 * <p>误差E的计算:E = SSE/SST
 * <p>SSE=sum((Yi-Y)^2) SST=sumYY - (sumY*sumY)/n;
 * <p> 
 */
public class LinearRegression
{
    private static final int MAX_POINTS = 10;
    private double E;

    /**
  * Main program.
  * 
  * @param args
  *            the array of runtime arguments
  */
    public static void main(String args[])
    {
        RegressionLine line = new RegressionLine();

        line.addDataPoint(new DataPoint(20, 136));
        line.addDataPoint(new DataPoint(40, 143));
        line.addDataPoint(new DataPoint(60, 152));
        line.addDataPoint(new DataPoint(80, 162));
        line.addDataPoint(new DataPoint(100, 167));
        
        printSums(line);
        printLine(line);
    }

    /**
  * Print the computed sums.
  * 
  * @param line
  *            the regression line
  */
    private static void printSums(RegressionLine line)
    {
        System.out.println("\n数据点个数 n = " + line.getDataPointCount());
        System.out.println("\nSum x  = " + line.getSumX());
        System.out.println("Sum y  = " + line.getSumY());
        System.out.println("Sum xx = " + line.getSumXX());
        System.out.println("Sum xy = " + line.getSumXY());
        System.out.println("Sum yy = " + line.getSumYY());       
        
    }

    /**
  * Print the regression line function.
  * 
  * @param line
  *            the regression line
  */
    private static void printLine(RegressionLine line)
    {
        System.out.println("\n回归线公式:  y = " +
                           line.getA1() +
                           "x + " + line.getA0());
        System.out.println("拟合度:     R^2 = " + line.getR());
    } 
    
}

 

分享到:
评论

相关推荐

    JAVA实现的一元线性回归 LINEAR REGRESSION

    在Java中实现一元线性回归,通常涉及以下步骤: 1. **数据预处理**:首先,你需要收集并组织数据,这通常包括读取数据集,可能来自CSV、Excel或其他格式的文件。这些数据应包含自变量 'x' 和因变量 'y' 的值。 2. ...

    一元线性回归分析与预测

    在Java编程环境中,我们可以使用各种库来实现一元线性回归分析。例如,Apache Commons Math库提供了一系列的统计功能,包括回归分析。首先,我们需要导入数据并将其转化为可以处理的数值格式。这通常涉及读取CSV或...

    java 实现的多元线性回归分析

    本篇文章介绍了一元线性回归和多元线性回归的Java实现方法,并通过具体的代码片段详细解释了各个步骤。通过这些方法,我们可以有效地建立预测模型,评估模型的准确性和可靠性。最小二乘法作为核心算法,在实际应用中...

    Java实现最小平方误差一元线性回归

    一元线性回归采用最小平方误差法并以Java实现。

    一元与多元线性回归_spss回归_memberxlc_一元多元线性回归JAVA_atqiz_

    例如,`LinearRegressionSimple`可能是一个简单的线性回归实现,而`LinearRegression`可能提供更复杂的功能,如处理多元回归。`DataPointSimple`可能用于表示数据集中的单个观测值,包含自变量和因变量的值。 在...

    多元线性回归方程

    在本篇文章中,我们将详细介绍多元线性回归的基本原理、算法实现,并对给定的Java代码进行解析。 #### 多元线性回归原理 在数学上,多元线性回归模型可以表示为: \[ Y = \beta_0 + \beta_1X_1 + \beta_2X_2 + ......

    线性回归数据集学习时间与分数数据集

    学习时间与分数数据集,25条数据

    回归算法Java程序

    在这个“回归算法Java程序”中,我们可以深入探讨如何使用Java来实现多元线性回归和一元回归。 一、回归分析基础 回归分析是一种统计方法,用于研究两个或多个变量之间的关系,尤其是一个变量(因变量)如何随其他...

    最小二乘法回归模型java.doc

    最小二乘法一元线性回归模型的Java实现(包含代码) 最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得...

    利用线性回归预测隧道车流量

    在最简单的形式下,线性回归模型是一元线性回归,表示为 y = ax + b,其中 y 是目标变量(车流量),a 是斜率(影响大小),x 是输入变量(例如时间),b 是截距(当输入为0时的预测值)。 3. 模型学习与预测:模型...

    R语言案例回归分析.pdf

    进行多元回归分析时,本文首先尝试了一元回归模型,接着在模型中加入了二次项和三次项,以探索是否存在非线性关系。在多元回归分析中,本文利用所有子集法选择变量,这是一种寻找最优回归模型的搜索方法,可以在所有...

    LinearRegressionModel:普通最小二乘线性回归模型的实现,该模型使用矩阵运算来计算权重

    线性回归试图找到一条直线(对于一元线性回归)或超平面(对于多元线性回归),使得数据点到这条线的距离之和最小。这条线被称为回归线,其方程通常表示为y = wx + b,其中y是因变量,x是自变量,w是权重,b是截距。...

    最小二乘法直线拟合.doc

    为了解决这些问题,学者们提出了其他回归技术,如岭回归(Ridge Regression)和套索回归(Lasso Regression),这些方法通过加入惩罚项来减少模型的复杂度或对参数施加限制,从而提高模型的泛化能力。 本文通过实例...

    java面试题(最新的java面试题)

    在一元线性回归中,通常使用最小二乘法来确定直线的斜率和截距,使得所有数据点到直线的距离平方和最小。 以上就是针对给定的 Java 面试题中涉及的关键知识点的详细介绍。这些知识点涵盖了 Java 编程语言的基础知识...

    zuixiaoercheng.rar_最小二乘曲线拟合

    在最小二乘法中,我们寻找一条直线(对于一元线性回归)或超平面(多元线性回归),使得所有数据点到这条直线的垂直距离(即残差)的平方和最小。 2. **残差**:残差是实际观测值与模型预测值之间的差值。在最小二...

    非线性电路混沌现象探究以及基于-Multisim仿真设计.pdf

    这种元件的伏安特性曲线可以通过一元线性回归进行拟合分析。非线性负阻的特点是其阻值随电压或电流的增加而减小,这在某些情况下可以导致电路动态行为的显著变化,进而产生混沌状态。 【Multisim仿真设计】Multisim...

    人工智能相关课程介绍 (2).docx

    应用统计学与 R 语言建模部分,主要内容包括数据的描述性分析、随机变量的概率分布、参数估计、假设检验、类别变量分析、方差分析、一元线性回归、多元线性回归、时间序列预测、聚类分析等,旨在培养学生的统计学和 ...

    Java开发实战1200例(第1卷).(清华出版.李钟尉.陈丹丹).part3

    实例224 一元线性回归计算 282 实例225 实数矩阵的运算 283 实例226 复数的常见运算 284 实例227 T分布常用计算 285 10.3 Commons IO组件简介 286 实例228 简化文件(夹)删除 286 实例229 简化文件(夹)复制 287 ...

Global site tag (gtag.js) - Google Analytics