`

[Mahout] 使用Mahout对iris数据进行分析 - Logistic Regression

 
阅读更多

在mahout的官网上面,有讲诉如何在命令行之中使用Logistic Regression对自带的donut.csv进行训练的例子。

现在我们要做的,是自己在java代码之中对iris的数据使用LR进行分析。

 

首先,我们要熟悉一下,使用LR需要哪些参数以及他们的作用。我们从《mahout实战》上面给出的命令行例子来了解一下:

 

$ bin/mahout trainlogistic --input donut.csv \
--output ./model \
--target color --categories 2 \
--predictors x y --types numeric \
--features 20 --passes 100 --rate 50

 

 

简单说明一下:

--input: 输入的文件

--output: 输出的模型存放的文件

--target: 目标变量名

--categories: 有几个分类

--predictors: 使用哪些属性进行预测。在上面的命令行之中只使用了x跟y两个属性

--type: 预测变量的类型,除了numeric, 还有word,text.

--passes: 对于小样本数据,可以多循环几次,对于大型数据样本,1次即可

--rate: 学习率

--features: 不知道中文如何描述,我对LR的理解还不够深入。英文描述:Sets the size of the internal feature vector to use in building the model. A larger value here can be helpful, especially with text-like input data

 

命令trainlogistic 对应着org.apache.mahout.classifier.sgd.TrainLogistic.java. 这是训练模型的代码。相应的,还有运行模型的代码:org.apache.mahout.classifier.sgd.RunLogistic.java

 

在大概了解之后,我们开始针对iris的数据进行实际操作一把:

 

package org.apache.mahout.classifier.sgd;

import java.io.File;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.List;
import java.util.Locale;

import org.apache.commons.io.FileUtils;
import org.apache.mahout.classifier.evaluation.Auc;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;

import com.google.common.base.Charsets;
import com.google.common.collect.Lists;

public class IrisLRTest {

	private static LogisticModelParameters lmp;
	private static PrintWriter output;

	public static void main(String[] args) throws IOException {
		// 1: new
		lmp = new LogisticModelParameters();
		output = new PrintWriter(new OutputStreamWriter(System.out,
				Charsets.UTF_8), true);

		// 2: init params
		lmp.setLambda(0.001);
		lmp.setLearningRate(50);
		lmp.setMaxTargetCategories(3); //总共有3种iris
		lmp.setNumFeatures(4);         //看起来除了class只有4种属性,先设定为4
		List<String> targetCategories = Lists.newArrayList("Iris-setosa", "Iris-versicolor", "Iris-versicolor");  //这里使用的是guava里面的api
		lmp.setTargetCategories(targetCategories);
		lmp.setTargetVariable("class"); // 需要进行预测的是class属性
		List<String> typeList = Lists.newArrayList("numeric", "numeric", "numeric", "numeric");
		List<String> predictorList = Lists.newArrayList("sepallength", "sepalwidth", "petallength", "petalwidth");
		lmp.setTypeMap(predictorList, typeList);

		// 3. load data
		List<String> raw = FileUtils.readLines(new File(
				"E:\\DataSet\\R\\iris.csv")); //使用common-io进行文件读取
		String header = raw.get(0);
		List<String> content = raw.subList(1, raw.size());
		// parse data
		CsvRecordFactory csv = lmp.getCsvRecordFactory();
		csv.firstLine(header); // !!!Note: this is a initialize step, do not
								// skip this step

		// 4. begin to train
		OnlineLogisticRegression lr = lmp.createRegression();
		for(int i = 0; i < 100; i++) {  //对于小数据集我们多运行几次
			for (String line : content) {
				Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
				int targetValue = csv.processLine(line, input);
				lr.train(targetValue, input);  // 核心的一句!!!
			}
		}

		// 5. show model performance: show classify score
		double correctRate = 0;
		double sampleCount = content.size();
		
		for (String line : content) {
			Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
			int target = csv.processLine(line, v);
			int score = lr.classifyFull(v).maxValueIndex();  // 分类核心语句!!!
			System.out.println("Target:" + target + "\tReal:" + score);
			if(score == target) {
				correctRate++;
			}
		}
		output.printf(Locale.ENGLISH, "Rate = %.2f%n", correctRate / sampleCount);
	}

}

 运行结果:Rate = 0.90

 

 

在上面的代码中,要注意的是:

1. 注意所有必需的参数一定要都设定好并设定正确

2. 在必要的参数初始化之后,才能正确的getCsvRecordFactor 跟 createRegression. 否则会遇到空指针异常

 

 

为了对模型进行调优,我们可以做如下事情:

1. 设定更大的numFeatures. 当前是4,我们设定为5、10、20 。。。 

2. 设定更大的循环次数,当前是100, 我们可以设定为200、300 ==

 

最终,我设定的参数:

numFeature = 5

passes = 40

结果: Rate = 0.98

0
0
分享到:
评论

相关推荐

    [Mahout] 使用Mahout 对Kddcup 1999的数据进行分析 -- Naive Bayes

    4. **预测**:用训练好的模型对新数据进行分类预测。 **6. 应用与优化** 在实际应用中,可能需要对模型进行调优,比如调整参数或者采用特征选择策略。此外,由于Kddcup 1999数据集非常大,利用Hadoop分布式计算的...

    apache-mahout-distribution-0.11.0-src.zip

    在"apache-mahout-distribution-0.11.0-src.zip"这个压缩包中,您将找到Mahout 0.11.0版本的源代码,这对于开发者和研究者来说是一个宝贵的资源,他们可以深入理解算法的内部工作原理,进行定制化开发或优化。...

    mahout-core-0.9.jar+mahout-core-0.8.jar+mahout-core-0.1.jar

    - mahout-core-0.9.jar:作为最新版本,它集成了更多的改进和新特性,包括算法的优化、API的调整以及对大数据处理的进一步支持。 4. **API变化**: 随着版本的更新,Mahout的API可能会发生变化,比如引入新的接口...

    测试mahout推荐算法的数据集

    【推荐算法】是一种重要的机器学习...通过对Chubbyjiang在GitHub上分享的数据集进行分析和处理,我们可以深入理解Mahout的协同过滤算法以及MapReduce在大数据环境下的工作原理,从而构建出更高效、更精准的推荐系统。

    mahout-0.11.1 相关的jar

    mahout-examples-0.11.1 mahout-examples-0.11.1-job mahout-h2o_2.10-0.11.1 mahout-h2o_2.10-0.11.1-dependency-reduced mahout-hdfs-0.11.1 mahout-integration-0.11.1 mahout-math-0.11.1 mahout-math-0.11.1 ...

    mahout-distribution-0.9.tar.gz

    2. **分类与回归**:Mahout支持决策树(如C4.5)、随机森林和感知机等算法,用于对数据进行分类和预测。 3. **聚类**:包括K-Means、Fuzzy K-Means、Canopy Clustering、DBSCAN等算法,可用于将相似的数据点分组到...

    mahout-distribution-0.8-src

    3. **分类(Classification)**:如随机森林(Random Forest)、朴素贝叶斯(Naive Bayes)等,用于根据已知特征对数据进行分类。这些模型在垃圾邮件过滤、情感分析等方面表现出色。 三、分布式计算与Hadoop集成 ...

    mahout所需jar包

    4. **预处理数据**:如果需要,可以使用Mahout的工具对数据进行预处理,例如规范化或归一化。 5. **运行K-Means**:使用Mahout提供的命令行接口,指定输入数据、K值(预定义的群组数量)、迭代次数和其他参数。 6. *...

    如何成功运行Apache Mahout的Taste Webapp-Mahout推荐教程-Maven3.0.5-JDK1.6-Mahout0.5

    在Mahout Taste Webapp工程中,需要添加对mahout-examples的依赖,这一步骤是必须的,因为示例代码提供了实际运行推荐系统所必需的组件。 6. 配置推荐引擎的属性 在Mahout Taste Webapp的recommender.properties...

    maven_mahout_template-mahout-0.8

    `maven_mahout_template-mahout-0.8`这个项目模板,是为使用Maven构建的Mahout项目提供的一种基础架构。它包含了配置文件、依赖管理和项目结构,使得开发者可以快速地搭建起一个基于Mahout的项目环境,进行机器学习...

    mahout-integration-0.7

    mahout-integration-0.7mahout-integration-0.7mahout-integration-0.7mahout-integration-0.7

    apache-mahout-distribution-0.10.2

    这个"apache-mahout-distribution-0.10.2"压缩包包含的是Mahout的0.10.2版本,该版本是2014年发布的一个稳定版本,旨在帮助大数据研发人员构建和实现复杂的机器学习模型。 在大数据领域,机器学习是关键的技术之一...

    mahout-distribution-0.7-src.zip

    2. 解压`mahout-distribution-0.7-src.zip`文件到本地目录。 3. 进入解压后的源码目录,执行`mvn clean install`命令进行编译。这会下载依赖项,构建Mahout的jar包。 4. 编译完成后,可以在`target`目录下找到编译...

    mahout-distribution-0.9-src.zip

    标题中的"mahout-distribution-0.9-src.zip"指的是Mahout项目在0.9版本的源代码分布,这对于开发者来说是一个宝贵的资源,可以深入理解其内部实现并进行定制化开发。 Apache Mahout的核心特性主要体现在以下几个...

    mahout-distribution-0.12.2-src.tar.gz

    这个压缩包“mahout-distribution-0.12.2-src.tar.gz”是Mahout项目的一个源码版本,版本号为0.12.2,提供给开发者进行深度研究和定制化开发。在解压后的文件“apache-mahout-distribution-0.12.2”中,我们可以找到...

    mahout文本训练测试数据

    **正文** 《Mahout文本训练测试...通过对这些数据进行处理和建模,我们可以提升对Mahout的理解,同时也能掌握如何处理大规模文本数据的技巧,这对于任何涉足大数据和人工智能领域的专业人士来说都是至关重要的技能。

    mahout-distribution-0.9含jar包

    "mahout-distribution-0.9含jar包" 是一个包含了Mahout项目0.9版本的预编译二进制文件集合,其中不包含源代码,适合那些希望直接使用Mahout功能而不需要进行编译或开发的用户。 在Mahout 0.9版本中,你可以找到以下...

    豆瓣电影大数据分析-【附带爬虫豆瓣,对数据处理,数据分析,可视化】

    主要是基于豆瓣电影的数据,进行分析,所以首先要爬取相关的电影数据,对应的源代码在DouBan_Spider目录下,主要是采用Python + BeautifulSoup + urllib进行数据采集 2:ETL预处理 3:数据分析 4:可视化 代码封装...

    Mahout-0.9-jar包

    2. **聚类**:包括K-Means、Fuzzy K-Means和Canopy Clustering等算法,可以对数据集进行无监督学习,将相似的数据点分组到一起,形成不同的簇。 3. **分类**:支持如Naive Bayes和Random Forest等监督学习算法,...

    深入浅出Hadoop Mahout数据挖掘实战 第04课-Mahout数据挖掘工具(4) 共9页.pptx

    【课程大纲】第01课-Mahout数据挖掘工具(1) 共9页第02课-Mahout数据挖掘工具(2) 共9页第03课-Mahout数据挖掘工具(3) 共12页第04课-Mahout数据挖掘工具(4) 共9页第05课-Mahout数据挖掘工具(5) 共11页第06课-Mahout...

Global site tag (gtag.js) - Google Analytics