Mahout Series: Conjugate Gradient

Unpreconditioned conjugate gradient

To solve the linear system $\boldsymbol{Ax}=\boldsymbol{b}$, the biconjugate gradient stabilized (BiCGSTAB) method starts from an initial guess $\boldsymbol{x}_0$ and iterates as follows (a runnable sketch follows the list):

  1. $\boldsymbol{r}_0=\boldsymbol{b}-\boldsymbol{Ax}_0$
  2. Choose an arbitrary vector $\hat{\boldsymbol{r}}_0$ such that $(\hat{\boldsymbol{r}}_0,\boldsymbol{r}_0)\neq 0$, e.g. $\hat{\boldsymbol{r}}_0=\boldsymbol{r}_0$
  3. $\rho_0=\alpha=\omega_0=1$
  4. $\boldsymbol{v}_0=\boldsymbol{p}_0=\boldsymbol{0}$
  5. For $i=1,2,3,\ldots$
    1. $\rho_i=(\hat{\boldsymbol{r}}_0,\boldsymbol{r}_{i-1})$
    2. $\beta=(\rho_i/\rho_{i-1})(\alpha/\omega_{i-1})$
    3. $\boldsymbol{p}_i=\boldsymbol{r}_{i-1}+\beta(\boldsymbol{p}_{i-1}-\omega_{i-1}\boldsymbol{v}_{i-1})$
    4. $\boldsymbol{v}_i=\boldsymbol{A}\boldsymbol{p}_i$
    5. $\alpha=\rho_i/(\hat{\boldsymbol{r}}_0,\boldsymbol{v}_i)$
    6. $\boldsymbol{s}=\boldsymbol{r}_{i-1}-\alpha\boldsymbol{v}_i$
    7. $\boldsymbol{t}=\boldsymbol{A}\boldsymbol{s}$
    8. $\omega_i=(\boldsymbol{t},\boldsymbol{s})/(\boldsymbol{t},\boldsymbol{t})$
    9. $\boldsymbol{x}_i=\boldsymbol{x}_{i-1}+\alpha\boldsymbol{p}_i+\omega_i\boldsymbol{s}$
    10. If $\boldsymbol{x}_i$ is accurate enough, quit
    11. $\boldsymbol{r}_i=\boldsymbol{s}-\omega_i\boldsymbol{t}$
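
Here is a minimal, self-contained Java sketch of the iteration above, using plain double arrays. The class and helper names (BiCgStabSketch, matVec, dot) are illustrative only, not Mahout API:

public final class BiCgStabSketch {

  static double dot(double[] u, double[] v) {
    double s = 0.0;
    for (int k = 0; k < u.length; k++) {
      s += u[k] * v[k];
    }
    return s;
  }

  static double[] matVec(double[][] a, double[] v) {
    double[] w = new double[a.length];
    for (int r = 0; r < a.length; r++) {
      w[r] = dot(a[r], v);
    }
    return w;
  }

  static double[] solve(double[][] a, double[] b, int maxIterations, double tolerance) {
    int n = b.length;
    double[] x = new double[n];                       // x_0 = 0
    double[] r = b.clone();                           // step 1: r_0 = b - A x_0 = b
    double[] rHat = r.clone();                        // step 2: r-hat_0 = r_0
    double rho = 1.0, alpha = 1.0, omega = 1.0;       // step 3
    double[] v = new double[n];                       // step 4: v_0 = 0
    double[] p = new double[n];                       // step 4: p_0 = 0
    for (int i = 1; i <= maxIterations && Math.sqrt(dot(r, r)) > tolerance; i++) {
      double rhoNext = dot(rHat, r);                  // step 5.1
      double beta = (rhoNext / rho) * (alpha / omega);  // step 5.2
      rho = rhoNext;
      for (int k = 0; k < n; k++) {                   // step 5.3
        p[k] = r[k] + beta * (p[k] - omega * v[k]);
      }
      v = matVec(a, p);                               // step 5.4
      alpha = rho / dot(rHat, v);                     // step 5.5
      double[] s = new double[n];                     // step 5.6
      for (int k = 0; k < n; k++) {
        s[k] = r[k] - alpha * v[k];
      }
      double[] t = matVec(a, s);                      // step 5.7
      omega = dot(t, s) / dot(t, t);                  // step 5.8
      for (int k = 0; k < n; k++) {                   // steps 5.9 and 5.11
        x[k] += alpha * p[k] + omega * s[k];
        r[k] = s[k] - omega * t[k];
      }                                               // step 5.10 is the loop guard
    }
    return x;
  }
}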

 

Preconditioned conjugate gradient

Preconditioning is commonly used to accelerate the convergence of iterative methods. To solve the linear system $\boldsymbol{Ax}=\boldsymbol{b}$ with a preconditioner $\boldsymbol{K}=\boldsymbol{K}_1\boldsymbol{K}_2\approx\boldsymbol{A}$, the preconditioned BiCGSTAB method starts from an initial guess $\boldsymbol{x}_0$ and iterates as follows:

 

  1. $\boldsymbol{r}_0=\boldsymbol{b}-\boldsymbol{Ax}_0$
  2. Choose an arbitrary vector $\hat{\boldsymbol{r}}_0$ such that $(\hat{\boldsymbol{r}}_0,\boldsymbol{r}_0)\neq 0$, e.g. $\hat{\boldsymbol{r}}_0=\boldsymbol{r}_0$
  3. $\rho_0=\alpha=\omega_0=1$
  4. $\boldsymbol{v}_0=\boldsymbol{p}_0=\boldsymbol{0}$
  5. For $i=1,2,3,\ldots$
    1. $\rho_i=(\hat{\boldsymbol{r}}_0,\boldsymbol{r}_{i-1})$
    2. $\beta=(\rho_i/\rho_{i-1})(\alpha/\omega_{i-1})$
    3. $\boldsymbol{p}_i=\boldsymbol{r}_{i-1}+\beta(\boldsymbol{p}_{i-1}-\omega_{i-1}\boldsymbol{v}_{i-1})$
    4. $\boldsymbol{y}=\boldsymbol{K}^{-1}\boldsymbol{p}_i$
    5. $\boldsymbol{v}_i=\boldsymbol{A}\boldsymbol{y}$
    6. $\alpha=\rho_i/(\hat{\boldsymbol{r}}_0,\boldsymbol{v}_i)$
    7. $\boldsymbol{s}=\boldsymbol{r}_{i-1}-\alpha\boldsymbol{v}_i$
    8. $\boldsymbol{z}=\boldsymbol{K}^{-1}\boldsymbol{s}$
    9. $\boldsymbol{t}=\boldsymbol{A}\boldsymbol{z}$
    10. $\omega_i=(\boldsymbol{K}_1^{-1}\boldsymbol{t},\boldsymbol{K}_1^{-1}\boldsymbol{s})/(\boldsymbol{K}_1^{-1}\boldsymbol{t},\boldsymbol{K}_1^{-1}\boldsymbol{t})$
    11. $\boldsymbol{x}_i=\boldsymbol{x}_{i-1}+\alpha\boldsymbol{y}+\omega_i\boldsymbol{z}$
    12. If $\boldsymbol{x}_i$ is accurate enough, quit
    13. $\boldsymbol{r}_i=\boldsymbol{s}-\omega_i\boldsymbol{t}$

 

This form is equivalent to applying the unpreconditioned BiCGSTAB method to the explicitly preconditioned system

 

\boldsymbol{\tilde{A}\tilde{x}}=\boldsymbol{\tilde{b}}

 

where $\tilde{\boldsymbol{A}}=\boldsymbol{K}_1^{-1}\boldsymbol{A}\boldsymbol{K}_2^{-1}$, $\tilde{\boldsymbol{x}}=\boldsymbol{K}_2\boldsymbol{x}$, and $\tilde{\boldsymbol{b}}=\boldsymbol{K}_1^{-1}\boldsymbol{b}$. In other words, both left and right preconditioning can be implemented through this single form.
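
Substituting these definitions confirms the equivalence; the algebra is one line:

$$
\tilde{\boldsymbol{A}}\tilde{\boldsymbol{x}}
  =\boldsymbol{K}_1^{-1}\boldsymbol{A}\boldsymbol{K}_2^{-1}\,\boldsymbol{K}_2\boldsymbol{x}
  =\boldsymbol{K}_1^{-1}\boldsymbol{A}\boldsymbol{x}
  =\boldsymbol{K}_1^{-1}\boldsymbol{b}
  =\tilde{\boldsymbol{b}},
$$

so the transformed system has the same solution (recovered as $\boldsymbol{x}=\boldsymbol{K}_2^{-1}\tilde{\boldsymbol{x}}$). Taking $\boldsymbol{K}_2=\boldsymbol{I}$ gives pure left preconditioning, and taking $\boldsymbol{K}_1=\boldsymbol{I}$ gives pure right preconditioning.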

 

Mahout's conjugate gradient implementation (ConjugateGradientSolver; the distributed solver extends it):

package org.apache.mahout.math.solver;
import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.PlusMult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * <p>Implementation of a conjugate gradient iterative solver for linear systems. Implements both
 * standard conjugate gradient and pre-conditioned conjugate gradient.
 *
 * <p>Conjugate gradient requires the matrix A in the linear system Ax = b to be symmetric and positive
 * definite. For convenience, this implementation allows the input matrix to be non-symmetric, in
 * which case the system A'Ax = b is solved. Because this requires only one pass through the matrix A, it
 * is faster than explicitly computing A'A, then passing the results to the solver.
 *
 * <p>For inputs that may be ill conditioned (often the case for highly sparse input), this solver
 * also accepts a parameter, lambda, which adds a scaled identity to the matrix A, solving the system
 * (A + lambda*I)x = b. This obviously changes the solution, but it will guarantee solvability. The
 * ridge regression approach to linear regression is a common use of this feature.
 *
 * <p>If only an approximate solution is required, the maximum number of iterations or the error threshold
 * may be specified to end the algorithm early at the expense of accuracy. When the matrix A is ill conditioned,
 * it may sometimes be necessary to increase the maximum number of iterations above the default of A.numCols()
 * due to numerical issues.
 *
 * <p>By default the solver will run a.numCols() iterations or until the residual falls below 1E-9.
 *
 * <p>For more information on the conjugate gradient algorithm, see Golub & van Loan, "Matrix Computations",
 * sections 10.2 and 10.3 or the <a href="http://en.wikipedia.org/wiki/Conjugate_gradient">conjugate gradient
 * wikipedia article</a>.
 */
public class ConjugateGradientSolver {
  public static final double DEFAULT_MAX_ERROR = 1.0e-9;
 
  private static final Logger log = LoggerFactory.getLogger(ConjugateGradientSolver.class);
  private static final PlusMult PLUS_MULT = new PlusMult(1.0);
  private int iterations;
  private double residualNormSquared;
 
  public ConjugateGradientSolver() {
    this.iterations = 0;
    this.residualNormSquared = Double.NaN;
  } 
  /**
   * Solves the system Ax = b with default termination criteria. A must be symmetric, square, and positive definite.
   * Only the squareness of a is checked, since tests for symmetry and positive definiteness are too expensive. If
   * an invalid matrix is specified, then the algorithm may not yield a valid result.
   * 
   * @param a  The linear operator A.
   * @param b  The vector b.
   * @return The result x of solving the system.
   * @throws IllegalArgumentException if a is not square or if the size of b is not equal to the number of columns of a.
   *
   */
  public Vector solve(VectorIterable a, Vector b) {
    return solve(a, b, null, b.size(), DEFAULT_MAX_ERROR);
  }
 
  /**
   * Solves the system Ax = b with default termination criteria using the specified preconditioner. A must be
   * symmetric, square, and positive definite. Only the squareness of a is checked, since tests for symmetry
   * and positive definiteness are too expensive. If an invalid matrix is specified, then the algorithm may not
   * yield a valid result.
   * 
   * @param a  The linear operator A.
   * @param b  The vector b.
   * @param precond A preconditioner to use on A during the solution process.
   * @return The result x of solving the system.
   * @throws IllegalArgumentException if a is not square or if the size of b is not equal to the number of columns of a.
   *
   */
  public Vector solve(VectorIterable a, Vector b, Preconditioner precond) {
    return solve(a, b, precond, b.size(), DEFAULT_MAX_ERROR);
  }
  /**
   * Solves the system Ax = b, where A is a linear operator and b is a vector. Uses the specified preconditioner
   * to improve numeric stability and possibly speed convergence. This version of solve() allows control over the
   * termination and iteration parameters.
   *
   * @param a  The matrix A.
   * @param b  The vector b.
   * @param preconditioner The preconditioner to apply.
   * @param maxIterations The maximum number of iterations to run.
   * @param maxError The maximum amount of residual error to tolerate. The algorithm will run until the residual falls
   * below this value or until maxIterations are completed.
   * @return The result x of solving the system.
   * @throws IllegalArgumentException if the matrix is not square, if the size of b is not equal to the number of
   * columns of A, if maxError is less than zero, or if maxIterations is not positive.
   */
// The main body of the conjugate gradient implementation. Note that the method works both with and
// without preconditioning. Mahout provides a single-machine Jacobi preconditioner, but no distributed
// Jacobi preconditioner; that one you have to write yourself (a sketch follows this listing). It is
// simple: just take the reciprocal of each diagonal element and use the result as a diagonal matrix.
  public Vector solve(VectorIterable a,
                      Vector b,
                      Preconditioner preconditioner,
                      int maxIterations,
                      double maxError) {
    if (a.numRows() != a.numCols()) {
      throw new IllegalArgumentException("Matrix must be square, symmetric and positive definite.");
    }
   
    if (a.numCols() != b.size()) {
      throw new CardinalityException(a.numCols(), b.size());
    }
    if (maxIterations <= 0) {
      throw new IllegalArgumentException("Max iterations must be positive.");     
    }
   
    if (maxError < 0.0) {
      throw new IllegalArgumentException("Max error must be non-negative.");
    }
   
    Vector x = new DenseVector(b.size());
    iterations = 0;
    Vector residual = b.minus(a.times(x));
    residualNormSquared = residual.dot(residual);
    log.info("Conjugate gradient initial residual norm = {}", Math.sqrt(residualNormSquared));
    double previousConditionedNormSqr = 0.0;
    Vector updateDirection = null;
    while (Math.sqrt(residualNormSquared) > maxError && iterations < maxIterations) {
      Vector conditionedResidual;
      double conditionedNormSqr;
      if (preconditioner == null) {
        conditionedResidual = residual;
        conditionedNormSqr = residualNormSquared;
      } else {
        conditionedResidual = preconditioner.precondition(residual);
        conditionedNormSqr = residual.dot(conditionedResidual);
      }     
     
      ++iterations;
     
      if (iterations == 1) {
        updateDirection = new DenseVector(conditionedResidual);
      } else {
        double beta = conditionedNormSqr / previousConditionedNormSqr;
       
        // updateDirection = residual + beta * updateDirection
        updateDirection.assign(Functions.MULT, beta);
        updateDirection.assign(conditionedResidual, Functions.PLUS);
      }
     
      Vector aTimesUpdate = a.times(updateDirection);
     
      double alpha = conditionedNormSqr / updateDirection.dot(aTimesUpdate);
     
      // x = x + alpha * updateDirection
      PLUS_MULT.setMultiplicator(alpha);
      x.assign(updateDirection, PLUS_MULT);
      // residual = residual - alpha * A * updateDirection
      PLUS_MULT.setMultiplicator(-alpha);
      residual.assign(aTimesUpdate, PLUS_MULT);
     
      previousConditionedNormSqr = conditionedNormSqr;
      residualNormSquared = residual.dot(residual);
     
      log.info("Conjugate gradient iteration {} residual norm = {}", iterations, Math.sqrt(residualNormSquared));
    }
    return x;
  }
  /**
   * Returns the number of iterations run once the solver is complete.
   *
   * @return The number of iterations run.
   */
  public int getIterations() {
    return iterations;
  }
  /**
   * Returns the norm of the residual at the completion of the solver. Usually this should be close to zero except in
   * the case of a non positive definite matrix A, which results in an unsolvable system, or for ill conditioned A, in
   * which case more iterations than the default may be needed.
   *
   * @return The norm of the residual in the solution.
   */
  public double getResidualNorm() {
    return Math.sqrt(residualNormSquared);
  } 
}
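
As the comment inside solve() notes, a Jacobi (diagonal) preconditioner is easy to supply. Below is a minimal sketch against Mahout's Preconditioner interface, assuming it declares the single method Vector precondition(Vector v) used above. The class name DiagonalJacobiPreconditioner is hypothetical, and filling inverseDiagonal for a distributed matrix would itself be a small MR job:

package org.apache.mahout.math.solver;

import org.apache.mahout.math.Vector;

/**
 * Minimal Jacobi preconditioner sketch: K = diag(A), so applying K^-1 is an
 * element-wise product with the precomputed reciprocals of A's diagonal.
 * The constructor argument must hold 1 / a_ii for every i.
 */
public class DiagonalJacobiPreconditioner implements Preconditioner {

  private final Vector inverseDiagonal;

  public DiagonalJacobiPreconditioner(Vector inverseDiagonal) {
    this.inverseDiagonal = inverseDiagonal;
  }

  @Override
  public Vector precondition(Vector v) {
    // (K^-1 v)_i = v_i / a_ii; Vector.times(Vector) multiplies element-wise
    return v.times(inverseDiagonal);
  }
}
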
DistributedConjugateGradientSolver is an extension of the CG solver above. The difference between DCG and CG is that DCG carries out the matrix-vector multiplication with a MapReduce (MR) job.
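
For completeness, here is a toy usage sketch of the single-machine solver above. The 3x3 system is made-up illustrative data; Mahout's Matrix types implement VectorIterable, so a DenseMatrix can be handed straight to solve():

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.solver.ConjugateGradientSolver;

public class CgExample {
  public static void main(String[] args) {
    // a small symmetric positive definite system (diagonally dominant)
    double[][] values = {
        {4.0, 1.0, 0.0},
        {1.0, 3.0, 1.0},
        {0.0, 1.0, 2.0}
    };
    Vector b = new DenseVector(new double[] {1.0, 2.0, 3.0});

    ConjugateGradientSolver solver = new ConjugateGradientSolver();
    // default termination: at most b.size() iterations or residual below 1e-9
    Vector x = solver.solve(new DenseMatrix(values), b);

    System.out.println("x = " + x);
    System.out.println("iterations = " + solver.getIterations()
        + ", residual norm = " + solver.getResidualNorm());
  }
}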
