`
leon_a
  • 浏览: 79022 次
  • 性别: Icon_minigender_1
  • 来自: 拜月神教
社区版块
存档分类
最新评论

决策树C4.5算法

阅读更多
数据挖掘中决策树C4.5预测算法实现(半成品,还要写规则后煎支及对非离散数据信息增益计算),下一篇博客讲原理

package org.struct.decisiontree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.TreeSet;

/**
 * @author Leon.Chen
 */
public class DecisionTreeBaseC4p5 {
	
	/**
	 * root node
	 */
	private DecisionTreeNode root;

	/**
	 * visableArray
	 */
	private boolean[] visable;

	private static final int NOT_FOUND = -1;

	private static final int DATA_START_LINE = 1;

	private Object[] trainingArray;

	private String[] columnHeaderArray;

	/**
	 * forecast node index
	 */
	private int nodeIndex;

	/**
	 * @param args
	 */
	@SuppressWarnings("boxing")
	public static void main(String[] args) {
		Object[] array = new Object[] {
				new String[] { "age",          "income",   "student", "credit_rating", "buys_computer" },
				new String[] { "youth",        "high",     "no",      "fair",          "no"  },
				new String[] { "youth",        "high",     "no",      "excellent",     "no"  },
				new String[] { "middle_aged",  "high",     "no",      "fair",          "yes" },
				new String[] { "senior",       "medium",   "no",      "fair",          "yes" },
				new String[] { "senior",       "low",      "yes",     "fair",          "yes" },
				new String[] { "senior",       "low",      "yes",     "excellent",     "no"  },
				new String[] { "middle_aged",  "low",      "yes",     "excellent",     "yes" },
				new String[] { "youth",        "medium",   "no",      "fair",          "no"  },
				new String[] { "youth",        "low",      "yes",     "fair",          "yes" },
				new String[] { "senior",       "medium",   "yes",     "fair",          "yes" },
				new String[] { "youth",        "medium",   "yes",     "excellent",     "yes" },
				new String[] { "middle_aged",  "medium",   "no",      "excellent",     "yes" },
				new String[] { "middle_aged",  "high",     "yes",     "fair",          "yes" },
				new String[] { "senior",       "medium",   "no",      "excellent",     "no"  },
		};

		DecisionTreeBaseC4p5 tree = new DecisionTreeBaseC4p5();
		tree.create(array, 4);
		System.out.println("===============END PRINT TREE===============");
		System.out.println("===============DECISION RESULT===============");
		//tree.forecast(printData, tree.root);
	}

	/**
	 * @param printData
	 * @param node
	 */
	public void forecast(String[] printData, DecisionTreeNode node) {
		int index = getColumnHeaderIndexByName(node.nodeName);
		if (index == NOT_FOUND) {
			System.out.println(node.nodeName);
		}
		DecisionTreeNode[] childs = node.childNodesArray;
		for (int i = 0; i < childs.length; i++) {
			if (childs[i] != null) {
				if (childs[i].parentArrtibute.equals(printData[index])) {
					forecast(printData, childs[i]);
				}
			}
		}
	}

	/**
	 * @param array
	 * @param index
	 */
	public void create(Object[] array, int index) {
		this.trainingArray = Arrays.copyOfRange(array, DATA_START_LINE,
				array.length);
		init(array, index);
		createDecisionTree(this.trainingArray);
		printDecisionTree(root);
	}

	/**
	 * @param array
	 * @return Object[]
	 */
	@SuppressWarnings("boxing")
	public Object[] getMaxGain(Object[] array) {
		Object[] result = new Object[2];
		double gain = 0;
		int index = -1;

		for (int i = 0; i < visable.length; i++) {
			if (!visable[i]) {
				//TODO ID3 change to C4.5
				double value = gainRatio(array, i, this.nodeIndex);
				System.out.println(value);
				if (gain < value) {
					gain = value;
					index = i;
				}
			}
		}
		result[0] = gain;
		result[1] = index;
		// TODO throws can't forecast this model exception
		if (index != -1) {
			visable[index] = true;
		}
		return result;
	}

	/**
	 * @param array
	 */
	public void createDecisionTree(Object[] array) {
		Object[] maxgain = getMaxGain(array);
		if (root == null) {
			root = new DecisionTreeNode();
			root.parentNode = null;
			root.parentArrtibute = null;
			root.arrtibutesArray = getArrtibutesArray(((Integer) maxgain[1])
					.intValue());
			root.nodeName = getColumnHeaderNameByIndex(((Integer) maxgain[1])
					.intValue());
			root.childNodesArray = new DecisionTreeNode[root.arrtibutesArray.length];
			insertDecisionTree(array, root);
		}
	}

	/**
	 * @param array
	 * @param parentNode
	 */
	public void insertDecisionTree(Object[] array, DecisionTreeNode parentNode) {
		String[] arrtibutes = parentNode.arrtibutesArray;
		for (int i = 0; i < arrtibutes.length; i++) {
			Object[] pickArray = pickUpAndCreateSubArray(array, arrtibutes[i],
					getColumnHeaderIndexByName(parentNode.nodeName));
			Object[] info = getMaxGain(pickArray);
			double gain = ((Double) info[0]).doubleValue();
			if (gain != 0) {
				int index = ((Integer) info[1]).intValue();
				DecisionTreeNode currentNode = new DecisionTreeNode();
				currentNode.parentNode = parentNode;
				currentNode.parentArrtibute = arrtibutes[i];
				currentNode.arrtibutesArray = getArrtibutesArray(index);
				currentNode.nodeName = getColumnHeaderNameByIndex(index);
				currentNode.childNodesArray = new DecisionTreeNode[currentNode.arrtibutesArray.length];
				parentNode.childNodesArray[i] = currentNode;
				insertDecisionTree(pickArray, currentNode);
			} else {
				DecisionTreeNode leafNode = new DecisionTreeNode();
				leafNode.parentNode = parentNode;
				leafNode.parentArrtibute = arrtibutes[i];
				leafNode.arrtibutesArray = new String[0];
				leafNode.nodeName = getLeafNodeName(pickArray,this.nodeIndex);
				leafNode.childNodesArray = new DecisionTreeNode[0];
				parentNode.childNodesArray[i] = leafNode;
			}
		}
	}

	/**
	 * @param node
	 */
	public void printDecisionTree(DecisionTreeNode node) {
		System.out.println(node.nodeName);
		DecisionTreeNode[] childs = node.childNodesArray;
		for (int i = 0; i < childs.length; i++) {
			if (childs[i] != null) {
				System.out.println(childs[i].parentArrtibute);
				printDecisionTree(childs[i]);
			}
		}
	}

	/**
	 * init data
	 * 
	 * @param dataArray
	 * @param index
	 */
	public void init(Object[] dataArray, int index) {
		this.nodeIndex = index;
		// init data
		this.columnHeaderArray = (String[]) dataArray[0];
		visable = new boolean[((String[]) dataArray[0]).length];
		for (int i = 0; i < visable.length; i++) {
			if (i == index) {
				visable[i] = true;
			} else {
				visable[i] = false;
			}
		}
	}

	/**
	 * @param array
	 * @param arrtibute
	 * @param index
	 * @return Object[]
	 */
	public Object[] pickUpAndCreateSubArray(Object[] array, String arrtibute,
			int index) {
		List<String[]> list = new ArrayList<String[]>();
		for (int i = 0; i < array.length; i++) {
			String[] strs = (String[]) array[i];
			if (strs[index].equals(arrtibute)) {
				list.add(strs);
			}
		}
		return list.toArray();
	}

	/**
	 * gain(A)
	 * 
	 * @param array
	 * @param index
	 * @return double
	 */
	public double gain(Object[] array, int index, int nodeIndex) {
		int[] counts = separateToSameValueArrays(array, nodeIndex);
		String[] arrtibutes = getArrtibutesArray(index);
		double infoD = infoD(array, counts);
		double infoaD = infoaD(array, index, nodeIndex, arrtibutes);
		return infoD - infoaD;
	}

	/**
	 * @param array
	 * @param nodeIndex
	 * @return
	 */
	public int[] separateToSameValueArrays(Object[] array, int nodeIndex) {
		String[] arrti = getArrtibutesArray(nodeIndex);
		int[] counts = new int[arrti.length];
		for (int i = 0; i < counts.length; i++) {
			counts[i] = 0;
		}
		for (int i = 0; i < array.length; i++) {
			String[] strs = (String[]) array[i];
			for (int j = 0; j < arrti.length; j++) {
				if (strs[nodeIndex].equals(arrti[j])) {
					counts[j]++;
				}
			}
		}
		return counts;
	}
	
	/**
	 * gainRatio = gain(A)/splitInfo(A)
	 * 
	 * @param array
	 * @param index
	 * @param nodeIndex
	 * @return
	 */
	public double gainRatio(Object[] array,int index,int nodeIndex){
		double gain = gain(array,index,nodeIndex);
		int[] counts = separateToSameValueArrays(array, index);
		double splitInfo = splitInfoaD(array,counts);
		if(splitInfo != 0){
			return gain/splitInfo;
		}
		return 0;
	}

	/**
	 * infoD = -E(pi*log2 pi)
	 * 
	 * @param array
	 * @param counts
	 * @return
	 */
	public double infoD(Object[] array, int[] counts) {
		double infoD = 0;
		for (int i = 0; i < counts.length; i++) {
			infoD += DecisionTreeUtil.info(counts[i], array.length);
		}
		return infoD;
	}

	/**
	 * splitInfoaD = -E|Dj|/|D|*log2(|Dj|/|D|)
	 * 
	 * @param array
	 * @param counts
	 * @return
	 */
	public double splitInfoaD(Object[] array, int[] counts) {
		return infoD(array, counts);
	}

	/**
	 * infoaD = E(|Dj| / |D|) * info(Dj)
	 * 
	 * @param array
	 * @param index
	 * @param arrtibutes
	 * @return
	 */
	public double infoaD(Object[] array, int index, int nodeIndex,
			String[] arrtibutes) {
		double sv_total = 0;
		for (int i = 0; i < arrtibutes.length; i++) {
			sv_total += infoDj(array, index, nodeIndex, arrtibutes[i],
					array.length);
		}
		return sv_total;
	}

	/**
	 * ((|Dj| / |D|) * Info(Dj))
	 * 
	 * @param array
	 * @param index
	 * @param arrtibute
	 * @param allTotal
	 * @return double
	 */
	public double infoDj(Object[] array, int index, int nodeIndex,
			String arrtibute, int allTotal) {
		String[] arrtibutes = getArrtibutesArray(nodeIndex);
		int[] counts = new int[arrtibutes.length];
		for (int i = 0; i < counts.length; i++) {
			counts[i] = 0;
		}

		for (int i = 0; i < array.length; i++) {
			String[] strs = (String[]) array[i];
			if (strs[index].equals(arrtibute)) {
				for (int k = 0; k < arrtibutes.length; k++) {
					if (strs[nodeIndex].equals(arrtibutes[k])) {
						counts[k]++;
					}
				}
			}
		}

		int total = 0;
		double infoDj = 0;
		for (int i = 0; i < counts.length; i++) {
			total += counts[i];
		}
		for (int i = 0; i < counts.length; i++) {
			infoDj += DecisionTreeUtil.info(counts[i], total);
		}
		return DecisionTreeUtil.getPi(total, allTotal) * infoDj;
	}

	/**
	 * @param index
	 * @return String[]
	 */
	@SuppressWarnings("unchecked")
	public String[] getArrtibutesArray(int index) {
		TreeSet<String> set = new TreeSet<String>(new SequenceComparator());
		for (int i = 0; i < trainingArray.length; i++) {
			String[] strs = (String[]) trainingArray[i];
			set.add(strs[index]);
		}
		String[] result = new String[set.size()];
		return set.toArray(result);
	}

	/**
	 * @param index
	 * @return String
	 */
	public String getColumnHeaderNameByIndex(int index) {
		for (int i = 0; i < columnHeaderArray.length; i++) {
			if (i == index) {
				return columnHeaderArray[i];
			}
		}
		return null;
	}

	/**
	 * @param array
	 * @return String
	 */
	public String getLeafNodeName(Object[] array,int nodeIndex) {
		if (array != null && array.length > 0) {
			String[] strs = (String[]) array[0];
			return strs[nodeIndex];
		}
		return null;
	}

	/**
	 * @param name
	 * @return int
	 */
	public int getColumnHeaderIndexByName(String name) {
		for (int i = 0; i < columnHeaderArray.length; i++) {
			if (name.equals(columnHeaderArray[i])) {
				return i;
			}
		}
		return NOT_FOUND;
	}
}

package org.struct.decisiontree;

/**
 * @author Leon.Chen
 */
public class DecisionTreeNode {

	DecisionTreeNode parentNode;

	String parentArrtibute;

	String nodeName;

	String[] arrtibutesArray;

	DecisionTreeNode[] childNodesArray;

}

package org.struct.decisiontree;

/**
 * @author Leon.Chen
 */
public class DecisionTreeUtil {

	/**
	 * entropy:Info(T)=(i=1...k)pi*log(2)pi
	 * 
	 * @param x
	 * @param total
	 * @return double
	 */
	public static double info(int x, int total) {
		if (x == 0) {
			return 0;
		}
		double x_pi = getPi(x, total);
		return -(x_pi * logYBase2(x_pi));
	}

	/**
	 * log2y
	 * 
	 * @param y
	 * @return double
	 */
	public static double logYBase2(double y) {
		return Math.log(y) / Math.log(2);
	}

	/**
	 * pi=|C(i,d)|/|D|
	 * 
	 * @param x
	 * @param total
	 * @return double
	 */
	public static double getPi(int x, int total) {
		return x / (double) total;
	}

}


package org.struct.decisiontree;

import java.util.Comparator;

/**
 * @author Leon.Chen
 * 
 */
@SuppressWarnings("unchecked")
public class SequenceComparator implements Comparator {

	public int compare(Object o1, Object o2) throws ClassCastException {
		String str1 = (String) o1;
		String str2 = (String) o2;
		return str1.compareTo(str2);
	}
}
4
0
分享到:
评论
2 楼 longay00 2010-11-09  
不错,很牛,不过没有原理与实验很难相信它的正确性。从代码上看,博主的编程能力不错。谢谢分享
1 楼 sdscx0530 2010-04-21  
老大,我在等着你的原理。

相关推荐

    基于Matlab实现决策树C4.5算法(源码+数据+教程).rar

    1、资源内容:基于Matlab实现决策树C4.5算法(源码+数据+教程).rar 2、适用人群:计算机,电子信息工程、数学等专业的大学生课程设计、期末大作业或毕业设计,作为“参考资料”使用。 3、解压说明:本资源需要电脑...

    决策树C4.5算法的c++实现

    C4.5算法是决策树构建的一种经典方法,由Ross Quinlan开发,是对ID3算法的改进版本。在C4.5算法中,它引入了连续属性处理、剪枝策略以及更优的特征选择标准,提高了决策树的泛化能力和准确性。 本项目是一个C++实现...

    C4.5算法.rar_C4.5算法matlab_c4.5 matlab_matlab决策树C4.5_决策树C4.5算法

    数据挖掘中的决策树C4.5算法的实现,用matlab实现

    c4.5.rar_C++决策树C4.5_C4.5决策树_c4.5算法_决策树c4.5_决策树算法

    C4.5决策树是一种广泛应用于机器学习领域的分类算法,由Ross Quinlan在ID3算法的基础上改进而成。本文将详细介绍C4.5算法的基本原理、实现方式以及在C++中的应用。 C4.5算法的核心是通过信息增益率来选择最优特征...

    决策树C4.5算法_c4.5_决策树

    014____决策树C4.5算法的文件可能包含了C4.5算法的实现代码、教程或者实例分析,对于初学者来说,这些资源可以帮助理解算法原理,动手实践构建和优化决策树模型。学习这个算法,不仅需要掌握基本概念,还要理解如何...

    决策树C4.5算法matlab源代码(完美运行).zip_C4.5算法matlab_c4.5_matlab 决策树_决策树_决策

    可以完美的实现用于统计学习的算法C4.5分类,完整的matlab程序

    决策树c4.5算法

    数据挖掘十大算法:决策树c4.5算法,可以快速了解c4.5算法原理

    决策树c4.5算法实现

    决策树c4.5算法简单实现,ID3改进,用C++编程实现

    论文研究-基于决策树C4.5集成算法的图像自动标注.pdf

    针对决策树C4.5集成算法中基分类器多样性差的问题,提出了修正矩阵correction matrix-C4.5(CMC4.5)集成学习算法,并将其应用于图像自动标注。该算法首先对特征子集进行多样性处理,然后通过构造修正矩阵依次得到基...

    决策树C4.5算法,C语言写的,很好很强大

    c语言的决策树C4.5。数据挖掘中决策树是一种经常要用到的技术,可以用于分析数据,同样也可以用来作预测。决策树的分类模型,以其特有的优点广为人们采用。首先,决策树方法结构简单,便于人们理解;其次,决策树...

    C4.5决策树算法

    C4.5是一种用于分类决策树的算法。

    决策树C4.5。C++代码加测试集

    在C++中实现决策树C4.5,主要涉及以下几个核心步骤: 1. **数据预处理**:首先,我们需要将数据集转换为C++可以处理的格式,通常是以结构化的数组或类的形式存储。数据集应该包含特征和对应的类别标签。 2. **熵与...

    C4.5算法源码 c语言实现

    总结,C4.5算法是机器学习领域中一个重要的决策树学习算法,它的C语言实现使得在资源有限的环境中也能进行高效的数据分析。通过理解其工作原理并实际操作源代码,可以深入掌握决策树的学习和预测过程,这对于理解和...

    决策树c4.5算法随碟附送地方

    C4.5算法是决策树构建的一种经典方法,由Ross Quinlan于1993年提出,是对早期ID3算法的改进。C4.5在处理连续属性和不纯度度量上更为高效,使其在实际应用中更为实用。 C4.5算法的核心思想是通过信息增益或信息增益...

    机器学习C4.5算法C语言实现

    总之,C4.5算法在机器学习中扮演着重要角色,它的C语言实现涉及到数据预处理、决策树构建、连续属性处理、缺失值处理、剪枝策略以及模型评估等多个方面。理解和掌握这些关键点对于理解和实现C4.5算法至关重要。在...

    决策树C4.5

    决策树C4.5是一种广泛应用于数据挖掘领域的分类算法,由Ross Quinlan在ID3算法的基础上改进而成。C4.5是“Classifier System 4.5”的缩写,它以其高效、易于理解和实现的特点而备受青睐。在这个讨论中,我们将深入...

Global site tag (gtag.js) - Google Analytics