`
slowman
  • 浏览: 38369 次
  • 性别: Icon_minigender_1
  • 来自: 武汉
社区版块
存档分类
最新评论

用JAVA进行神经网络建模及泛化能力测试

    博客分类:
  • AI
阅读更多

 

作者:桂子山下一棵草   email: slowguy@qq.com 

 

 

题目:

                     表一澳大利亚野兔眼睛晶状体重量与年龄的对应关系

 

 

编号

年龄()

重量(mg)

年龄()

重量(mg)

年龄()

重量(mg)

年龄()

重量(mg)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

15

15

15

18

28

29

37

37

44

50

50

60

61

64

65

65

72

75

21.66

22.75

22.3

31.25

44.79

40.55

50.25

46.88

52.03

63.47

61.13

81

73.09

79.09

79.51

65.31

71.9

86.1

75

82

85

91

91

97

98

125

142

142

147

147

150

159

165

183

192

195

94.6

92.5

105

101.7

102.9

110

104.3

134.9

130.68

140.58

155.3

152.2

144.5

142.15

139.81

153.22

145.72

161.1

218

218

219

224

225

227

232

232

237

246

258

276

285

300

301

305

312

317

174.18

173.03

173.54

178.86

177.68

173.73

159.98

161.29

187.07

176.13

183.4

186.26

189.66

186.09

186.7

186.8

195.1

216.41

338

347

354

357

375

394

513

535

554

591

648

660

705

723

756

768

860

203.23

188.38

189.7

195.31

202.63

224.82

203.3

209.7

233.9

234.7

244.3

231

242.4

230.77

242.57

232.12

246.7

 

澳大利亚野兔眼睛晶状体的重量为年龄的函数。利用BP算法,设计一个多层感知器,为表中的数据集提供一个非线性逼近,并测试其泛化能力。

算法源码:

 

package com.lwm.cn.althom;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.GregorianCalendar;
import java.util.Random;

public class BackProp {
	private int randomPrecision = 8; // 生成double型随机数的精度,默认为6位小数

	private int input_dimension; // 输入向量的维度

	private int output_dimension; // 输出向量的维数

	private int mid_dimension; // 隐层结点的个数

	private double[][] V; // 输入层到隐层的权值矩阵

	private double[][] W; // 隐层到输出层的权值矩阵

	private double[] inputArray; // 输入层向量

	private double[] midArray; // 隐层输出向量

	private double[] outputArray; // 输出层向量

	private double[] teacherArray; // 期望层向量

	private double mid_Threshold; // 隐层阈值

	private double out_Threshold; // 输出层阈值

	private double[] midError; // 隐层的误差

	private double[] outError; // 输出层的误差

	private double totalError = 0.0;

	private double outPrecision; // 要达到的精度

	private double learnRate; // 学习的速率

	private int trainTotal = 3000; // 学习1000次

	private boolean isQualify = false; // 用于判断是不是达到精度要求

	private ArrayList<SampleNode> trainArray = new ArrayList<SampleNode>(100); // 存入训练集

	private ArrayList<SampleNode> testArray = new ArrayList<SampleNode>(100); // 存放测试集

	private BufferedWriter bw = null; // 用于将学习和测试过程写于文件

	Date startTime;

	// SampleNode sample;

	// Math.random()
	public BackProp(double[][] v, double[][] w, int input_dimension,
			int output_dimension, int mid_dimension) {
		super();
		V = v;
		W = w;
		this.input_dimension = input_dimension;
		this.output_dimension = output_dimension;
		this.mid_dimension = mid_dimension;
	}

	/**
	 * 默认构造函数 ,对于本次实验,输入向量只有一个,输出也只有一个. 隐层结点的个数默认为4
	 * 
	 */
	public BackProp() {
		input_dimension = 1;
		output_dimension = 1;
		mid_dimension = 8;

		inputArray = new double[input_dimension];
		teacherArray = new double[output_dimension];
		midArray = new double[mid_dimension];
		outputArray = new double[output_dimension];

		V = new double[input_dimension][mid_dimension];
		W = new double[mid_dimension][input_dimension];

		midError = new double[mid_dimension];
		outError = new double[output_dimension];
	}

	/**
	 * 初始化函数 我认为一个完整的BP算法应该具备通用性,可以任意设置输入结点个数和隐层的层数及每一层的结点个数
	 * 初始化权值矩阵V和W,每个元素的值均为0-1之间的六位小数
	 */

	public void init()
	{
		// 记录程序开始时间及结束时间,以开始时间命名一个文件,用来保存学习和测试结果.
		startTime = new Date();
		SimpleDateFormat sdf = new SimpleDateFormat("yyyy年MM月dd日HH时mm分ss秒");
		String timeStr = sdf.format(startTime);
		String filePathName = "E:" + File.separator + timeStr + ".txt";
		try
		{
			bw = new BufferedWriter(new FileWriter(filePathName));
			bw.write("程序开始时间:" + timeStr + "\n");
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		mid_Threshold = MathExtend.round(Math.random(), randomPrecision); // 初始化隐层的阈值
		out_Threshold = MathExtend.round(Math.random(), randomPrecision); // 初始化输出层的阈值
		// 初始化V矩阵
		for (int i = 0; i < input_dimension; i++)
			for (int j = 0; j < mid_dimension; j++)
				V[i][j] = MathExtend.round(Math.random(), randomPrecision);

		// 初始化W矩阵
		for (int i = 0; i < mid_dimension; i++)
			for (int j = 0; j < output_dimension; j++)
				W[i][j] = MathExtend.round(Math.random(), randomPrecision);

		// 置总的误差为0,学习率为0-1之间的小数,网络训练后达到的精度为一正小数
		totalError = 0.0;
		learnRate = MathExtend.round(Math.random(), randomPrecision);
		// learnRate = 0.12;
		outPrecision = MathExtend.round(Math.random(), randomPrecision);

		try
		{
			StringBuilder sb = new StringBuilder();
			sb.append("本次实验随机生成的学习率: " + learnRate);
			sb.append("\n");
			sb.append("期望达到的精度为: " + outPrecision);
			sb.append("\n");
			bw.write(sb.toString());
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}

		getTrainData(); // 取得训练集
		getTestData(); // 取得测试集
		normalized(); // 归一化
	}

	/**
	 * @author Administrator 输入层向隐层,隐层向输出层的传播
	 * 
	 */
	public void finish()
	{
		// Date endDate = new Date() ;

		try
		{
			bw.close();
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}

	public void forword()
	{
		int i, j;
		double temp_sum ; // 用于向量的内积
		// 输出层到隐层
		for (i = 0; i < mid_dimension; i++)
		{
			temp_sum = 0.0  ; //初始化为0
			for (j = 0; j < input_dimension; j++)
				temp_sum += V[j][i] * inputArray[j];
			temp_sum = temp_sum - mid_Threshold;
			midArray[i] = 1.0 / (1 + Math.exp(-temp_sum));
		}

		
		// 隐层到输出层
		for (i = 0; i < output_dimension; i++)
		{
			temp_sum = 0.0; // 初始化
			for (j = 0; j < mid_dimension; j++)
				temp_sum = W[j][i] * midArray[j];
			temp_sum = temp_sum - out_Threshold;
			outputArray[i] = 1.0 / (1 + Math.exp(-temp_sum));
		}
		// 计算误差,累加起来,
		temp_sum = 0.0;
		for (i = 0; i < output_dimension; i++)
		{
			temp_sum = teacherArray[i] - outputArray[i]; // 注意中,本设计中output_dimension=1的
			totalError += temp_sum * temp_sum / 2;
		}
		// printResult();
	}

	private void printResult()
	{
		/*
		 * StringBuilder sb = new StringBuilder() ;
		 * sb.append("输入数据:"+inputArray[0]); sb.append("
		 * 实际输出数据:"+outputArray[0]); sb.append(" 期望输出数据为:"+teacherArray[0]) ;
		 * sb.append("\\n") ; try { bw.write(sb.toString()); } catch
		 * (IOException e) { // TODO Auto-generated catch block
		 * e.printStackTrace(); }
		 */
		System.out.print("输入数据:" + inputArray[0]);
		System.out.print("   实际输出数据:" + outputArray[0]);
		System.out.println("   期望输出数据为:" + teacherArray[0]);
	}

	/**
	 * 反向调整权值矩阵
	 */
	public void adjustWeight()
	{
		double temp_sum = 0.0;
		int i, j;
		// 计算各层的误差信号  输出层
		for (i = 0; i < output_dimension; i++)
		{
			outError[i] = (teacherArray[i] - outputArray[i])
					* (1 - outputArray[i]) * outputArray[i];
		}
//    隐层误差
		for (i = 0; i < mid_dimension; i++)
		{
			temp_sum=0.0d ;
			for (j = 0; j < output_dimension; j++)
				temp_sum += outError[j] * W[i][j];
			midError[i] = temp_sum * (1 - midArray[i]) * midArray[i];
		}

		// 调整W权值矩阵
		for (i = 0; i < mid_dimension; i++)
		{
			for (j = 0; j < output_dimension; j++)
				W[i][j] += learnRate * outError[j] * midArray[i];
		}
		// 调整V权值矩阵

		for (i = 0; i < input_dimension; i++)
			for (j = 0; j < mid_dimension; j++)
				V[i][j] += learnRate * midError[j] * inputArray[i];

	}

	public void getTrainData()
	{
		String filePathName = "E:" + File.separator + "traindata.txt";
		BufferedReader br = null;
		try
		{
			br = new BufferedReader(new FileReader(filePathName));
		} catch (FileNotFoundException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		String s = null;
		SampleNode sNode = null;
		try
		{
			while ((s = br.readLine()) != null)
			{
				String data[] = s.trim().split("[\\s]+");
				if (data == null || data.length != 2)
				{

					System.out.println("traindata文件数据有问题!");
					return;
				}
				double in = Double.parseDouble(data[0]);
				double hope = Double.parseDouble(data[1]);
				sNode = new SampleNode(in, hope);
				trainArray.add(sNode);

				// trainArray.
			}
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		trainArray.trimToSize();
	}

	public void getTestData()
	{
		String fileName = "E:" + File.separator + "testdata.txt";
		BufferedReader br = null;
		try
		{
			br = new BufferedReader(new FileReader(fileName));
		} catch (FileNotFoundException e)
		{
			// TODO Auto-generated catch block
			System.out.println("testdata.txt文件不存在");
			e.printStackTrace();
		}

		String s = null;
		SampleNode sNode = null;
		try
		{
			while ((s = br.readLine()) != null)
			{
				String data[] = s.trim().split("[\\s]+");
				if (data == null || data.length != 2)
				{

					System.out.println("testdata文件数据有问题!");
					return;
				}
				double in = Double.parseDouble(data[0]);
				double hope = Double.parseDouble(data[1]);
				sNode = new SampleNode(in, hope);
				testArray.add(sNode);

				// trainArray.
			}
		} catch (IOException e)
		{
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		testArray.trimToSize();

	}

	/**
	 * 对输入数据进行归一化处理,将输入数据限制在[0,1]区间内
	 * 
	 */

	private void normalized()
	{
		if (trainArray == null || trainArray.size() == 0 || testArray == null
				|| testArray.size() == 0)
		{
			System.out.println("测试数据或者训练数据有问题!");
			return;
		}
		SampleNode sNode = null;
		// 训练数据归一化
		int size = trainArray.size();
		int i = 0;
		while (i < size)
		{
			sNode = trainArray.get(i);
			double in = sNode.in;
			double hope = sNode.hope;
			in /= 1000.0; // 归一
			hope /= 250.0;
			sNode.in = in;
			sNode.hope = hope;
			trainArray.set(i, sNode);
			i++;
		}

		size = testArray.size();
		i = 0;
		// 测试数据归一化
		while (i < size)
		{
			sNode = testArray.get(i);
			double in = sNode.in;
			double hope = sNode.hope;
			in /= 1000.0; // 归一
			hope /= 250.0;
			sNode.in = in;
			sNode.hope = hope;
			trainArray.set(i, sNode);
			i++;
		}

	}

	public void startTrain()
	{
		if (trainArray == null || trainArray.size() == 0)
			return;
		System.out.println("训练开始");
		System.out.println("当前学习速率:" + learnRate);
		System.out.println("期望精度为:" + outPrecision);
		int trainConunter = 0;
		while (trainConunter++ < trainTotal)
		{
			System.out.println("第" + trainConunter + "次训练开始:");
			for (SampleNode sNode : trainArray)
			{
				// 说明:在本设计中inputArray,和teacherArray虽然都是数组,但均只有一个元素.
				// 本人为了综合虑,才将设为数组的.
				inputArray[0] = sNode.in;
				teacherArray[0] = sNode.hope;
				forword(); // 学习一次
				printResult();
			} // 至此,所有训练集全部学习完毕,下面应该进行权值调整.
			/*System.out.println("此次学习后,总的误差为:" + totalError);
			StringBuilder sb = new StringBuilder();
			sb.append("第" + trainConunter);
			sb.append("次学习后,总的误差为:" + totalError);
			sb.append("\n");*/
			try
			{
			//	bw.write(sb.toString());
				bw.write(Double.toString(totalError)+"\n") ;
			} catch (IOException e)
			{
				e.printStackTrace();
			}
			adjustWeight(); // 集体主义原则来调整权值

			if (totalError <= outPrecision)
			{
				isQualify = true; // 置标志位为真,表示达到要求

				break;
			}
			totalError = 0.0; // 误差初化
		}

		Date endTime = new Date();
		SimpleDateFormat sdf = new SimpleDateFormat("yyyy年MM月dd日HH时mm分ss秒");
		String endtimeStr = sdf.format(endTime);
		long gap = endTime.getTime() - this.startTime.getTime();
		StringBuilder sb = new StringBuilder();
		try
		{
			sb.append("训练结束时间为:" + endtimeStr);
			sb.append("\n");
			sb.append("总的学习时间为:" + gap);
			sb.append("微秒\n");
			sb.append("********************************************\n");
			bw.write(sb.toString());
		} catch (IOException e1)
		{
			// TODO Auto-generated catch block
			e1.printStackTrace();
		}

		if (!isQualify)
		{
			System.out.println("达到训练次数,训练结束!");
			try
			{
				bw.write("训练次数:" + trainTotal + "次\n");
			} catch (IOException e)
			{
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
		} else
		{
			try
			{
				bw.write("达到精度要求,学习完毕!\n");
			} catch (IOException e)
			{
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
			System.out.println("达到要求的精度,训练结束!");
		}
	}

	public void startTest()
	{

		if (testArray == null || testArray.isEmpty() == true)
			return;
		
		for (SampleNode sNode : testArray)
		{
			StringBuilder sb = new StringBuilder();
			inputArray[0] = sNode.in;
			teacherArray[0] = sNode.hope;
			forword();
			sb.append("输入测试数据: " + inputArray[0]);
			sb.append("   实际输出:" + outputArray[0]);
			sb.append("   期望输出:" + teacherArray[0]);
			sb.append("\n");
			try
			{
				bw.write(sb.toString());
			} catch (IOException e)
			{
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
			printResult();
		}
	}

}

 

 

 

测试输出结果如下图:

 

 

 

程序运行一次的收敛图如下图:



 

<!--EndFragment-->

  • 大小: 187.8 KB
  • 大小: 3.8 KB
1
0
分享到:
评论
2 楼 slowman 2012-01-18  
是的,不过这只是实现了算法的最基本要求。
1 楼 loveq369 2011-04-11  
自己写的吗?厉害哦。

相关推荐

    JAVA实现bp神经网络

    8. **预测**:训练完成后,使用测试集数据进行预测,评估模型的泛化能力。 9. **优化策略**:可能需要调整学习率、动量项、正则化参数等超参数,或者采用更高级的优化算法,如Adam,以提高训练效率和模型性能。 10...

    15.MATLAB神经网络43个案例分析 定制神经网络的实现-神经网络的个性化建模与仿真.zip

    在这个名为“15.MATLAB神经网络43个案例分析 定制神经网络的实现-神经网络的个性化建模与仿真.zip”的资源包中,重点是利用MATLAB进行神经网络的构建、训练和应用。神经网络是一种模仿人脑神经元结构的计算模型,...

    基于maltalb神经网络代码实现

    在IT领域,神经网络是一种模仿人脑神经元结构的计算模型,广泛应用于机器学习和人工智能。...通过研究和实践这些代码,不仅可以掌握各种神经网络的工作原理,还能提升在Matlab环境下进行机器学习和人工智能开发的能力。

    45.配套案例2 BP神经网络的非线性系统建模-非线性函数拟合.zip

    BP神经网络,全称为Backpropagation Neural Network,是人工神经网络的一种典型模型,尤其在非线性系统建模和预测领域应用广泛。本案例通过BP神经网络实现对非线性函数的拟合,以展示其在处理复杂问题时的能力。 在...

    27.MATLAB神经网络43个案例分析 基于灰色神经网络的预测算法研究-订单需求预测.zip

    在本资源中,"27.MATLAB神经网络43个案例分析 基于灰色神经网络的预测算法研究-订单需求预测.zip" 提供了一套深入的学习资料,旨在教授如何利用MATLAB进行神经网络的建模和预测,特别是关注灰色神经网络在订单需求...

    用matlab编BP神经网络预测程序.doc.zip

    标题 "用matlab编BP神经网络预测程序.doc.zip" 提示我们这个压缩包包含一个MATLAB编写的BP(Backpropagation)神经网络预测程序的文档。BP神经网络是一种广泛用于函数逼近、分类和预测的监督学习算法,尤其在处理非...

    MATLAB神经网络43个案例分析 基于MIV的神经网络变量筛选-基于BP神经网络的变量筛选.zip

    在本资源中,我们主要探讨的是如何利用MATLAB这一强大的计算工具进行神经网络的应用,特别是针对变量筛选问题。...通过对BP神经网络的学习和实践,我们可以提升数据分析和建模的能力,更好地应对实际工程问题。

    基于BP神经网络的空调能耗预测与监控系统.pdf

    BP神经网络是机器学习和深度学习领域中的一种常用算法,通过反向传播算法来调整神经网络中的权值,从而实现对复杂数据的建模和预测。在本文中,作者提出了基于BP神经网络的空调能耗预测与监控系统,以解决空调能耗...

    Java实现对Weka算法的应用案例

    这个案例主要探讨了如何在Java环境中利用Weka进行数据分析和建模。以下是对标题和描述中涉及知识点的详细解释: 1. **Weka算法**:Weka(Waikato Environment for Knowledge Analysis)是新西兰怀卡托大学开发的一...

    MATLAB分类与判别模型代码 基于模糊神经网络的嘉陵江水质评价代码.zip

    在本资源中,我们主要探讨的是使用MATLAB进行分类与判别建模,特别是基于模糊神经网络的方法来评价嘉陵江的水质。MATLAB是一款强大的数值计算和编程环境,广泛应用于工程、科研和数据分析等领域。在这个项目中,我们...

    5.MATLAB神经网络43个案例分析 LIBSVM-FarutoUltimate工具箱及GUI版本介绍与使用.zip

    本资源“5.MATLAB神经网络43个案例分析 LIBSVM-FarutoUltimate工具箱及GUI版本介绍与使用.zip”是针对MATLAB神经网络应用的深入学习材料,其中包含了43个具体的案例,涵盖了各种神经网络模型和应用场景,以及LIBSVM...

    BP神经网络

    BP神经网络,全称为Backpropagation Neural Network,是一种广泛应用的人工神经网络模型,主要用于监督学习中的非线性建模。它的核心思想是通过反向传播误差来调整网络中各连接权重,从而实现对复杂函数的逼近。在这...

    67.配套案例24 模糊神经网络的预测算法-嘉陵江水质评价.zip

    MATLAB是进行模糊系统和神经网络建模的常用工具,其强大的库函数和直观的编程环境为模型构建提供了便利。 "data1.mat"和"data2.mat"是数据文件,可能包含了嘉陵江不同时间点的水质监测数据,如pH值、溶解氧、氨氮、...

    基于深度神经网络的空气质量预测系统.pdf

    基于深度神经网络的空气质量预测系统.pdf 本文提出了一种基于深度神经网络的空气质量预测系统,使用栈式自编码模型来预测空气质量。该系统选择了PM、PM10等污染物数据作为样本,基于Java平台构建,进行了训练和参数...

    MATLAB Builder for Java混合编程开发手册

    本手册将以神经网络为例,详细介绍如何使用MATLAB Builder for Java进行混合编程。首先,将在MATLAB环境中编写神经网络模型的训练和预测代码,然后使用MATLAB Builder for Java将其转换为Java组件。最后,这些组件将...

    基于双隐含层BP神经网络的预测源码.zip

    描述中提到的“预测源码”表明这是用某种编程语言实现的代码,可能是Python、Java、C++等,用于构建和训练双隐含层的BP神经网络模型,并进行预测。源码可能包括数据预处理、网络结构定义、权重初始化、反向传播算法...

    基于BP神经网络的温度传感器辐射误差修正.pdf

    此外,我们还使用C#和Java将BP神经网络得到的辐射误差修正方程分别进行封装,并完成Windows上位机和安卓端移动应用开发。这使得我们的方法可以在实际应用中得到更好的效果。 在气候变化研究中,温度测量精确性的...

    56.配套案例13 SVM神经网络中的参数优化-提升分类器性能.zip

    SVM以其高效性和强大的泛化能力而受到青睐,而神经网络则以其强大的非线性建模能力著名。本案例“56.配套案例13 SVM神经网络中的参数优化-提升分类器性能.zip”旨在探讨如何通过参数优化来提升这两种模型的性能。 ...

    MATLAB神经网络43个案例分析 基于SVM的回归预测分析-上证指数开盘指数预测.zip

    此外,压缩包内的“19.MATLAB神经网络43个案例分析 基于BP_Adaboost的强分类器设计——公司财务预警建模.zip”文件可能包含另一个案例,说明如何使用增强学习方法,如Adaboost,结合BP神经网络进行公司财务危机的...

Global site tag (gtag.js) - Google Analytics