`
daweibalong
  • 浏览: 45582 次
  • 性别: Icon_minigender_1
  • 来自: 厦门
社区版块
存档分类
最新评论

Machine Learning系列实验--感知机学习

阅读更多

感知机时二分类的线性分类模型,其目的就是寻找通过训练将实例划分为正负两类的分离超平面,其采用的策略是根据现有的超平面和输出值来识别出误分类点,也就是说y*(w*t+b)<=0,并采用随机梯度下降的方法不断修改参数,直至没有误分类点。其实质是最小化误分类点到超平面的距离总和。

具体感知机学习的具体算法,包括原始形式和对偶形式。实验采用的是《统计学习方法》中的例2.1:

 

1)原始形式代码如下:

 

#include <iostream>
using namespace std;

int x[3][2] = {
	{3, 3},
	{4, 3},
	{1, 1}
};

int y[3] = {1, 1, -1};

int w[2] = {0};
int b = 0;

int L(int y, int* x)
{
	int temp = (w[0] * x[0] + w[1] * x[1] + b) * y;
	if (temp <= 0)
		return 1;//存在错误点
	else
		return 0;
}

int main(void)
{
	int j = 1;
	while (true)
	{
		cout << j++ << " ";
		
		int i;
		int num = 0;
		for (i = 0; i < 3; i++)
		{
			if (L(y[i], x[i]) == 1)
			{
				cout << "error point:";
				cout << "x" << i <<" w:";
				int j;
				for (j = 0; j < 2; j++)
				{
					w[j] += y[i] * x[i][j];
					cout << w[j] << " ";
				}
				b += y[i];
				cout << "b:" << b <<endl;
				num++;
				break;
			}
		}
		if (num == 0)
			break;
	}
	return 0;
}

 实验结果:

1 error point:x0 w:3 3 b:1

2 error point:x2 w:2 2 b:0

3 error point:x2 w:1 1 b:-1

4 error point:x2 w:0 0 b:-2

5 error point:x0 w:3 3 b:-1

6 error point:x2 w:2 2 b:-2

7 error point:x2 w:1 1 b:-3

 

这跟p30的结果是一样的,不过要注意的是,在极小化的过程中,为了达到书中的结果,选择的误分类点都是第一次遇到的误分类点,而实际上在选择误分类点时应该采用随机的方法来选取,而且每次梯度下降的时候并不是对所有误分类点进行梯度下降,而是只对随机选择的一个误分类点进行梯度下降。结果与误分类点的选择有关。

 

2)对偶形式,代码如下:

 

#include <iostream>
using namespace std;

int x[3][2] = {
	{3, 3},
	{4, 3},
	{1, 1}
};

int y[3] = {1, 1, -1};

int b = 0;
int a[3] = {0};
int G[3][3] = {
	{18, 21, 6},
	{21, 25, 7},
	{6, 7, 2}
};//Gram matrix
int L(int j)
{
	int temp = 0;
	for (int i=0 ;i < 3; i++)
	{
		temp += a[i] * G[i][j] * y[i];
	}
	temp += b;
	temp *= y[j];
	if (temp <= 0)
		return 1;//存在错误点
	else
		return 0;
}

int main(void)
{
	int j = 1;
	while (true)
	{
		cout << j++ << " ";
		
		int i;
		int num = 0;
		for (i = 0; i < 3; i++)
		{
			if (L(i) == 1)
			{
				cout << "error point:";
				cout << "x" << i <<" a:";
				int j;
				a[i] += 1;
				for (j = 0; j < 3; j++)
				{
					cout << a[j] << " ";
				}
				b += y[i];
				cout << "b:" << b <<endl;
				num++;
				break;
			}
		}
		if (num == 0)
			break;
	}
	return 0;
}

 实验结果如下:

1 error point:x0 a:1 0 0 b:1

2 error point:x2 a:1 0 1 b:0

3 error point:x2 a:1 0 2 b:-1

4 error point:x2 a:1 0 3 b:-2

5 error point:x0 a:2 0 3 b:-1

6 error point:x2 a:2 0 4 b:-2

7 error point:x2 a:2 0 5 b:-3

结果与1)相同。(注:书中p35,k=4的数据有误,应该为a1=1 a2=0 a3=3 b=-2)

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics