`
leon_a
  • 浏览: 79052 次
  • 性别: Icon_minigender_1
  • 来自: 拜月神教
社区版块
存档分类
最新评论

决策树ID3算法

阅读更多
算了,还是自己修正一个BUG....
package graph;

import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

/**
 * 决策树的ID3算法
 * 参照实现http://www.blog.edu.cn/user2/huangbo929/archives/2006/1533249.shtml
 * 
 * @author Leon.Chen
 * 
 */
public class DTree {
	/**
	 * 根节点
	 */
	TreeNode root;

	/**
	 * 可见性数组
	 */
	private boolean[] visable;

	/**
	 * 未找到节点
	 */
	private static final int NO_FOUND = -1;

	/**
	 * 训练集
	 */
	private Object[] trainingArray;

	/**
	 * 节点索引
	 */
	private int nodeIndex;

	/**
	 * @param args
	 */
	@SuppressWarnings("boxing")
	public static void main(String[] args) {
		Object[] array = new Object[] {
				new String[] { "男", "中年", "未婚", "大学", "中", "没购买" },
				new String[] { "女", "中年", "未婚", "大学", "中", "购买" },
				new String[] { "男", "中年", "已婚", "大学", "中", "购买" },
				new String[] { "男", "老年", "已婚", "大学以下", "低", "购买" } };

		DTree tree = new DTree();
		tree.create(array, 5);
		System.out.println("===============END PRINT TREE===============");
		String[] printData = new String[] { "女", "中年", "未婚", "大学", "中" };
		System.out.println("===============DECISION RESULT===============");
		tree.compare(printData, tree.root);
	}

	/**
	 * 根据传入数据进行预测
	 * 
	 * @param printData
	 * @param node
	 */
	public void compare(String[] printData, TreeNode node) {
		int index = getNodeIndex(node.nodeName);
		if (index == NO_FOUND) {
			System.out.println(node.nodeName);
			System.out.println((node.percent * 100) + "%");
		}
		TreeNode[] childs = node.childNodes;
		for (int i = 0; i < childs.length; i++) {
			if (childs[i] != null) {
				if (childs[i].parentArrtibute.equals(printData[index])) {
					compare(printData, childs[i]);
				}
			}
		}
	}

	/**
	 * 创建
	 * 
	 * @param array
	 * @param index
	 */
	public void create(Object[] array, int index) {
		this.trainingArray = array;
		init(array, index);
		createDTree(array);
		printDTree(root);
	}

	/**
	 * 得到最大信息增益
	 * 
	 * @param array
	 * @return Object[]
	 */
	@SuppressWarnings("boxing")
	public Object[] getMaxGain(Object[] array) {
		Object[] result = new Object[2];
		double gain = 0;
		int index = -1;

		for (int i = 0; i < visable.length; i++) {
			if (!visable[i]) {
				double value = gain(array, i);
				if (gain < value) {
					gain = value;
					index = i;
				}
			}
		}
		result[0] = gain;
		result[1] = index;
		if (index != -1) {
			visable[index] = true;
		}
		return result;
	}

	/**
	 * 创建决策树
	 * 
	 * @param array
	 */
	public void createDTree(Object[] array) {
		Object[] maxgain = getMaxGain(array);
		if (root == null) {
			root = new TreeNode();
			root.parent = null;
			root.parentArrtibute = null;
			root.arrtibutes = getArrtibutes(((Integer) maxgain[1]).intValue());
			root.nodeName = getNodeName(((Integer) maxgain[1]).intValue());
			root.childNodes = new TreeNode[root.arrtibutes.length];
			insertTree(array, root);
		}
	}

	/**
	 * 插入到决策树
	 * 
	 * @param array
	 * @param parentNode
	 */
	public void insertTree(Object[] array, TreeNode parentNode) {
		String[] arrtibutes = parentNode.arrtibutes;
		for (int i = 0; i < arrtibutes.length; i++) {
			Object[] pickArray = pickUpAndCreateArray(array, arrtibutes[i],
					getNodeIndex(parentNode.nodeName));
			Object[] info = getMaxGain(pickArray);
			double gain = ((Double) info[0]).doubleValue();
			if (gain != 0) {
				int index = ((Integer) info[1]).intValue();
				TreeNode currentNode = new TreeNode();
				currentNode.parent = parentNode;
				currentNode.parentArrtibute = arrtibutes[i];
				currentNode.arrtibutes = getArrtibutes(index);
				currentNode.nodeName = getNodeName(index);
				currentNode.childNodes = new TreeNode[currentNode.arrtibutes.length];
				parentNode.childNodes[i] = currentNode;
				insertTree(pickArray, currentNode);
			} else {
				TreeNode leafNode = new TreeNode();
				leafNode.parent = parentNode;
				leafNode.parentArrtibute = arrtibutes[i];
				leafNode.arrtibutes = new String[0];
				leafNode.nodeName = getLeafNodeName(pickArray);
				leafNode.childNodes = new TreeNode[0];
				parentNode.childNodes[i] = leafNode;

				double percent = 0;
				String[] arrs = getArrtibutes(this.nodeIndex);
				for (int j = 0; j < arrs.length; j++) {
					if (leafNode.nodeName.equals(arrs[j])) {
						Object[] subo = pickUpAndCreateArray(pickArray,
								arrs[j], this.nodeIndex);
						Object[] o = pickUpAndCreateArray(this.trainingArray,
								arrs[j], this.nodeIndex);
						double subCount = subo.length;
						percent = subCount / o.length;
					}
				}
				leafNode.percent = percent;
			}
		}
	}

	/**
	 * 打印决策树
	 * 
	 * @param node
	 */
	public void printDTree(TreeNode node) {
		System.out.println(node.nodeName);
		TreeNode[] childs = node.childNodes;
		for (int i = 0; i < childs.length; i++) {
			if (childs[i] != null) {
				System.out.println(childs[i].parentArrtibute);
				printDTree(childs[i]);
			}
		}
	}

	/**
	 * 初始化
	 * 
	 * @param dataArray
	 * @param index
	 */
	public void init(Object[] dataArray, int index) {
		this.nodeIndex = index;
		// 数据初始化
		visable = new boolean[((String[]) dataArray[0]).length];
		for (int i = 0; i < visable.length; i++) {
			if (i == index) {
				visable[i] = true;
			} else {
				visable[i] = false;
			}
		}
	}

	/**
	 * 剪取数组
	 * 
	 * @param array
	 * @param arrtibute
	 * @param index
	 * @return Object[]
	 */
	public Object[] pickUpAndCreateArray(Object[] array, String arrtibute,
			int index) {
		List<String[]> list = new ArrayList<String[]>();
		for (int i = 0; i < array.length; i++) {
			String[] strs = (String[]) array[i];
			if (strs[index].equals(arrtibute)) {
				list.add(strs);
			}
		}
		return list.toArray();
	}

	/**
	 * Entropy(S)
	 * 
	 * @param array
	 * @param index
	 * @return double
	 */
	public double gain(Object[] array, int index) {
		String[] playBalls = getArrtibutes(this.nodeIndex);
		int[] counts = new int[playBalls.length];
		for (int i = 0; i < counts.length; i++) {
			counts[i] = 0;
		}
		for (int i = 0; i < array.length; i++) {
			String[] strs = (String[]) array[i];
			for (int j = 0; j < playBalls.length; j++) {
				if (strs[this.nodeIndex].equals(playBalls[j])) {
					counts[j]++;
				}
			}
		}
		/**
		 * Entropy(S) = S -p(I) log2 p(I)
		 */
		double entropyS = 0;
		for (int i = 0; i < counts.length; i++) {
			entropyS += DTreeUtil.sigma(counts[i], array.length);
		}
		String[] arrtibutes = getArrtibutes(index);
		/**
		 * total ((|Sv| / |S|) * Entropy(Sv))
		 */
		double sv_total = 0;
		for (int i = 0; i < arrtibutes.length; i++) {
			sv_total += entropySv(array, index, arrtibutes[i], array.length);
		}
		return entropyS - sv_total;
	}

	/**
	 * ((|Sv| / |S|) * Entropy(Sv))
	 * 
	 * @param array
	 * @param index
	 * @param arrtibute
	 * @param allTotal
	 * @return double
	 */
	public double entropySv(Object[] array, int index, String arrtibute,
			int allTotal) {
		String[] playBalls = getArrtibutes(this.nodeIndex);
		int[] counts = new int[playBalls.length];
		for (int i = 0; i < counts.length; i++) {
			counts[i] = 0;
		}

		for (int i = 0; i < array.length; i++) {
			String[] strs = (String[]) array[i];
			if (strs[index].equals(arrtibute)) {
				for (int k = 0; k < playBalls.length; k++) {
					if (strs[this.nodeIndex].equals(playBalls[k])) {
						counts[k]++;
					}
				}
			}
		}

		int total = 0;
		double entropySv = 0;
		for (int i = 0; i < counts.length; i++) {
			total += counts[i];
		}
		for (int i = 0; i < counts.length; i++) {
			entropySv += DTreeUtil.sigma(counts[i], total);
		}
		return DTreeUtil.getPi(total, allTotal) * entropySv;
	}

	/**
	 * 取得属性数组
	 * 
	 * @param index
	 * @return String[]
	 */
	@SuppressWarnings("unchecked")
	public String[] getArrtibutes(int index) {
		TreeSet<String> set = new TreeSet<String>(new SequenceComparator());
		for (int i = 0; i < trainingArray.length; i++) {
			String[] strs = (String[]) trainingArray[i];
			set.add(strs[index]);
		}
		String[] result = new String[set.size()];
		return set.toArray(result);
	}

	/**
	 * 取得节点名
	 * 
	 * @param index
	 * @return String
	 */
	public String getNodeName(int index) {
		String[] strs = new String[] { "性别", "年龄", "婚否", "学历", "中还是低", "是否购买" };
		for (int i = 0; i < strs.length; i++) {
			if (i == index) {
				return strs[i];
			}
		}
		return null;
	}

	/**
	 * 取得页节点名
	 * 
	 * @param array
	 * @return String
	 */
	public String getLeafNodeName(Object[] array) {
		if (array != null && array.length > 0) {
			String[] strs = (String[]) array[0];
			return strs[nodeIndex];
		}
		return null;
	}

	/**
	 * 取得节点索引
	 * 
	 * @param name
	 * @return int
	 */
	public int getNodeIndex(String name) {
		String[] strs = new String[] { "性别", "年龄", "婚否", "学历", "中还是低", "是否购买" };
		for (int i = 0; i < strs.length; i++) {
			if (name.equals(strs[i])) {
				return i;
			}
		}
		return NO_FOUND;
	}
}

package graph;

/**
 * @author Leon.Chen
 */
public class TreeNode {

    /**
     * 父节点
     */
    TreeNode parent;

    /**
     * 指向父的哪个属性
     */
    String parentArrtibute;

    /**
     * 节点名
     */
    String nodeName;

    /**
     * 属性数组
     */
    String[] arrtibutes;

    /**
     * 节点数组
     */
    TreeNode[] childNodes;
    
    /**
     * 可信度
     */
    double percent;

}  

package graph;

/**
 * @author Leon.Chen
 */
public class DTreeUtil {

	/**
	 * 属性值熵的计算 Info(T)=(i=1...k)pi*log(2)pi
	 * 
	 * @param x
	 * @param total
	 * @return double
	 */
	public static double sigma(int x, int total) {
		if (x == 0) {
			return 0;
		}
		double x_pi = getPi(x, total);
		return -(x_pi * logYBase2(x_pi));
	}

	/**
	 * log2y
	 * 
	 * @param y
	 * @return double
	 */
	public static double logYBase2(double y) {
		return Math.log(y) / Math.log(2);
	}

	/**
	 * pi是当前这个属性出现的概率(=出现次数/总数)
	 * 
	 * @param x
	 * @param total
	 * @return double
	 */
	public static double getPi(int x, int total) {
		return x * Double.parseDouble("1.0") / total;
	}

}

package graph;

import java.util.Comparator;

/**
 * @author Leon.Chen
 *
 */
@SuppressWarnings("unchecked")
public class SequenceComparator implements Comparator {

    public int compare(Object o1, Object o2) throws ClassCastException {
        String str1 = (String) o1;
        String str2 = (String) o2;
        return str1.compareTo(str2);
    }
}
6
0
分享到:
评论
1 楼 yumingxing 2008-12-31  
程序有错误啊,数组越界

相关推荐

    决策树ID3算法描述与实现

    决策树是一种常用的数据挖掘技术...通过这个项目,不仅可以学习到决策树ID3算法的基本原理,还能掌握如何在实际开发中运用MFC进行GUI编程。这有助于提升软件开发能力和数据分析技能,为今后从事相关工作打下坚实基础。

    决策树ID3算法(Java实现)

    文件名`ID3`可能是指包含了整个决策树ID3算法实现的Java源代码文件。在阅读和理解代码时,要关注如何计算信息熵、信息增益,以及如何根据这些值选择最佳属性进行划分,同时注意树的构建和预测过程。 总的来说,ID3...

    决策树ID3算法的实例解析

    ### 决策树ID3算法的实例解析 #### 一、引言 本文将深入探讨决策树ID3算法的核心概念及其应用案例。ID3(Iterative Dichotomiser 3)算法是由Ross Quinlan在1986年提出的一种用于生成决策树的经典算法。在数据挖掘...

    决策树ID3算法编程(c语言课程设计)

    决策树ID3算法是机器学习领域中的一种经典分类方法,主要应用于数据挖掘和模式识别。在C语言课程设计中,实现ID3算法可以帮助学生深入理解数据处理和算法逻辑。ID3算法是由Ross Quinlan提出的,它基于信息熵和信息...

    java实现决策树ID3算法

    ### Java 实现决策树ID3算法 #### 一、决策树与ID3算法简介 决策树是一种常用的机器学习方法,用于分类与回归任务。它通过树状结构来表示规则,其中每个内部节点代表一个特征上的判断,每个分支代表一个判断结果,...

    决策树ID3算法实验报告广工(附源码java)

    《决策树ID3算法实验详解——以广工实验为例》 决策树ID3算法是一种经典的机器学习算法,常用于分类任务。本实验报告基于广东工业大学(广工)的人工智能课程,通过具体案例——UCI标准数据集Car-Evaluation,详细...

    C 实现决策树ID3算法.txt

    ### C 实现决策树ID3算法 #### 一、概览 本文档旨在解析一个用C语言实现的决策树ID3算法的代码片段。决策树是一种常用的机器学习方法,广泛应用于分类与回归任务中。ID3(Iterative Dichotomiser 3)是决策树的一种...

    决策树ID3算法编程(c语言课程设计) by Chain_Gank

    决策树ID3算法编程(C语言课程设计) 本文将详细介绍决策树ID3算法编程的实践报告,涵盖了决策树的原理分析、实现步骤、程序设计及测试结果等方面的内容。 一、决策树ID3算法原理分析 决策树ID3算法是一种常用的...

    决策树ID3算法实验_数据集car_databases

    用python编写的决策树ID3算法,运用了Car-Evaluation的例子。BUG较少,综合了网上的优秀代码,并进一步形成自己的代码。代码基本有注释,风格良好,能够很快看懂。内含有比较规范的报告文档,包含所有流程图,说明图...

    08-2第八章机器学习-决策树ID3算法的实例解析.pptx

    决策树ID3算法实例解析、机器学习算法排名、信息量和熵的定义与计算 本资源摘要信息来自于一个PPT文件,标题为“08-2第八章机器学习-决策树ID3算法的实例解析.pptx”。该资源涵盖了机器学习、决策树、ID3算法、信息...

    决策树ID3算法的实现

    ID3(Iterative Dichotomiser 3)是决策树算法的一种早期版本,由Ross Quinlan于1986年提出。它基于信息熵和信息增益的概念来选择最优特征,构建决策树模型。 **ID3算法的基本概念** 1. **信息熵(Entropy)**:...

    0/1背包算法与决策树ID3算法实现

    0/1背包算法与决策树ID3算法实现 本文主要讨论了 0/1 背包动态规划算法与决策树 ID3 算法的实现 DETAILS 。 0/1 背包问题 0/1 背包问题是一种组合优化的 NP 完全问题。问题可以描述为:给定一组物品,每种物品都...

    Python实现决策树ID3算法

    Tom编写的机器学习教材中PlayTennis例题—ID3算法python实现

    论文《决策树ID3 算法的改进》

    ### 决策树ID3算法的改进 #### 引言 决策树方法因其高效的数据分析能力和直观性,在机器学习和知识发现领域得到了广泛的应用。在众多的决策树构建算法中,ID3算法由Ross Quinlan于1986年提出,是最早且最具影响力...

    数据挖掘决策树ID3算法优化

    "数据挖掘决策树ID3算法优化" 数据挖掘是从大量数据中提取出可信、新颖、有效并能被人理解的模式的高级处理过程。它是一门交叉学科,把人们对数据的应用从低层次的简单查询,提升到从大量数据中提炼有价值的信息,...

    决策树id3算法实现1

    决策树id3算法实现,递归函数,决策树打印,提供数据集

    决策树ID3算法 python

    总结来说,决策树ID3算法是基于信息熵和信息增益的分类方法,它通过递归地选择最优特征来构建决策树模型。在Python中,可以通过自定义代码或利用现有的机器学习库实现ID3算法,以解决分类问题。结合压缩包中的资源,...

    python实现决策树ID3算法的示例代码

    通过以上步骤,我们可以实现决策树ID3算法。需要注意的是,ID3算法仅适用于离散型特征,并且由于使用了信息增益,它倾向于选择取值多的特征。此外,在实际应用中,为了避免过拟合,可能需要对ID3算法进行剪枝处理。

    数据挖掘决策树ID3算法C++实现

    ID3(Iterative Dichotomiser 3)算法是决策树学习方法的一种,由Ross Quinlan于1986年提出,主要用于分类任务。在本项目中,我们将深入探讨ID3算法的原理以及如何使用C++进行实现。 ID3算法基于信息熵和信息增益来...

Global site tag (gtag.js) - Google Analytics