linest

Reading the Code: TrainClassifier and TestClassifier

 
package org.apache.mahout.classifier.bayes;
public final class TrainClassifier
Entry-point class for training both the Bayes and the CBayes (Complementary Naive Bayes) classifiers.

Two branches, one per classifier type:
  public static void trainNaiveBayes(Path dir, Path outputDir, BayesParameters params) throws IOException {
    BayesDriver driver = new BayesDriver();
    driver.runJob(dir, outputDir, params);
  }
  
  public static void trainCNaiveBayes(Path dir, Path outputDir, BayesParameters params) throws IOException {
    CBayesDriver driver = new CBayesDriver();
    driver.runJob(dir, outputDir, params);
  }



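All default parameter values are set first; any non-default values given on the command line then override them.
Because there are so many parameters, they are wrapped in a single class, which makes them easy to pass along to later stages.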
      BayesParameters params = new BayesParameters();
      // Setting all the default parameter values
      params.setGramSize(1);
      params.setMinDF(1);
      params.set("alpha_i","1.0");
      params.set("dataSource", "hdfs");
      
      if (cmdLine.hasOption(gramSizeOpt)) {
        params.setGramSize(Integer.parseInt((String) cmdLine.getValue(gramSizeOpt)));
      }
      
      if (cmdLine.hasOption(minDfOpt)) {
        params.setMinDF(Integer.parseInt((String) cmdLine.getValue(minDfOpt)));
      }
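The defaults-then-override pattern above can be sketched in isolation. This is a minimal illustration, not Mahout code: `ParamDefaultsSketch`, `resolve`, and the map keys `gramSize` and `minDf` are hypothetical names chosen for the example, and a plain `Map` stands in for `BayesParameters`.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for a parameter holder; this is NOT Mahout's BayesParameters.
final class ParamDefaultsSketch {

  // Start from the default values, then overlay whatever the caller supplied.
  static Map<String, String> resolve(Map<String, String> supplied) {
    Map<String, String> params = new HashMap<>();
    // Defaults mirror the values in the excerpt; the key names are assumptions.
    params.put("gramSize", "1");
    params.put("minDf", "1");
    params.put("alpha_i", "1.0");
    params.put("dataSource", "hdfs");
    // Any supplied entry replaces the default stored under the same key.
    params.putAll(supplied);
    return params;
  }

  public static void main(String[] args) {
    Map<String, String> supplied = new HashMap<>();
    supplied.put("gramSize", "2");
    Map<String, String> params = resolve(supplied);
    System.out.println("gramSize=" + params.get("gramSize"));
    System.out.println("minDf=" + params.get("minDf"));
  }
}
```

Setting every default up front means later code can read any parameter without checking whether it was supplied.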



      Path inputPath = new Path((String) cmdLine.getValue(inputDirOpt));
      Path outputPath = new Path((String) cmdLine.getValue(outputOpt));
      if ("cbayes".equalsIgnoreCase(classifierType)) {
        log.info("Training Complementary Bayes Classifier");
        trainCNaiveBayes(inputPath, outputPath, params);
      } else {
        log.info("Training Bayes Classifier");
        // setup the HDFS and copy the files there, then run the trainer
        trainNaiveBayes(inputPath, outputPath, params);
      }




package org.apache.mahout.classifier.bayes;
public final class TestClassifier

Entry point (parallel path):
  public static void classifyParallel(BayesParameters params) throws IOException {
    BayesClassifierDriver.runJob(params);
  }


Classification has two implementations, sequential (non-parallel) and parallel (MapReduce):
      if ("sequential".equalsIgnoreCase(classificationMethod)) {
        classifySequential(params);
      } else if ("mapreduce".equalsIgnoreCase(classificationMethod)) {
        classifyParallel(params);
      }
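Note that in the excerpt above, a value of `classificationMethod` that matches neither string falls through both branches and nothing runs. One possible way to make that case explicit is to parse the string into an enum first. This is only a sketch under that assumption: `MethodDispatchSketch`, `Method`, and `parse` are hypothetical names, not part of Mahout.

```java
import java.util.Locale;

// Hypothetical sketch; not part of Mahout's TestClassifier.
final class MethodDispatchSketch {

  enum Method { SEQUENTIAL, MAPREDUCE }

  // Case-insensitive parse that rejects unknown values instead of ignoring them.
  static Method parse(String value) {
    try {
      return Method.valueOf(value.toUpperCase(Locale.ROOT));
    } catch (IllegalArgumentException e) {
      throw new IllegalArgumentException("Unknown classification method: " + value, e);
    }
  }

  public static void main(String[] args) {
    switch (parse("MapReduce")) {
      case SEQUENTIAL:
        System.out.println("would run the sequential classifier");
        break;
      case MAPREDUCE:
        System.out.println("would run the MapReduce classifier");
        break;
    }
  }
}
```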