`

[Mahout] 使用Mahout对iris数据进行分析 - Logistic Regression

 
阅读更多

在mahout的官网上面,有讲诉如何在命令行之中使用Logistic Regression对自带的donut.csv进行训练的例子。

现在我们要做的,是自己在java代码之中对iris的数据使用LR进行分析。

 

首先,我们要熟悉一下,使用LR需要哪些参数以及他们的作用。我们从《mahout实战》上面给出的命令行例子来了解一下:

 

$ bin/mahout trainlogistic --input donut.csv \
--output ./model \
--target color --categories 2 \
--predictors x y --types numeric \
--features 20 --passes 100 --rate 50

 

 

简单说明一下:

--input: 输入的文件

--output: 输出的模型存放的文件

--target: 目标变量名

--categories: 有几个分类

--predictors: 使用哪些属性进行预测。在上面的命令行之中只使用了x跟y两个属性

--type: 预测变量的类型,除了numeric, 还有word,text.

--passes: 对于小样本数据,可以多循环几次,对于大型数据样本,1次即可

--rate: 学习率

--features: 不知道中文如何描述,我对LR的理解还不够深入。英文描述:Sets the size of the internal feature vector to use in building the model. A larger value here can be helpful, especially with text-like input data

 

命令trainlogistic 对应着org.apache.mahout.classifier.sgd.TrainLogistic.java. 这是训练模型的代码。相应的,还有运行模型的代码:org.apache.mahout.classifier.sgd.RunLogistic.java

 

在大概了解之后,我们开始针对iris的数据进行实际操作一把:

 

package org.apache.mahout.classifier.sgd;

import java.io.File;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.List;
import java.util.Locale;

import org.apache.commons.io.FileUtils;
import org.apache.mahout.classifier.evaluation.Auc;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;

import com.google.common.base.Charsets;
import com.google.common.collect.Lists;

public class IrisLRTest {

	private static LogisticModelParameters lmp;
	private static PrintWriter output;

	public static void main(String[] args) throws IOException {
		// 1: new
		lmp = new LogisticModelParameters();
		output = new PrintWriter(new OutputStreamWriter(System.out,
				Charsets.UTF_8), true);

		// 2: init params
		lmp.setLambda(0.001);
		lmp.setLearningRate(50);
		lmp.setMaxTargetCategories(3); //总共有3种iris
		lmp.setNumFeatures(4);         //看起来除了class只有4种属性,先设定为4
		List<String> targetCategories = Lists.newArrayList("Iris-setosa", "Iris-versicolor", "Iris-versicolor");  //这里使用的是guava里面的api
		lmp.setTargetCategories(targetCategories);
		lmp.setTargetVariable("class"); // 需要进行预测的是class属性
		List<String> typeList = Lists.newArrayList("numeric", "numeric", "numeric", "numeric");
		List<String> predictorList = Lists.newArrayList("sepallength", "sepalwidth", "petallength", "petalwidth");
		lmp.setTypeMap(predictorList, typeList);

		// 3. load data
		List<String> raw = FileUtils.readLines(new File(
				"E:\\DataSet\\R\\iris.csv")); //使用common-io进行文件读取
		String header = raw.get(0);
		List<String> content = raw.subList(1, raw.size());
		// parse data
		CsvRecordFactory csv = lmp.getCsvRecordFactory();
		csv.firstLine(header); // !!!Note: this is a initialize step, do not
								// skip this step

		// 4. begin to train
		OnlineLogisticRegression lr = lmp.createRegression();
		for(int i = 0; i < 100; i++) {  //对于小数据集我们多运行几次
			for (String line : content) {
				Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
				int targetValue = csv.processLine(line, input);
				lr.train(targetValue, input);  // 核心的一句!!!
			}
		}

		// 5. show model performance: show classify score
		double correctRate = 0;
		double sampleCount = content.size();
		
		for (String line : content) {
			Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
			int target = csv.processLine(line, v);
			int score = lr.classifyFull(v).maxValueIndex();  // 分类核心语句!!!
			System.out.println("Target:" + target + "\tReal:" + score);
			if(score == target) {
				correctRate++;
			}
		}
		output.printf(Locale.ENGLISH, "Rate = %.2f%n", correctRate / sampleCount);
	}

}

 运行结果:Rate = 0.90

 

 

在上面的代码中,要注意的是:

1. 注意所有必需的参数一定要都设定好并设定正确

2. 在必要的参数初始化之后,才能正确的getCsvRecordFactor 跟 createRegression. 否则会遇到空指针异常

 

 

为了对模型进行调优,我们可以做如下事情:

1. 设定更大的numFeatures. 当前是4,我们设定为5、10、20 。。。 

2. 设定更大的循环次数,当前是100, 我们可以设定为200、300 ==

 

最终,我设定的参数:

numFeature = 5

passes = 40

结果: Rate = 0.98

0
0
分享到:
评论

相关推荐

    Logistic-Regression:逻辑回归的实现

    5. 预测:使用训练好的模型对新数据进行预测。 6. 评估:通过混淆矩阵、准确率等指标评估模型性能。 **自定义实现** 如果你选择自定义逻辑回归,你需要实现梯度下降法或最大似然估计法的算法,同时处理数据输入、...

    hadoop学习-基于hive的航空公司客户价值的LRFCM模型案例数据源

    Hive可能不直接支持这些算法,因此可能需要借助其他Hadoop生态系统的工具,如Mahout或Spark MLlib,对数据进行离线处理,然后将结果导入Hive进行后续分析。 最后,LRFCM模型的目的是评估和提升航空公司的客户价值。...

    机器学习常见算法

    应用场景包括分类和回归,算法包括一些对常用监督式学习算法的延伸,这些算法首先试图对未标识数据进行建模,在此基础上再对标识的数据进行预测。 4. 强化学习:在强化学习中,输入数据作为对模型的反馈,不像监督...

    你需要Spark的10个理由

    而如果基于Hadoop就需要分别构建实时流处理团队、数据统计分析团队、数据挖掘团队等,而且这些团队之间无论是代码还是经验都不可相互借鉴,会形成巨大的成本,而使用Spark就不存在这个问题; 6,Mahout前一阶段表示...

    MLNaiveBayesTextClassification

    2. **逻辑回归分类器(Logistic Regression Classifier)** - **基础原理**:逻辑回归是一种线性模型,用于二分类问题,通过sigmoid函数将连续的线性组合转换为0-1之间的概率值。 - **扩展到多项式分类**:在多...

    分类:不同分类算法的实现

    1. **逻辑回归(Logistic Regression)** 逻辑回归虽然名字中含有“回归”,但实际上是一种广义线性模型,常用于二分类问题。通过构建Sigmoid函数,将连续值转化为0和1之间的概率,进而确定数据属于某一类别的可能...

Global site tag (gtag.js) - Google Analytics