
An example of solving the XOR problem with a neural network, implemented in Java. The Network class below implements a simple feed-forward network trained with backpropagation, and the XorExample class provides a small Swing front end for entering the XOR training data, training the network, and running it.

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

public class Network implements Serializable {

  /**
   * The global error for the training.
   */
  protected double globalError;


  /**
   * The number of input neurons.
   */
  protected int inputCount;

  /**
   * The number of hidden neurons.
   */
  protected int hiddenCount;

  /**
   * The number of output neurons.
   */
  protected int outputCount;

  /**
   * The total number of neurons in the network.
   */
  protected int neuronCount;

  /**
   * The number of weights in the network.
   */
  protected int weightCount;

  /**
   * The learning rate.
   */
  protected double learnRate;

  /**
   * The outputs from the various levels.
   */
  protected double fire[];

  /**
   * The weight matrix. This, along with the thresholds, can be
   * thought of as the "memory" of the neural network.
   */
  protected double matrix[];

  /**
   * The errors from the last calculation.
   */
  protected double error[];

  /**
   * Accumulates matrix deltas during training.
   */
  protected double accMatrixDelta[];

  /**
   * The thresholds. These, along with the weight matrix,
   * can be thought of as the memory of the neural network.
   */
  protected double thresholds[];

  /**
   * The changes that should be applied to the weight
   * matrix.
   */
  protected double matrixDelta[];

  /**
   * The accumulation of the threshold deltas.
   */
  protected double accThresholdDelta[];

  /**
   * The threshold deltas.
   */
  protected double thresholdDelta[];

  /**
   * The momentum for training.
   */
  protected double momentum;

  /**
   * The changes in the errors.
   */
  protected double errorDelta[];


  /**
   * Construct the neural network.
   *
   * @param inputCount The number of input neurons.
   * @param hiddenCount The number of hidden neurons.
   * @param outputCount The number of output neurons.
   * @param learnRate The learning rate to be used when training.
   * @param momentum The momentum to be used when training.
   */
  public Network(int inputCount,
                 int hiddenCount,
                 int outputCount,
                 double learnRate,
                 double momentum) {

    this.learnRate = learnRate;
    this.momentum = momentum;

    this.inputCount = inputCount;
    this.hiddenCount = hiddenCount;
    this.outputCount = outputCount;
    neuronCount = inputCount + hiddenCount + outputCount;
    weightCount = (inputCount * hiddenCount) + (hiddenCount * outputCount);

    fire        = new double[neuronCount];
    matrix      = new double[weightCount];
    matrixDelta = new double[weightCount];
    thresholds  = new double[neuronCount];
    errorDelta  = new double[neuronCount];
    error       = new double[neuronCount];
    accThresholdDelta = new double[neuronCount];
    accMatrixDelta = new double[weightCount];
    thresholdDelta = new double[neuronCount];

    reset();
  }



  /**
   * Returns the root mean square error for a complete training set.
   *
   * @param len The length of a complete training set.
   * @return The current error for the neural network.
   */
  public double getError(int len) {
    double err = Math.sqrt(globalError / (len * outputCount));
    globalError = 0;  // clear the accumulator
    return err;

  }

  /**
   * The threshold method. You may wish to override this method to provide other
   * threshold functions.
   *
   * @param sum The activation from the neuron.
   * @return The result of applying the threshold function to the activation.
   */
  public double threshold(double sum) {
    return 1.0 / (1 + Math.exp(-1.0 * sum));
  }

  /**
   * Compute the output for a given input to the neural network.
   *
   * @param input The input provided to the neural network.
   * @return The results from the output neurons.
   */
  public double []computeOutputs(double input[]) {
    int i, j;
    final int hiddenIndex = inputCount;
    final int outIndex = inputCount + hiddenCount;

    for (i = 0; i < inputCount; i++) {
      fire[i] = input[i];
    }

    // hidden layer: weighted sum of the inputs passed through the threshold function
    int inx = 0;  // index into the weight matrix

    for (i = hiddenIndex; i < outIndex; i++) {
      double sum = thresholds[i];

      for (j = 0; j < inputCount; j++) {
        sum += fire[j] * matrix[inx++];
      }
      fire[i] = threshold(sum);
    }

    // output layer: weighted sum of the hidden outputs passed through the threshold function

    double result[] = new double[outputCount];

    for (i = outIndex; i < neuronCount; i++) {
      double sum = thresholds[i];

      for (j = hiddenIndex; j < outIndex; j++) {
        sum += fire[j] * matrix[inx++];
      }
      fire[i] = threshold(sum);
      result[i-outIndex] = fire[i];
    }

    return result;
  }


  /**
   * Calculate the error for the recognition just done.
   *
   * @param ideal What the output neurons should have yielded.
   */
  public void calcError(double ideal[]) {
    int i, j;
    final int hiddenIndex = inputCount;
    final int outputIndex = inputCount + hiddenCount;

    // clear hidden and output layer errors
    for (i = inputCount; i < neuronCount; i++) {
      error[i] = 0;
    }

    // layer errors and deltas for output layer
    for (i = outputIndex; i < neuronCount; i++) {
      error[i] = ideal[i - outputIndex] - fire[i];
      globalError += error[i] * error[i];
      errorDelta[i] = error[i] * fire[i] * (1 - fire[i]);
    }

    // accumulate hidden-to-output weight deltas and back-propagate errors to the hidden layer
    int winx = inputCount * hiddenCount;  // offset of the hidden-to-output weights

    for (i = outputIndex; i < neuronCount; i++) {
      for (j = hiddenIndex; j < outputIndex; j++) {
        accMatrixDelta[winx] += errorDelta[i] * fire[j];
        error[j] += matrix[winx] * errorDelta[i];
        winx++;
      }
      accThresholdDelta[i] += errorDelta[i];
    }

    // hidden layer deltas
    for (i = hiddenIndex; i < outputIndex; i++) {
      errorDelta[i] = error[i] * fire[i] * (1 - fire[i]);
    }

    // accumulate input-to-hidden weight deltas (errors propagated to the input layer are not used)
    winx = 0;  // offset of the input-to-hidden weights
    for (i = hiddenIndex; i < outputIndex; i++) {
      for (j = 0; j < hiddenIndex; j++) {
        accMatrixDelta[winx] += errorDelta[i] * fire[j];
        error[j] += matrix[winx] * errorDelta[i];
        winx++;
      }
      accThresholdDelta[i] += errorDelta[i];
    }
  }

  /**
   * Modify the weight matrix and thresholds based on the last call to
   * calcError.
   */
  public void learn() {
    int i;

    // process the matrix
    for (i = 0; i < matrix.length; i++) {
      matrixDelta[i] = (learnRate * accMatrixDelta[i]) + (momentum * matrixDelta[i]);
      matrix[i] += matrixDelta[i];
      accMatrixDelta[i] = 0;
    }

    // process the thresholds
    for (i = inputCount; i < neuronCount; i++) {
      thresholdDelta[i] = learnRate * accThresholdDelta[i] + (momentum * thresholdDelta[i]);
      thresholds[i] += thresholdDelta[i];
      accThresholdDelta[i] = 0;
    }
  }

  /**
   * Reset the weight matrix and the thresholds.
   */
  public void reset() {
    int i;

    for (i = 0; i < neuronCount; i++) {
      thresholds[i] = 0.5 - (Math.random());
      thresholdDelta[i] = 0;
      accThresholdDelta[i] = 0;
    }
    for (i = 0; i < matrix.length; i++) {
      matrix[i] = 0.5 - (Math.random());
      matrixDelta[i] = 0;
      accMatrixDelta[i] = 0;
    }
  }
  
  /**
   * Serialize this network to the given file.
   */
  public File saveToFile(File file) {
    try {
      ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(file));
      outputStream.writeObject(this);
      outputStream.close();
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return file;
  }

  /**
   * Deserialize a network from the given file.
   */
  public static Network readFromFile(File file) {
    Network network = null;
    try {
      ObjectInputStream inputStream = new ObjectInputStream(new FileInputStream(file));
      network = (Network) inputStream.readObject();
      inputStream.close();
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return network;
  }
}
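
For readers who only want the network itself, here is a minimal console sketch of how the Network class above might be trained on the XOR truth table without the Swing front end. It assumes both classes sit in the same package; the XorConsoleExample class name, the epoch count, the print interval, and the xor-network.ser file name are my own illustrative choices, not taken from the original article. The topology and training parameters match those used in XorExample below.

import java.io.File;

public class XorConsoleExample {

  public static void main(String[] args) {
    // XOR truth table: the four input patterns and their ideal outputs.
    double xorInput[][] = { {0, 0}, {0, 1}, {1, 0}, {1, 1} };
    double xorIdeal[][] = { {0}, {1}, {1}, {0} };

    // Same setup as XorExample: 2 inputs, 3 hidden neurons, 1 output,
    // learning rate 0.5, momentum 0.7.
    Network network = new Network(2, 3, 1, 0.5, 0.7);

    // Train for a fixed number of epochs (one epoch = one pass over all four patterns).
    for (int epoch = 1; epoch <= 10000; epoch++) {
      for (int i = 0; i < xorInput.length; i++) {
        network.computeOutputs(xorInput[i]);
        network.calcError(xorIdeal[i]);
        network.learn();
      }
      // getError() also clears the accumulated error, so call it once per epoch.
      double err = network.getError(xorInput.length);
      if (epoch % 1000 == 0) {
        System.out.println("Epoch " + epoch + ", RMS error: " + err);
      }
    }

    // Evaluate the trained network on the four patterns.
    for (int i = 0; i < xorInput.length; i++) {
      double out[] = network.computeOutputs(xorInput[i]);
      System.out.println((int) xorInput[i][0] + " XOR " + (int) xorInput[i][1] + " = " + out[0]);
    }

    // Optionally persist the trained network and load it back.
    File saved = network.saveToFile(new File("xor-network.ser"));
    Network restored = Network.readFromFile(saved);
    System.out.println("Restored network output for (1, 0): "
                       + restored.computeOutputs(new double[] {1, 0})[0]);
  }
}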






import java.awt.Container;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.awt.GridLayout;
import java.awt.Toolkit;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.text.NumberFormat;

import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JTextField;
import javax.swing.WindowConstants;

public class XorExample extends JFrame implements ActionListener, Runnable {

  /**
   * The train button.
   */
  JButton btnTrain;

  /**
   * The run button.
   */
  JButton btnRun;

  /**
   * The quit button.
   */
  JButton btnQuit;

  /**
   * The status line.
   */
  JLabel status;

  /**
   * The background worker thread.
   */
  protected Thread worker = null;

/**
 * The number of input neurons.
 */
  protected final static int NUM_INPUT = 2;

/**
 * The number of output neurons.
 */
  protected final static int NUM_OUTPUT = 1;

/**
 * The number of hidden neurons.
 */
  protected final static int NUM_HIDDEN = 3;

/**
 * The learning rate.
 */
  protected final static double RATE = 0.5;

/**
 * The learning momentum.
 */
  protected final static double MOMENTUM = 0.7;


  /**
   * The training data that the user enters.
   * This represents the inputs and expected
   * outputs for the XOR problem.
   */
  protected JTextField data[][] = new JTextField[4][4];

  /**
   * The neural network.
   */
  protected Network network;



  /**
   * Constructor. Setup the components.
   */
  public XorExample()
  {
    setTitle("XOR Solution");
    network = new Network(
                         NUM_INPUT,
                         NUM_HIDDEN,
                         NUM_OUTPUT,
                         RATE,
                         MOMENTUM);

    Container content = getContentPane();

    GridBagLayout gridbag = new GridBagLayout();
    GridBagConstraints c = new GridBagConstraints();
    content.setLayout(gridbag);

    c.fill = GridBagConstraints.NONE;
    c.weightx = 1.0;

    // Training input label
    c.gridwidth = GridBagConstraints.REMAINDER; //end row
    c.anchor = GridBagConstraints.NORTHWEST;
    content.add(
               new JLabel(
                         "Enter training data:"),c);

    JPanel grid = new JPanel();
    grid.setLayout(new GridLayout(5,4));
    grid.add(new JLabel("IN1"));
    grid.add(new JLabel("IN2"));
    grid.add(new JLabel("Expected OUT   "));
    grid.add(new JLabel("Actual OUT"));

    for ( int i=0;i<4;i++ ) {
      int x = (i&1);
      int y = (i&2)>>1;
      grid.add(data[i][0] = new JTextField(""+y));
      grid.add(data[i][1] = new JTextField(""+x));
      grid.add(data[i][2] = new JTextField(""+(x^y)));
      grid.add(data[i][3] = new JTextField("??"));
      data[i][0].setEditable(false);
      data[i][1].setEditable(false);
      data[i][3].setEditable(false);
    }

    content.add(grid,c);

    // the button panel
    JPanel buttonPanel = new JPanel(new FlowLayout());
    buttonPanel.add(btnTrain = new JButton("Train"));
    buttonPanel.add(btnRun = new JButton("Run"));
    buttonPanel.add(btnQuit = new JButton("Quit"));
    btnTrain.addActionListener(this);
    btnRun.addActionListener(this);
    btnQuit.addActionListener(this);

    // Add the button panel
    c.gridwidth = GridBagConstraints.REMAINDER; //end row
    c.anchor = GridBagConstraints.CENTER;
    content.add(buttonPanel,c);

    // Status line
    c.gridwidth = GridBagConstraints.REMAINDER; //end row
    c.anchor = GridBagConstraints.NORTHWEST;
    content.add(
               status = new JLabel("Click train to begin training..."),c);

    // adjust size and position
    pack();
    Toolkit toolkit = Toolkit.getDefaultToolkit();
    Dimension d = toolkit.getScreenSize();
    setLocation(
               (int)(d.width-this.getSize().getWidth())/2,
               (int)(d.height-this.getSize().getHeight())/2 );
    setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
    setResizable(false);

    btnRun.setEnabled(false);
  }

  /**
   * The main method. Just displays the JFrame.
   *
   * @param args No arguments are used.
   */
  public static void main(String args[])
  {
    new XorExample().setVisible(true);
  }

  /**
   * Called when the user clicks one of the three
   * buttons.
   *
   * @param e The event.
   */
  public void actionPerformed(ActionEvent e)
  {
    if ( e.getSource()==btnQuit )
      System.exit(0);
    else if ( e.getSource()==btnTrain )
      train();
    else if ( e.getSource()==btnRun )
      evaluate();
  }

  /**
   * Called when the user clicks the run button.
   */
  protected void evaluate()
  {
    double xorData[][] = getGrid();
    NumberFormat nf = NumberFormat.getInstance();

    for (int i=0;i<4;i++) {
      double d[] = network.computeOutputs(xorData[i]);
      data[i][3].setText(nf.format(d[0]));
    }
  }


  /**
  * Called when the user clicks the train button.
  */
  protected void train()
  {
    worker = new Thread(this);
    worker.setPriority(Thread.MIN_PRIORITY);
    worker.start();
  }

  /**
  * The thread worker, used for training
  */
  public void run()
  {
    double xorData[][] = getGrid();
    double xorIdeal[][] = getIdeal();
    int update=0;

    int max = 10000;
    for (int i=0;i<max;i++) {
      for (int j=0;j<xorData.length;j++) {
        network.computeOutputs(xorData[j]);
        network.calcError(xorIdeal[j]);
        network.learn();
      }


      update++;
      if (update==100) {
        status.setText( "Cycles Left:" + (max-i) + ",Error:" + network.getError(xorData.length) );
        update=0;
      }
    }
    btnRun.setEnabled(true);
  }


  /**
   * Called to generate an array of doubles based on
   * the training data that the user has entered.
   *
   * @return An array of doubles
   */
  double [][]getGrid()
  {
    double array[][] = new double[4][2];

    for ( int i=0;i<4;i++ ) {
      array[i][0] = Double.parseDouble(data[i][0].getText());
      array[i][1] = Double.parseDouble(data[i][1].getText());
    }

    return array;
  }

  /**
   * Called to get the ideal values that the neural network
   * should return for each of the grid training values.
   *
   * @return The ideal results.
   */
  double [][]getIdeal()
  {
    double array[][] = new double[4][1];

    for ( int i=0;i<4;i++ ) {
      array[i][0] = Double.parseDouble(data[i][2].getText());
    }

    return array;
  }


}