一、算法介绍
GBDT(Gradient boosting decision tree)包括两种树分类树和回归树。今天主要介绍MPI版回归树pGBRT。
几篇比较好介绍GBRT算法的博文:
http://hi.baidu.com/hehehehello/item/96cc42e45c16e7265a2d64ee
http://blog.csdn.net/puqutogether/article/details/44781035
pGBRT源码下载:
http://machinelearning.wustl.edu/pmwiki.php/Main/Pgbrt
pGBRT版本是基于最小二乘(残差)梯度回归树
二、Gradient Boosting 和 decision tree
Boosting的步骤如上:
三、源码介绍
单颗回归树的构建过程
class StaticNode { public: int feature; //属性全局index float split; //属性分割点 double label;//均值 double loss; //最小化均方误差 //temp var int m_infty, m_s; //总个数, leftNode个数 float s; //当前分隔点 double l_infty, l_s;//总残差和, leftNode残差和 }; void StaticTree::updateBestSplits(FeatureData* data, int f) { // compute global feature index int globalf = data->globalFeatureIndex(f); // reset counts at nodes for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; node->m_s = 0; node->l_s = 0.0; } // iterate over feature for (int j=0; j<data->getN(); j++) { // get current value float v = data->getSortedFeature(f,j); int i = data->getSortedIndex(f,j); int n = data->getNode(i); float l = data->getResidual(i); // get node StaticNode* node = layers[layer][n]; // if not first instance at node and greater than split point, consider new split at v if (node->m_s > 0 and v > node->s) { //最小化均方误差 double loss_i = pow(node->l_s,2.0) / (double) node->m_s + pow(node->l_infty - node->l_s,2.0) / (double) (node->m_infty - node->m_s); if (node->loss < 0 or loss_i > node->loss) { node->loss = loss_i; node->feature = globalf; node->split = (node->s + v) / 2.f; // TODO : create a lookup table for these child values at tree construction, store in static tree or store each in static node StaticNode* child1 = layers[layer+1][2*n]; child1->label = node->l_s / (double) node->m_s; StaticNode* child2 = layers[layer+1][2*n+1]; child2->label = (node->l_infty - node->l_s) / (double) (node->m_infty - node->m_s); // if (child2->label > 5.0) // printf("### %f %d %d %f %d %f %d %f %d %d %d %f %d\n", v, i, n, l, node->m_s, node->l_s, node->m_infty, node->l_infty, globalf, f, j, loss_i, layer); } } // update variables node->m_s += 1; node->l_s += l; node->s = v; } } void StaticTree::exchangeBestSplits() { // instantiate buffer int buffersize = nodes*5; double* buffer = new double[buffersize]; // write layer of tree to buffer for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; buffer[n*5 + 0] = node->loss; buffer[n*5 + 1] = node->feature; buffer[n*5 + 2] = node->split; StaticNode* child1 = layers[layer+1][n*2]; buffer[n*5 + 3] = child1->label; StaticNode* child2 = layers[layer+1][n*2+1]; buffer[n*5 + 4] = child2->label; } // get myid and numprocs int myid; MPI_Comm_rank(MPI_COMM_WORLD, &myid); int numprocs; MPI_Comm_size(MPI_COMM_WORLD, &numprocs); // determine isRoot int root = numprocs-1; bool isRoot = (myid == root); // exchange buffers double* rbuffer = (isRoot ? new double[numprocs*buffersize] : NULL); MPI_Gather(buffer, buffersize, MPI_DOUBLE, rbuffer, buffersize, MPI_DOUBLE, root, MPI_COMM_WORLD); // save best global splits if (isRoot) for (int n=0; n<nodes; n++) { // reset loss at node and get pointers StaticNode* node = layers[layer][n]; node->loss = -1; StaticNode* child1 = layers[layer+1][n*2]; StaticNode* child2 = layers[layer+1][n*2+1]; // consider loss from all processors for (int p=0; p<numprocs; p++) { int offset = p*buffersize + n*5; double loss = rbuffer[offset + 0]; // update if better than current if (node->loss < 0 or loss > node->loss) { node->loss = loss; node->feature = (int) rbuffer[offset + 1]; node->split = (float) rbuffer[offset + 2]; child1->label = rbuffer[offset + 3]; child2->label = rbuffer[offset + 4]; } } } // buffer best global splits if (isRoot) for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; buffer[n*5 + 0] = node->loss; buffer[n*5 + 1] = node->feature; buffer[n*5 + 2] = node->split; StaticNode* child1 = layers[layer+1][n*2]; buffer[n*5 + 3] = child1->label; StaticNode* child2 = layers[layer+1][n*2+1]; buffer[n*5 + 4] = child2->label; } // broadcast best splits MPI_Bcast(buffer, nodes*5, MPI_DOUBLE, root, MPI_COMM_WORLD); // update tree with best global splits for (int n=0; n<nodes; n++) { StaticNode* node = layers[layer][n]; node->loss = buffer[n*5 + 0]; node->feature = (int) buffer[n*5 + 1]; node->split = (float) buffer[n*5 + 2]; StaticNode* child1 = layers[layer+1][n*2]; child1->label = buffer[n*5 + 3]; StaticNode* child2 = layers[layer+1][n*2+1]; child2->label = buffer[n*5 + 4]; } // delete buffers delete [] buffer; delete [] rbuffer; }
四、缺点
MPI每个节点按列读取数据,所以每个节点至少读取一列的数据,导致数据行数扩展有限,在数据量大时单节点内存狂涨。后续会介绍如何按行列分隔数据。
相关推荐
在本项目"基于深度学习的图像隐写分析算法源码和UI系统设计.zip"中,主要探讨了如何利用深度学习技术进行图像隐写分析,并构建了一个用户界面(UI)系统来辅助这一过程。图像隐写分析是数字媒体安全领域的重要组成...
基于深度学习LSTM的特征值识别的社交媒体谣言分析Python源码基于深度学习LSTM的特征值识别的社交媒体谣言分析Python源码基于深度学习LSTM的特征值识别的社交媒体谣言分析Python源码基于深度学习LSTM的特征值识别的...
计算机大作业设计源码基于Python车牌识别系统深度学习项目源码.zip本项目是一套成熟的大作业项目系统,获取98分,主要针对计算机相关专业的正在做大作业的学生和需要项目实战练习的学习者,可作为课程设计、期末大...
深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析项目源码.zip深度学习的情感分析...
《cpp-darknet深度学习框架源码分析》 cpp-darknet是一个基于C++的深度学习框架,专注于速度和效率,尤其适合于嵌入式系统和实时应用。本资料将深入探讨cpp-darknet的源码,通过详细中文注释帮助读者理解框架的原理...
为方便阅读,把blog上的libevent源码深度剖析系列文章整合成一个pdf。
中文情感分析模型源码(含各种主流的情感词典、机器学习、深度学习、预训练模型方法).zip 中文情感分析模型源码(含各种主流的情感词典、机器学习、深度学习、预训练模型方法).zip 中文情感分析模型源码(含各种...
Spring 源码分析 Spring 框架是 Java 语言中最流行的开源框架之一,它提供了一个强大且灵活的基础设施来构建企业级应用程序。在 Spring 框架中,IOC 容器扮演着核心角色,本文将深入分析 Spring 源码,了解 IOC ...
在深入Libevent源码分析之前,需要了解它的核心概念,主要包括事件循环、事件处理器、IO事件、定时器事件、信号事件等。事件循环是Libevent的中枢,它在后台运行,检测事件源的变化,并触发相应的事件处理器。事件...
Python实现基于深度学习的图像隐写分析项目源码+GUI界面+毕业论文,该项目是个人毕设项目,答辩评审分达到98分,代码都经过调试测试,确保可以运行!欢迎下载使用,可用于小白学习、进阶。该资源主要针对计算机、...
《深度学习入门-源码-斋藤康逸》是一份深度学习的学习资源,包含了斋藤康逸关于深度学习基础知识和实践的源代码。这个压缩包包含的文件夹主要有`common`、`ch02`至`ch08`、`.idea`、`dataset`和`ch04`至`ch07`,每个...
Spring源码深度解析第二版 Spring是一款广泛应用于Java企业级应用程序的开源框架,旨在简化Java应用程序的开发和部署。Spring框架的核心主要包括了IoC容器、AOP、MVC框架等模块。 第1章 Spring整体架构和环境搭建 ...
基于各种机器学习和深度学习的中文微博情感分析源码+文档说明(高分项目)基于各种机器学习和深度学习的中文微博情感分析源码+文档说明(高分项目)基于各种机器学习和深度学习的中文微博情感分析源码+文档说明...
基于深度学习(LSTM)的电商购物情感分析项目源码和项目说明.zip个人经导师指导并认可通过的高分毕业设计项目,评审分98分。主要针对计算机相关专业的正在做毕设的学生和需要项目实战练习的学习者,也可作为课程设计...
毕设项目-智慧教室基于深度学习开发的课堂专注度分析和考试作弊检测系统python源码.zip个人经导师指导并认可通过的高分毕业设计项目,主要针对计算机相关专业的正在做毕设的学生和需要项目实战练习的学习者。...
基于各种机器学习和深度学习的中文微博情感分析项目源码(高分项目)含有代码注释,新手也可看懂,个人手打98分项目,毕业设计、期末大作业、课程设计、高分必看,下载下来,简单部署,就可以使用。该项目系统功能...
压缩包中的4个项目源码可能是分别针对不同的任务,比如图像分类、文本情感分析、语音识别或机器翻译。每个项目都可能涵盖数据预处理、模型构建、训练、验证和评估等阶段,这将帮助学习者了解深度学习模型在实际问题...
人工智能基于深度学习的学生课堂行为识别评价综合系统源码(毕业设计).zip本资源中的源码都是经过本地编译过可运行的,资源项目的难度比较适中,内容都是经过助教老师审定过的能够满足学习、使用需求,如果有需要的...
基于主流文本深度学习模型的中文商品金融文本精细化分类和情感分类,可用于对商品评价进行量化分析(源码+项目说明).zip 基于主流文本深度学习模型的中文商品金融文本精细化分类和情感分类,可用于对商品评价进行...
python基于深度学习的电影评论情感分析系统源码 次就是利用了flask框架以及深度学习中的word2vac向量模型来进行一款深度学习的电影评论软件开发,通过该软件的开发来更加有效的对众多的影评文本进行情感分析来判断出...