`

贝叶斯推断及其互联网应用: 已知推断未知概率

阅读更多
已知推断未知概率, 也叫贝叶斯分类

先上问题吧,我们统计了14天的气象数据(指标包括outlook,temperature,humidity,windy),并已知这些天气是否打球(play)。如果给出新一天的气象指标数据:sunny,cool,high,TRUE,判断一下会不会去打球。

table 1

outlooktemperaturehumiditywindyplay
sunnyhothighFALSEno
sunnyhothighTRUEno
overcasthothighFALSEyes
rainymildhighFALSEyes
rainycoolnormalFALSEyes
rainycoolnormalTRUEno
overcastcoolnormalTRUEyes
sunnymildhighFALSEno
sunnycoolnormalFALSEyes
rainymildnormalFALSEyes
sunnymildnormalTRUEyes
overcastmildhighTRUEyes
overcasthotnormalFALSEyes
rainymildhighTRUEno

这个问题可以用决策树的方法来求解,当然我们今天讲的是朴素贝叶斯法。这个一”打球“还是“不打球”是个两类分类问题,实际上朴素贝叶斯可以没有任何改变地解决多类分类问题。决策树也一样,它们都是有导师的分类方法。

朴素贝叶斯模型有两个假设:所有变量对分类均是有用的,即输出依赖于所有的属性;这些变量是相互独立的,即不相关的。之所以称为“朴素”,就是因为这些假设从未被证实过。

注意上面每项属性(或称指标)的取值都是离散的,称为“标称变量”。

step1.对每项指标分别统计:在不同的取值下打球和不打球的次数。

table 2

outlooktemperaturehumiditywindyplay
yesno yesno yesno yesnoyesno
sunny23hot22high34FALSE6295
overcast40mild42normal61TRUR33
rainy32cool31

step2.分别计算在给定“证据”下打球和不打球的概率。

这里我们的“证据”就是sunny,cool,high,TRUE,记为E,E1=sunny,E2=cool,E3=high,E4=TRUE。

A、B相互独立时,由:


得贝叶斯定理:


得:




又因为4个指标是相互独立的,所以


我们只需要比较P(yes|E)和P(no|E)的大小,就可以决定打不打球了。所以分母P(E)实际上是不需要计算的。

P(yes|E)*P(E)=2/9×3/9×3/9×3/9×9/14=0.0053

P(no|E)*P(E)=3/5×1/5×4/5×3/5×5/14=0.0206

所以不打球的概率更大。

零频问题

注意table 2中有一个数据为0,这意味着在outlook为overcast的情况下,不打球和概率为0,即只要为overcast就一定打球,这违背了朴素贝叶斯的基本假设:输出依赖于所有的属性。

数据平滑的方法很多,最简单最古老的是拉普拉斯估计(Laplace estimator)--即为table2中的每个计数都加1。它的一种演变是每个计数都u(0<u<1)。

Good-Turing是平滑算法中的佼佼者,有兴趣的可以了解下。我在作基于隐马尔可夫的词性标注时发现Good-Turing的效果非常不错。
对于任何发生r次的事件,都假设它发生了r*次:



nr是历史数据中发生了r次的事件的个数。

数值属性

当属性的取值为连续的变量时,称这种属性为“数值属性“。通常我们假设数值属性的取值服从正态分布。

outlooktemperaturehumiditywindyplay
yesno yesno yesno yesnoyesno
sunny23 8385 8685FALSE6295
overcast40 7080 9690TRUR33
rainy32 6865 8070
6472 6595
6971 7091
75 80
75 70
72 90
81 75
sunny2/93/5mean value7374.6mean value79.186.2FALSE6/92/59/155/14
overcast4/90/5deviation6.27.9deviation10.29.7TRUR3/93/5

正态分布的概率密度函数为:


现在已知天气为:outlook=overcast,temperature=66,humidity=90,windy=TRUE。问是否打球?

f(温度=66|yes)=0.0340

f(湿度=90|yes)=0.0221

yes的似然=2/9×0.0340×0.0221×3/9×9/14=0.000036

no的似然=3/5×0.0291×0.0380×3/5×9/14=0.000136

不打球的概率更大一些。

用于文本分类

朴素贝叶斯分类是一种基于概率的有导师分类器。

词条集合W,文档集合D,类别集合C。

根据(1)式(去掉分母)得文档d属于类别cj的概率为:


p(cj)表示类别j出现的概率,让属于类别j的文档数量除以总文档数量即可。

而已知类别cj的情况下词条wt出现的后验概率为:类别cj中包含wt的文档数目  除以 类别cj中包含的文档总数目 。

结束语

实践已多次证明,朴素贝叶斯在许多数据集上不逊于甚至优于一些更复杂的分类方法。这里的原则是:优先尝试简单的方法。

机器学习的研究者尝试用更复杂的学习模型来得到良好的结果,许多年后发现简单的方法仍可取得同样甚至更好的结果。

实现代码:
Classifier.java
/**
 * 
 * 描述: 算法接口.
 * @author 
 *
 */
public interface Classifier {

    /**
     * 处理模型数据.
     * @param lable 标签名称.
     * @param value 标签值.
     * @param cnt 数量(该条数据的数量)
     * @param target 目标名称.
     * @param targetValue 目标值.
     */
    void train(String[] lable, String[] value, int cnt, String target, String targetValue);

    /**
     * 先验概率计算出其后验概率.
     * @param features 属性值.
     * @return 后验概率较大的数值.
     */
    String predict(String[] features);
}


NaiveBayes.java

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * 
 * 描述: 朴树贝叶斯算法.
 * @author 
 *
 */
public class NaiveBayes implements Classifier {
    private static final Logger LOG = LoggerFactory.getLogger(NaiveBayes.class);
    // 小数点后6位
    private static final int AFTER_POINT = 6;
    // 所有标签名称
    private String[] labelName;
    // 目标名称
    private String targetName;
    // 目标标签
    private String[] targetLabelName;
    // 标签列表
    private List<LabelBo> lstLabelBo;
    // 目标值列表
    private Set<String> setTargetVal;
    
    /**
     * 构造函数.
     */
    public NaiveBayes() {
        this.lstLabelBo = new ArrayList<LabelBo>();
        this.setTargetVal = new TreeSet<String>();
    }

    public void setLabelName(String[] labelName) {
        this.labelName = labelName;
    }
    
    public void setTargetLabelName(String[] targetLabelName) {
        this.targetLabelName = targetLabelName;
    }

    public void setTargetName(String targetName) {
        this.targetName = targetName;
    }
    
    public String[] getLabelName() {
        return labelName;
    }
    
    public List<LabelBo> getLstLabelBo() {
        return lstLabelBo;
    }
    
    public String[] getTargetLabelName() {
        return targetLabelName;
    }
    
    public String getTargetName() {
        return targetName;
    }
    
    public Set<String> getSetTargetVal() {
        return setTargetVal;
    }
    
    /**
     * 读文件.
     * @param path 路径.
     */
    public void readFile(String path) {
        if (null == this.labelName || null == this.targetName) {
            return;
        }
        
        try {
            BufferedReader reader = new BufferedReader(new FileReader(path));
            String line;
            boolean isTrue = false;
            while ((line = reader.readLine()) != null) {
                if ("@data".equals(line)) {
                    isTrue = true;
                    continue;
                }
                
                if (!isTrue) {
                    continue;
                }
                
                String[] atts = line.split(",");
                this.train(this.labelName, atts, 1,  this.targetName, atts[atts.length - 1]);
            }
            reader.close();
        } catch (FileNotFoundException ex) {
            LOG.error("Read naivebayes mode data failed, not found file, " + ex.getMessage());
        } catch (IOException ex) {
            LOG.error("Read naivebayes mode data failed, IO exception, " + ex.getMessage());
        }
    }
    
    @Override
    public void train(String[] lable, String[] value, int cnt, String target, String targetValue) {
        for (int i = 0; i < lable.length; i++) {
            LabelBo labelBo = null;
            for (LabelBo lb : this.lstLabelBo) {
                if (lable[i].equals(lb.getLableName()) && value[i].equals(lb.getItemName())) {
                    labelBo = lb;
                    break;
                }
            }

            if (null == labelBo) {
                labelBo = new LabelBo();
                labelBo.setLableName(lable[i]);
                labelBo.setItemName(value[i]);
                this.lstLabelBo.add(labelBo);
            }
            int index = labelBo.addName(targetValue);
            labelBo.addCount(index, cnt);
        }
    }

    /**
     * 计算比例.
     */
    public void rate() {
        Map<String, Integer> mapTotal = new HashMap<String, Integer>();
        for (LabelBo lb : this.lstLabelBo) {
            for (int i = 0; i < lb.getLstCount().size(); i++) {
                String tmp = lb.getLableName() + "." + lb.getLstName().get(i);
                if (mapTotal.containsKey(tmp)) {
                    mapTotal.put(tmp, mapTotal.get(tmp) + lb.getLstCount().get(i));
                } else {
                    mapTotal.put(tmp, lb.getLstCount().get(i));
                }
            }
        }
        for (LabelBo lb : this.lstLabelBo) {
            List<Integer> lst = lb.getLstTotal();
            for (int i = 0; i < lb.getLstName().size(); i++) {
                String tmp = lb.getLableName() + "." + lb.getLstName().get(i);
                lst.add(mapTotal.get(tmp));
            }
        }

        // 目标计算
        List<LabelBo> lstTmpLabelBo = new ArrayList<LabelBo>();
        for (LabelBo lb : this.lstLabelBo) {
            if (this.targetName.equalsIgnoreCase(lb.getLableName())) {
                lstTmpLabelBo.add(lb);
            }
        }

        int total = 0;
        for (LabelBo labelBo : lstTmpLabelBo) {
            if (null != labelBo) {
                for (int i = 0; i < labelBo.getLstCount().size(); i++) {
                    total += labelBo.getLstCount().get(i);
                    this.setTargetVal.add(labelBo.getLstName().get(i));
                }
            }
        }

        for (LabelBo labelBo : lstTmpLabelBo) {
            for (int i = 0; i < labelBo.getLstName().size(); i++) {
                labelBo.getLstTotal().set(i, total);
            }
        }
    }
    
    @Override
    public String predict(String[] features) {
        String score = "";
        double rate = 0;
        Set<String> lstTv = this.getSetTargetVal();
//        double total = 0;
        for (String v : lstTv) {
            String result = this.doPredict(this.targetLabelName, features, this.targetName, v);
            if (rate < Double.valueOf(result)) {
                rate = Double.valueOf(result);
                score =  v;
//                total += Double.valueOf(result);
            }
//            System.out.println(result + ":" +  v);
        }
        return score + ":" + rate;
    }

    /**
     * 计算后验概率.
     * @param lable 标签名称
     * @param features 标签值
     * @param target 目标名称
     * @param targetValue 目标值
     * @return 结果.
     */
    private String doPredict(String[] lable, String[] features, String target, String targetValue) {
        int pre = 1;
        int dev = 1;
        for (int i = 0; i < lable.length; i++) {
            LabelBo labelBo = null;
            for (LabelBo lb : this.lstLabelBo) {
                if (lable[i].equalsIgnoreCase(lb.getLableName()) && features[i].equalsIgnoreCase(lb.getItemName())) {
                    labelBo = lb;
                    break;
                }
            }
            
            if (null == labelBo) {
                continue;
            }

            List<String> lstName = labelBo.getLstName();
            for (String str : lstName) {
                if (targetValue.equals(str)) {
                    pre *= labelBo.getLstCount().get(lstName.indexOf(str));
                    dev *= labelBo.getLstTotal().get(lstName.indexOf(str));
                }
            }
        }

        LabelBo labelBo = null;
        for (LabelBo lb : this.lstLabelBo) {
            if (target.equalsIgnoreCase(lb.getLableName()) && targetValue.equalsIgnoreCase(lb.getItemName())) {
                labelBo = lb;
                break;
            }
        }

        if (null != labelBo) {
            List<String> lstName = labelBo.getLstName();
            for (String str : lstName) {
                if (targetValue.equals(str)) {
                    pre *= labelBo.getLstCount().get(lstName.indexOf(str));
                    dev *= labelBo.getLstTotal().get(lstName.indexOf(str));
                }
            }
        }

        BigDecimal result = new BigDecimal(pre).divide(new BigDecimal(dev), AFTER_POINT, BigDecimal.ROUND_HALF_UP);
        return result.toString();
    }
    
    /**
     * 重置.
     */
    public void reset() {
        this.lstLabelBo.clear();
        this.setTargetVal.clear();
    }
    
    /**
     * 打印数据.
     */
    public void print() {
        for (LabelBo key : this.lstLabelBo) {
            System.out.println(key.getLableName() + "=======>" + key.getItemName());
            List<String> lstName = key.getLstName();
            List<Integer> lstCount = key.getLstCount();
            List<Integer> lstTotal = key.getLstTotal();
            for (int i = 0; i < lstName.size(); i++) {
                System.out.println(lstName.get(i) + ":" + lstCount.get(i) + "/" + lstTotal.get(i));
            }
        }
    }
}


LabelBo.java

import java.util.ArrayList;
import java.util.List;


/**
 * 
 * 描述: 标签对象.
 * @author 
 *
 */
public class LabelBo {
    private String lableName;
    private String itemName;
    // 目标项对应的值.
    private List<String> lstName;
    private List<Integer> lstCount;
    private List<Integer> lstTotal;
    
    /**
     * 构造方法.
     */
    public LabelBo() {
        this.lstCount = new ArrayList<Integer>();
        this.lstName = new ArrayList<String>();
        this.lstTotal = new ArrayList<Integer>();
    }
    
    public void setLableName(String lableName) {
        this.lableName = lableName;
    }
    
    public String getLableName() {
        return lableName;
    }
    
    public void setItemName(String itemName) {
        this.itemName = itemName;
    }
    
    public String getItemName() {
        return itemName;
    }
    
    public List<String> getLstName() {
        return lstName;
    }
    
    public List<Integer> getLstCount() {
        return lstCount;
    }
    
    public List<Integer> getLstTotal() {
        return lstTotal;
    }
    
    /**
     * 添加标签对应的种类名称.
     * @param name 名称.
     * @return 下标.
     */
    public int addName(String name) {
        if (!this.lstName.contains(name)) {
            this.lstName.add(name);
        }
        return this.lstName.indexOf(name);
    }

    /**
     * 添加标签对应的种类名称的数量.
     * @param index 下标.
     * @param count 数量.
     */
    public void addCount(int index, Integer count) {
        if (this.lstCount.size() - 1 < index) {
            this.lstCount.add(count);
            return;
        }
        int temp = this.lstCount.get(index) + count;
        this.lstCount.set(index, temp);
    }

    public void setLstRate(List<Integer> lstTotal) {
        this.lstTotal = lstTotal;
    }
}
  • 大小: 1.5 KB
  • 大小: 1.2 KB
  • 大小: 1.6 KB
  • 大小: 1.3 KB
  • 大小: 2 KB
  • 大小: 670 Bytes
  • 大小: 1.1 KB
  • 大小: 1.8 KB
分享到:
评论

相关推荐

    贝叶斯推断及其互联网应用.doc

    贝叶斯推断的核心在于先验概率与后验概率的概念,以及可能性函数的作用,它能够帮助我们更准确地估计事件发生的概率,尤其在面对大量数据和不确定性时。随着大数据和高性能运算的发展,贝叶斯推断在各种领域中都显示...

    贝叶斯网络20题目.docx

    4.贝叶斯网络的推断:在贝叶斯网络中,推断是指根据已知数据计算未知数据的概率。 5.高阶联合概率计算低阶联合概率:高阶联合概率是指多个变量之间的联合概率,而低阶联合概率是指少数变量之间的联合概率。 6....

    全概率公式和贝叶斯公式的证明与应用

    全概率公式和贝叶斯公式是概率论中的两个核心概念,它们在统计推断、机器学习、信息检索、人工智能等领域有着广泛的应用。这篇毕业论文详细探讨了这两个公式的理论证明及其实际应用。 全概率公式(Total ...

    贝叶斯方法 概率编程与贝叶斯推断

    贝叶斯方法是一种基于概率论的统计分析方法,它在数据科学、机器学习以及人工智能等...通过研读这些材料,你可以系统地学习贝叶斯推断的原理,学习如何构建和应用概率模型,并掌握如何利用编程工具进行高效的推断计算。

    贝叶斯统计推断 统计学习

    在统计学习中,贝叶斯推断被广泛应用,特别是在处理不确定性和数据信息不完整时,贝叶斯方法能够提供一种灵活的框架来更新和调整概率判断。 贝叶斯定理的提出,虽然在其生前未被正式发表,但在贝叶斯去世后,由其...

    贝叶斯方法 概率编程与贝叶斯推断 附代码

    本资料主要探讨了概率编程与贝叶斯推断的概念,并提供了相关的代码实例,帮助读者深入理解并掌握这些理论知识。 首先,我们来详细了解贝叶斯方法。贝叶斯方法的核心是贝叶斯定理,它描述了在已知观测数据的情况下,...

    统计经典贝叶斯推断论文

    在实际应用中,贝叶斯推断不仅限于简单事件的概率计算,它还可以应用于更复杂的模型,如贝叶斯网络、贝叶斯回归等,这些模型能够处理多变量之间的关系,为各类实际问题提供解决方案。贝叶斯推断的优势在于它能够结合...

    贝叶斯滤波器及其应用

    综上所述,贝叶斯滤波器是一类基于贝叶斯推断的统计滤波技术,它可以适应各种噪声分布和系统模型的不确定性。从理论到应用,贝叶斯滤波器及其分支技术,如卡尔曼滤波器和粒子滤波器,在许多领域都有广泛的应用,如...

    贝叶斯推断在MCDB分布式平台上的实现.pdf

    在信息技术和数据分析领域,贝叶斯推断是一种重要的统计分析方法,它基于贝叶斯定理,允许我们在已知某些先验信息的情况下,更新我们对未知参数的概率分布的理解。在大数据和分布式计算的背景下,贝叶斯推断的应用变...

    大数据-算法-现代经济管理中的线性贝叶斯推断理论与多总体贝叶斯分类识别方法研究.pdf

    同时,论文还提出了基于贝叶斯推断的随机误差序列自相关诊断和单位根检验方法,以及方差已知和未知情况下的贝叶斯均值控制图。 对于多方程模型系统,论文证明了矩阵正态-Wishart先验分布作为模型参数的共轭先验分布...

    贝叶斯思维:统计建模的PYTHON学习法

    3. **贝叶斯推断**:探讨如何利用MCMC(马尔科夫链蒙特卡洛)方法,如Metropolis-Hastings算法和Gibbs采样,进行高维复杂模型的推断。 4. **PyMC3**:讲解如何使用PyMC3库进行贝叶斯模型的搭建和求解,包括定义随机...

    贝叶斯的博弈:数学、思维与人工智能.docx

    贝叶斯博弈是一种基于贝叶斯概率理论的动态博弈模型,在不完全信息环境下,参与者根据自身观察到的信息进行推断和决策的一种博弈方式。它在数学、思维和人工智能领域具有重要的应用价值,被广泛应用于金融、医疗、...

    贝叶斯统计-习题答案)借鉴.pdf

    贝叶斯推断是指通过已知条件来推断未知概率的过程。在实际应用中,贝叶斯推断通常用于分析随机事件,例如在机器学习和人工智能领域,贝叶斯方法被用于分类、预测等任务。 3. 概率分布的应用: 文件内容中提到了多种...

    高斯贝叶斯进行概率估计_贝叶斯估计_贝叶斯估计_贝叶斯概率_wherels3_print_

    高斯贝叶斯概率估计是一种在统计学和机器学习领域广泛应用的方法,特别是在处理分类问题时。这个算法基于贝叶斯定理,结合了先验概率和似然性来估计未知参数。在这里,我们主要探讨以下几个核心概念: 1. **贝叶斯...

    贝叶斯算法讲义.pdf

    贝叶斯算法是一种统计学上用于推理的算法,由英国数学家托马斯·贝叶斯(Thomas Bayes)提出,用于在已知某些其他概率的情况下计算某个事件的概率。贝叶斯算法的核心在于逆概率问题的求解,即在已知某些条件下求解另...

    贝叶斯统计的入门介绍书籍

    ### 二、贝叶斯推断在正态分布中的应用 #### 2.1 贝叶斯推断的本质 - 讨论贝叶斯推断的基本原理,强调利用先验知识的重要性。 #### 2.2 正态先验与似然函数 - 当数据服从正态分布时,如何选择合适的先验分布,并...

    贝叶斯短信文本分类,应用于手机客户端

    简单来说,贝叶斯定理描述了事件A发生的条件下,事件B发生的概率(P(B|A))如何通过已知的先验概率P(A)和条件概率P(A|B)来计算。在文本分类中,事件A可以代表一个类别,事件B则是文档中的一条特征。 ### 短信分类...

    朴素贝叶斯实战.pptx贝叶斯分类器ppt代码全

    条件概率是事件A在已知事件B发生的前提下的概率,表示为P(A|B)。全概率公式则用于计算事件A的概率,通过将样本空间划分为互斥的事件B1, B2, ..., Bn,并利用加法规则,公式可以表示为P(A) = ∑[P(Bi) * P(A|Bi)]。 ...

    5-贝叶斯算法.7z5-贝叶斯算法.7z

    8. **贝叶斯推断**:在数据稀疏或者不确定性强的情况下,贝叶斯推断可以帮助我们估计参数的不确定性,提供更全面的见解。 9. **贝叶斯非参数方法**:这种方法不预先设定模型的参数数量,而是让数据自适应地确定模型...

Global site tag (gtag.js) - Google Analytics