一、算法介绍
GBDT(Gradient boosting decision tree)包括两种树分类树和回归树。今天主要介绍MPI版回归树pGBRT。
几篇比较好介绍GBRT算法的博文:
http://hi.baidu.com/hehehehello/item/96cc42e45c16e7265a2d64ee
http://blog.csdn.net/puqutogether/article/details/44781035
pGBRT源码下载:
http://machinelearning.wustl.edu/pmwiki.php/Main/Pgbrt
pGBRT版本是基于最小二乘(残差)梯度回归树
二、Gradient Boosting 和 decision tree
Boosting的步骤如上:
三、源码介绍
单颗回归树的构建过程
class StaticNode { public: int feature; //属性全局index float split; //属性分割点 double label;//均值 double loss; //最小化均方误差 //temp var int m_infty, m_s; //总个数, leftNode个数 float s; //当前分隔点 double l_infty, l_s;//总残差和, leftNode残差和 }; void StaticTree::updateBestSplits(FeatureData* data, int f) { // compute global feature index int globalf = data->globalFeatureIndex(f); // reset counts at nodes for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; node->m_s = 0; node->l_s = 0.0; } // iterate over feature for (int j=0; j<data->getN(); j++) { // get current value float v = data->getSortedFeature(f,j); int i = data->getSortedIndex(f,j); int n = data->getNode(i); float l = data->getResidual(i); // get node StaticNode* node = layers[layer][n]; // if not first instance at node and greater than split point, consider new split at v if (node->m_s > 0 and v > node->s) { //最小化均方误差 double loss_i = pow(node->l_s,2.0) / (double) node->m_s + pow(node->l_infty - node->l_s,2.0) / (double) (node->m_infty - node->m_s); if (node->loss < 0 or loss_i > node->loss) { node->loss = loss_i; node->feature = globalf; node->split = (node->s + v) / 2.f; // TODO : create a lookup table for these child values at tree construction, store in static tree or store each in static node StaticNode* child1 = layers[layer+1][2*n]; child1->label = node->l_s / (double) node->m_s; StaticNode* child2 = layers[layer+1][2*n+1]; child2->label = (node->l_infty - node->l_s) / (double) (node->m_infty - node->m_s); // if (child2->label > 5.0) // printf("### %f %d %d %f %d %f %d %f %d %d %d %f %d\n", v, i, n, l, node->m_s, node->l_s, node->m_infty, node->l_infty, globalf, f, j, loss_i, layer); } } // update variables node->m_s += 1; node->l_s += l; node->s = v; } } void StaticTree::exchangeBestSplits() { // instantiate buffer int buffersize = nodes*5; double* buffer = new double[buffersize]; // write layer of tree to buffer for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; buffer[n*5 + 0] = node->loss; buffer[n*5 + 1] = node->feature; buffer[n*5 + 2] = node->split; StaticNode* child1 = layers[layer+1][n*2]; buffer[n*5 + 3] = child1->label; StaticNode* child2 = layers[layer+1][n*2+1]; buffer[n*5 + 4] = child2->label; } // get myid and numprocs int myid; MPI_Comm_rank(MPI_COMM_WORLD, &myid); int numprocs; MPI_Comm_size(MPI_COMM_WORLD, &numprocs); // determine isRoot int root = numprocs-1; bool isRoot = (myid == root); // exchange buffers double* rbuffer = (isRoot ? new double[numprocs*buffersize] : NULL); MPI_Gather(buffer, buffersize, MPI_DOUBLE, rbuffer, buffersize, MPI_DOUBLE, root, MPI_COMM_WORLD); // save best global splits if (isRoot) for (int n=0; n<nodes; n++) { // reset loss at node and get pointers StaticNode* node = layers[layer][n]; node->loss = -1; StaticNode* child1 = layers[layer+1][n*2]; StaticNode* child2 = layers[layer+1][n*2+1]; // consider loss from all processors for (int p=0; p<numprocs; p++) { int offset = p*buffersize + n*5; double loss = rbuffer[offset + 0]; // update if better than current if (node->loss < 0 or loss > node->loss) { node->loss = loss; node->feature = (int) rbuffer[offset + 1]; node->split = (float) rbuffer[offset + 2]; child1->label = rbuffer[offset + 3]; child2->label = rbuffer[offset + 4]; } } } // buffer best global splits if (isRoot) for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; buffer[n*5 + 0] = node->loss; buffer[n*5 + 1] = node->feature; buffer[n*5 + 2] = node->split; StaticNode* child1 = layers[layer+1][n*2]; buffer[n*5 + 3] = child1->label; StaticNode* child2 = layers[layer+1][n*2+1]; buffer[n*5 + 4] = child2->label; } // broadcast best splits MPI_Bcast(buffer, nodes*5, MPI_DOUBLE, root, MPI_COMM_WORLD); // update tree with best global splits for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; node->loss = buffer[n*5 + 0]; node->feature = (int) buffer[n*5 + 1]; node->split = (float) buffer[n*5 + 2]; StaticNode* child1 = layers[layer+1][n*2]; child1->label = buffer[n*5 + 3]; StaticNode* child2 = layers[layer+1][n*2+1]; child2->label = buffer[n*5 + 4]; } // delete buffers delete [] buffer; delete [] rbuffer; }
四、缺点
MPI每个节点按列读取数据,所以每个节点至少读取一列的数据,导致数据行数扩展有限,在数据量大时单节点内存狂涨。后续会介绍如何按行列分隔数据。
相关推荐
Python数据分析+机器学习+深度学习教程源码;Python数据分析+机器学习+深度学习教程源码;Python数据分析+机器学习+深度学习教程源码;Python数据分析+机器学习+深度学习教程源码;Python数据分析+机器学习+深度学习...
【资源介绍】 1、该资源包括项目的全部源码,下载可以直接使用! 2、本项目适合作为计算机、数学、电子信息等专业的课程设计、期末大作业和毕设项目,也可以作为小白...基于深度学习的异常行为分析算法源码+数据.zip
基于oneflow框架开发的深度学习模型源码.zip基于oneflow框架开发的深度学习模型源码.zip基于oneflow框架开发的深度学习模型源码.zip基于oneflow框架开发的深度学习模型源码.zip基于oneflow框架开发的深度学习模型...
Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring ...
Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring 源码深度剖析Spring ...
【资源介绍】 1、该资源包括项目的全部源码,下载可以直接使用! 2、本项目适合作为计算机、数学、电子信息等专业的课程设计、期末大作业和毕设...用于CTR预估的深度学习模型源码.zip用于CTR预估的深度学习模型源码.zip
基于深度学习LSTM的特征值识别的社交媒体谣言分析Python源码基于深度学习LSTM的特征值识别的社交媒体谣言分析Python源码基于深度学习LSTM的特征值识别的社交媒体谣言分析Python源码基于深度学习LSTM的特征值识别的...
计算机大作业设计源码基于Python车牌识别系统深度学习项目源码.zip本项目是一套成熟的大作业项目系统,获取98分,主要针对计算机相关专业的正在做大作业的学生和需要项目实战练习的学习者,可作为课程设计、期末大...
【资源介绍】 1、该资源包括项目的全部源码,下载可以直接使用! 2、本项目适合作为计算机、数学、电子信息等专业的课程设计、期末大作业和毕设项目,也...基于深度学习的原油与化工期货的预测分析算法源码+数据.zip
深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析...
【资源介绍】 1、该资源包括项目的全部源码,下载可以直接使用! 2、本项目适合作为计算机、数学、电子信息等专业的课程设计、期末大作业和毕设项目,也可以作为小白实战演练和...生成式对抗网络深度学习模型源码 .zip
《cpp-darknet深度学习框架源码分析》 cpp-darknet是一个基于C++的深度学习框架,专注于速度和效率,尤其适合于嵌入式系统和实时应用。本资料将深入探讨cpp-darknet的源码,通过详细中文注释帮助读者理解框架的原理...
为方便阅读,把blog上的libevent源码深度剖析系列文章整合成一个pdf。
【资源介绍】 1、该资源包括项目的全部源码,下载可以直接使用! 2、本项目适合作为计算机、数学、电子信息等专业的课程设计、期末大作业...基于深度学习技术与手机性能结合的抗病新品种培育分析算法源码+说明+视频.zip
【资源介绍】 1、该资源包括项目的全部源码,下载可以直接使用! 2、本项目适合作为计算机、数学、电子信息等专业的课程设计、期末大作业和毕设...基于度量学习分类器的人脸识别系统(深度学习matlab源码+说明).zip
基于深度学习(LSTM)的电商购物情感分析项目源码+全部数据(高分毕业设计项目)基于深度学习(LSTM)的电商购物情感分析项目源码+全部数据(高分毕业设计项目)基于深度学习(LSTM)的电商购物情感分析项目源码+...
【资源介绍】 1、该资源包括项目的全部源码,下载可以直接使用! 2、本项目适合作为计算机、数学、电子信息等专业的课程设计、期末大作业和毕设项目,也可以作为小白实战...深度学习与股票分析预测算法源码+数据.zip
期末大作业Python基于深度学习的电影评论情感分析源码+报告PDF期末大作业Python基于深度学习的电影评论情感分析源码+报告PDF期末大作业Python基于深度学习的电影评论情感分析源码+报告PDF期末大作业Python基于深度...
基于深度学习 LSTM + BERT 词向量的混合架构的药品评论情感分析系统源码(可自动分析药品评论的情感倾向(积极、中性、消极)).zip 基于深度学习 LSTM + BERT 词向量的混合架构的药品评论情感分析系统源码(可自动...
Python期末大作业-深度学习与股票分析预测项目源码(高分项目)Python期末大作业-深度学习与股票分析预测项目源码(高分项目)Python期末大作业-深度学习与股票分析预测项目源码(高分项目)Python期末大作业-深度...