`
fighting_2013
  • 浏览: 15704 次
  • 性别: Icon_minigender_1
社区版块
存档分类
最新评论

数据挖掘笔记-分类-决策树-4

阅读更多

之前写的代码都是单机上跑的,发现现在很流行hadoop,所以又试着用hadoop mapreduce来处理下决策树的创建。因为hadoop接触的也不多,所以写的不好,勿怪。

 

看了一些mahout在处理决策树和随机森林的过程,大体过程是Job只有一个Mapper处理,在map方法里面做数据的转换收集工作,然后在cleanup方法里面去做决策树的创建过程。然后将决策树序列化到HDFS上面,分类样本数据集的时候,在从HDFS上面取回决策树结构。大体来说,mahout决策树的构建过程好像并没有结合分布式计算,因为我也并没有仔仔细细的去研读mahout里面的源码,所以可能是我没发现。下面是我实现的一个简单hadoop版本决策树,用的C4.5算法,通过MapReduce去计算增益率。最后生成的决策树并未保存在HDFS上面,后面有时间在考虑下吧。下面是具体代码实现:

 

public class DecisionTreeC45Job extends AbstractJob {
	
	/** 对数据集做准备工作,主要就是将填充好默认值的数据集再次传到HDFS上*/
	public String prepare(Data trainData) {
		String path = FileUtils.obtainRandomTxtPath();
		DataHandler.writeData(path, trainData);
		System.out.println(path);
		String name = path.substring(path.lastIndexOf(File.separator) + 1);
		String hdfsPath = HDFSUtils.HDFS_TEMP_INPUT_URL + name;
		HDFSUtils.copyFromLocalFile(conf, path, hdfsPath);
		return hdfsPath;
	}
	
	/** 选择最佳属性,读取MapReduce计算后产生的文件,取增益率最大*/
	public AttributeGainWritable chooseBestAttribute(String output) {
		AttributeGainWritable maxAttribute = null;
		Path path = new Path(output);
		try {
			FileSystem fs = path.getFileSystem(conf);
			Path[] paths = HDFSUtils.getPathFiles(fs, path);
			ShowUtils.print(paths);
			double maxGainRatio = 0.0;
			SequenceFile.Reader reader = null;
			for (Path p : paths) {
				reader = new SequenceFile.Reader(fs, p, conf);
				Text key = (Text) ReflectionUtils.newInstance(
						reader.getKeyClass(), conf);
				AttributeGainWritable value = new AttributeGainWritable();
				while (reader.next(key, value)) {
					double gainRatio = value.getGainRatio();
					if (gainRatio >= maxGainRatio) {
						maxGainRatio = gainRatio;
						maxAttribute = value;
					}
					value = new AttributeGainWritable();
				}
				IOUtils.closeQuietly(reader);
			}
			System.out.println("output: " + path.toString());
			HDFSUtils.delete(conf, path);
			System.out.println("hdfs delete file : " + path.toString());
		} catch (IOException e) {
			e.printStackTrace();
		}
		return maxAttribute;
	}
	
	/** 构造决策树 */
	public Object build(String input, Data data) {
		Object preHandleResult = preHandle(data);
		if (null != preHandleResult) return preHandleResult;
		String output = HDFSUtils.HDFS_TEMP_OUTPUT_URL;
		HDFSUtils.delete(conf, new Path(output));
		System.out.println("delete output path : " + output);
		String[] paths = new String[]{input, output};
		//通过MapReduce计算增益率
		CalculateC45GainRatioMR.main(paths);
		
		AttributeGainWritable bestAttr = chooseBestAttribute(output);
		String attribute = bestAttr.getAttribute();
		System.out.println("best attribute: " + attribute);
		System.out.println("isCategory: " + bestAttr.isCategory());
		if (bestAttr.isCategory()) {
			return attribute;
		}
		String[] splitPoints = bestAttr.obtainSplitPoints();
		System.out.print("splitPoints: ");
		ShowUtils.print(splitPoints);
		TreeNode treeNode = new TreeNode(attribute);
		String[] attributes = data.getAttributesExcept(attribute);
		
		//分割数据集,并将分割后的数据集传到HDFS上
		DataSplit dataSplit = DataHandler.split(new Data(
				data.getInstances(), attribute, splitPoints));
		for (DataSplitItem item : dataSplit.getItems()) {
			String path = item.getPath();
			String name = path.substring(path.lastIndexOf(File.separator) + 1);
			String hdfsPath = HDFSUtils.HDFS_TEMP_INPUT_URL + name;
			HDFSUtils.copyFromLocalFile(conf, path, hdfsPath);
			treeNode.setChild(item.getSplitPoint(), build(hdfsPath, 
					new Data(attributes, item.getInstances())));
		}
		return treeNode;
	}
	
	/** 分类,根据决策树节点判断测试样本集的类型,并将结果上传到HDFS上*/
	private void classify(TreeNode treeNode, String trainSet, String testSet, String output) {
		OutputStream out = null;
		BufferedWriter writer = null;
		try {
			Path trainSetPath = new Path(trainSet);
			FileSystem trainFS = trainSetPath.getFileSystem(conf);
			Path[] trainHdfsPaths = HDFSUtils.getPathFiles(trainFS, trainSetPath);
			FSDataInputStream trainFSInputStream = trainFS.open(trainHdfsPaths[0]);
			Data trainData = DataLoader.load(trainFSInputStream, true);
			
			Path testSetPath = new Path(testSet);
			FileSystem testFS = testSetPath.getFileSystem(conf);
			Path[] testHdfsPaths = HDFSUtils.getPathFiles(testFS, testSetPath);
			FSDataInputStream fsInputStream = testFS.open(testHdfsPaths[0]);
			Data testData = DataLoader.load(fsInputStream, true);
			
			DataHandler.fill(testData.getInstances(), trainData.getAttributes(), 0);
			Object[] results = (Object[]) treeNode.classify(testData);
			ShowUtils.print(results);
			DataError dataError = new DataError(testData.getCategories(), results);
			dataError.report();
			String path = FileUtils.obtainRandomTxtPath();
			out = new FileOutputStream(new File(path));
			writer = new BufferedWriter(new OutputStreamWriter(out));
			StringBuilder sb = null;
			for (int i = 0, len = results.length; i < len; i++) {
				sb = new StringBuilder();
				sb.append(i+1).append("\t").append(results[i]);
				writer.write(sb.toString());
				writer.newLine();
			}
			writer.flush();
			Path outputPath = new Path(output);
			FileSystem fs = outputPath.getFileSystem(conf);
			if (!fs.exists(outputPath)) {
				fs.mkdirs(outputPath);
			}
			String name = path.substring(path.lastIndexOf(File.separator) + 1);
			HDFSUtils.copyFromLocalFile(conf, path, output + 
					File.separator + name);
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			IOUtils.closeQuietly(out);
			IOUtils.closeQuietly(writer);
		}
	}
	
	public void run(String[] args) {
		try {
			if (null == conf) conf = new Configuration();
			String[] inputArgs = new GenericOptionsParser(
					conf, args).getRemainingArgs();
			if (inputArgs.length != 3) {
				System.out.println("error, please input three path.");
				System.out.println("1. trainset path.");
				System.out.println("2. testset path.");
				System.out.println("3. result output path.");
				System.exit(2);
			}
			Path input = new Path(inputArgs[0]);
			FileSystem fs = input.getFileSystem(conf);
			Path[] hdfsPaths = HDFSUtils.getPathFiles(fs, input);
			FSDataInputStream fsInputStream = fs.open(hdfsPaths[0]);
			Data trainData = DataLoader.load(fsInputStream, true);
			/** 填充缺失属性的默认值*/
			DataHandler.fill(trainData, 0);
			String hdfsInput = prepare(trainData);
			TreeNode treeNode = (TreeNode) build(hdfsInput, trainData);
			TreeNodeHelper.print(treeNode, 0, null);
			classify(treeNode, inputArgs[0], inputArgs[1], inputArgs[2]);
		} catch (Exception e) {
			e.printStackTrace();
		}
	}
	
	public static void main(String[] args) {
		DecisionTreeC45Job job = new DecisionTreeC45Job();
		long startTime = System.currentTimeMillis();
		job.run(args);
		long endTime = System.currentTimeMillis();
		System.out.println("spend time: " + (endTime - startTime));
	}

}

CalculateC45GainRatioMR具体实现:

public class CalculateC45GainRatioMR {
	
	private static void configureJob(Job job) {
		job.setJarByClass(CalculateC45GainRatioMR.class);
		
		job.setMapperClass(CalculateC45GainRatioMapper.class);
		job.setMapOutputKeyClass(Text.class);
		job.setMapOutputValueClass(AttributeWritable.class);

		job.setReducerClass(CalculateC45GainRatioReducer.class);
		job.setOutputKeyClass(Text.class);
		job.setOutputValueClass(AttributeGainWritable.class);
		
		job.setInputFormatClass(TextInputFormat.class);
		job.setOutputFormatClass(SequenceFileOutputFormat.class);
	}

	public static void main(String[] args) {
		Configuration configuration = new Configuration();
		try {
			String[] inputArgs = new GenericOptionsParser(
						configuration, args).getRemainingArgs();
			if (inputArgs.length != 2) {
				System.out.println("error, please input two path. input and output");
				System.exit(2);
			}
			Job job = new Job(configuration, "Decision Tree");
			
			FileInputFormat.setInputPaths(job, new Path(inputArgs[0]));
			FileOutputFormat.setOutputPath(job, new Path(inputArgs[1]));
			
			configureJob(job);
			
			System.out.println(job.waitForCompletion(true) ? 0 : 1);
		} catch (Exception e) {
			e.printStackTrace();
		}
	}
}

class CalculateC45GainRatioMapper extends Mapper<LongWritable, Text, 
	Text, AttributeWritable> {
	
	@Override
	protected void setup(Context context) throws IOException,
			InterruptedException {
		super.setup(context);
	}

	@Override
	protected void map(LongWritable key, Text value, Context context)
			throws IOException, InterruptedException {
		String line = value.toString();
		StringTokenizer tokenizer = new StringTokenizer(line);
		Long id = Long.parseLong(tokenizer.nextToken());
		String category = tokenizer.nextToken();
		boolean isCategory = true;
		while (tokenizer.hasMoreTokens()) {
			isCategory = false;
			String attribute = tokenizer.nextToken();
			String[] entry = attribute.split(":");
			context.write(new Text(entry[0]), new AttributeWritable(id, category, entry[1]));
		}
		if (isCategory) {
			context.write(new Text(category), new AttributeWritable(id, category, category));
		}
	}
	
	@Override
	protected void cleanup(Context context) throws IOException, InterruptedException {
		super.cleanup(context);
	}
}

class CalculateC45GainRatioReducer extends Reducer<Text, AttributeWritable, Text, AttributeGainWritable> {
	
	@Override
	protected void setup(Context context) throws IOException, InterruptedException {
		super.setup(context);
	}
	
	@Override
	protected void reduce(Text key, Iterable<AttributeWritable> values,
			Context context) throws IOException, InterruptedException {
		String attributeName = key.toString();
		double totalNum = 0.0;
		Map<String, Map<String, Integer>> attrValueSplits = 
				new HashMap<String, Map<String, Integer>>();
		Iterator<AttributeWritable> iterator = values.iterator();
		boolean isCategory = false;
		while (iterator.hasNext()) {
			AttributeWritable attribute = iterator.next();
			String attributeValue = attribute.getAttributeValue();
			if (attributeName.equals(attributeValue)) {
				isCategory = true;
				break;
			}
			Map<String, Integer> attrValueSplit = attrValueSplits.get(attributeValue);
			if (null == attrValueSplit) {
				attrValueSplit = new HashMap<String, Integer>();
				attrValueSplits.put(attributeValue, attrValueSplit);
			}
			String category = attribute.getCategory();
			Integer categoryNum = attrValueSplit.get(category);
			attrValueSplit.put(category, null == categoryNum ? 1 : categoryNum + 1);
			totalNum++;
		}
		if (isCategory) {
			System.out.println("is Category");
			int sum = 0;
			iterator = values.iterator();
			while (iterator.hasNext()) {
				iterator.next();
				sum += 1;
			}
			System.out.println("sum: " + sum);
			context.write(key, new AttributeGainWritable(attributeName,
					sum, true, null));
		} else {
			double gainInfo = 0.0;
			double splitInfo = 0.0;
			for (Map<String, Integer> attrValueSplit : attrValueSplits.values()) {
				double totalCategoryNum = 0;
				for (Integer categoryNum : attrValueSplit.values()) {
					totalCategoryNum += categoryNum;
				}
				double entropy = 0.0;
				for (Integer categoryNum : attrValueSplit.values()) {
					double p = categoryNum / totalCategoryNum;
					entropy -= p * (Math.log(p) / Math.log(2));
				}
				double dj = totalCategoryNum / totalNum;
				gainInfo += dj * entropy;
				splitInfo -= dj * (Math.log(dj) / Math.log(2));
			}
			double gainRatio = splitInfo == 0.0 ? 0.0 : gainInfo / splitInfo;
			StringBuilder splitPoints = new StringBuilder();
			for (String attrValue : attrValueSplits.keySet()) {
				splitPoints.append(attrValue).append(",");
			}
			splitPoints.deleteCharAt(splitPoints.length() - 1);
			System.out.println("attribute: " + attributeName);
			System.out.println("gainRatio: " + gainRatio);
			System.out.println("splitPoints: " + splitPoints.toString());
			context.write(key, new AttributeGainWritable(attributeName,
					gainRatio, false, splitPoints.toString()));
		}
	}
	
	@Override
	protected void cleanup(Context context) throws IOException, InterruptedException {
		super.cleanup(context);
	}
	
}

 

 

分享到:
评论

相关推荐

    efficient-decision-tree-notes高效决策树算法系列笔记

    高效决策树算法是数据挖掘和机器学习领域中的一个重要工具,尤其在分类问题中表现出色。这一系列笔记将深入探讨如何构建高效、准确的决策树模型。决策树是一种以树状结构进行决策的模型,其中每个内部节点代表一个...

    基于C4.5决策树的大学生笔记本电脑购买行为的数据挖掘.pdf

    在数据挖掘领域,决策树算法是一种常用的分类方法,它通过一系列规则对数据进行分类或回归。C4.5决策树是决策树算法的一种改进形式,由Ross Quinlan开发,它继承了ID3决策树处理离散型属性的能力,并且还能够处理...

    山东大学数据挖掘期末复习笔记.pdf

    在数据挖掘中,常用的分类方法有KNN、决策树、朴素贝叶斯分类等。 KNN算法是指K-Nearest Neighbors算法,该算法通过计算测试样本与训练样本之间的距离来预测测试样本的类别。 决策树算法是指使用决策树来分类数据...

    《数据挖掘概念与技术》-思维导图学习笔记,第一章。

    5. 数据挖掘技术:常见的数据挖掘技术包括决策树、贝叶斯网络、支持向量机、聚类算法如K-means和DBSCAN,以及关联规则算法如Apriori。这些技术各有优缺点,适用于不同的数据特性和问题场景。 6. 数据挖掘的应用领域...

    《数据挖掘技术》课程学习笔记

    分类算法如决策树(C4.5, ID3)、随机森林和神经网络,它们能根据已有数据构建模型,预测未知数据的类别。聚类算法如K-means、层次聚类和DBSCAN,则是无监督学习方法,用于发现数据的自然分组。关联规则学习,如...

    决策树随堂笔记.pdf

    决策树是一种常用的数据挖掘方法,尤其在机器学习领域中占据着重要的地位。它通过一系列基于数据属性的判断规则,将数据集分割成不同的类别或数值预测。Spark 是一个开源的大数据处理框架,它提供了MLlib库,其中...

    [浙大-数据挖掘].1-10\4.rar [浙大-数据挖掘].1-10\4.rar

    在浙江大学的数据挖掘课程中,可能会涵盖这些基本概念,同时深入到更具体的算法和技术,如SVM(支持向量机)、决策树、神经网络、Apriori算法、K-means聚类等。此外,还可能涉及数据库管理系统、统计学基础、机器...

    Python版数据挖掘实验4报告:用决策树预测获胜球队.pdf

    ### Python版数据挖掘实验4报告:用决策树预测获胜球队 #### 实验名称与目的 本次实验名为“用决策树预测获胜球队”。其主要目的是利用机器学习中的决策树算法来预测篮球比赛中哪支球队可能获胜。这不仅是一次理论...

    数据挖掘笔记思维导图1

    分类和预测任务中,支持向量机(SVM)、决策树、贝叶斯网络和神经网络是常用的模型。SVM通过构造最大分类间隔的超平面实现分类,对于非线性问题,它引入了核函数进行映射。贝叶斯网络则利用概率和条件概率来表示变量间...

    机器学习与数据挖掘学习笔记.zip

    《机器学习与数据挖掘学习笔记》是一份综合性的学习资料,涵盖了这两个领域的重要概念、算法和技术。这份笔记的目的是为了帮助读者深入理解机器学习和数据挖掘的基础知识,并提供实际操作的指导。 首先,我们来探讨...

    数据挖掘十大算法详解.zip

    数据挖掘十大算法详解,数据挖掘学习笔记--决策树C4.5 、数据挖掘十大算法--K-均值聚类算法 、机器学习与数据挖掘-支持向量机(SVM)、拉格朗日对偶、支持向量机(SVM)(三)-- 最优间隔分类器 (optimal margin ...

    数据挖掘完整项目/课堂记录笔记/比赛代码

    3. 学习和实践各种数据挖掘算法,如决策树、随机森林、支持向量机和神经网络等。 4. 了解如何在大数据环境中实现模型的训练和验证。 5. 提升问题解决能力,通过比赛代码学习如何解决实际问题并优化模型性能。 这个...

    机器学习&数据挖掘笔记_16(常见面试之机器学习算法思想简单梳理)1

    决策树算法是一种直观且易于理解的分类和回归方法。它通过学习一系列的决策规则,将特征空间划分为多个子空间,并且这些子空间对应于不同的类别标签。决策树的核心是选择最优的属性进行分割,这通常依据信息增益或...

    数据仓库笔记

    数据仓库笔记的知识点涵盖了数据仓库和数据挖掘的基本概念、数据挖掘的主要任务与方法、学习算法以及搭建数据仓库的相关知识。下面将详细阐述这些知识点。 首先,数据仓库是为了企业决策支持而设计的系统,它主要...

    数据挖掘资料(吐血汇总).rar

    "数据挖掘笔记"这部分内容可能是学习者对所学知识的整理,包括关键概念的总结、公式解析、算法实现步骤等,对于初学者来说,这是一份极具价值的参考资料,能帮助他们更好地理解和记忆复杂的知识点。 "习题"则提供了...

    06数据挖掘21

    数据挖掘中的分类技术 数据挖掘是一种常用的数据分析技术,旨在从大量数据中提取有价值的信息。数据挖掘技术可以分为多种类型,包括分类、预测、聚类、关联规则等。其中,分类是数据挖掘中的一种重要技术,旨在对...

    斯坦福大学CS345A 数据挖掘 课程所有课件(pdf+ppt)

    分类算法如决策树、随机森林和支持向量机,用于将数据分成不同的类别。聚类方法如K-means和层次聚类则用于无监督学习,帮助发现数据的自然分组。关联规则学习如Apriori算法常用于市场篮子分析,找出商品之间的购买...

    《数据挖掘》读书笔记.pdf

    《数据挖掘》读书笔记主要涵盖了数据可视化、建模方法、数据挖掘技术和预测分析的应用。作者Philipp K. Janer凭借其在物理学和软件工程领域的深厚背景,为读者提供了丰富的数据分析和数学建模知识。 在全书中,作者...

    基于 jupyterlab的决策树模型,decision_tree.zip

    在本项目中,我们主要探讨的是如何在JupyterLab环境下使用Python进行数据挖掘,并通过决策树模型对数据进行分析。JupyterLab是一个交互式的开发环境,适合数据分析、机器学习等任务,而决策树是一种常见的监督学习...

    ch04 决策树_学习笔记1

    决策树是一种广泛应用于机器学习和数据挖掘中的分类和回归模型,它的主要特点是通过构建树状结构来模拟一系列的决定过程。在本章的学习笔记中,我们聚焦于决策树的生成流程、属性划分的选择以及剪枝处理,同时也涉及...

Global site tag (gtag.js) - Google Analytics