一、算法介绍
GBDT(Gradient boosting decision tree)包括两种树分类树和回归树。今天主要介绍MPI版回归树pGBRT。
几篇比较好介绍GBRT算法的博文:
http://hi.baidu.com/hehehehello/item/96cc42e45c16e7265a2d64ee
http://blog.csdn.net/puqutogether/article/details/44781035
pGBRT源码下载:
http://machinelearning.wustl.edu/pmwiki.php/Main/Pgbrt
pGBRT版本是基于最小二乘(残差)梯度回归树
二、Gradient Boosting 和 decision tree
Boosting的步骤如上:
三、源码介绍
单颗回归树的构建过程
class StaticNode { public: int feature; //属性全局index float split; //属性分割点 double label;//均值 double loss; //最小化均方误差 //temp var int m_infty, m_s; //总个数, leftNode个数 float s; //当前分隔点 double l_infty, l_s;//总残差和, leftNode残差和 }; void StaticTree::updateBestSplits(FeatureData* data, int f) { // compute global feature index int globalf = data->globalFeatureIndex(f); // reset counts at nodes for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; node->m_s = 0; node->l_s = 0.0; } // iterate over feature for (int j=0; j<data->getN(); j++) { // get current value float v = data->getSortedFeature(f,j); int i = data->getSortedIndex(f,j); int n = data->getNode(i); float l = data->getResidual(i); // get node StaticNode* node = layers[layer][n]; // if not first instance at node and greater than split point, consider new split at v if (node->m_s > 0 and v > node->s) { //最小化均方误差 double loss_i = pow(node->l_s,2.0) / (double) node->m_s + pow(node->l_infty - node->l_s,2.0) / (double) (node->m_infty - node->m_s); if (node->loss < 0 or loss_i > node->loss) { node->loss = loss_i; node->feature = globalf; node->split = (node->s + v) / 2.f; // TODO : create a lookup table for these child values at tree construction, store in static tree or store each in static node StaticNode* child1 = layers[layer+1][2*n]; child1->label = node->l_s / (double) node->m_s; StaticNode* child2 = layers[layer+1][2*n+1]; child2->label = (node->l_infty - node->l_s) / (double) (node->m_infty - node->m_s); // if (child2->label > 5.0) // printf("### %f %d %d %f %d %f %d %f %d %d %d %f %d\n", v, i, n, l, node->m_s, node->l_s, node->m_infty, node->l_infty, globalf, f, j, loss_i, layer); } } // update variables node->m_s += 1; node->l_s += l; node->s = v; } } void StaticTree::exchangeBestSplits() { // instantiate buffer int buffersize = nodes*5; double* buffer = new double[buffersize]; // write layer of tree to buffer for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; buffer[n*5 + 0] = node->loss; buffer[n*5 + 1] = node->feature; buffer[n*5 + 2] = node->split; StaticNode* child1 = layers[layer+1][n*2]; buffer[n*5 + 3] = child1->label; StaticNode* child2 = layers[layer+1][n*2+1]; buffer[n*5 + 4] = child2->label; } // get myid and numprocs int myid; MPI_Comm_rank(MPI_COMM_WORLD, &myid); int numprocs; MPI_Comm_size(MPI_COMM_WORLD, &numprocs); // determine isRoot int root = numprocs-1; bool isRoot = (myid == root); // exchange buffers double* rbuffer = (isRoot ? new double[numprocs*buffersize] : NULL); MPI_Gather(buffer, buffersize, MPI_DOUBLE, rbuffer, buffersize, MPI_DOUBLE, root, MPI_COMM_WORLD); // save best global splits if (isRoot) for (int n=0; n<nodes; n++) { // reset loss at node and get pointers StaticNode* node = layers[layer][n]; node->loss = -1; StaticNode* child1 = layers[layer+1][n*2]; StaticNode* child2 = layers[layer+1][n*2+1]; // consider loss from all processors for (int p=0; p<numprocs; p++) { int offset = p*buffersize + n*5; double loss = rbuffer[offset + 0]; // update if better than current if (node->loss < 0 or loss > node->loss) { node->loss = loss; node->feature = (int) rbuffer[offset + 1]; node->split = (float) rbuffer[offset + 2]; child1->label = rbuffer[offset + 3]; child2->label = rbuffer[offset + 4]; } } } // buffer best global splits if (isRoot) for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; buffer[n*5 + 0] = node->loss; buffer[n*5 + 1] = node->feature; buffer[n*5 + 2] = node->split; StaticNode* child1 = layers[layer+1][n*2]; buffer[n*5 + 3] = child1->label; StaticNode* child2 = layers[layer+1][n*2+1]; buffer[n*5 + 4] = child2->label; } // broadcast best splits MPI_Bcast(buffer, nodes*5, MPI_DOUBLE, root, MPI_COMM_WORLD); // update tree with best global splits for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; node->loss = buffer[n*5 + 0]; node->feature = (int) buffer[n*5 + 1]; node->split = (float) buffer[n*5 + 2]; StaticNode* child1 = layers[layer+1][n*2]; child1->label = buffer[n*5 + 3]; StaticNode* child2 = layers[layer+1][n*2+1]; child2->label = buffer[n*5 + 4]; } // delete buffers delete [] buffer; delete [] rbuffer; }
四、缺点
MPI每个节点按列读取数据,所以每个节点至少读取一列的数据,导致数据行数扩展有限,在数据量大时单节点内存狂涨。后续会介绍如何按行列分隔数据。
相关推荐
Python数据分析+机器学习+深度学习教程源码;Python数据分析+机器学习+深度学习教程源码;Python数据分析+机器学习+深度学习教程源码;Python数据分析+机器学习+深度学习教程源码;Python数据分析+机器学习+深度学习...
Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring ...
Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring ...
基于深度学习的单目深度估计总结源码+文档说明+全部资料.zip基于深度学习的单目深度估计总结源码+文档说明+全部资料.zip基于深度学习的单目深度估计总结源码+文档说明+全部资料.zip基于深度学习的单目深度估计总结...
基于深度学习LSTM的特征值识别的社交媒体谣言分析Python源码基于深度学习LSTM的特征值识别的社交媒体谣言分析Python源码基于深度学习LSTM的特征值识别的社交媒体谣言分析Python源码基于深度学习LSTM的特征值识别的...
计算机大作业设计源码基于Python车牌识别系统深度学习项目源码.zip本项目是一套成熟的大作业项目系统,获取98分,主要针对计算机相关专业的正在做大作业的学生和需要项目实战练习的学习者,可作为课程设计、期末大...
深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析...
《cpp-darknet深度学习框架源码分析》 cpp-darknet是一个基于C++的深度学习框架,专注于速度和效率,尤其适合于嵌入式系统和实时应用。本资料将深入探讨cpp-darknet的源码,通过详细中文注释帮助读者理解框架的原理...
基于深度学习的单目深度估计源码+项目说明(DIP课程项目).zip基于深度学习的单目深度估计源码+项目说明(DIP课程项目).zip基于深度学习的单目深度估计源码+项目说明(DIP课程项目).zip基于深度学习的单目深度估计...
为方便阅读,把blog上的libevent源码深度剖析系列文章整合成一个pdf。
基于深度学习(LSTM)的电商购物情感分析项目源码+全部数据(高分毕业设计项目)基于深度学习(LSTM)的电商购物情感分析项目源码+全部数据(高分毕业设计项目)基于深度学习(LSTM)的电商购物情感分析项目源码+...
Python期末大作业-深度学习与股票分析预测项目源码(高分项目)Python期末大作业-深度学习与股票分析预测项目源码(高分项目)Python期末大作业-深度学习与股票分析预测项目源码(高分项目)Python期末大作业-深度...
Spring 源码分析 Spring 框架是 Java 语言中最流行的开源框架之一,它提供了一个强大且灵活的基础设施来构建企业级应用程序。在 Spring 框架中,IOC 容器扮演着核心角色,本文将深入分析 Spring 源码,了解 IOC ...
文本通用处理流程:文本分词、分词向量化、文本分类、聚类、深度学习等源码.zip文本通用处理流程:文本分词、分词向量化、文本分类、聚类、深度学习等源码.zip文本通用处理流程:文本分词、分词向量化、文本分类、...
1、该资源内项目代码经过严格调试,下载即用确保可以运行! 2、该资源适合计算机相关专业(如计科、人工智能、大数据、数学、...基于机器学习模型SVM和深度学习模型LSTM的nlp中情感分析实例源码(从打标签语料开始).zip
基于各种机器学习和深度学习算法的中文微博情感分析源码.zip基于各种机器学习和深度学习算法的中文微博情感分析源码.zip基于各种机器学习和深度学习算法的中文微博情感分析源码.zip基于各种机器学习和深度学习算法的...
基于深度学习LSTM的电商购物情感分析项目源码+文档说明(高分毕业设计)基于深度学习LSTM的电商购物情感分析项目源码+文档说明(高分毕业设计)基于深度学习LSTM的电商购物情感分析项目源码+文档说明(高分毕业设计...
基于深度学习模型的文本情感分析WSGI应用源码+项目说明.zip 基于深度学习模型的文本情感分析WSGI应用源码+项目说明.zip 基于深度学习模型的文本情感分析WSGI应用源码+项目说明.zip 基于深度学习模型的文本情感分析...
Python实现基于深度学习的图像隐写分析项目源码+GUI界面+毕业论文,该项目是个人毕设项目,答辩评审分达到98分,代码都经过调试测试,确保可以运行!欢迎下载使用,可用于小白学习、进阶。该资源主要针对计算机、...
Spring源码深度解析第二版 Spring是一款广泛应用于Java企业级应用程序的开源框架,旨在简化Java应用程序的开发和部署。Spring框架的核心主要包括了IoC容器、AOP、MVC框架等模块。 第1章 Spring整体架构和环境搭建 ...