`
czhsuccess
  • 浏览: 41750 次
社区版块
存档分类
最新评论

java实现ID3算法

阅读更多

ID3是经典的分类算法,要理解ID3算法,需要先了解一些基本的信息论概念,包括信息量,熵,后验熵,条件熵。ID3算法的核心思想是选择互信息量最大的属性作为分割节点,这样做可以保证所建立的决策树高度最小。

树结构代码:

 

/**
 * C4.5决策树数据结构
 * @author zhenhua.chen
 * @Description: TODO
 * @date 2013-3-1 上午10:47:37 
 *
 */
public class TreeNode {
	private String nodeName; // 决策树节点名称
	private List<String> splitAttributes; // 分裂属性名
	private ArrayList<TreeNode> childrenNodes; // 决策树的子节点
	private ArrayList<ArrayList<String>> dataSet; // 划分到该节点的数据集 
	private ArrayList<String> arrributeSet; // 数据集所有属性
	
	public TreeNode(){
		childrenNodes = new ArrayList<TreeNode>();
	}
	
	public String getNodeName() {
		return nodeName;
	}
	public void setNodeName(String nodeName) {
		this.nodeName = nodeName;
	}
	public List<String> getSplitAttributes() {
		return splitAttributes;
	}
	public void setSplitAttributes(List<String> splitAttributes) {
		this.splitAttributes = splitAttributes;
	}
	public ArrayList<TreeNode> getChildrenNodes() {
		return childrenNodes;
	}
	public void setChildrenNodes(ArrayList<TreeNode> childrenNodes) {
		this.childrenNodes = childrenNodes;
	}
	public ArrayList<ArrayList<String>> getDataSet() {
		return dataSet;
	}
	public void setDataSet(ArrayList<ArrayList<String>> dataSet) {
		this.dataSet = dataSet;
	}
	public ArrayList<String> getArrributeSet() {
		return arrributeSet;
	}
	public void setArrributeSet(ArrayList<String> arrributeSet) {
		this.arrributeSet = arrributeSet;
	}
}

 

 

决策树算法:

/**
 * 构造决策树的类
 * @author zhenhua.chen
 * @Description: TODO
 * @date 2013-3-1 下午4:42:07 
 *
 */
public class DecisionTree {
	/**
	 * 建树类
	 * @param dataSet
	 * @param attributeSet
	 * @return
	 */
	public TreeNode buildTree(ArrayList<ArrayList<String>> dataSet, ArrayList<String> attributeSet) {
		TreeNode node = new TreeNode();
		node.setDataSet(dataSet);
		node.setArrributeSet(attributeSet);
		
		// 根据当前数据集计算决策树的节点
		int index = -1;
		double gain = 0;
		double maxGain = 0;
		for(int i = 0; i < attributeSet.size() - 1; i++) {
			gain = ComputeUtil.computeEntropy(dataSet, attributeSet.size() - 1) - ComputeUtil.computeConditinalEntropy(dataSet, i);
			if(gain > maxGain) {
				index = i;
				maxGain = gain;
			}
		}
		ArrayList<String> splitAttributes = ComputeUtil.getTypes(dataSet, index); // 获取该节点下的分裂属性
		node.setSplitAttributes(splitAttributes);
		node.setNodeName(attributeSet.get(index));
		
		// 判断每个属性列是否需要继续分裂
		for(int i = 0; i < splitAttributes.size(); i++) {
			ArrayList<ArrayList<String>> splitDataSet = ComputeUtil.getDataSet(dataSet, index, splitAttributes.get(i));
			
			// 判断这个分裂子数据集的目标属性是否纯净,如果纯净则结束,否则继续分裂
			int desColumn = splitDataSet.get(0).size() - 1; // 目标属性列所在的列号
			ArrayList<String> desAttributes = ComputeUtil.getTypes(splitDataSet, desColumn);
			TreeNode childNode = new TreeNode();
			if(desAttributes.size() == 1) {
				childNode.setNodeName(desAttributes.get(0));
			} else {
				ArrayList<String> newAttributeSet = new ArrayList<String>();
				for(String s : attributeSet) { // 删除新属性集合中已作为决策树节点的属性值
					if(!s.equals(attributeSet.get(index))) {
						newAttributeSet.add(s);
					}
				}
				
				ArrayList<ArrayList<String>> newDataSet = new ArrayList<ArrayList<String>>();
				for(ArrayList<String> data : splitDataSet) { // 除掉columnIndex参数指定的
					ArrayList<String> tmp = new ArrayList<String>();
					for(int j = 0; j < data.size(); j++) {
						if(j != index) {
							tmp.add(data.get(j));
						}
					}
					newDataSet.add(tmp);
				}
				
				childNode = buildTree(newDataSet, newAttributeSet); // 递归建树
			}
			node.getChildrenNodes().add(childNode);
		}
		return node;
	}
	
	/**
	 * 打印建好的树
	 * @param root
	 */
	public void printTree(TreeNode root) {
		System.out.println("----------------");
		if(null != root.getSplitAttributes()) {
			System.out.print("分裂节点:" + root.getNodeName());
			for(String attr : root.getSplitAttributes()) {
				System.out.print("(" + attr + ") ");
			}
		} else {
			System.out.print("分裂节点:" + root.getNodeName());
		}
		
		if(null != root.getChildrenNodes()) {
			for(TreeNode node : root.getChildrenNodes()) {
				printTree(node);
			}
		}
		
	}
	
	/**
	 * 
	* @Title: searchTree 
	* @Description: 层次遍历树
	* @return void
	* @throws
	 */
	public void searchTree(TreeNode root) {
		Queue<TreeNode> queue = new LinkedList<TreeNode>();
		queue.offer(root);
		
		while(queue.size() != 0) {
			TreeNode node = queue.poll();
			if(null != node.getSplitAttributes()) {
				System.out.print("分裂节点:" + node.getNodeName() + "; "); 
				for(String attr : node.getSplitAttributes()) {
					System.out.print(" (" + attr + ") ");
				}
			} else {
				System.out.print("叶子节点:" + node.getNodeName() + "; "); 
			}
			
			if(null != node.getChildrenNodes()) {
				for(TreeNode nod : node.getChildrenNodes()) {
					queue.offer(nod);
				}
			}
		}
	}
	
}

 

 

一些util代码:

/**
 * C4.5算法所需的各类计算方法
 * @author zhenhua.chen
 * @Description: TODO
 * @date 2013-3-1 上午10:48:47 
 *
 */
public class ComputeUtil {
	
	/**
	 * 获取指定数据集中指定属性列的各个类别
	* @Title: getTypes 
	* @Description: TODO
	* @return ArrayList<String>
	* @throws
	 */
	public static ArrayList<String> getTypes(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
		ArrayList<String> list = new ArrayList<String>();
		for(ArrayList<String> data : dataSet) {
			if(!list.contains(data.get(columnIndex))) {
				list.add(data.get(columnIndex));
			}
		}
		return list;
	}
	
	/**
	 * 获取指定数据集中指定属性列的各个类别及其计数
	* @Title: getClassCounts 
	* @Description: TODO
	* @return Map<String,Integer>
	* @throws
	 */
	public static Map<String, Integer> getTypeCounts(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
		Map<String, Integer> map = new HashMap<String, Integer>();
		for(ArrayList<String> data : dataSet) {
			String key = data.get(columnIndex);
			if(map.containsKey(key)) {
				map.put(key, map.get(key) + 1);
			} else {
				map.put(key, 1);
			}
		}
		return map;
	}
	
	/**
	 * 获取指定列上指定类别的数据集合(分裂后的数据子集)
	* @Title: getDataSet 
	* @Description: TODO
	* @return ArrayList<ArrayList<String>>
	* @throws
	 */
	public static ArrayList<ArrayList<String>> getDataSet(ArrayList<ArrayList<String>> dataSet, int columnIndex, String attribueClass) {
		ArrayList<ArrayList<String>> splitDataSet = new ArrayList<ArrayList<String>>();
		for(ArrayList<String> data : dataSet) {
			if(data.get(columnIndex).equals(attribueClass)) {
				splitDataSet.add(data);
			}
		}
		
		return splitDataSet;
	}
	
	/**
	 * 计算指定列(属性)的信息熵
	* @Title: computeEntropy 
	* @Description: TODO
	* @return double
	* @throws
	 */
	public static double computeEntropy(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
		Map<String, Integer> map = getTypeCounts(dataSet, columnIndex);
		int dataSetSize = dataSet.size();
		Iterator<String> keyIter = map.keySet().iterator();
		double entropy = 0;
		while(keyIter.hasNext()) {
			double prob = (double)map.get((String)keyIter.next()) / (double)dataSetSize;
			entropy += (-1) * prob * Math.log(prob) / Math.log(2); 
			
		}
		return entropy;
	}
	
	/**
	 * 计算基于指定属性列对目标属性的条件信息熵
	 */
	public static double computeConditinalEntropy(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
		Map<String, Integer> map = getTypeCounts(dataSet, columnIndex);  // 获取该属性列的所有列别及其计数
		
		double conditionalEntropy = 0; // 条件熵
		
		// 获取根据每个类别分割后的数据集合
		Iterator<String> iter = map.keySet().iterator(); 
		while(iter.hasNext()) {
			ArrayList<ArrayList<String>> splitDataSet = getDataSet(dataSet, columnIndex, (String)iter.next());
			// 计算目标属性列的列索引
			int desColumn = 0;
			if(splitDataSet.get(0).size() > 0) {
				desColumn = splitDataSet.get(0).size() - 1;
			}
			
			double probY = (double)splitDataSet.size() / (double)dataSet.size();
			
			Map<String, Integer> map1 = getTypeCounts(splitDataSet, desColumn); //根据分割后的子集计算后验熵
			Iterator<String> iter1 = map1.keySet().iterator();
			double proteriorEntropy = 0;
			while(iter1.hasNext()) {
				String key = (String)iter1.next(); // 目标属性列中的一个分类
				double posteriorProb = (double)map1.get(key) / (double)splitDataSet.size();
				proteriorEntropy += (-1) * posteriorProb * Math.log(posteriorProb) / Math.log(2);
			}
			
			conditionalEntropy += probY * proteriorEntropy; // 基于某个分割属性计算条件熵
		}
		return conditionalEntropy;
	}
}

 测试代码:

public class Test {
	public static void main(String[] args) {
		File f = new File("D:/test.txt");
		BufferedReader reader = null;
		
		try {
			reader = new BufferedReader(new FileReader(f));
			String str = null;
			try {
				str = reader.readLine(); 
				ArrayList<String> attributeList = new ArrayList<String>();
				String[] attributes = str.split("\t");
				
				for(int i = 0; i < attributes.length; i++) {
					attributeList.add(attributes[i]);
				}
				
				ArrayList<ArrayList<String>> dataSet = new ArrayList<ArrayList<String>>();
				while((str = reader.readLine()) != null) {
					ArrayList<String> tmpList = new ArrayList<String>();
					String[] s = str.split("\t");
					for(int i = 0; i < s.length; i++) {
						tmpList.add(s[i]);
					}
					dataSet.add(tmpList);
				}
				
				DecisionTree dt = new DecisionTree();
				TreeNode root = dt.buildTree(dataSet, attributeList);
//				dt.printTree(root);
				dt.searchTree(root);
				
			} catch (IOException e) {
				e.printStackTrace();
			}
			
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		}
	}
}

 

 

 

分享到:
评论
1 楼 tcblackmamba 2015-12-07  
能看一下您的数据集嘛??

相关推荐

    JAVA实现ID3算法

    总的来说,用Java实现ID3算法涉及数据结构设计、信息熵与信息增益计算以及递归算法的运用。这个过程可以帮助我们理解决策树的构建原理,并为其他机器学习算法提供基础。同时,通过控制台输出或UI显示,可以让用户...

    决策树ID3算法(Java实现)

    文件名`ID3`可能是指包含了整个决策树ID3算法实现的Java源代码文件。在阅读和理解代码时,要关注如何计算信息熵、信息增益,以及如何根据这些值选择最佳属性进行划分,同时注意树的构建和预测过程。 总的来说,ID3...

    基于java实现的 决策树之 ID3 算法

    在本案例中,我们将探讨如何用Java实现ID3算法,这是一种早期的决策树学习算法,由Ross Quinlan于1986年提出。 ID3(Iterative Dichotomiser 3)算法基于信息熵和信息增益来选择最优特征进行分裂。信息熵是度量数据...

    用java实现的大数据分析 ID3算法

    这个例子来源于Quinlan的论文。 假设,有种户外活动。该活动能否正常进行与各种天气因素有关。不同的天气因素组合会产生两种后果,也就是分成2类:能进行活动或不能。我们用P表示该活动可以进行,N表示该活动无法...

    决策树ID3算法java实现

    现在我们主要探讨的是如何使用Java实现ID3算法。 在Java中实现ID3算法,我们需要以下几个关键步骤: 1. **数据预处理**:首先,你需要将数据集转换成适合ID3算法处理的格式。数据集通常包含实例(样本)和属性...

    id3算法的实现 java 如果你对数据挖掘

    在这个Java实现中,我们将深入探讨ID3算法的基本原理、步骤以及如何在Java编程环境中进行实现。 ID3算法的核心思想是信息熵和信息增益。熵是用来衡量数据集合纯度的一个指标,信息增益则是通过选择最佳属性来减少熵...

    Java实现ID3的代码

    Java实现ID3算法是数据挖掘领域中的一个基础任务,它主要用于构建决策树。ID3(Iterative Dichotomiser 3)是由Ross Quinlan提出的,适用于离散属性的分类问题。在这里,我们将深入探讨ID3算法的核心概念,以及如何...

    java实现决策树ID3算法

    ### Java 实现决策树ID3算法 #### 一、决策树与ID3算法简介 决策树是一种常用的机器学习方法,用于分类与回归任务。它通过树状结构来表示规则,其中每个内部节点代表一个特征上的判断,每个分支代表一个判断结果,...

    ID3算法java实现

    在Java中实现ID3算法可以帮助我们构建基于数据的决策树模型,用于预测未知数据的类别。以下是关于ID3算法及其Java实现的一些关键知识点: 1. **决策树基础**:决策树是一种直观易懂的机器学习模型,通过一系列问题...

    数据挖掘的ID3算法

    在压缩包中的"ID3数据挖掘"可能包含了实现ID3算法的Java源代码、测试数据集以及生成的XML决策树文件。通过分析这些文件,可以进一步了解如何将理论知识转化为实际的程序实现。 总的来说,ID3算法是数据挖掘中的一种...

    ID3算法Java实现

    4. **Java实现**:在Java中实现ID3算法,我们需要定义数据结构来存储训练集、特征和类别信息,以及计算信息熵和信息增益的函数。同时,需要实现决策树的构建过程,包括节点的创建、划分和终止条件判断。代码中应该...

    ID3_java.rar_ID3 决策树 java_id3_id3 java_id3 决策树_决策树

    在"ID3_java.rar"压缩包中,包含的"ID3_java"文件可能是Java代码实现的ID3算法。开发者可以通过阅读源代码,了解具体的实现细节,如如何处理连续值、如何处理缺失值、如何优化决策树的剪枝等。同时,可以将这个代码...

    ID3-Java.rar_id3_java ID3

    这个名为"ID3-Java.rar_id3_java ID3"的压缩包包含了一个使用Java编程语言实现的ID3算法。ID3算法由Ross Quinlan在1986年提出,它的核心思想是通过信息熵和信息增益来选择最优特征,构建决策树模型。 首先,我们...

    ID3分类算法Java实现

    总的来说,这个Java实现为理解ID3算法提供了一个实践平台,同时也为其他数据挖掘任务提供了基础。学习和理解这个实现,有助于深入掌握决策树分类的原理,并能应用于实际项目中。对于数据挖掘初学者来说,这是一个很...

    FP树增长算法的java实现

    在Java中实现FP树算法,我们可以按照以下步骤进行: 1. **数据预处理**:首先,我们需要对原始数据进行预处理,将交易数据转换为事务ID和项ID的形式,即每条记录表示一个交易,其中包含交易中出现的所有项。 2. **...

    机器学习 决策树算法(ID3)java实现

    4. **Java实现**:在`ID3.java`文件中,开发者可能实现了ID3算法的类,包括读取数据(可能从`app.arff`文件中,这是一个常见的数据格式,用于存储WEKA等机器学习库的数据),计算信息熵、信息增益,以及构建和遍历...

    java实现银行家算法

    在Java编程语言中,我们可以利用面向对象的特点来实现这一算法。 ### 1. 死锁概念 在多任务环境下,如果两个或更多的进程互相等待对方释放资源,而它们都无法继续执行,就形成了死锁。银行家算法的出现就是为了...

    基于Java实现的同态加密算法的实现

    在"research_encrypt-code"这个压缩包中,很可能包含了Java实现同态加密算法的源代码,包括密钥管理、加密、解密和操作加密数据的函数。通过研究这些代码,我们可以深入了解如何在实际应用中利用Java来构建安全的...

    Java实现数据挖掘算法

    Weka提供了一个直观的界面和API,支持包括C4.5、ID3和CART在内的多种决策树算法。而Apache Mahout则更注重大规模数据处理,它的决策树算法适合处理大数据集。 粗糙集理论则是数据挖掘中的另一重要概念,它由波兰...

Global site tag (gtag.js) - Google Analytics