`
thecloud
  • 浏览: 953814 次
文章分类
社区版块
存档分类
最新评论

mahout源码分析之Decision Forest结束篇

 
阅读更多

Mahout版本:0.7,hadoop版本:1.0.4,jdk:1.7.0_25 64bit。

Mahout系列之Decision Forest写了几篇,其中的一些过程并没有详细说明,这里就分析一下,作为Decision Forest算法系列的结束篇。

主要的问题包括:(1)在Build Forest中分析完了Step1Mapper后就没有向下分析了,而是直接进行TestForest的分析了,中间其实还是有很多操作的,比如:把Step1Mapper的Job的输出进行转换写入文件。(2)在BuildForest中没有分析当输入是Categorical的情况,这种情况下面执行的某些步骤是不一样的,主要是在DecisionTreeBuilder中的build方法中的区分。(3)在前一篇中最后的使用forest进行对数据的分类只是简要的说了下,这里详细分析下代码。(4)决策树同样可以做回归分析,在Describe阶段设置为回归问题就可以了,但是这里就不想做分析了。下面来分条进行分析:

(1)在BuildForest中提交任务后实际运行的类是Builder中的build方法中的代码。这里面的代码任务运行后的代码如下:

if (isOutput(conf)) {
      log.debug("Parsing the output...");
      DecisionForest forest = parseOutput(job);
      HadoopUtil.delete(conf, outputPath);
      return forest;
    }
isOutput():

protected static boolean isOutput(Configuration conf) {
    return conf.getBoolean("debug.mahout.rf.output", true);
  }
可以看到这个函数去判断是否设置了debug.mahout.rf.output,如果没有设置则返回true,否则,就说明设置过了就按照设置的值来返回。这里一般都没有设置,所以就会运行if里面的代码先把job的输出传入到forest变量,然后删除job的输出。看parseOutput的操作:

protected DecisionForest parseOutput(Job job) throws IOException {
    Configuration conf = job.getConfiguration();
    
    int numTrees = Builder.getNbTrees(conf);
    
    Path outputPath = getOutputPath(conf);
    
    TreeID[] keys = new TreeID[numTrees];
    Node[] trees = new Node[numTrees];
        
    processOutput(job, outputPath, keys, trees);
    
    return new DecisionForest(Arrays.asList(trees));
  }
这里面又有一个processOutput函数,前面就是设置一些变量的size之类的,然后到processOutput函数,看这个函数:

protected static void processOutput(JobContext job,
                                      Path outputPath,
                                      TreeID[] keys,
                                      Node[] trees) throws IOException {
    Preconditions.checkArgument(keys == null && trees == null || keys != null && trees != null,
        "if keys is null, trees should also be null");
    Preconditions.checkArgument(keys == null || keys.length == trees.length, "keys.length != trees.length");

    Configuration conf = job.getConfiguration();

    FileSystem fs = outputPath.getFileSystem(conf);

    Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);

    // read all the outputs
    int index = 0;
    for (Path path : outfiles) {
      for (Pair<TreeID,MapredOutput> record : new SequenceFileIterable<TreeID, MapredOutput>(path, conf)) {
        TreeID key = record.getFirst();
        MapredOutput value = record.getSecond();
        if (keys != null) {
          keys[index] = key;
        }
        if (trees != null) {
          trees[index] = value.getTree();
        }
        index++;
      }
    }

    // make sure we got all the keys/values
    if (keys != null && index != keys.length) {
      throw new IllegalStateException("Some key/values are missing from the output");
    }
  }
这里看到就是把job的输出按条读出然后写入到Node[] trees数组中,然后把数组转换为list,赋值给DecisionForest变量new DecisionForest(Arrays.asList(trees))。最后返回到BuildForest中DFUtils.storeWritable(getConf(), forestPath, forest);,这个主要是写文件,基本没啥内容了。

(2)当输入数据中存在有Categorical的属性列时,最先的不同就是在dataset的values属性。这个values数组当输入数据属性是Numerical的时候对应的值就是null,如果是Categorical的时候就会存入相应的离散值。其次就是在DecisionTreeBuilder中find the best split这一部分的代码(源文件中192行),这里计算Split的时候分为了Categorical和Numerical,如下:

public Split computeSplit(Data data, int attr) {
    if (data.getDataset().isNumerical(attr)) {
      return numericalSplit(data, attr);
    } else {
      return categoricalSplit(data, attr);
    }
  }
看categoricalSplit函数:

 private static Split categoricalSplit(Data data, int attr) {
    double[] values = data.values(attr);
    int[][] counts = new int[values.length][data.getDataset().nblabels()];
    int[] countAll = new int[data.getDataset().nblabels()];

    Dataset dataset = data.getDataset();

    // compute frequencies
    for (int index = 0; index < data.size(); index++) {
      Instance instance = data.get(index);
      counts[ArrayUtils.indexOf(values, instance.get(attr))][(int) dataset.getLabel(instance)]++;
      countAll[(int) dataset.getLabel(instance)]++;
    }

    int size = data.size();
    double hy = entropy(countAll, size); // H(Y)
    double hyx = 0.0; // H(Y|X)
    double invDataSize = 1.0 / size;

    for (int index = 0; index < values.length; index++) {
      size = DataUtils.sum(counts[index]);
      hyx += size * invDataSize * entropy(counts[index], size);
    }

    double ig = hy - hyx;
    return new Split(attr, ig);
  }

这里返回的Split只有两个属性,其实因为属性值是离散的,所以这里只用确定是这个值或者不是即可,不会还要说比较值的大小(而且也没法比)。
然后就是建立节点的部分了。获得最佳属性后,根据这个属性是否是Numerical而进入不同的代码块,如果是Categorical的话,进入:

else { // CATEGORICAL attribute
      double[] values = data.values(best.getAttr());

      // tree is complemented
      Collection<Double> subsetValues = null;
      if (complemented) {
        subsetValues = Sets.newHashSet();
        for (double value : values) {
          subsetValues.add(value);
        }
        values = fullSet.values(best.getAttr());
      }

      int cnt = 0;
      Data[] subsets = new Data[values.length];
      for (int index = 0; index < values.length; index++) {
        if (complemented && !subsetValues.contains(values[index])) {
          continue;
        }
        subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index]));
        if (subsets[index].size() >= minSplitNum) {
          cnt++;
        }
      }

      // size of the subset is less than the minSpitNum
      if (cnt < 2) {
        // branch is not split
        double label;
        if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
          label = sum / data.size();
        } else {
          label = data.majorityLabel(rng);
        }
        log.debug("branch is not split Leaf({})", label);
        return new Leaf(label);
      }

      selected[best.getAttr()] = true;

      Node[] children = new Node[values.length];
      for (int index = 0; index < values.length; index++) {
        if (complemented && (subsetValues == null || !subsetValues.contains(values[index]))) {
          // tree is complemented
          double label;
          if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
            label = sum / data.size();
          } else {
            label = data.majorityLabel(rng);
          }
          log.debug("complemented Leaf({})", label);
          children[index] = new Leaf(label);
          continue;
        }
        children[index] = build(rng, subsets[index]);
      }

      selected[best.getAttr()] = alreadySelected;

      childNode = new CategoricalNode(best.getAttr(), values, children);
    }
其实上面的代码和Numerical差不多,可以说作为Numerical的一种特殊情况,即对于Numerical把其区分为等于属性值和不等于属性值即可(但是Numerical是分为小于和等于、大于两种)。其他基本就差不多了。

(3)用forest对数据Instance变量进行分类的代码是在DecisionForest的classify函数里面:

public double classify(Dataset dataset, Random rng, Instance instance) {
    if (dataset.isNumerical(dataset.getLabelId())) {
      double sum = 0;
      int cnt = 0;
      for (Node tree : trees) {
        double prediction = tree.classify(instance);
        if (prediction != -1) {
          sum += prediction;
          cnt++;
        }
      }
      return sum / cnt;
    } else {
      int[] predictions = new int[dataset.nblabels()];
      for (Node tree : trees) {
        double prediction = tree.classify(instance);
        if (prediction != -1) {
          predictions[(int) prediction]++;
        }
      }
      
      if (DataUtils.sum(predictions) == 0) {
        return -1; // no prediction available
      }
上面就是前篇讲到的所有树都对这个数据进行分类,然后按最多次数的那个类别即是最后的结果。但是一棵树是如何分类的?这个又分为了两种,好吧,应该不难猜,就是Numerical的树和Categorical的树。分别来看,首先是Numerical:

public double classify(Instance instance) {
    if (instance.get(attr) < split) {
      return loChild.classify(instance);
    } else {
      return hiChild.classify(instance);
    }
  }
看到它是去找它的子树去了,然后最后到哪里?其实是到了Leaf的classify函数了:

@Override
  public double classify(Instance instance) {
    return label;
  }
这个也是一个递归的过程,其实就是建树过程的一个反过程而已,这样其实Categorical也是一样的了,只是要做些转换而已:

public double classify(Instance instance) {
    int index = ArrayUtils.indexOf(values, instance.get(attr));
    if (index == -1) {
      // value not available, we cannot predict
      return -1;
    }
    return childs[index].classify(instance);
  }
这样基本就ok了,下次再看这个算法的时候应该是要分析回归问题了?



分享,成长,快乐

转载请注明blog地址:http://blog.csdn.net/fansy1990



分享到:
评论

相关推荐

    mahout Algorithms源码分析

    樊哲是Mahout的积极学习者和实践者,他在CSDN上分享了关于Mahout算法的解析与案例实战的博客,获得了“CSDN2013博客之星”的荣誉。樊哲的经验表明,虽然Hadoop平台上算法开发一般需要耗费很长时间,但Mahout已经实现...

    Mahout源码

    **Apache Mahout 源码解析** Apache Mahout 是一个基于Java的开源机器学习库,旨在简化大规模数据集上的机器学习算法实现。它为开发者提供了一系列预构建的、可扩展的机器学习算法,包括分类、聚类、推荐系统以及...

    mahout源码

    本篇将深入探讨Mahout中的朴素贝叶斯分类以及中文分词这两个核心功能。 一、Mahout与朴素贝叶斯分类 1. **Mahout简介** Apache Mahout的命名来源于古印度的一种数学算术工具,它体现了项目的目标——通过数学算法...

    mahout-distribution-0.5-src.zip mahout 源码包

    mahout-distribution-0.5-src.zip mahout 源码包

    Mahout RandomForest Example使用步骤

    **Apache Mahout Random Forest 示例详解** Apache Mahout 是一个基于 Apache Hadoop 的机器学习库,提供了多种算法,包括分类、聚类和推荐系统等。在这些算法中,随机森林(Random Forest)是一种广泛使用的集成...

    svd mahout算法

    svd算法的工具类,直接调用出结果,调用及设置方式参考http://blog.csdn.net/fansy1990 &lt;mahout源码分析之DistributedLanczosSolver(七)&gt;

    MAHOUT源码包

    Mahout 是 Apache Software Foundation(ASF) 旗下的一个开源项目,提供一些可扩展的机器学习领域经典算法的实现,旨在帮助开发人员更加方便快捷地创建智能应用程序。Mahout包含许多实现,包括聚类、分类、推荐过滤...

    mahout in action中的源码

    《Mahout in Action》是一本深入探讨Apache Mahout机器学习框架的专业书籍,其源码提供了丰富的实践示例和深入理解Mahout算法的机会。在GitHub上,你可以找到这些源码的完整版本,链接为。下面,我们将详细探讨...

    mahout in action源代码maven编译jar包

    Apache Mahout是一个流行的机器学习库,广泛用于数据挖掘和大数据分析。《Mahout in Action》这本书是Mahout技术的权威指南,提供了丰富的示例代码供读者实践。然而,在实际操作过程中,使用Maven编译书中提供的源...

    mahout0.9 源码

    以上就是关于Mahout 0.9源码及其在Eclipse中的使用介绍。通过学习和实践,开发者可以利用Mahout构建强大的机器学习应用,处理各种数据挖掘任务。在实际应用中,可以根据项目需求选择合适的算法,结合Hadoop分布式...

    mahout 0.7 src

    mahout 0.7 src, mahout 源码包, hadoop 机器学习子项目 mahout 源码包

    mahout-distribution-0.7-src.zip

    《Apache Mahout 0.7源码解析与应用探索》 Apache Mahout 是一个开源机器学习库,专注于大规模数据集的算法实现。该库由Java编写,并采用Maven作为构建工具,提供了一系列用于构建智能应用的高效算法。本文将深入...

    [Mahout] Windows下Mahout单机安装

    打开命令行,进入解压后的Mahout源码目录,执行以下Maven命令来构建Mahout: ``` mvn clean install -DskipTests ``` 这个过程可能会比较耗时,因为Maven会自动下载所有依赖。等待编译完成后,Mahout的可执行jar文件...

    mahout0.9源码(支持hadoop2)

    mahout0.9的源码,支持hadoop2,需要自行使用mvn编译。mvn编译使用命令: mvn clean install -Dhadoop2 -Dhadoop.2.version=2.2.0 -DskipTests

    人工智能-推荐系统-新闻推荐-基于Mahout的新闻推荐系统

    Mahout:整体框架,实现了协同过滤 Deeplearning4j,构建VSM Jieba:分词,关键词提取 HanLP:分词,关键词提取 Spring Boot:提供API、ORM 关键实现 基于用户的协同过滤 直接调用Mahout相关接口即可 选择不同...

    mahout-core-0.9.jar+mahout-core-0.8.jar+mahout-core-0.1.jar

    Mahout是建立在Hadoop之上的,利用其分布式计算能力处理大规模数据集。这使得Mahout能够处理超出单台机器内存和计算能力的数据。 3. **版本差异**: - mahout-core-0.1.jar:这是早期版本,可能包含的基本功能,...

    apache-mahout-distribution-0.11.0-src.zip

    在源码中,您可以探索Mahout实现的各种算法,如协同过滤(Collaborative Filtering)、频繁项集挖掘(Frequent Itemset Mining)、近邻搜索(Nearest Neighbor Search)等。这些算法是通过Java编程语言实现的,因此...

    mahout所需jar包

    Mahout的目标是帮助开发人员构建智能应用程序,如推荐系统、分类和聚类算法,这些在大数据分析领域中极为重要。 **K-Means聚类算法** K-Means是一种无监督学习的聚类算法,用于将数据集分成不同的群组或类别。在...

    Mahout In Action英文完整版

    3. **分类算法**:除了推荐系统和聚类外,Mahout还支持多种分类算法,如决策树(Decision Tree)、随机森林(Random Forest)等。这些算法主要用于预测数据的类别归属,广泛应用于文本分类、情感分析等领域。 #### 五、...

Global site tag (gtag.js) - Google Analytics