`
hzxdark
  • 浏览: 77863 次
社区版块
存档分类
最新评论

BP网络JAVA版源代码

阅读更多

 

神经网络课程的作业,一个简单的BP网络。准确率有点低,可能是我算法有点问题,100个训练数据,测试50个数据,只得80%正确。

 

 

BP类  //封装bp算法

package arithmetic;

public class BP {
 private double[] P;

 private double[] T;

 private double[][] W1;

 private double[][] W2;

 private int n_a0;

 private int n_a1;

 private int n_a2;

 private double[] B1;

 private double[] B2;

 private double[] a1;

 private double[] a2;

 private double[] q;

 private double[] db1;

 private double[] db2;

 private double[][] dw1;

 private double[][] dw2;

 private double e;

 private double r;

 private double e0;

 public BP(double[][] W1, double[][] W2, double[] B1, double[] B2) {
  this.W1 = W1;
  this.W2 = W2;
  this.B1 = B1;
  this.B2 = B2;
  n_a0 = W1[0].length;
  n_a1 = W1.length;
  n_a2 = W2.length;
  init();
 }

 public void setP(double[] P) {
  this.P = P;
 }

 public void setT(double[] T) {
  this.T = T;
 }

 private void init() {
  a1 = new double[n_a1];
  a2 = new double[n_a2];
  r = 0.4;
  e0 = 0.02;
  q = new double[n_a2];
  db2 = new double[n_a2];
  dw2 = new double[n_a2][n_a1];
  db1 = new double[n_a1];
  dw1 = new double[n_a1][n_a0];
 }

 public void calA1() {
  double temp = 0;
  for (int i = 0; i < n_a1; i++) {
   for (int j = 0; j < n_a0; j++) {
    temp += W1[i][j] * P[j];
   }
   temp += B1[i];
   a1[i] = F.f1(temp);
  }
 }

 public double[] getA1() {
  return a1;
 }

 public double[] getA2() {
  return a2;
 }

 public void calA2() {
  double temp = 0;
  for (int k = 0; k < n_a2; k++) {
   for (int i = 0; i < n_a1; i++) {
    temp += W2[k][i] * a1[i];
   }
   temp += B2[k];
   a2[k] = F.f2(temp);
  }
 }

 public void calE() {
  e = 0;
  for (int k = 0; k < n_a2; k++) {
   double ek = T[k] - a2[k];
   e += ek * ek;
   e /= 2;
  }
 }

 public void calDb2() {
  for (int k = 0; k < n_a2; k++) {
   q[k] = (T[k] - a2[k]) * F.f2_1(a2[k]);
   db2[k] = q[k] * r;
  }
 }

 public void calDw2() {
  for (int k = 0; k < n_a2; k++) {
   for (int i = 0; i < n_a1; i++) {
    dw2[k][i] = db2[k] * a1[i];
   }
  }
 }

 public void calDb1() {
  for (int i = 0; i < n_a1; i++) {
   db1[i] = 0;
   for (int k = 0; k < n_a2; k++) {
    db1[i] += q[k] * W2[k][i];
   }
   db1[i] *= r * F.f1_1(a1[i]);
  }
 }

 public void calDw1() {
  for (int i = 0; i < n_a1; i++) {
   for (int j = 0; j < n_a0; j++) {
    dw1[i][j] = db1[i] * P[j];
   }
  }
 }

 public void changeDb2() {
  for (int i = 0; i < n_a2; i++) {
   B2[i] += db2[i];
  }
 }

 public void changeDw2() {
  for (int i = 0; i < n_a2; i++) {
   for (int j = 0; j < n_a1; j++) {
    W2[i][j] += dw2[i][j];
   }
  }
 }

 public void changeDb1() {
  for (int i = 0; i < n_a1; i++) {
   B1[i] += db1[i];
  }
 }

 public void changeDw1() {
  for (int i = 0; i < n_a1; i++) {
   for (int j = 0; j < n_a0; j++) {
    W1[i][j] += dw1[i][j];
   }
  }
 }

 public void train(double[][] P, double[][] T) {
  while(true){
   boolean isChange = false;
   for (int n = 0; n < P.length; n++) {
    setP(P[n]);
    setT(T[n]);
    this.calA1();
    this.calA2();
    this.calE();
    if (e < e0)
     continue;
    this.calDb2();
    this.calDw2();
    this.calDb1();
    this.calDw1();
    this.changeDb2();
    this.changeDw2();
    this.changeDb1();
    this.changeDw1();
    isChange = true;
   // break;

   }
   if (!isChange) {
    System.out.println("train succeed");
    break;
   }
  }
 }

 public double[] divide(double[] p, double[] t) {
  setP(p);
  setT(t);
  this.calA1();
  this.calA2();
  return a2;
 }
 
 public double getE(){
  return e;
 }
}

F类 //封装神经元函数

package arithmetic;

public class F {
 public static double f1(double x){
  return 1/(1+Math.exp(-1*x));
 }
 public static double f1_1(double y){
  return y*(1-y);
}
 public static double f2(double x){
  return x;
 }
 public static double f2_1(double y){
  return 1;
 }
}

Controller类 读入训练数据和测试数据,并创建BP实例进行训练测试

package arithmetic;

import java.io.*;
import java.util.ArrayList;

import javax.swing.JFrame;

public class Controler {
 private double[][] p_test;
 private double[][] t_test;
 private double[][] p_train;
 private double[][] t_train; 
 
 private BP bp;
 private JFrame viwer;
 
 
 public Controler() throws IOException{
  getTestData();
  getTrainData();
  double[][] w1 = new double[][] { { 0.2, 0.3, 0.4, 0.1 },
    { 0.3, 0.4, 0.2, 0.4 }, { 0.4, 0.8, 0.9, 0.3 }};
  double[][] w2 = new double[][] { { 0.3, 0.6, 0.7 }, { 0.1, 0.3, 0.7 } };
  double[] b1 = new double[] { 0.2, 0.4, 0.5 };
  double[] b2 = new double[] { 0.1, 0.5 };
  bp = new BP(w1,w2,b1,b2);
  bp.train(p_train, t_train);
  int a = 0;
  for(int i =0;i<p_test.length;i++){
   double[] a2 = bp.divide(p_test[i], t_test[i]);  
   int t0 = (int)(t_test[i][0])*2+(int)(t_test[i][1]);
   int t1 = (int)(a2[0])*2+(int)(a2[1]);
   boolean equals = t1==t0;
   if(equals)a++;
   System.out.println("expected:"+t0+"\t"+"output:"+t1+"\t"+equals);
  }
  a = (int)(a/50.0*100);
  System.out.println(a);
 }
 
 private void getTestData() throws IOException{
  String fileName = "testData.txt";
  BufferedReader br = null;
  try {
   br = new BufferedReader(new FileReader(fileName));
  } catch (FileNotFoundException e) {
   br.close();
   e.printStackTrace();
  }
  ArrayList al = new ArrayList();
  String s = null;
  while((s=br.readLine())!=null){
   al.add(s);
  }
  br.close();
  p_test = new double[al.size()][4];
  t_test = new double[al.size()][2];
  double[] maxData = new double[]{0,0,0,0};
  for(int i =0;i<al.size();i++){
   String[] temp = al.get(i).toString().split(" ");;
   for(int j =0;j<4;j++){
    p_test[i][j] = Double.parseDouble(temp[j+1]);
    if(p_test[i][j]>maxData[j])maxData[j] = p_test[i][j];
   }
   int d = Integer.parseInt(temp[0]);
   switch (d){
   case 0:
    t_test[i][0] = 0;
    t_test[i][1] = 0;
    break;
   case 1:
    t_test[i][0] = 0;
    t_test[i][1] = 1;
    break;
   case 2:
    t_test[i][0] = 1;
    t_test[i][1] = 0;
    break;
   default:
    t_test[i][0] = 1;
    t_test[i][1] = 1;
    break;
   }
  }
  for(int i =0;i<p_test.length;i++){
   for(int j =0;j<4;j++){
    p_test[i][j] /= maxData[j];
   }
  }
 }
 private void getTrainData() throws IOException{
  String fileName = "trainData.txt";
  BufferedReader br = null;
  try {
   br = new BufferedReader(new FileReader(fileName));
  } catch (FileNotFoundException e) {
   br.close();
   e.printStackTrace();
  }
  ArrayList al = new ArrayList();
  String s = null;
  while((s=br.readLine())!=null){
   al.add(s);
  }
  p_train = new double[al.size()][4];
  t_train = new double[al.size()][2];
  double[] maxData = new double[]{0,0,0,0};
  for(int i =0;i<al.size();i++){
   String[] temp = al.get(i).toString().split(" ");;
   for(int j =0;j<temp.length-1;j++){
    p_train[i][j] = Double.parseDouble(temp[j+1]);
    if(p_train[i][j]>maxData[j])maxData[j] = p_train[i][j];
   }
   int d = Integer.parseInt(temp[0]);
   switch (d){
   case 0:
    t_train[i][0] = 0;
    t_train[i][1] = 0;
    break;
   case 1:
    t_train[i][0] = 0;
    t_train[i][1] = 1;
    break;
   case 2:
    t_train[i][0] = 1;
    t_train[i][1] = 0;
    break;
   default:
    t_train[i][0] = 1;
    t_train[i][1] = 1;
    break;
   }
  }
  for(int i =0;i<p_train.length;i++){
   for(int j =0;j<4;j++){
    p_train[i][j] /= maxData[j];
   }
  }
 }
 public static void main(String[] args) throws Exception{
  Controler ctr =new Controler();
  
 }
}

 

分享到:
评论
4 楼 xiaolu_yatou 2012-07-25  
乱 
3 楼 张空空 2010-06-29  
没有注释哦
2 楼 ydsakyclguozi 2009-06-30  
楼主,那你发个什么劲儿啊,改好了再发啊!
1 楼 mating 2008-06-20  
我用了一下怎么有问题啊???

相关推荐

    BP神经网络程序,java语言源代码

    RBF神经网络(Radial Basis Function Network)是一种特殊的BP网络,其隐藏层神经元使用径向基函数作为激活函数,通常用于函数拟合、分类和回归任务。在这个程序中,`RbfNet`类包含了网络结构的各个关键属性,如输入...

    粒子群训练BP神经网络源代码

    Java语言编写的这份源代码,不仅具有跨平台的特性,而且能够利用Java丰富的库资源和框架来增强代码的可读性和维护性。程序文件“BPPSO”可能是整个项目的主程序或核心类库,它将包含一系列关键函数,例如初始化粒子...

    BP.rar_BP_bp 神经网络 java 算法_bp 预测_java BP预测算法_预测算法java

    总的来说,这个BP神经网络Java实现为理解和实践预测算法提供了一个基础平台,对于学习和研究神经网络算法的开发者来说,是一个有价值的资源。通过深入研究源代码,我们可以更深入地理解BP算法的工作原理,以及如何在...

    BP神经网络JAVA实现

    BP神经网络,全称为Backpropagation Neural ...在`pkg1`这个压缩包文件中,可能包含了BP神经网络的Java源代码、训练数据集、测试数据集以及其他相关资源。具体代码实现细节和数据结构,需要查看源代码才能深入了解。

    java BP神经网络

    Java BP神经网络是一种基于反向传播(Backpropagation)算法的多层前馈神经网络,广泛应用于模式识别、预测分析和优化问题。该算法通过不断调整权重和偏置来最小化损失函数,以达到对训练数据的拟合。在Java编程环境...

    Ann.rar_ANN java实现_BP神经网络_BP神经网络 java_ann java

    描述中提到“使用java语言编写的BP神经网络算法,实现了BP算法的功能”,这意味着该压缩包内包含了Java源代码,用于构建和训练BP神经网络模型。BP算法的核心是通过反向传播错误来调整神经元之间的权重,以最小化预测...

    神经网络BP算法源代码

    BP算法的神经网络的源代码, 可以根据向量建立网络,网络的训练结果和初始结构可以用XML保存和载入。 &lt;br&gt;其中 Compressor/TrainerWithDiagram.class , 是一个用于演示的训练器, 产生制定范围内的数,生成...

    BP代码java

    文件"BP代码.txt"可能是这个BP算法Java实现的源代码。源码可能包含了上述所有组件,例如网络结构的定义、数据输入、训练过程、预测功能以及可能的可视化工具。阅读和理解这份代码可以帮助你更深入地了解如何在Java中...

    BP.rar_BP_bp java_java BP_neural network java_神经网络

    "file.java"很可能是一个Java源代码文件,包含了BP神经网络的实现细节。"chengxu"可能是另一个源代码文件、数据集或者是项目文档,具体作用需要查看文件内容才能确定。通常,Java源代码文件会定义神经网络的结构、...

    BP神经网络matlab源程序代码.doc.zip

    BP神经网络,全称为Backpropagation Neural Network,是一种在机器学习领域广泛应用的多层前馈神经网络。...通过深入研究这些代码,读者不仅可以掌握BP网络的工作机制,还能熟悉MATLAB在神经网络领域的应用技巧。

    java语言写的BP源程序

    本压缩包文件名为"java语言写的BP源程序",显然它包含了使用Java编程语言实现的BP(Backpropagation)神经网络算法的源代码。BP神经网络是一种广泛应用的监督学习算法,主要用于模式识别和函数逼近问题。 BP神经...

    BP.rar_BP_BP神经网络 java_bp java pudn_神经网络

    BP神经网络算法源代码,文件为JAVA语言编写的,编译环境为Eclipse

    bpann_java.rar_BP神经网络 java_bp.rar java_神经网络 java_绁炵粡缃戠粶

    标签中的"bp.rar_java"可能是指项目中包含的源代码文件是`.rar`格式,通常需要解压后才能查看和使用。 总的来说,这个Java BP神经网络项目为学习者提供了一个实践神经网络算法的实例,涵盖了神经网络的基本构建、...

    BP and MP算法源代码

    在BP算法的源代码实现中,涉及到的编程语言通常是Python、C++或Java,这些代码实现了前向传播、误差计算和反向传播的关键步骤。在神经网络框架如TensorFlow或PyTorch中,反向传播过程是自动完成的,简化了算法的实现...

    Java写的BP神经网络实现(BP)

    Java的BP神经网络库如Deeplearning4j提供了方便的API来构建和训练神经网络,但这里的“bp”可能是一个自定义实现,因此可能需要对源代码有深入理解,以掌握其具体工作方式和优化策略。 在实际应用中,BP神经网络...

    bp.rar_BP_java BP_javxxz_neural network java

    而"bp"很可能是一个Java源代码文件,直接对应于BP神经网络的实现。 BP神经网络的基本结构包括输入层、隐藏层和输出层,每层由若干个神经元组成。网络的训练过程通过不断调整连接各层神经元之间的权重来最小化预测...

    bp神经网络对数据分类的实现(java代码,iris测试数据)

    本项目中的"BPNN"文件可能包含了实现神经网络的Java源代码、数据处理逻辑和运行脚本。通过阅读和理解这些代码,我们可以深入学习BP神经网络的工作原理,以及如何用Java进行实际的实现。同时,这也是一个很好的实践...

    e2029bd3daf3.rar_BP神经网络 java

    BP神经网络的java源代码,采用动量梯度下降法

    Java基于BP神经网络的手写数字识别源代码+训练集

    Java基于BP神经网络的手写数字识别源代码+训练集 assets/inputHidden.csv 输入层到隐藏层的权重矩阵 assets/hiddenOutput.csv 隐藏层到输出层的矩阵 assets/train-images-idx3-ubyte 训练集图片 assets/train-labels...

    BPNN_Java.zip_BPNN_BP神经网络 java_iris

    10. `src`目录是源代码文件夹,应该包含BP神经网络的Java源代码文件。 11. `iris`可能是数据集文件,可能包含处理后的Iris数据,以便供Java程序读取和使用。 总之,这个项目提供了一个使用Java实现的BP神经网络...

Global site tag (gtag.js) - Google Analytics