`
cyzhang999
  • 浏览: 26857 次
  • 性别: Icon_minigender_1
  • 来自: 北京
社区版块
存档分类
最新评论

决策树java实现(转)

 
阅读更多

一直看决策树的原理,但没实现过,所以找个代码看看。

来源:http://www.cnblogs.com/zhangchaoyang/articles/2196631.html

格式可能不太好,可参考原博客。

 

先上问题吧,我们统计了14天的气象数据(指标包括outlook,temperature,humidity,windy),并已知这些天气是否打球(play)。如果给出新一天的气象指标数据:sunny,cool,high,TRUE,判断一下会不会去打球。

table 1  

 

outlook temperature humidity windy play
sunny hot high FALSE no
sunny hot high TRUE no
overcast hot high FALSE yes
rainy mild high FALSE yes
rainy cool normal FALSE yes
rainy cool normal TRUE no
overcast cool normal TRUE yes
sunny mild high FALSE no
sunny cool normal FALSE yes
rainy mild normal FALSE yes
sunny mild normal TRUE yes
overcast mild high TRUE yes
overcast hot normal FALSE yes
rainy mild high TRUE no

这个问题当然可以用朴素贝叶斯法求解,分别计算在给定天气条件下打球和不打球的概率,选概率大者作为推测结果。

现在我们使用ID3归纳决策树的方法来求解该问题。

预备知识:信息熵

熵是无序性(或不确定性)的度量指标。假如事件A的全概率划分是(A1,A2,...,An),每部分发生的概率是(p1,p2,...,pn),那信息熵定义为:

通常以2为底数,所以信息熵的单位是bit。

补充两个对数去处公式:

ID3算法

构造树的基本想法是随着树深度的增加,节点的熵迅速地降低。熵降低的速度越快越好,这样我们有望得到一棵高度最矮的决策树。

在没有给定任何天气信息时,根据历史数据,我们只知道新的一天打球的概率是9/14,不打的概率是5/14。此时的熵为:

属性有4个:outlook,temperature,humidity,windy。我们首先要决定哪个属性作树的根节点。

对每项指标分别统计:在不同的取值下打球和不打球的次数。

table 2

 

 

 

下面我们计算当已知变量outlook的值时,信息熵为多少。

outlook=sunny时,2/5的概率打球,3/5的概率不打球。entropy=0.971

outlook=overcast时,entropy=0

outlook=rainy时,entropy=0.971

而根据历史统计数据,outlook取值为sunny、overcast、rainy的概率分别是5/14、4/14、5/14,所以当已知变量outlook的值时,信息熵为:5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693

这样的话系统熵就从0.940下降到了0.693,信息增溢gain(outlook)为0.940-0.693=0.247

同样可以计算出gain(temperature)=0.029,gain(humidity)=0.152,gain(windy)=0.048。

gain(outlook)最大(即outlook在第一步使系统的信息熵下降得最快),所以决策树的根节点就取outlook。

接下来要确定N1取temperature、humidity还是windy?在已知outlook=sunny的情况,根据历史数据,我们作出类似table 2的一张表,分别计算gain(temperature)、gain(humidity)和gain(windy),选最大者为N1。

依此类推,构造决策树。当系统的信息熵降为0时,就没有必要再往下构造决策树了,此时叶子节点都是纯的--这是理想情况。最坏的情况下,决策树的高度为属性(决策变量)的个数,叶子节点不纯(这意味着我们要以一定的概率来作出决策)。

Java实现

最终的决策树保存在了XML中,使用了Dom4J,注意如果要让Dom4J支持按XPath选择节点,还得引入包jaxen.jar。程序代码要求输入文件满足ARFF格式,并且属性都是标称变量。

实验用的数据文件:

@relation weather.symbolic 
  
@attribute outlook {sunny, overcast, rainy} 
@attribute temperature {hot, mild, cool} 
@attribute humidity {high, normal} 
@attribute windy {TRUE, FALSE} 
@attribute play {yes, no} 
  
@data 
sunny,hot,high,FALSE,no 
sunny,hot,high,TRUE,no 
overcast,hot,high,FALSE,yes 
rainy,mild,high,FALSE,yes 
rainy,cool,normal,FALSE,yes 
rainy,cool,normal,TRUE,no 
overcast,cool,normal,TRUE,yes 
sunny,mild,high,FALSE,no 
sunny,cool,normal,FALSE,yes 
rainy,mild,normal,FALSE,yes 
sunny,mild,normal,TRUE,yes 
overcast,mild,high,TRUE,yes 
overcast,hot,normal,FALSE,yes 
rainy,mild,high,TRUE,no 

 

程序代码:

package dt; 
  
import java.io.BufferedReader; 
import java.io.File; 
import java.io.FileReader; 
import java.io.FileWriter; 
import java.io.IOException; 
import java.util.ArrayList; 
import java.util.Iterator; 
import java.util.LinkedList; 
import java.util.List; 
import java.util.regex.Matcher; 
import java.util.regex.Pattern; 
  
import org.dom4j.Document; 
import org.dom4j.DocumentHelper; 
import org.dom4j.Element; 
import org.dom4j.io.OutputFormat; 
import org.dom4j.io.XMLWriter; 
  
public class ID3 { 
    private ArrayList<String> attribute = new ArrayList<String>(); // 存储属性的名称 
    private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); // 存储每个属性的取值 
    private ArrayList<String[]> data = new ArrayList<String[]>();; // 原始数据 
    int decatt; // 决策变量在属性集中的索引 
    public static final String patternString = "@attribute(.*)[{](.*?)[}]"; 
  
    Document xmldoc; 
    Element root; 
  
    public ID3() { 
        xmldoc = DocumentHelper.createDocument(); 
        root = xmldoc.addElement("root"); 
        root.addElement("DecisionTree").addAttribute("value", "null"); 
    } 
  
    public static void main(String[] args) { 
        ID3 inst = new ID3(); 
        inst.readARFF(new File("/home/orisun/test/weather.nominal.arff")); 
        inst.setDec("play"); 
        LinkedList<Integer> ll=new LinkedList<Integer>(); 
        for(int i=0;i<inst.attribute.size();i++){ 
            if(i!=inst.decatt) 
                ll.add(i); 
        } 
        ArrayList<Integer> al=new ArrayList<Integer>(); 
        for(int i=0;i<inst.data.size();i++){ 
            al.add(i); 
        } 
        inst.buildDT("DecisionTree", "null", al, ll); 
        inst.writeXML("/home/orisun/test/dt.xml"); 
        return; 
    } 
  
    //读取arff文件,给attribute、attributevalue、data赋值 
    public void readARFF(File file) { 
        try { 
            FileReader fr = new FileReader(file); 
            BufferedReader br = new BufferedReader(fr); 
            String line; 
            Pattern pattern = Pattern.compile(patternString); 
            while ((line = br.readLine()) != null) { 
                Matcher matcher = pattern.matcher(line); 
                if (matcher.find()) { 
                    attribute.add(matcher.group(1).trim()); 
                    String[] values = matcher.group(2).split(","); 
                    ArrayList<String> al = new ArrayList<String>(values.length); 
                    for (String value : values) { 
                        al.add(value.trim()); 
                    } 
                    attributevalue.add(al); 
                } else if (line.startsWith("@data")) { 
                    while ((line = br.readLine()) != null) { 
                        if(line=="") 
                            continue; 
                        String[] row = line.split(","); 
                        data.add(row); 
                    } 
                } else { 
                    continue; 
                } 
            } 
            br.close(); 
        } catch (IOException e1) { 
            e1.printStackTrace(); 
        } 
    } 
  
    //设置决策变量 
    public void setDec(int n) { 
        if (n < 0 || n >= attribute.size()) { 
            System.err.println("决策变量指定错误。"); 
            System.exit(2); 
        } 
        decatt = n; 
    } 
    public void setDec(String name) { 
        int n = attribute.indexOf(name); 
        setDec(n); 
    } 
  
    //给一个样本(数组中是各种情况的计数),计算它的熵 
    public double getEntropy(int[] arr) { 
        double entropy = 0.0; 
        int sum = 0; 
        for (int i = 0; i < arr.length; i++) { 
            entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2); 
            sum += arr[i]; 
        } 
        entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2); 
        entropy /= sum; 
        return entropy; 
    } 
  
    //给一个样本数组及样本的算术和,计算它的熵 
    public double getEntropy(int[] arr, int sum) { 
        double entropy = 0.0; 
        for (int i = 0; i < arr.length; i++) { 
            entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2); 
        } 
        entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2); 
        entropy /= sum; 
        return entropy; 
    } 
  
    public boolean infoPure(ArrayList<Integer> subset) { 
        String value = data.get(subset.get(0))[decatt]; 
        for (int i = 1; i < subset.size(); i++) { 
            String next=data.get(subset.get(i))[decatt]; 
            //equals表示对象内容相同,==表示两个对象指向的是同一片内存 
            if (!value.equals(next)) 
                return false; 
        } 
        return true; 
    } 
  
    // 给定原始数据的子集(subset中存储行号),当以第index个属性为节点时计算它的信息熵 
    public double calNodeEntropy(ArrayList<Integer> subset, int index) { 
        int sum = subset.size(); 
        double entropy = 0.0; 
        int[][] info = new int[attributevalue.get(index).size()][]; 
        for (int i = 0; i < info.length; i++) 
            info[i] = new int[attributevalue.get(decatt).size()]; 
        int[] count = new int[attributevalue.get(index).size()]; 
        for (int i = 0; i < sum; i++) { 
            int n = subset.get(i); 
            String nodevalue = data.get(n)[index]; 
            int nodeind = attributevalue.get(index).indexOf(nodevalue); 
            count[nodeind]++; 
            String decvalue = data.get(n)[decatt]; 
            int decind = attributevalue.get(decatt).indexOf(decvalue); 
            info[nodeind][decind]++; 
        } 
        for (int i = 0; i < info.length; i++) { 
            entropy += getEntropy(info[i]) * count[i] / sum; 
        } 
        return entropy; 
    } 
  
    // 构建决策树 
    public void buildDT(String name, String value, ArrayList<Integer> subset, 
            LinkedList<Integer> selatt) { 
        Element ele = null; 
        @SuppressWarnings("unchecked") 
        List<Element> list = root.selectNodes("//"+name); 
        Iterator<Element> iter=list.iterator(); 
        while(iter.hasNext()){ 
            ele=iter.next(); 
            if(ele.attributeValue("value").equals(value)) 
                break; 
        } 
        if (infoPure(subset)) { 
            ele.setText(data.get(subset.get(0))[decatt]); 
            return; 
        } 
        int minIndex = -1; 
        double minEntropy = Double.MAX_VALUE; 
        for (int i = 0; i < selatt.size(); i++) { 
            if (i == decatt) 
                continue; 
            double entropy = calNodeEntropy(subset, selatt.get(i)); 
            if (entropy < minEntropy) { 
                minIndex = selatt.get(i); 
                minEntropy = entropy; 
            } 
        } 
        String nodeName = attribute.get(minIndex); 
        selatt.remove(new Integer(minIndex)); 
        ArrayList<String> attvalues = attributevalue.get(minIndex); 
        for (String val : attvalues) { 
            ele.addElement(nodeName).addAttribute("value", val); 
            ArrayList<Integer> al = new ArrayList<Integer>(); 
            for (int i = 0; i < subset.size(); i++) { 
                if (data.get(subset.get(i))[minIndex].equals(val)) { 
                    al.add(subset.get(i)); 
                } 
            } 
            buildDT(nodeName, val, al, selatt); 
        } 
    } 
  
    // 把xml写入文件 
   public void writeXML(String filename) { 
        try { 
            File file = new File(filename); 
            if (!file.exists()) 
                file.createNewFile(); 
            FileWriter fw = new FileWriter(file); 
            OutputFormat format = OutputFormat.createPrettyPrint(); // 美化格式 
            XMLWriter output = new XMLWriter(fw, format); 
            output.write(xmldoc); 
            output.close(); 
        } catch (IOException e) { 
            System.out.println(e.getMessage()); 
        } 
    } 
} 

 

最终生成的文件如下:

用图形象地表示就是:

  • 大小: 38 KB
分享到:
评论

相关推荐

    决策树Java代码实现

    决策树的Java实现涉及多个关键概念和技术点,包括但不限于节点的定义、数据结构的选择、属性的选择等。通过上述分析,我们不仅了解了决策树的基本原理,还深入了解了其在Java中的具体实现方式。这对于理解和开发实际...

    决策树ID3算法(Java实现)

    总的来说,ID3算法是理解和实现机器学习基础的重要步骤,它的Java实现可以帮助开发者更深入地掌握数据分类和决策树的构建原理。通过分析和调试这个代码,可以提高对决策树算法的理解,为进一步探索更复杂的算法如C...

    决策树Java实现

    在Java中实现决策树,我们可以选择自定义算法,如ID3(Iterative Dichotomiser 3),或者使用现有的机器学习库,如Weka、Deeplearning4j等。本篇文章将深入探讨ID3算法以及如何在Java中实现它。 ID3算法是基于信息...

    决策树java

    下面我们将深入探讨决策树的基本原理、Java实现的关键步骤以及如何处理试验数据。 决策树的构建主要基于以下步骤: 1. **选择特征**:在每个节点上,决策树算法会选择一个最优特征进行划分。这个最优特征通常是...

    基于java实现的决策树代码

    决策树是一种广泛应用于人工智能、机器学习领域的算法,它通过学习数据的特征来进行分类或回归分析。在本项目中,我们关注的是使用Java语言实现的决策树代码,这将涵盖数据处理、模型构建、预测以及性能评估等多个...

    决策树算法(Java实现)

    决策树生成算法的Java实现,可能还有一些BUG,没有做仔细校验与测试,完成主要功能。决策树具体详解移步:http://blog.csdn.net/adiaixin123456/article/details/50573849 项目的目录结构分为四个文件夹algorithm,...

    Java实现的决策树算法完整实例

    Java实现的决策树算法完整实例 决策树算法是机器学习领域中的一种常见算法,主要用于分类和预测。Java实现的决策树算法完整实例中,主要介绍了决策树的概念、原理,并结合完整实例形式分析了Java实现决策树算法的...

    java实现的决策树算法

    总之,Java实现的决策树算法是一个涉及数据处理、特征选择、节点划分和模型优化的过程。理解并掌握这些核心概念,将有助于你在实际项目中构建高效、可维护的决策树模型。通过不断迭代和调整,你可以根据特定需求优化...

    Java实现基于C4.5算法的决策树,实现银行贷款风险预测

    本项目通过Java编程语言实现了基于C4.5算法的决策树,用于预测银行贷款的风险。 首先,让我们深入理解C4.5算法的核心概念。C4.5选择最优特征进行分裂时,基于信息增益率(Gini指数或熵)来度量数据集的纯度。信息...

    C4.5决策树(Java实现)

    C4.5决策树是一种广泛应用于分类问题的机器学习算法,由Ross Quinlan于1993年提出,是对之前ID3算法的改进...通过这个Java实现,初学者和专业人士都能更深入地理解C4.5决策树的工作原理,并将其应用于实际的分类任务。

    使用决策树实现分类

    在Python中,我们可以使用scikit-learn库来实现决策树分类。以下是一段基本的代码示例: ```python from sklearn.tree import DecisionTreeClassifier from sklearn.model_selection import train_test_split from ...

    id3 决策树 Java实现

    6. **Java实现细节**:在Java中,可以使用面向对象编程思想来设计决策树类,包括节点类(包含特征、子节点等属性)和决策树类(包含构建、预测方法)。同时,需要处理好数据结构,如使用ArrayList或LinkedList存储...

    java实现的决策树算法(ID3)

    这个Java实现的ID3算法提供了一个在编程环境中应用决策树的方法,包括训练模型、构建决策树、以及对新数据进行预测。 ID3算法的核心思想是基于信息熵和信息增益来选择最优特征。熵是衡量数据纯度的一个指标,信息...

    决策树的java实现

    BI中决策树ID3算法的java实现,无界面,命令行方式

    java实现决策树ID3算法

    本文介绍了如何使用Java实现决策树ID3算法,包括算法原理、实现细节及代码解析。通过这种方式,可以更好地理解ID3算法的工作机制,并掌握其实现方法。此外,文件读取部分也是实现决策树算法的重要组成部分,确保了...

    决策树ID3JAVA实现

    在Java中实现ID3决策树,我们需要关注以下几个关键步骤: 1. 数据预处理:将原始数据转换为适合算法处理的格式,如二维数组或列表,其中包含实例的特征和对应的类别标签。 2. 实现信息熵和信息增益的计算函数。 3. ...

    决策树 java实现

    java实现决策树算法 id3 数据挖掘领域经典算法应用广泛

    DecisionTree决策树数据挖掘算法的实现(Java)

    4. Java实现决策树: 在Java中,可以使用各种库如Weka、Apache Mahout或自定义代码实现决策树。自定义实现通常包括以下组件: - 数据结构:用于存储数据集和决策树结构,如ArrayList、HashMap等。 - 分类器:包含...

    java 决策树Demo1

    本教程将基于`java 决策树Demo1`来深入探讨决策树的基本概念、工作原理以及如何在Java中实现。 决策树是一种图形模型,它通过一系列的if-else条件判断来模拟决策过程。每个内部节点代表一个特征或属性测试,每个...

    机器学习 决策树算法(ID3)java实现

    在这个“机器学习 决策树算法(ID3)java实现”的项目中,我们主要关注以下几个关键知识点: 1. **信息熵**:信息熵是衡量数据集纯度的一个指标,表示不确定性或信息的平均量。对于一个二分类问题,如果所有样本都...

Global site tag (gtag.js) - Google Analytics