`
Marshal_R
  • 浏览: 133203 次
  • 性别: Icon_minigender_1
  • 来自: 广州
社区版块
存档分类
最新评论

人工智能应用实例:决策树

阅读更多

人工智能应用实例:决策树

 

 

场景设置

    数据来自于1984年美国议会投票记录,共和党、民主党的议员对16个重要问题进行投票,投票选项有voted for, paired for, announced for, voted against, paired against, announced against, voted present, voted present to avoid conflict of interest, did not vote or otherwise make a position known,而数据集中进行了简化,只有y、n、?三个选项,分别表示支持、反对、中立。

 

    数据集中共有435条记录,其中267条来自民主党,其它168条来自共和党,每条记录均有17个属性,罗列如下:

 

1. Class Name: 2 (democrat, republican)

2. handicapped-infants: 2 (y,n)

3. water-project-cost-sharing: 2 (y,n)

4. adoption-of-the-budget-resolution: 2 (y,n)

5. physician-fee-freeze: 2 (y,n)

6. el-salvador-aid: 2 (y,n)

7. religious-groups-in-schools: 2 (y,n)

8. anti-satellite-test-ban: 2 (y,n)

9. aid-to-nicaraguan-contras: 2 (y,n)

10. mx-missile: 2 (y,n)

11. immigration: 2 (y,n)

12. synfuels-corporation-cutback: 2 (y,n)

13. education-spending: 2 (y,n)

14. superfund-right-to-sue: 2 (y,n)

15. crime: 2 (y,n)

16. duty-free-exports: 2 (y,n)

17. export-administration-act-south-africa: 2 (y,n)

 

    要求通过其中一部分数据生成一棵决策树,即通过党派名字之外的16个属性即能判断属于哪个党派,最后利用剩下的数据检验决策树的准确性。

 

 

决策树

    1、基本介绍

    首先,我们来了解一下什么是决策树。假设有这么一家餐馆,餐馆经理想知道以下几个因素对顾客是否会到此就餐的影响:目前就餐人数、顾客饿不饿、菜式、今天是否周末。那么,很有可能经理对几百个人进行调查,发现这样一些规律:假如餐馆是空的,没有人正在就餐,顾客感觉不对劲就不会选择到此就餐,如果有一些人在就餐,顾客感觉还行吧就会选择在此就餐,而如果餐馆坐满了,那么顾客就要想一想了,假如自己不饿的话,那还是不在这里就餐了,否则,就要看你餐馆有什么菜色了......

 

    据此,可能存在如下的决策树,它的每个非叶子节点表示一个属性,叶子节点表示顾客是否会到此就餐。有了这棵决策树,对于一个新来的顾客,我们只要从根结点从上往下遍历到叶子节点便能决定这个顾客是否会选择在这里就餐。

 
图1 就餐决策树

 
    当然,这里的决策树还可以有其它的形态,比如它的根结点是顾客饿不饿,而假设顾客简单到只要饿了就会选择在此就餐,不饿就不在此就餐,那最终的决策树只需要一个顾客饿不饿的根结点和两个是否在此就餐的叶子节点就足够了。
 
    因此,我们的任务也就是决策树算法的目标就是,尽可能地先选具有最大区分度的属性对数据进行分类,比如像上面所讲,根据顾客饿不饿把顾客分成了两堆,一堆就餐,一堆不就餐,这样每堆都是纯的就是最大的区分度。这里必须注意一些概念,否则容易混淆,搞得糊涂不清。区分度是针对目标属性的,比如上面的顾客就不就餐,而分类是针对当前选定的属性,如上边的根据顾客饿不饿把顾客分成两堆。当一堆数据是纯的,就不需要继续分类了,否则就要选择其它的属性继续对它进行分类。
 

    2、决策树算法

     首先让我们来看一下决策树算法的大致框架。

图2 决策树算法
 
    还是以上面的就餐决策树为例,假设我们一共收集了420条数据,当分类到type这个分支的时候只剩下240条数据了(其中150条数据是就餐的,而另外90条数据是不就餐的),然后这240条数据再根据type分类成4堆,每堆数量依次为0、60、100、80。
 
    显然第1堆已经是空的了,这时候对应的就是决策树算法中的第一个if语句,这个空堆要被设置为叶子节点,叶子节点的值根据它父节点数据的特征来设置,我们发现父节点type的240条数据中更多属于就餐类型的(150>90),因此叶子节点的值被设置为T,即选择在此就餐。
 
    我们继续看第2堆,假设这60条数据都是属于不就餐类型的,这样的数据就是纯的,因此这一堆也可以被设置为叶子节点,它的值就是目标属性的分类,在这里就是F,即选择不在此就餐,这样对应于第二条if语句。
 
    再假设第3堆100条数据经过Fri/Sat分类为两堆,右边那堆有60条数据,其中40条属于就餐类型的,其它20条属于不就餐类型的,这时候我们发现,所有的4个属性都还不足以把这堆分类成纯数据,对它已经不能继续分类了,于是我们必须把它设置为叶子节点,叶子节点的值只能根据这堆数据的特征来设置。因为这60条数据中更多是属于就餐类型的(40>20),所以叶子节点的值为T,即选择在此就餐,这样对应于第三条if语句。
 
    在所有其它情况之下,就要选一个最恰当的属性对数据进行分类,那么必须有一个衡量标准,怎样的属性才是一个好的属性?这里就涉及到信息熵的概念了。信息熵反映了一个事件或者一个格局包含信息的多少。在决策树算法中,信息熵也是针对目标属性的。以我们上边的就餐问题为例,如果一堆数据全都属于就餐或者不就餐类型的,那么这堆数据信息熵为0,而如果就餐和不就餐类型分别占了一半,这时候的信息熵就是最大的,信息熵反映了数据的参差程度。对一堆数据求信息熵的公式如下:
    其中,q为某一类的数据(比如就餐类型,或者不就餐类型)占所有数据的比例。
 
    而要求一个属性的信息熵,我们要先用这个属性将数据分类为若干堆,然后对每一堆求信息熵,最后加权相加得到的值即为这个属性的信息熵,权值为(子堆数据集大小/父节点数据集大小),计算公式如下:
    其中,p、n分别表示目标属性的两个分类的数据的多少。
 
    显然,生成树算法就是要枚举所有的属性,从中选出使得信息熵最小的属性作为当前节点,据此属性把数据分类为若干堆,然后递归对各个堆继续进行分类。
 
 

代码实现

    只要掌握了上面的算法原理,我相信要写出这个投票记录的生成树算法应该是不难的。废话少说,直接看代码吧,代码本身就是最好的解释!
    注:在文档末尾可以下载到数据集和所有代码文件。
 
// record.h
/*
 * 选举记录的数据结构
 * attr[0]为政党名字,'A'对应“democrat”,'B'对应"republican"
 */
class Record {
public:
	Record(char*);
	char get_attr(int);
private:
	char attr[17];
};
 
// record.cpp
#include <cstring>
#include "record.h"

using namespace std;

/*
 * Record类的构造函数
 * line字符串格式:政党名字,16个逗号分隔的属性值(y, n, ?)
 */
Record::Record(char *line) {
	char delims[] = ",";
	char* result = NULL;

	result = strtok(line, delims);
	if (!strcmp(result, "democrat"))
		attr[0] = 'A';
	else
		attr[0] = 'B';

	int i = 1;
	result = strtok(NULL, delims);
	while (result != NULL) {
		attr[i++] = result[0];
		result = strtok(NULL, delims);
	}
}

char Record::get_attr(int i) {
	return attr[i];
}
 
// main.cpp
/*
 * =========================================================================
 *
 *       Filename:  main.cpp
 *
 *    Description:  决策树算法
 *
 *        Version:  1.0
 *        Created:  2014年12月06日 17时38分47秒
 *       Revision:  none
 *       Compiler:  gcc
 *
 *         Author:  阮仕海
 *   Organization:  AI 选修班第8组
 *
 * =========================================================================
 */

#include <iostream>
#include <fstream>
#include <ctime>
#include <stdlib.h>
#include <vector>
#include <cmath>
#include "record.h"

// 训练集占所有数据的比例
#define RATE 0.8
// 包括政党名字在内的属性值的个数
#define ATTR_AMOUNT 17

using namespace std;

/* 
 * 决策树每个节点的数据结构
 * attr: 表示第几个属性,范围1-16,‘A’,'B'
 * 	'A','B'表示叶子节点,分别对应政党democrat和republican
 * ptr: 分别对应属性值为'y','?','n'的子树
 */
struct Node {
	Node() {
		ptr[0] = NULL;
		ptr[1] = NULL;
		ptr[2] = NULL;
	}
	int attr;
	Node *ptr[3];
};

/* 
 * 判断给定选举记录是否同属一个政党
 */
bool has_same_class(vector<Record> &examples) {
	for (int i=0; i<examples.size()-1; i++) {
		if (examples[i].get_attr(0) != examples[i+1].get_attr(0))
			return false;
	}
	return true;
}

/* 
 * 返回给定选举记录中大部分记录所属政党的标志
 * 'A': democrat
 * 'B': republican
 */
int get_majority_value(vector<Record> &examples) {
	int v[2] = {0};
	for (int i=0; i<examples.size(); i++) {
		v[examples[i].get_attr(0)-'A']++;
	}
	if (v[0] > v[1])
		return 'A';
	return 'B';
}

/* 
 * 以给定属性对给定记录分类得到结果的信息量
 */
double cal_arg_remainder(vector<Record> &examples, int arg) {
	double remainder = 0.0; // 信息量初始为0
	vector<Record> sub_examples[3];
	int la = 0, lb = 0, ma = 0, mb = 0, ra = 0, rb = 0;

	// 根据给定属性的属性值'y','?','n'把记录分为3类
	for (int i=0; i<examples.size(); i++) {
		if (examples[i].get_attr(arg) == 'y') {
			sub_examples[0].push_back(examples[i]);
			if (examples[i].get_attr(0) == 'A')
				la++; // 统计此分类中democrat政党的记录数量
			else
				lb++; // 统计此分类中republican政党的记录数量
		}
		else if (examples[i].get_attr(arg) == '?') {
			sub_examples[1].push_back(examples[i]);
			if (examples[i].get_attr(0) == 'A')
				ma++;
			else
				mb++;
		}
		else if (examples[i].get_attr(arg) == 'n') {
			sub_examples[2].push_back(examples[i]);
			if (examples[i].get_attr(0) == 'A')
				ra++;
			else
				rb++;
		}
	}

	int l_len = sub_examples[0].size();
	int m_len = sub_examples[1].size();
	int r_len = sub_examples[2].size();
	int len = examples.size();

	// 对每个分类加权求信息量,若该分类所有记录同属一个政党,则信息量为0
	if (la!=0 && lb!=0)
		remainder -= ((double)l_len/len)*
			((double)la/l_len*(log10((double)la/l_len)/log10(2))+
			 (double)lb/l_len*(log10((double)lb/l_len)/log10(2)));
	if (ma!=0 && mb!=0)
		remainder -= ((double)m_len/len)*
			((double)ma/m_len*(log10((double)ma/m_len)/log10(2))+
			 (double)mb/m_len*(log10((double)mb/m_len)/log10(2)));
	if (ra!=0 && rb!=0)
		remainder -= ((double)r_len/len)*
			((double)ra/r_len*(log10((double)ra/r_len)/log10(2))+
			 (double)rb/r_len*(log10((double)rb/r_len)/log10(2)));

	return remainder;
}

/* 
 * 决策树算法
 * 根据给定记录集和给定属性集生成决策树,并且返回
 * 决策树非叶子节点为属性,叶子节点为政党标志
 * majority_value为父节点记录集大部分记录所属政党的标志
 */
Node *dcl_dfs(vector<Record> &examples, vector<int> &attr_set, int majority_value) {
	Node *head = new Node();
	int max_value;

	// 当前记录集大部分记录所属政党的标志
	max_value = get_majority_value(examples);

	// 记录集为空
	if (examples.empty()) {
		head->attr = majority_value;
		return head;
	}

	// 记录集所有记录同属一个政党
	if (has_same_class(examples)) {
		head->attr = examples[0].get_attr(0);
		return head;
	}

	// 属性值为空
	if (attr_set.empty()) {
		head->attr = max_value;
		return head;
	}

	int arg;
	double remainder = 1000.0;
	// 获取最佳属性
	for (int i=0; i<attr_set.size(); i++) {
		double cur_value = cal_arg_remainder(examples, attr_set[i]);
		if (cur_value < remainder) {
			arg = attr_set[i];
			remainder = cur_value;
		}
	}

	vector<Record> sub_examples[3];
	// 根据最佳属性对记录集进行分类
	for (int i=0; i<examples.size(); i++) {
		if (examples[i].get_attr(arg) == 'y')
			sub_examples[0].push_back(examples[i]);
		else if (examples[i].get_attr(arg) == '?')
			sub_examples[1].push_back(examples[i]);
		else if (examples[i].get_attr(arg) == 'n')
			sub_examples[2].push_back(examples[i]);
	}

	vector<int> sub_attr_set;
	// 获取子树属性集
	for (int i=0; i<attr_set.size(); i++) {
		if (attr_set[i] != arg)
			sub_attr_set.push_back(attr_set[i]);
	}

	head->attr = arg;
	// 深度优先递归生成子树
	head->ptr[0] = dcl_dfs(sub_examples[0], sub_attr_set, max_value);
	head->ptr[1] = dcl_dfs(sub_examples[1], sub_attr_set, max_value);
	head->ptr[2] = dcl_dfs(sub_examples[2], sub_attr_set, max_value);

	return head;
}

/* 
 * 分类算法
 * 对一条记录,根据决策树进行分类,返回该记录的政党标志
 */
int classify(Node *head, Record &rec) {
	Node *cur = head;

	// 自顶向下遍历到叶子节点
	while (cur->ptr[0] != NULL) {
		if (rec.get_attr(cur->attr) == 'y')
			cur = cur->ptr[0];
		else if (rec.get_attr(cur->attr) == '?')
			cur = cur->ptr[1];
		else if (rec.get_attr(cur->attr) == 'n')
			cur = cur->ptr[2];
	}

	return cur->attr;
}

int main() {
	ifstream fin("dataset.txt");
	char line[50];
	vector<Record> test_set, training_set;
	vector<int> attr_set;
	Node *head = NULL;

	if (!fin.is_open())
		exit(1);

	// 设置随机数种子
	srand((unsigned)time(NULL));
	// 从文件读入记录按照给定概率生成训练集和测试集
	while (!fin.eof()) {
		fin.getline(line, 50);
		Record rec(line);
		if ((double)(rand()%1000)/1000.0 < RATE)
			training_set.push_back(rec);
		else
			test_set.push_back(rec);
	}
	fin.close();

	// 生成属性集,1-16
	for (int i=1; i<ATTR_AMOUNT; i++) {
		attr_set.push_back(i);
	}
	// 生成决策树
	head = dcl_dfs(training_set, attr_set, 0);

	int t_count = 0;
	// 对测试记录集进行分类,统计分类正确的记录数量
	for (int i=0; i<test_set.size(); i++) {
		int class_name = classify(head, test_set[i]);
		if (class_name == (int)test_set[i].get_attr(0))
			t_count++;
	}

	cout << "Training set size: " << training_set.size() << endl
		<< "Test set size: " << test_set.size() << endl
		<< endl
		<< "Acuracy amount: " << t_count << endl
		<< "Acuracy ratio: " << (double)t_count/test_set.size() << endl;

	return 0;
}
 
    运行结果:

   
 
 
  • 大小: 32.8 KB
  • 大小: 72.2 KB
  • 大小: 3.8 KB
  • 大小: 7.3 KB
  • 大小: 5.6 KB
分享到:
评论

相关推荐

    人工智能实验报告:决策树、循环神经网络、遗传算法、A*算法、归结原理

    在本篇人工智能实验报告中,我们深入探讨了五个核心主题:决策树、循环神经网络、遗传算法、A*算法以及归结原理。这些是人工智能领域中的关键算法和技术,它们在解决复杂问题时扮演着重要角色。 首先,让我们来了解...

    人工智能决策树

    ### 人工智能决策树 #### 一、决策树基础概述 决策树是一种常用的数据挖掘和机器学习方法,通过构建树状模型来进行预测分析。在决策树中,每一个内部节点表示一个特征上的测试,每个分支代表一个测试输出,而叶...

    决策树课件,简介及应用

    决策树是一种非常重要的机器学习算法,它广泛应用于数据挖掘、人工智能、机器学习等领域。继承自概念学习系统 CLS,决策树方法发展到 ID3 方法,然后演化为能处理连续属性的 C4.5。另外, CART 和 Assistant 也是...

    决策树机器学习算法在乳腺癌诊断中的应用.pdf

    "决策树机器学习算法在乳腺癌诊断中的应用" 本文主要介绍了决策树机器学习算法在乳腺癌诊断中的应用,旨在解决传统医疗诊断的弊端,提高医疗效率和质量。文章首先介绍了机器学习在医疗领域的重要性和国内外研究现状...

    决策树资料合集.rar_决策树_决策树 word_决策树 文档_源代码

    决策树是一种广泛应用于数据分析、机器学习以及人工智能领域的算法模型,它通过模拟人类做决策的过程,以树状结构来表示可能的决策路径和结果。在这个"决策树资料合集"中,包含了关于决策树的源文件、实例、内容详解...

    信贷树决策树的金融行业应用

    #### 应用实例 假设一个银行需要评估客户的贷款申请。根据客户是否有房产、是否有车辆、收入水平、学历以及婚姻状况等特征,可以构建一个决策树模型来预测客户是否会按时偿还贷款。例如: 1. **客户是否有房产**:...

    java 决策树Demo2

    决策树是一种常用的人工智能和机器学习算法,用于分类和回归任务。在Java中实现决策树可以帮助开发者构建预测模型,解决复杂的问题。本教程将详细讲解如何使用Java进行决策树的实现,以及“Demo2”可能涉及的具体...

    决策树应用拓展及算法

    决策树是一种广泛应用于数据分析、机器学习以及人工智能领域的算法,它以树状结构来表示一系列决策过程,通过将数据集划分为不同的子集,逐步形成一个能够预测目标变量的模型。在商业环境中,决策树是最常用的数据...

    决策树ID3算法(Java实现)

    决策树是一种广泛应用于机器学习领域的算法,主要用于分类和回归任务。ID3(Iterative Dichotomiser 3)是决策树算法的一种早期形式,由Ross Quinlan在1986年提出。这个算法主要基于信息熵和信息增益来选择最佳属性...

    C45决策树算法 C45决策树算法

    - 决策树是一种图形模型,它通过树状结构来表示对实例进行分类的过程,每个内部节点代表一个特征或属性测试,每个分支代表一个测试输出,而叶子节点则代表类别。 2. **C45算法改进**: - ID3算法主要基于信息熵...

    c++决策树算法源码

    决策树是一种常用的人工智能和机器学习算法,用于分类和回归任务。在C++中实现决策树算法,我们可以采用ID3(Iterative Dichotomiser 3)算法作为基础,这是一种早期的基于信息熵和信息增益的决策树构建方法。下面...

    人工智能-决策树实验(对西瓜数据集 3.0 的分类)

    在本实验中,我们将深入探讨如何利用人工智能中的决策树算法对西瓜数据集 3.0 进行分类。决策树是一种流行的监督学习方法,尤其适用于分类问题,它通过构建一个树形结构来模拟一系列决策过程,最终达到预测目标变量...

    决策树二元分类

    决策树二元分类是机器学习领域中一种广泛应用的算法,特别是在人工智能和数据挖掘中。它是一种监督学习方法,主要用于处理二分类问题,即将数据集分为两个明显的类别。在这个压缩包中,很可能包含了关于决策树二元...

    java 决策树Demo1

    决策树是一种常用的人工智能和机器学习算法,用于分类和回归任务。在Java中实现决策树可以帮助开发者构建预测模型,解决复杂的问题。本教程将基于`java 决策树Demo1`来深入探讨决策树的基本概念、工作原理以及如何在...

    ID3,C4.5决策树完整代码以及结果图片

    ID3和C4.5是两种著名的决策树算法,在机器学习和数据挖掘领域广泛应用。它们主要用于分类任务,通过构建树状模型来预测目标变量的值,以实现对数据集的解释性和预测性分析。 **ID3算法(Iterative Dichotomiser 3)...

    决策树与随机森林算法,随机森林算法应用实例,Python源码.rar

    它们在数据挖掘、预测分析和人工智能中占据着重要地位。本资料包包含了对这两种算法的深入理解、实际应用案例以及相关的Python源码,旨在帮助用户更好地掌握这些概念和技术。 **决策树(Decision Tree)** 决策树是...

    机器学习实战 - 决策树PDF知识点总结 + 代码实现

    决策树通过一系列if-else规则来对数据进行分类,确保每个实例都能被且仅被一条规则覆盖。 **一、决策树的基本概念** 1. **节点类型**:根节点无入边,中间节点有单入多出,叶节点只有入边无出边。 2. **父节点与子...

    决策树ID3\ID4算法实例源码

    决策树是一种常用的人工智能和机器学习算法,用于分类和回归任务。ID3(Iterative Dichotomiser 3)是决策树算法的早期版本,由Ross Quinlan于1986年提出,主要用于分类问题。ID4是ID3的后续改进,增加了对连续属性...

Global site tag (gtag.js) - Google Analytics