`
wbj0110
  • 浏览: 1591203 次
  • 性别: Icon_minigender_1
  • 来自: 上海
文章分类
社区版块
存档分类
最新评论

mahout之TrainNaiveBayesJob源码分析

阅读更多

mahout的trainnb调用的是TrainNaiveBayesJob完成训练模型任务。所在包:

org.apache.mahout.classifier.naivebayes.training

TrainNaiveBayesJob的输入是在tfidf文件上split出来的一部分,用作训练。
TrainNaiveBayesJob代码分析,
首先加入一些命令行选项,如

LABEL      -L
ALPHA_I  -a
LABEL_INDEX  -li
TRAIN_COMPLEMENTARY      -c

然后从输入文件中读取label,将label保存于label index,例如20news group的例子,读取的label有两个,label index如下

Key class: class org.apache.hadoop.io.Text   Value Class: class org.apache.hadoop.io.IntWritable
Key: 20news-bydate-test: Value: 0
Key: 20news-bydate-train: Value: 1

其实也就是将分类建一个索引。

接下来,将相同label的vectors相加。也就是将同一个类别的所有的文章的vector相加。这里vector其实是一个key/value vector,每项由词的id和tfidf值组成。这样相加后就是一个一个类的vector,相同id的tfidf相加,没有的则插入,类似两个递增的链表的合并。由一个job来完成:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
//      Key class: class org.apache.hadoop.io.Text
//      Value Class: class org.apache.mahout.math.VectorWritable
//add up all the vectors with the same labels, while mapping the labels into our index
Job indexInstances = prepareJob(getInputPath()//input path
             getTempPath(SUMMED_OBSERVATIONS),             //output path
            SequenceFileInputFormat.class,                        //input format
        IndexInstancesMapper.class,                             //mapper class
        IntWritable.class,                                                 //mapper key
        VectorWritable.class,                                           //mapper value
        VectorSumReducer.class,                                   //reducer class
        IntWritable.class,                                                  //reducer key
        VectorWritable.class,                                          //reducer value
        SequenceFileOutputFormat.class);          //output format
indexInstances.setCombinerClass(VectorSumReducer.class);
boolean succeeded = indexInstances.waitForCompletion(true);
if (!succeeded) {
   return -1;
}

Mapper为IndexInstancesMapper,Reducer为Reducer VectorSumReducer,代码也比较简单,如下,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
  protected void map(Text labelText, VectorWritable instance, Context ctx) throws IOExceptionInterruptedException {
    String label = labelText.toString().split("/")[1];
if (labelIndex.containsKey(label)) {
//从文件中读取的类的index作为key
      ctx.write(new IntWritable(labelIndex.get(label)), instance);
    } else {
      ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
    }
  }
  //相同key的vector相加
  protected void reduce(WritableComparable< ? > key, Iterable< VectorWritable > values, Context ctx)
    throws IOExceptionInterruptedException {
    Vector vector = null;
    for (VectorWritable v : values) {
      if (vector == null) {
        vector = v.get();
      } else {
        vector.assign(v.get(), Functions.PLUS);
      }
    }
    ctx.write(key, new VectorWritable(vector));
  }

OK,到现在已经得到了< label_index,label_vector >,即类的id和类中所有item(或者说feature)的TFIDF值。此步得到类似如下的输出,

Key: 0
Value: /comp.sys.ibm.pc.hardware/60252:{93562:17.52922821044922,93559:9.745443344116211,93558:107.53932094573975,93557:49.015570640563965,93556:9.745443344116211……}
key:1
Value:
/alt.atheism/53261:{93562:26.293842315673828,93560:19.490886688232422,93559:9.745443344116211,93558:78.52010536193848,93557:62.2713, 93555:14.35555171……}

下一个阶段就是统计每个label的所有ITIDF和,输入为上一步的输出,并由一个job来执行,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
    //sum up all the weights from the previous step, per label and per feature
    Job weightSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
                       getTempPath(WEIGHTS),
            SequenceFileInputFormat.class,
            WeightsMapper.class,
            Text.class,
            VectorWritable.class,
            VectorSumReducer.class,
            Text.class,
            VectorWritable.class,
            SequenceFileOutputFormat.class);
    weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELSString.valueOf(labelSize));
    weightSummer.setCombinerClass(VectorSumReducer.class);
    succeeded = weightSummer.waitForCompletion(true);
    if (!succeeded) {
      return -1;
    }

job的mapper为WeightsMapper,reducer与上一步的相同,为VectorSumReducer。
mapper如下,

1
2
3
4
5
6
7
8
9
  protected void map(IntWritable index, VectorWritable value, Context ctx) throws IOExceptionInterruptedException {
    Vector instance = value.get();
    if (weightsPerFeature == null) {
      weightsPerFeature = new RandomAccessSparseVector(instance.size(), instance.getNumNondefaultElements());
    }
    int label = index.get();
    weightsPerFeature.assign(instance, Functions.PLUS);
    weightsPerLabel.set(label, weightsPerLabel.get(label) + instance.zSum());
  }

此步的输出写在cleanup()中。

1
2
3
4
5
6
7
8
9
  protected void cleanup(Context ctx) throws IOExceptionInterruptedException {
    if (weightsPerFeature != null) {
      ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE),
new VectorWritable(weightsPerFeature));
      ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL),
new VectorWritable(weightsPerLabel));
    }
    super.cleanup(ctx);
  }

也就是说输出只有两个key/value.
一个是WEIGHTS_PER_FEATURE(定义的常量,__SPF)
一个是WEIGHTS_PER_LABEL(__SPL)
weightsPerFeature其实就是保持上一步的vector没变,仍然是一个类中所有iterm(feature)的TFIDF。
weightsPerLabel就是求每个label中的和了。
可以看到输出为,

Key: __SPF
Value: {93562:43.82307052612305,93560:19.490886688232422,93559:19.490886688232422,93558:186.05942630767822,93557:111.28696632385254,93556:9.745443344116211……}
Key: __SPL
Value: {1:7085520.472989678,0:4662610.912284017}

最后一步,先看源代码,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
//calculate the Thetas, write out to LABEL_THETA_NORMALIZER vectors
//-- TODO: add reference here to the part of the Rennie paper that discusses this
Job thetaSummer =
prepareJob(getTempPath(SUMMED_OBSERVATIONS), getTempPath(THETAS),
            SequenceFileInputFormat.class,
            ThetaMapper.class,
            Text.class,
            VectorWritable.class,
            VectorSumReducer.class,
            Text.class,
            VectorWritable.class,
            SequenceFileOutputFormat.class);
    thetaSummer.setCombinerClass(VectorSumReducer.class);
    thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI);
    thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary);
    /* TODO(robinanil): Enable this when thetanormalization works.
    succeeded = thetaSummer.waitForCompletion(true);
    if (!succeeded) {
      return -1;
}*/

可以看到thetaSummer.waitForCompletion(true)被注释掉了,job没有执行。注释里面说的Rennie paper指的就是mahout bayes算法参考的这篇论文:Tackling the Poor Assumptions of Naive Bayes Text Classifiers,论文里面有个求Ɵ的公式如下。不知为何注释掉?求解。

最最后一步,其实model有weightsPerFeature和weightsPerLabel就完成了。这一步也就是把它们变成矩阵形式,如下,每行一个权重vector。
____|item1,iterm2,item3……
lab1|
lab2|
……

源代码如下,

1
2
3
4
5
//得到SparseMatrix矩阵
NaiveBayesModel naiveBayesModel = BayesUtils.readModelFromDir(getTempPath(), getConf());
naiveBayesModel.validate();
//序列化,写到output/naiveBayesModel.bin
naiveBayesModel.serialize(getOutputPath(), getConf());

THE END

 

http://hnote.org/big-data/mahout/mahout-train-naive-bayes-job

 

 

http://soledede.com/

 

大家可以加我个人微信号:scccdgf

 

 

或者关注soledede的微信公众号:soledede
微信公众号:
分享到:
评论

相关推荐

    Mahout源码

    Mahout 构建在Hadoop之上,利用MapReduce进行分布式计算。这意味着,对于处理大量数据,Mahout 可以在多台机器上并行运行,大大提高了计算效率。Hadoop的输入/输出机制与Mahout相结合,使得大数据处理变得简单易行。...

    mahout源码

    在大数据时代,Mahout已经成为数据科学家和工程师们的重要工具,尤其在文本分析、推荐系统和分类任务中扮演着关键角色。本篇将深入探讨Mahout中的朴素贝叶斯分类以及中文分词这两个核心功能。 一、Mahout与朴素...

    Mahout教程内含源码以及说明书可以自己运行复现.zip

    安装Mahout首先需要准备Hadoop环境,因为Mahout是构建在Hadoop之上的。你需要下载并安装Hadoop,配置Hadoop环境变量,并确保集群运行正常。接着,从Apache官方网站获取Mahout的最新版本,解压后将其添加到你的系统...

    mahout0.9源码(支持hadoop2)

    mahout0.9的源码,支持hadoop2,需要自行使用mvn编译。mvn编译使用命令: mvn clean install -Dhadoop2 -Dhadoop.2.version=2.2.0 -DskipTests

    mahout in action中的源码

    《Mahout in Action》是一本深入探讨Apache Mahout机器学习框架的专业书籍,其源码提供了丰富的实践示例和深入理解Mahout算法的机会。在GitHub上,你可以找到这些源码的完整版本,链接为。下面,我们将详细探讨...

    mahout0.9 源码

    以上就是关于Mahout 0.9源码及其在Eclipse中的使用介绍。通过学习和实践,开发者可以利用Mahout构建强大的机器学习应用,处理各种数据挖掘任务。在实际应用中,可以根据项目需求选择合适的算法,结合Hadoop分布式...

    Mahout之Item-based应用使用

    《Mahout之Item-based应用使用》 Apache Mahout是一个开源的机器学习库,主要专注于大规模数据集上的推荐系统、分类和聚类算法。在这个主题中,我们将深入探讨Mahout中的Item-based协同过滤(Item-based ...

    人工智能-推荐系统-新闻推荐-基于Mahout的新闻推荐系统

    Mahout:整体框架,实现了协同过滤 Deeplearning4j,构建VSM Jieba:分词,关键词提取 HanLP:分词,关键词提取 Spring Boot:提供API、ORM 关键实现 基于用户的协同过滤 直接调用Mahout相关接口即可 选择不同...

    mahout-distribution-0.5-src.zip mahout 源码包

    mahout-distribution-0.5-src.zip mahout 源码包

    mahout所需jar包

    Mahout的目标是帮助开发人员构建智能应用程序,如推荐系统、分类和聚类算法,这些在大数据分析领域中极为重要。 **K-Means聚类算法** K-Means是一种无监督学习的聚类算法,用于将数据集分成不同的群组或类别。在...

    mahout-core-0.9.jar+mahout-core-0.8.jar+mahout-core-0.1.jar

    Mahout是建立在Hadoop之上的,利用其分布式计算能力处理大规模数据集。这使得Mahout能够处理超出单台机器内存和计算能力的数据。 3. **版本差异**: - mahout-core-0.1.jar:这是早期版本,可能包含的基本功能,...

    [Mahout] Windows下Mahout单机安装

    打开命令行,进入解压后的Mahout源码目录,执行以下Maven命令来构建Mahout: ``` mvn clean install -DskipTests ``` 这个过程可能会比较耗时,因为Maven会自动下载所有依赖。等待编译完成后,Mahout的可执行jar文件...

    apache-mahout-distribution-0.11.0-src.zip

    在源码中,您可以探索Mahout实现的各种算法,如协同过滤(Collaborative Filtering)、频繁项集挖掘(Frequent Itemset Mining)、近邻搜索(Nearest Neighbor Search)等。这些算法是通过Java编程语言实现的,因此...

    【甘道夫】通过Mahout构建贝叶斯文本分类器案例详解 -- 配套源码

    Apache Mahout是一个基于Hadoop的机器学习库,它提供了一系列的算法,包括聚类、分类和协同过滤,用于大数据分析。贝叶斯分类器是其中一种常用的文本分类方法,因其简单高效而在实际应用中广泛使用。 首先,我们要...

    mahout-distribution-0.7-src.zip

    《Apache Mahout 0.7源码解析与应用探索》 Apache Mahout 是一个开源机器学习库,专注于大规模数据集的算法实现。该库由Java编写,并采用Maven作为构建工具,提供了一系列用于构建智能应用的高效算法。本文将深入...

    spring-mahout-demo

    在项目中,"bhagyas-spring-mahout-demo-f25ecc6"这个文件可能是项目的源码仓库,包含了整个示例的完整代码。通过分析这些代码,我们可以学习到以下几点关键知识点: 1. **Spring配置**:Spring配置文件中会包含对...

    maven_mahout_template-mahout-0.8

    《Apache Maven与Mahout实战:基于maven_mahout_template-mahout-0.8的探索》 Apache Maven是一款强大的项目管理和依赖管理工具,广泛应用于Java开发领域。它通过一个项目对象模型(Project Object Model,POM)来...

Global site tag (gtag.js) - Google Analytics