一、算法介绍
GBDT(Gradient boosting decision tree)包括两种树分类树和回归树。今天主要介绍MPI版回归树pGBRT。
几篇比较好介绍GBRT算法的博文:
http://hi.baidu.com/hehehehello/item/96cc42e45c16e7265a2d64ee
http://blog.csdn.net/puqutogether/article/details/44781035
pGBRT源码下载:
http://machinelearning.wustl.edu/pmwiki.php/Main/Pgbrt
pGBRT版本是基于最小二乘(残差)梯度回归树
二、Gradient Boosting 和 decision tree
Boosting的步骤如上:
三、源码介绍
单颗回归树的构建过程
class StaticNode { public: int feature; //属性全局index float split; //属性分割点 double label;//均值 double loss; //最小化均方误差 //temp var int m_infty, m_s; //总个数, leftNode个数 float s; //当前分隔点 double l_infty, l_s;//总残差和, leftNode残差和 }; void StaticTree::updateBestSplits(FeatureData* data, int f) { // compute global feature index int globalf = data->globalFeatureIndex(f); // reset counts at nodes for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; node->m_s = 0; node->l_s = 0.0; } // iterate over feature for (int j=0; j<data->getN(); j++) { // get current value float v = data->getSortedFeature(f,j); int i = data->getSortedIndex(f,j); int n = data->getNode(i); float l = data->getResidual(i); // get node StaticNode* node = layers[layer][n]; // if not first instance at node and greater than split point, consider new split at v if (node->m_s > 0 and v > node->s) { //最小化均方误差 double loss_i = pow(node->l_s,2.0) / (double) node->m_s + pow(node->l_infty - node->l_s,2.0) / (double) (node->m_infty - node->m_s); if (node->loss < 0 or loss_i > node->loss) { node->loss = loss_i; node->feature = globalf; node->split = (node->s + v) / 2.f; // TODO : create a lookup table for these child values at tree construction, store in static tree or store each in static node StaticNode* child1 = layers[layer+1][2*n]; child1->label = node->l_s / (double) node->m_s; StaticNode* child2 = layers[layer+1][2*n+1]; child2->label = (node->l_infty - node->l_s) / (double) (node->m_infty - node->m_s); // if (child2->label > 5.0) // printf("### %f %d %d %f %d %f %d %f %d %d %d %f %d\n", v, i, n, l, node->m_s, node->l_s, node->m_infty, node->l_infty, globalf, f, j, loss_i, layer); } } // update variables node->m_s += 1; node->l_s += l; node->s = v; } } void StaticTree::exchangeBestSplits() { // instantiate buffer int buffersize = nodes*5; double* buffer = new double[buffersize]; // write layer of tree to buffer for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; buffer[n*5 + 0] = node->loss; buffer[n*5 + 1] = node->feature; buffer[n*5 + 2] = node->split; StaticNode* child1 = layers[layer+1][n*2]; buffer[n*5 + 3] = child1->label; StaticNode* child2 = layers[layer+1][n*2+1]; buffer[n*5 + 4] = child2->label; } // get myid and numprocs int myid; MPI_Comm_rank(MPI_COMM_WORLD, &myid); int numprocs; MPI_Comm_size(MPI_COMM_WORLD, &numprocs); // determine isRoot int root = numprocs-1; bool isRoot = (myid == root); // exchange buffers double* rbuffer = (isRoot ? new double[numprocs*buffersize] : NULL); MPI_Gather(buffer, buffersize, MPI_DOUBLE, rbuffer, buffersize, MPI_DOUBLE, root, MPI_COMM_WORLD); // save best global splits if (isRoot) for (int n=0; n<nodes; n++) { // reset loss at node and get pointers StaticNode* node = layers[layer][n]; node->loss = -1; StaticNode* child1 = layers[layer+1][n*2]; StaticNode* child2 = layers[layer+1][n*2+1]; // consider loss from all processors for (int p=0; p<numprocs; p++) { int offset = p*buffersize + n*5; double loss = rbuffer[offset + 0]; // update if better than current if (node->loss < 0 or loss > node->loss) { node->loss = loss; node->feature = (int) rbuffer[offset + 1]; node->split = (float) rbuffer[offset + 2]; child1->label = rbuffer[offset + 3]; child2->label = rbuffer[offset + 4]; } } } // buffer best global splits if (isRoot) for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; buffer[n*5 + 0] = node->loss; buffer[n*5 + 1] = node->feature; buffer[n*5 + 2] = node->split; StaticNode* child1 = layers[layer+1][n*2]; buffer[n*5 + 3] = child1->label; StaticNode* child2 = layers[layer+1][n*2+1]; buffer[n*5 + 4] = child2->label; } // broadcast best splits MPI_Bcast(buffer, nodes*5, MPI_DOUBLE, root, MPI_COMM_WORLD); // update tree with best global splits for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; node->loss = buffer[n*5 + 0]; node->feature = (int) buffer[n*5 + 1]; node->split = (float) buffer[n*5 + 2]; StaticNode* child1 = layers[layer+1][n*2]; child1->label = buffer[n*5 + 3]; StaticNode* child2 = layers[layer+1][n*2+1]; child2->label = buffer[n*5 + 4]; } // delete buffers delete [] buffer; delete [] rbuffer; }
四、缺点
MPI每个节点按列读取数据,所以每个节点至少读取一列的数据,导致数据行数扩展有限,在数据量大时单节点内存狂涨。后续会介绍如何按行列分隔数据。
相关推荐
Python数据分析+机器学习+深度学习教程源码;Python数据分析+机器学习+深度学习教程源码;Python数据分析+机器学习+深度学习教程源码;Python数据分析+机器学习+深度学习教程源码;Python数据分析+机器学习+深度学习...
基于深度学习LSTM的特征值识别的社交媒体谣言分析Python源码基于深度学习LSTM的特征值识别的社交媒体谣言分析Python源码基于深度学习LSTM的特征值识别的社交媒体谣言分析Python源码基于深度学习LSTM的特征值识别的...
深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析...
计算机大作业设计源码基于Python车牌识别系统深度学习项目源码.zip本项目是一套成熟的大作业项目系统,获取98分,主要针对计算机相关专业的正在做大作业的学生和需要项目实战练习的学习者,可作为课程设计、期末大...
《cpp-darknet深度学习框架源码分析》 cpp-darknet是一个基于C++的深度学习框架,专注于速度和效率,尤其适合于嵌入式系统和实时应用。本资料将深入探讨cpp-darknet的源码,通过详细中文注释帮助读者理解框架的原理...
为方便阅读,把blog上的libevent源码深度剖析系列文章整合成一个pdf。
中文情感分析模型源码(含各种主流的情感词典、机器学习、深度学习、预训练模型方法).zip 中文情感分析模型源码(含各种主流的情感词典、机器学习、深度学习、预训练模型方法).zip 中文情感分析模型源码(含各种...
Spring 源码分析 Spring 框架是 Java 语言中最流行的开源框架之一,它提供了一个强大且灵活的基础设施来构建企业级应用程序。在 Spring 框架中,IOC 容器扮演着核心角色,本文将深入分析 Spring 源码,了解 IOC ...
文本通用处理流程:文本分词、分词向量化、文本分类、聚类、深度学习等源码.zip文本通用处理流程:文本分词、分词向量化、文本分类、聚类、深度学习等源码.zip文本通用处理流程:文本分词、分词向量化、文本分类、...
1、该资源内项目代码经过严格调试,下载即用确保可以运行! 2、该资源适合计算机相关专业(如计科、人工智能、大数据、数学、...基于机器学习模型SVM和深度学习模型LSTM的nlp中情感分析实例源码(从打标签语料开始).zip
基于各种机器学习和深度学习算法的中文微博情感分析源码.zip基于各种机器学习和深度学习算法的中文微博情感分析源码.zip基于各种机器学习和深度学习算法的中文微博情感分析源码.zip基于各种机器学习和深度学习算法的...
基于深度学习模型的文本情感分析WSGI应用源码+项目说明.zip 基于深度学习模型的文本情感分析WSGI应用源码+项目说明.zip 基于深度学习模型的文本情感分析WSGI应用源码+项目说明.zip 基于深度学习模型的文本情感分析...
Python实现基于深度学习的图像隐写分析项目源码+GUI界面+毕业论文,该项目是个人毕设项目,答辩评审分达到98分,代码都经过调试测试,确保可以运行!欢迎下载使用,可用于小白学习、进阶。该资源主要针对计算机、...
《深度学习入门-源码-斋藤康逸》是一份深度学习的学习资源,包含了斋藤康逸关于深度学习基础知识和实践的源代码。这个压缩包包含的文件夹主要有`common`、`ch02`至`ch08`、`.idea`、`dataset`和`ch04`至`ch07`,每个...
Spring源码深度解析第二版 Spring是一款广泛应用于Java企业级应用程序的开源框架,旨在简化Java应用程序的开发和部署。Spring框架的核心主要包括了IoC容器、AOP、MVC框架等模块。 第1章 Spring整体架构和环境搭建 ...
基于深度学习(LSTM)的电商购物情感分析项目源码和项目说明.zip个人经导师指导并认可通过的高分毕业设计项目,评审分98分。主要针对计算机相关专业的正在做毕设的学生和需要项目实战练习的学习者,也可作为课程设计...
毕设项目-智慧教室基于深度学习开发的课堂专注度分析和考试作弊检测系统python源码.zip个人经导师指导并认可通过的高分毕业设计项目,主要针对计算机相关专业的正在做毕设的学生和需要项目实战练习的学习者。...
毕业设计:python基于深度学习的中文情感分析系统(源码 + 数据库 + 说明文档) 二、 技术及工具介绍 4 (一) B/S架构 4 (二) MYSQL 4 (三) 算法 5 (四) Python技术 5 三、 系统分析 5 (一) 可行性分析 5 (二) 需求...
首先,我们从“源码深度解析dubbo.pdf”入手。这本书籍主要涵盖了以下几个方面: 1. **Dubbo架构设计**:书中详细介绍了Dubbo的整体架构,包括服务提供者、消费者、注册中心和服务监控等组件,以及它们之间的交互...
源码分析是理解一个软件系统工作原理的关键步骤,对于深度学习框架Caffe来说,通过阅读源码,开发者可以深入理解其背后的计算流程、优化策略以及与其他框架的差异。这有助于开发者定制自己的网络结构,提高模型的...