`

神经网络解决Xor问题的例子 Java实现代码

阅读更多
public class Network implements Serializable{

  /**
   * The global error for the training.
   */
  protected double globalError;


  /**
   * The number of input neurons.
   */
  protected int inputCount;

  /**
   * The number of hidden neurons.
   */
  protected int hiddenCount;

  /**
   * The number of output neurons
   */
  protected int outputCount;

  /**
   * The total number of neurons in the network.
   */
  protected int neuronCount;

  /**
   * The number of weights in the network.
   */
  protected int weightCount;

  /**
   * The learning rate.
   */
  protected double learnRate;

  /**
   * The outputs from the various levels.
   */
  protected double fire[];

  /**
   * The weight matrix this, along with the thresholds can be
   * thought of as the "memory" of the neural network.
   */
  protected double matrix[];

  /**
   * The errors from the last calculation.
   */
  protected double error[];

  /**
   * Accumulates matrix delta's for training.
   */
  protected double accMatrixDelta[];

  /**
   * The thresholds, this value, along with the weight matrix
   * can be thought of as the memory of the neural network.
   */
  protected double thresholds[];

  /**
   * The changes that should be applied to the weight
   * matrix.
   */
  protected double matrixDelta[];

  /**
   * The accumulation of the threshold deltas.
   */
  protected double accThresholdDelta[];

  /**
   * The threshold deltas.
   */
  protected double thresholdDelta[];

  /**
   * The momentum for training.
   */
  protected double momentum;

  /**
   * The changes in the errors.
   */
  protected double errorDelta[];


  /**
   * Construct the neural network.
   *
   * @param inputCount The number of input neurons.
   * @param hiddenCount The number of hidden neurons
   * @param outputCount The number of output neurons
   * @param learnRate The learning rate to be used when training.
   * @param momentum The momentum to be used when training.
   */
  public Network(int inputCount,
                 int hiddenCount,
                 int outputCount,
                 double learnRate,
                 double momentum) {

    this.learnRate = learnRate;
    this.momentum = momentum;

    this.inputCount = inputCount;
    this.hiddenCount = hiddenCount;
    this.outputCount = outputCount;
    neuronCount = inputCount + hiddenCount + outputCount;
    weightCount = (inputCount * hiddenCount) + (hiddenCount * outputCount);

    fire        = new double[neuronCount];
    matrix      = new double[weightCount];
    matrixDelta = new double[weightCount];
    thresholds  = new double[neuronCount];
    errorDelta  = new double[neuronCount];
    error       = new double[neuronCount];
    accThresholdDelta = new double[neuronCount];
    accMatrixDelta = new double[weightCount];
    thresholdDelta = new double[neuronCount];

    reset();
  }



  /**
   * Returns the root mean square error for a complet training set.
   *
   * @param len The length of a complete training set.
   * @return The current error for the neural network.
   */
  public double getError(int len) {
    double err = Math.sqrt(globalError / (len * outputCount));
    globalError = 0;  // clear the accumulator
    return err;

  }

  /**
   * The threshold method. You may wish to override this class to provide other
   * threshold methods.
   *
   * @param sum The activation from the neuron.
   * @return The activation applied to the threshold method.
   */
  public double threshold(double sum) {
    return 1.0 / (1 + Math.exp(-1.0 * sum));
  }

  /**
   * Compute the output for a given input to the neural network.
   *
   * @param input The input provide to the neural network.
   * @return The results from the output neurons.
   */
  public double []computeOutputs(double input[]) {
    int i, j;
    final int hiddenIndex = inputCount;
    final int outIndex = inputCount + hiddenCount;

    for (i = 0; i < inputCount; i++) {
      fire[i] = input[i];
    }

    // first layer
    int inx = 0;

    for (i = hiddenIndex; i < outIndex; i++) {
      double sum = thresholds[i];

      for (j = 0; j < inputCount; j++) {
        sum += fire[j] * matrix[inx++];
      }
      fire[i] = threshold(sum);
    }

    // hidden layer

    double result[] = new double[outputCount];

    for (i = outIndex; i < neuronCount; i++) {
      double sum = thresholds[i];

      for (j = hiddenIndex; j < outIndex; j++) {
        sum += fire[j] * matrix[inx++];
      }
      fire[i] = threshold(sum);
      result[i-outIndex] = fire[i];
    }

    return result;
  }


  /**
   * Calculate the error for the recogntion just done.
   *
   * @param ideal What the output neurons should have yielded.
   */
  public void calcError(double ideal[]) {
    int i, j;
    final int hiddenIndex = inputCount;
    final int outputIndex = inputCount + hiddenCount;

    // clear hidden layer errors
    for (i = inputCount; i < neuronCount; i++) {
      error[i] = 0;
    }

    // layer errors and deltas for output layer
    for (i = outputIndex; i < neuronCount; i++) {
      error[i] = ideal[i - outputIndex] - fire[i];
      globalError += error[i] * error[i];
      errorDelta[i] = error[i] * fire[i] * (1 - fire[i]);
    }

    // hidden layer errors
    int winx = inputCount * hiddenCount;

    for (i = outputIndex; i < neuronCount; i++) {
      for (j = hiddenIndex; j < outputIndex; j++) {
        accMatrixDelta[winx] += errorDelta[i] * fire[j];
        error[j] += matrix[winx] * errorDelta[i];
        winx++;
      }
      accThresholdDelta[i] += errorDelta[i];
    }

    // hidden layer deltas
    for (i = hiddenIndex; i < outputIndex; i++) {
      errorDelta[i] = error[i] * fire[i] * (1 - fire[i]);
    }

    // input layer errors
    winx = 0;  // offset into weight array
    for (i = hiddenIndex; i < outputIndex; i++) {
      for (j = 0; j < hiddenIndex; j++) {
        accMatrixDelta[winx] += errorDelta[i] * fire[j];
        error[j] += matrix[winx] * errorDelta[i];
        winx++;
      }
      accThresholdDelta[i] += errorDelta[i];
    }
  }

  /**
   * Modify the weight matrix and thresholds based on the last call to
   * calcError.
   */
  public void learn() {
    int i;

    // process the matrix
    for (i = 0; i < matrix.length; i++) {
      matrixDelta[i] = (learnRate * accMatrixDelta[i]) + (momentum * matrixDelta[i]);
      matrix[i] += matrixDelta[i];
      accMatrixDelta[i] = 0;
    }

    // process the thresholds
    for (i = inputCount; i < neuronCount; i++) {
      thresholdDelta[i] = learnRate * accThresholdDelta[i] + (momentum * thresholdDelta[i]);
      thresholds[i] += thresholdDelta[i];
      accThresholdDelta[i] = 0;
    }
  }

  /**
   * Reset the weight matrix and the thresholds.
   */
  public void reset() {
    int i;

    for (i = 0; i < neuronCount; i++) {
      thresholds[i] = 0.5 - (Math.random());
      thresholdDelta[i] = 0;
      accThresholdDelta[i] = 0;
    }
    for (i = 0; i < matrix.length; i++) {
      matrix[i] = 0.5 - (Math.random());
      matrixDelta[i] = 0;
      accMatrixDelta[i] = 0;
    }
  }
  
  	public File saveToFile(File file){
		try{
			ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(file));
			outputStream.writeObject(this);
			outputStream.close();
		}catch(Exception e){
			throw new RuntimeException(e.getMessage() , e.getCause());
		}
		return file;
	}
	
	public static Network readFromFile(File file){
		Network network = null;
		try{
			ObjectInputStream inputStream = new ObjectInputStream(new FileInputStream(file));
			network = (Network) inputStream.readObject();
			inputStream.close();
		}catch(Exception e){
			throw new RuntimeException(e.getMessage() , e.getCause());
		}
		return network;
	}
}






public class XorExample extends JFrame implements
ActionListener,Runnable {

  /**
   * The train button.
   */
  JButton btnTrain;

  /**
   * The run button.
   */
  JButton btnRun;

  /**
   * The quit button.
   */
  JButton btnQuit;

  /**
   * The status line.
   */
  JLabel status;

  /**
   * The background worker thread.
   */
  protected Thread worker = null;

/**
 * The number of input neurons.
 */
  protected final static int NUM_INPUT = 2;

/**
 * The number of output neurons.
 */
  protected final static int NUM_OUTPUT = 1;

/**
 * The number of hidden neurons.
 */
  protected final static int NUM_HIDDEN = 3;

/**
 * The learning rate.
 */
  protected final static double RATE = 0.5;

/**
 * The learning momentum.
 */
  protected final static double MOMENTUM = 0.7;


  /**
   * The training data that the user enters.
   * This represents the inputs and expected
   * outputs for the XOR problem.
   */
  protected JTextField data[][] = new JTextField[4][4];

  /**
   * The neural network.
   */
  protected Network network;



  /**
   * Constructor. Setup the components.
   */
  public XorExample()
  {
    setTitle("XOR Solution");
    network = new Network(
                         NUM_INPUT,
                         NUM_HIDDEN,
                         NUM_OUTPUT,
                         RATE,
                         MOMENTUM);

    Container content = getContentPane();

    GridBagLayout gridbag = new GridBagLayout();
    GridBagConstraints c = new GridBagConstraints();
    content.setLayout(gridbag);

    c.fill = GridBagConstraints.NONE;
    c.weightx = 1.0;

    // Training input label
    c.gridwidth = GridBagConstraints.REMAINDER; //end row
    c.anchor = GridBagConstraints.NORTHWEST;
    content.add(
               new JLabel(
                         "Enter training data:"),c);

    JPanel grid = new JPanel();
    grid.setLayout(new GridLayout(5,4));
    grid.add(new JLabel("IN1"));
    grid.add(new JLabel("IN2"));
    grid.add(new JLabel("Expected OUT   "));
    grid.add(new JLabel("Actual OUT"));

    for ( int i=0;i<4;i++ ) {
      int x = (i&1);
      int y = (i&2)>>1;
      grid.add(data[i][0] = new JTextField(""+y));
      grid.add(data[i][1] = new JTextField(""+x));
      grid.add(data[i][2] = new JTextField(""+(x^y)));
      grid.add(data[i][3] = new JTextField("??"));
      data[i][0].setEditable(false);
      data[i][1].setEditable(false);
      data[i][3].setEditable(false);
    }

    content.add(grid,c);

    // the button panel
    JPanel buttonPanel = new JPanel(new FlowLayout());
    buttonPanel.add(btnTrain = new JButton("Train"));
    buttonPanel.add(btnRun = new JButton("Run"));
    buttonPanel.add(btnQuit = new JButton("Quit"));
    btnTrain.addActionListener(this);
    btnRun.addActionListener(this);
    btnQuit.addActionListener(this);

    // Add the button panel
    c.gridwidth = GridBagConstraints.REMAINDER; //end row
    c.anchor = GridBagConstraints.CENTER;
    content.add(buttonPanel,c);

    // Training input label
    c.gridwidth = GridBagConstraints.REMAINDER; //end row
    c.anchor = GridBagConstraints.NORTHWEST;
    content.add(
               status = new JLabel("Click train to begin training..."),c);

    // adjust size and position
    pack();
    Toolkit toolkit = Toolkit.getDefaultToolkit();
    Dimension d = toolkit.getScreenSize();
    setLocation(
               (int)(d.width-this.getSize().getWidth())/2,
               (int)(d.height-this.getSize().getHeight())/2 );
    setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
    setResizable(false);

    btnRun.setEnabled(false);
  }

  /**
   * The main function, just display the JFrame.
   *
   * @param args No arguments are used.
   */
  public static void main(String args[])
  {
    (new XorExample()).show(true);
  }

  /**
   * Called when the user clicks one of the three
   * buttons.
   *
   * @param e The event.
   */
  public void actionPerformed(ActionEvent e)
  {
    if ( e.getSource()==btnQuit )
      System.exit(0);
    else if ( e.getSource()==btnTrain )
      train();
    else if ( e.getSource()==btnRun )
      evaluate();
  }

  /**
   * Called when the user clicks the run button.
   */
  protected void evaluate()
  {
    double xorData[][] = getGrid();
    int update=0;

    for (int i=0;i<4;i++) {
      NumberFormat nf = NumberFormat.getInstance();
      double d[] = network.computeOutputs(xorData[i]);
      data[i][3].setText(nf.format(d[0]));
    }

  }


  /**
  * Called when the user clicks the train button.
  */
  protected void train()
  {
    if ( worker != null )
      worker = null;
    worker = new Thread(this);
    worker.setPriority(Thread.MIN_PRIORITY);
    worker.start();
  }

  /**
  * The thread worker, used for training
  */
  public void run()
  {
    double xorData[][] = getGrid();
    double xorIdeal[][] = getIdeal();
    int update=0;

    int max = 10000;
    for (int i=0;i<max;i++) {
      for (int j=0;j<xorData.length;j++) {
        network.computeOutputs(xorData[j]);
        network.calcError(xorIdeal[j]);
        network.learn();
      }


      update++;
      if (update==100) {
        status.setText( "Cycles Left:" + (max-i) + ",Error:" + network.getError(xorData.length) );
        update=0;
      }
    }
    btnRun.setEnabled(true);
  }


  /**
   * Called to generate an array of doubles based on
   * the training data that the user has entered.
   *
   * @return An array of doubles
   */
  double [][]getGrid()
  {
    double array[][] = new double[4][2];

    for ( int i=0;i<4;i++ ) {
      array[i][0] =
      Float.parseFloat(data[i][0].getText());
      array[i][1] =
      Float.parseFloat(data[i][1].getText());
    }

    return array;
  }

  /**
   * Called to the the ideal values that that the neural network
   * should return for each of the grid training values.
   *
   * @return The ideal results.
   */
  double [][]getIdeal()
  {
    double array[][] = new double[4][1];

    for ( int i=0;i<4;i++ ) {
      array[i][0] =
      Float.parseFloat(data[i][2].getText());
    }

    return array;
  }


}
分享到:
评论

相关推荐

    用Java实现人工智能编程.pdf

    XOR的问题可以通过训练神经网络来解决,首先准备包含输入数据的文本文件,然后利用JOONE读取这些数据,训练神经网络。训练过程包括将XOR的例子提交给网络,检查输出,根据预期结果与实际输出的差距调整突触权重,这...

    Java Encog神经网络简介

    "XorExample.zip"文件可能包含一个演示如何使用Java Encog解决XOR问题的完整示例代码。在这个例子中,开发者将展示如何初始化网络、加载数据、训练网络以及评估结果。通过阅读和分析这个代码,你可以更直观地了解...

    Artificial_neural_network:使用 XOR 测试人工神经网络。 它将成为更大事物的基础

    XOR(异或)问题在逻辑运算中是一个重要的例子,因为它无法通过单一的逻辑门(如与、或、非)直接解决,但可以通过两个或更多的门组合来完成。在神经网络中,XOR问题被用来测试网络的非线性学习能力,因为它的输出...

    神经网络 joone (资料很全)

    压缩包中有三个文件: 1. joone文件夹中是官方网站提供的开发包和工具 2. joone-javadoc.zip压缩文件是存放了api(英文版,暂时没找到...3. XOR_using_NeuralNet.java文件是一个简单的例子(这个例子很好,我找了好久)

    《数据结构》(02331)基础概念

    内容概要:本文档《数据结构》(02331)第一章主要介绍数据结构的基础概念,涵盖数据与数据元素的定义及其特性,详细阐述了数据结构的三大要素:逻辑结构、存储结构和数据运算。逻辑结构分为线性结构(如线性表、栈、队列)、树形结构(涉及根节点、父节点、子节点等术语)和其他结构。存储结构对比了顺序存储和链式存储的特点,包括访问方式、插入删除操作的时间复杂度以及空间分配方式,并介绍了索引存储和散列存储的概念。最后讲解了抽象数据类型(ADT)的定义及其组成部分,并探讨了算法分析中的时间复杂度计算方法。 适合人群:计算机相关专业学生或初学者,对数据结构有一定兴趣并希望系统学习其基础知识的人群。 使用场景及目标:①理解数据结构的基本概念,掌握逻辑结构和存储结构的区别与联系;②熟悉不同存储方式的特点及应用场景;③学会分析简单算法的时间复杂度,为后续深入学习打下坚实基础。 阅读建议:本章节内容较为理论化,建议结合实际案例进行理解,尤其是对于逻辑结构和存储结构的理解要深入到具体的应用场景中,同时可以尝试编写一些简单的程序来加深对抽象数据类型的认识。

    【工业自动化】施耐德M580 PLC系统架构详解:存储结构、硬件配置与冗余设计

    内容概要:本文详细介绍了施耐德M580系列PLC的存储结构、系统硬件架构、上电写入程序及CPU冗余特性。在存储结构方面,涵盖拓扑寻址、Device DDT远程寻址以及寄存器寻址三种方式,详细解释了不同类型的寻址方法及其应用场景。系统硬件架构部分,阐述了最小系统的构建要素,包括CPU、机架和模块的选择与配置,并介绍了常见的系统拓扑结构,如简单的机架间拓扑和远程子站以太网菊花链等。上电写入程序环节,说明了通过USB和以太网两种接口进行程序下载的具体步骤,特别是针对初次下载时IP地址的设置方法。最后,CPU冗余部分重点描述了热备功能的实现机制,包括IP通讯地址配置和热备拓扑结构。 适合人群:从事工业自动化领域工作的技术人员,特别是对PLC编程及系统集成有一定了解的工程师。 使用场景及目标:①帮助工程师理解施耐德M580系列PLC的寻址机制,以便更好地进行模块配置和编程;②指导工程师完成最小系统的搭建,优化系统拓扑结构的设计;③提供详细的上电写入程序指南,确保程序下载顺利进行;④解释CPU冗余的实现方式,提高系统的稳定性和可靠性。 其他说明:文中还涉及一些特殊模块的功能介绍,如定时器事件和Modbus串口通讯模块,这些内容有助于用户深入了解M580系列PLC的高级应用。此外,附录部分提供了远程子站和热备冗余系统的实物图片,便于用户直观理解相关概念。

    某型自动垂直提升仓储系统方案论证及关键零部件的设计.zip

    某型自动垂直提升仓储系统方案论证及关键零部件的设计.zip

    2135D3F1EFA99CB590678658F575DB23.pdf#page=1&view=fitH

    2135D3F1EFA99CB590678658F575DB23.pdf#page=1&view=fitH

    agentransack文本搜索软件

    可以搜索文本内的内容,指定目录,指定文件格式,匹配大小写等

    Windows 平台 Android Studio 下载与安装指南.zip

    Windows 平台 Android Studio 下载与安装指南.zip

    Android Studio Meerkat 2024.3.1 Patch 1(android-studio-2024.3.1.14-windows-zip.zip.002)

    Android Studio Meerkat 2024.3.1 Patch 1(android-studio-2024.3.1.14-windows.zip)适用于Windows系统,文件使用360压缩软件分割成两个压缩包,必须一起下载使用: part1: https://download.csdn.net/download/weixin_43800734/90557033 part2: https://download.csdn.net/download/weixin_43800734/90557035

    4-3-台区智能融合终端功能模块技术规范(试行).pdf

    国网台区终端最新规范

    4-13-台区智能融合终端软件检测规范(试行).pdf

    国网台区终端最新规范

    【锂电池剩余寿命预测】Transformer-GRU锂电池剩余寿命预测(Matlab完整源码和数据)

    1.【锂电池剩余寿命预测】Transformer-GRU锂电池剩余寿命预测(Matlab完整源码和数据) 2.数据集:NASA数据集,已经处理好,B0005电池训练、B0006测试; 3.环境准备:Matlab2023b,可读性强; 4.模型描述:Transformer-GRU在各种各样的问题上表现非常出色,现在被广泛使用。 5.领域描述:近年来,随着锂离子电池的能量密度、功率密度逐渐提升,其安全性能与剩余使用寿命预测变得愈发重要。本代码实现了Transformer-GRU在该领域的应用。 6.作者介绍:机器学习之心,博客专家认证,机器学习领域创作者,2023博客之星TOP50,主做机器学习和深度学习时序、回归、分类、聚类和降维等程序设计和案例分析,文章底部有博主联系方式。从事Matlab、Python算法仿真工作8年,更多仿真源码、数据集定制私信。

    基于android的家庭收纳App的设计与实现.zip

    Android项目原生java语言课程设计,包含LW+ppt

    大学生入门前端-五子棋vue项目

    大学生入门前端-五子棋vue项目

    二手车分析完整项目,包含源代码和数据集,包含:XGBoost 模型,训练模型代码,数据集包含 10,000 条二手车记录的数据集,涵盖车辆品牌、型号、年份、里程数、发动机缸数、价格等

    这是一个完整的端到端解决方案,用于分析和预测阿联酋(UAE)地区的二手车价格。数据集包含 10,000 条二手车信息,覆盖了迪拜、阿布扎比和沙迦等城市,并提供了精确的地理位置数据。此外,项目还包括一个基于 Dash 构建的 Web 应用程序代码和一个训练好的 XGBoost 模型,帮助用户探索区域市场趋势、预测车价以及可视化地理空间洞察。 数据集内容 项目文件以压缩 ZIP 归档形式提供,包含以下内容: 数据文件: data/uae_used_cars_10k.csv:包含 10,000 条二手车记录的数据集,涵盖车辆品牌、型号、年份、里程数、发动机缸数、价格、变速箱类型、燃料类型、颜色、描述以及销售地点(如迪拜、阿布扎比、沙迦)。 模型文件: models/stacking_model.pkl:训练好的 XGBoost 模型,用于预测二手车价格。 models/scaler.pkl:用于数据预处理的缩放器。 models.py:模型相关功能的实现。 train_model.py:训练模型的脚本。 Web 应用程序文件: app.py:Dash 应用程序的主文件。 callback

    《基于YOLOv8的船舶航行违规并线预警系统》(包含源码、可视化界面、完整数据集、部署教程)简单部署即可运行。功能完善、操作简单,适合毕设或课程设计.zip

    资源内项目源码是来自个人的毕业设计,代码都测试ok,包含源码、数据集、可视化页面和部署说明,可产生核心指标曲线图、混淆矩阵、F1分数曲线、精确率-召回率曲线、验证集预测结果、标签分布图。都是运行成功后才上传资源,毕设答辩评审绝对信服的保底85分以上,放心下载使用,拿来就能用。包含源码、数据集、可视化页面和部署说明一站式服务,拿来就能用的绝对好资源!!! 项目备注 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载学习,也适合小白学习进阶,当然也可作为毕设项目、课程设计、大作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能,也可用于毕设、课设、作业等。 下载后请首先打开README.txt文件,仅供学习参考, 切勿用于商业用途。

    《基于YOLOv8的工业布匹瑕疵分类系统》(包含源码、可视化界面、完整数据集、部署教程)简单部署即可运行。功能完善、操作简单,适合毕设或课程设计.zip

    资源内项目源码是来自个人的毕业设计,代码都测试ok,包含源码、数据集、可视化页面和部署说明,可产生核心指标曲线图、混淆矩阵、F1分数曲线、精确率-召回率曲线、验证集预测结果、标签分布图。都是运行成功后才上传资源,毕设答辩评审绝对信服的保底85分以上,放心下载使用,拿来就能用。包含源码、数据集、可视化页面和部署说明一站式服务,拿来就能用的绝对好资源!!! 项目备注 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载学习,也适合小白学习进阶,当然也可作为毕设项目、课程设计、大作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能,也可用于毕设、课设、作业等。 下载后请首先打开README.txt文件,仅供学习参考, 切勿用于商业用途。

    CodeCount.exe

    此为代码审查工具 可查 文件数,字节数,总行数,代码行数,注释行数,空白行数,注释率等

Global site tag (gtag.js) - Google Analytics