决策树
package decisiontree; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.FileReader; import java.io.FileWriter; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; public class DecisionTree { public static Map<String, Item> train(String[][] trainData) { Map<String, Item> model = new HashMap<String, Item>(); List<String[]> trainingDataList = new ArrayList<String[]>(); for (int i = 0; i < trainData.length; i++) { trainingDataList.add(trainData[i]); } Set<Integer> handledSet = new HashSet<Integer>(); train(model, Item.ROOT_EKY, handledSet, trainingDataList); return model; } private static void train(Map<String, Item> model, String currentKey, Set<Integer> handledSet, List<String[]> trainingDataList) { double entropyValue = getEntropyValue(trainingDataList); if (Math.abs(entropyValue) < Double.MIN_VALUE) { // all are the same value Item modelItem = new Item(currentKey, trainingDataList.get(0)[trainingDataList.get(0).length - 1]); model.put(modelItem.key, modelItem); } else { // not the only value double minEntropyValue = Double.MAX_VALUE; Map<String, List<String[]>> minEntropySplitDataMap = null; int minEntropyAttrIndex = -1; for (int i = 0; i <= trainingDataList.get(0).length - 2; i++) { if (!handledSet.contains(i)) { Map<String, List<String[]>> splitData = getSplitData(trainingDataList, i); entropyValue = getTotalEntropyValue(splitData); if (entropyValue < minEntropyValue) { minEntropySplitDataMap = splitData; minEntropyAttrIndex = i; minEntropyValue = entropyValue; } } } handledSet.add(minEntropyAttrIndex); if (minEntropySplitDataMap.size() == 1) { // there is only value in result list, skip this attribute; train(model, currentKey, handledSet, trainingDataList); } else { // there are more than one attribute value Item modelItem = new Item(currentKey, null); modelItem.currentIndex = minEntropyAttrIndex; model.put(modelItem.key, modelItem); for (String attrKey : minEntropySplitDataMap.keySet()) { String subKey = getKey(currentKey, minEntropyAttrIndex, attrKey); train(model, subKey, handledSet, minEntropySplitDataMap.get(attrKey)); } } handledSet.remove(minEntropyAttrIndex); } } private static String getKey(String parentKey, int attrIndex, String value) { String key = ""; if (parentKey == null || parentKey.trim().length() == 0) { key = String.valueOf(attrIndex) + "-" + value; } else { key = parentKey + "-" + String.valueOf(attrIndex) + "-" + value; } return key; } private static double getTotalEntropyValue(Map<String, List<String[]>> splitData) { double rtn = 0; for (List<String[]> itemList : splitData.values()) { rtn += getEntropyValue(itemList); } return rtn; } private static double getEntropyValue(List<String[]> splitData) { double rtn = 0; Map<String, AtomicInteger> countMap = new HashMap<String, AtomicInteger>(); for (String[] itemData : splitData) { String value = itemData[itemData.length - 1]; if (!countMap.containsKey(value)) { countMap.put(value, new AtomicInteger(0)); } countMap.get(value).getAndIncrement(); } for (AtomicInteger count : countMap.values()) { double probability = 1.0d * count.get() / splitData.size(); rtn -= probability * Math.log(probability) / Math.log(2.0); } return rtn; } private static Map<String, List<String[]>> getSplitData(List<String[]> data, int i) { Map<String, List<String[]>> rtn = new HashMap<String, List<String[]>>(); for (String[] itemData : data) { String value = itemData[i]; List<String[]> itemDataList = rtn.get(value); if (itemDataList == null) { itemDataList = new ArrayList<String[]>(); rtn.put(value, itemDataList); } itemDataList.add(itemData); } return rtn; } public static void saveModel(String fileName, Map<String, Item> model) { try (BufferedWriter writer = new BufferedWriter(new FileWriter(fileName));) { for (Item item : model.values()) { writer.write(item.toStr()); writer.write("\n"); } } catch (Exception e) { System.out.println("save Model error"); } } public static Map<String, Item> loadModel(String fileName) { Map<String, Item> model = new HashMap<String, Item>(); try (BufferedReader reader = new BufferedReader(new FileReader(fileName));) { String lineStr = reader.readLine(); while (lineStr != null) { if (lineStr.trim().length() > 0) { String[] itemStr = (lineStr + ",^^").split(","); if (itemStr.length == 4) { Item itme = new Item(itemStr[1], itemStr[2]); if (itemStr[0] != null && itemStr[0].trim().length() > 0) { itme.currentIndex = Integer.valueOf(itemStr[0]); } model.put(itme.key, itme); } else { System.out.println("Error model line:" + lineStr); } } lineStr = reader.readLine(); } } catch (Exception e) { System.out.println("load model error"); } return model; } public static String getValue(Map<String, Item> model, String[] fieldValues) { String rtn = null; if (model != null && model.size() > 0) { rtn = getValueFromModel(model, Item.ROOT_EKY, fieldValues); } return rtn; } private static String getValueFromModel(Map<String, Item> model, String key, String[] fieldValues) { String rtn = null; Item item = model.get(key); if (item != null) { if (item.value != null && item.value.trim().length() > 0) { return item.value; } else { String fieldValue = fieldValues[item.currentIndex]; String fieldIndex = String.valueOf(item.currentIndex); String currentKey = null; if (key != null && key.trim().length() > 0) { currentKey = key + Item.KEY_SEPARATOR + fieldIndex + Item.KEY_SEPARATOR + fieldValue; } else { currentKey = fieldIndex + Item.KEY_SEPARATOR + fieldValue; } rtn = getValueFromModel(model, currentKey, fieldValues); } } return rtn; } }
item
public class Item { public static final String KEY_SEPARATOR = "-"; public static final String STR_SEPARATOR = ","; public static final String ROOT_EKY = ""; public String parentKey = null; public String key = ""; public String value = null; public int currentIndex = -1; public Item(String key, String value) { super(); this.key = key; this.value = value; } public String toStr() { StringBuilder b = new StringBuilder(); if (currentIndex >= 0) { b.append(String.valueOf(currentIndex)); } b.append(STR_SEPARATOR); if (key != null) { b.append(key); } b.append(STR_SEPARATOR); if (value != null) { b.append(value); } return b.toString(); } }
简单的训练数据(最后一列为目标属性)
A,C,A
A,D,A
A,A,A
B,C,B
C,C,C
序列化的模型(下个属性列序号, Key:列序号1-列属性1-列序号2-列属性2,目标属性)
0,,
,0-A,A
,0-B,B
,0-C,C
相关推荐
Java实现的决策树算法完整实例中,主要介绍了决策树的概念、原理,并结合完整实例形式分析了Java实现决策树算法的相关操作技巧。 决策树算法的基本概念 决策树算法是一种典型的分类方法,首先对数据进行处理,利用...
下面是一个简化的Java实现决策树的步骤: 1. 定义数据结构:创建`Sample`类,包含特征数组和类别标签。 2. 创建`Node`类,表示决策树的节点,包含特征索引、子节点和类别。 3. 实现特征选择函数,如计算信息增益或...
决策树生成算法的Java实现,可能还有一些BUG,没有做仔细校验与测试,完成主要功能。决策树具体详解移步:http://blog.csdn.net/adiaixin123456/article/details/50573849 项目的目录结构分为四个文件夹algorithm,...
在决策树算法中,选择哪些特征来进行分裂是非常重要的步骤之一,这通常基于某种信息增益标准来完成。 ### Java代码解析 #### 包声明与导入语句 代码以包声明 `package cn.liip.jcs;` 开始,并且导入了必要的类库...
### Java 实现决策树ID3算法 #### 一、决策树与ID3算法简介 决策树是一种常用的机器学习方法,...此外,文件读取部分也是实现决策树算法的重要组成部分,确保了能够正确处理训练数据,从而构建出有效的决策树模型。
ID3(Iterative Dichotomiser 3)是决策树算法的一种早期形式,由Ross Quinlan在1986年提出。这个算法主要基于信息熵和信息增益来选择最佳属性进行划分,以构建最优的决策树。 ID3算法的核心思想是通过不断划分数据...
4. Java实现决策树: 在Java中,可以使用各种库如Weka、Apache Mahout或自定义代码实现决策树。自定义实现通常包括以下组件: - 数据结构:用于存储数据集和决策树结构,如ArrayList、HashMap等。 - 分类器:包含...
在Java实现决策树时,我们需要创建以下核心类: 1. `TreeNode`:表示决策树的节点,包含特征、类别和子节点。 2. `DecisionTree`:决策树的主类,包含构建树、预测等方法。 3. `Dataset`:数据集类,用于存储样本和...
6. **源码实例**:`ID3.java`中的源码提供了实际操作的示例,可以学习如何在Java环境中实现决策树算法。通过阅读和理解这段代码,你可以了解到如何处理数据、计算信息增益并构造决策树的具体步骤。 这个项目对于...
总的来说,这个Java实现的ID3决策树算法可以帮助开发者理解和运用决策树模型,进行分类任务的建模和预测,同时提供了评估模型性能的能力。通过深入理解并实践这个实现,可以为进一步探索更复杂的决策树算法(如C4.5...
在本案例中,我们将探讨如何用Java实现ID3算法,这是一种早期的决策树学习算法,由Ross Quinlan于1986年提出。 ID3(Iterative Dichotomiser 3)算法基于信息熵和信息增益来选择最优特征进行分裂。信息熵是度量数据...
在本项目中,`DecisionTree1.java`很可能是实现决策树的主要代码文件。它可能包含了构建决策树的逻辑,包括计算信息增益,选择最佳分割属性,以及递归地构建子树。关键方法可能包括`buildTree()`用于构造决策树,`...
1. **选择特征**:在每个节点上,决策树算法会选择一个最优特征进行划分。这个最优特征通常是根据某种信息增益或基尼不纯度指标来确定的。 2. **分裂节点**:基于选定的特征,数据集被分割成多个子集,每个子集对应...
在Java环境中,我们可以使用不同的算法来实现决策树,如KNN(K-最近邻)、C4.5和ID3。 1. **K-最近邻(K-Nearest Neighbors, KNN)**: KNN是一种基于实例的学习,属于懒惰学习类别。它并不立即对数据进行任何假设...
Java基于ssm+mysql的决策树算法的大学生就业预测系统的实现.zipJava基于ssm+mysql的决策树算法的大学生就业预测系统的实现.zipJava基于ssm+mysql的决策树算法的大学生就业预测系统的实现.zipJava基于ssm+mysql的决策...
java实现决策树算法 id3 数据挖掘领域经典算法应用广泛
在Java中实现决策树,我们可以选择自定义算法,如ID3(Iterative Dichotomiser 3),或者使用现有的机器学习库,如Weka、Deeplearning4j等。本篇文章将深入探讨ID3算法以及如何在Java中实现它。 ID3算法是基于信息...
在Java环境中实现决策树算法,你需要理解基本的数据结构,如二叉树和列表,以及如何操作数据集。首先,定义一个类表示决策树节点,包含属性、类别和子节点等信息。接着,编写一个函数用于构建决策树,该函数接受训练...
ID3(Iterative Dichotomiser 3)是决策树算法的早期版本,由Ross Quinlan于1986年提出。这个算法基于信息熵和信息增益的概念来选择最优特征进行节点划分,构建出一个递归的树形结构。现在我们主要探讨的是如何使用...
综上所述,这个Java实现的ID3决策树算法改良版提供了对不同数据源的支持,使得用户可以方便地应用到自己的项目中,通过调整参数和优化策略,适应不同的数据集和任务需求。在实际应用中,还需要考虑算法的效率和内存...