`

机器学习初识之KNN算法

 
阅读更多

         刚刚开始在一个视频上学习机器学习,不懂的还是很多,这也算作是学习机器学习的笔记吧

KNN算法,K nearest neighbor 最近的K个邻居,了解一个算法,先从了解一个问题开始,现在问题如下,有很多的数字图片,每个图片上面有一个数字,每个图片是28*28像素的的,灰度值从0~255,我们把每个图片看作是一个1X784的一行矩阵,因为784=28*28,矩阵上的数字大小表示该像素点的灰度值,有一些已知的图像和未知的图像,当然这里的图像都是使用矩阵表示的,现在需要通过已知图像的数字来预测未知图像的数字

 

KNN算法其实是比较需要预测的和已知的结果的用例之间的相似度,寻找相似度最接近的K个已知用例作为预测和分类结果

我们这里使用的相似度比较方法是余弦比较,计算公式如图片所示,所谓余弦比较,就是将带预测的像素矩阵与已知的矩阵的每行求余弦乘积,在所有的乘积中选取最大的数值的那一组作为预测值,因为越大越接近一的表示两者相似度越高

 

这是在一个黑板课教学视频python的算法,可供参考

 

# -*- coding: utf-8 -*-

import pandas as pd
import numpy as np
import time

def normalize(x):
    """
    linalg.norm(x), return sum(abs(xi)**2)**0.5
    apply_along_axis(func, axis, x),
    """
    norms = np.apply_along_axis(np.linalg.norm, 1, x) + 1.0e-7
    return x / np.expand_dims(norms, -1)

def normalize2(x):
    """
    linalg.norm(x), return sum(abs(xi)**2)**0.5
    apply_along_axis(func, axis, x),
    """
    norms = np.apply_along_axis(np.mean, 1, x) + 1.0e-7
    return x - np.expand_dims(norms, -1)

    
def nearest_neighbor(norm_func,train_x, train_y, test_x):
    train_x = norm_func(train_x)
    test_x = norm_func(test_x)
    
    # cosine
    corr = np.dot(test_x, np.transpose(train_x))
    argmax = np.argmax(corr, axis=1)
    preds = train_y[argmax]

    return preds

def validate(preds, test_y):
    count = len(preds)
    correct = (preds == test_y).sum()
    return float(correct) / count

if __name__=='__main__':
    TRAIN_NUM = 220
    TEST_NUM = 420
    # Read data 42000
    data = pd.read_csv('train.csv')
##    print data
    train_data = data.values[0:TRAIN_NUM,1:]
    train_label = data.values[0:TRAIN_NUM,0]
    test_data = data.values[TRAIN_NUM:TEST_NUM,1:]
    test_label = data.values[TRAIN_NUM:TEST_NUM,0]

    norm_funcs =  [normalize,normalize2]
    for norm_f in norm_funcs:
        t = time.time()
        preds = nearest_neighbor(norm_f,train_data, train_label, test_data)
        acc = validate(preds, test_label)
        print("%s Validation Accuracy: %f, %.2fs" % (norm_f.__name__,acc, time.time() - t))


这段代码里面包含两种比较函数,其中第一种就是使用余弦回归计算得到的,我们看一下预测结果:

 

normalize Validation Accuracy: 0.815000, 0.12s
normalize2 Validation Accuracy: 0.770000, 0.03s

 可见,使用余弦比较相似度的算法最终的识别率能达到八成以上,当然数据量越大越准确,当数据到达2200个的时候,准确率能到90%

另外还附加一个数据文件(见附录)

 

 

  • 大小: 59.5 KB
  • 大小: 2.9 KB
分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics