`
flyfoxs
  • 浏览: 298130 次
  • 性别: Icon_minigender_1
  • 来自: 合肥
社区版块
存档分类
最新评论

机器学习之画蛇添足--线性回归

阅读更多

本系列是理解一些开源项目中已经存在的例子,并配上一些读书笔记,分享出来,也许对其他一些初学者有用, 如果对于分享出来的代码,有问题. 欢迎大家提问交流.

 

下面的代码节选至: 

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/notebooks/2_getting_started.ipynb

 

#@test {"output": "ignore"}
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

# Set up the data with a noisy linear relationship between X and Y.
num_examples = 50
X = np.array([np.linspace(-2, 4, num_examples), np.linspace(-6, 6, num_examples)])

#生成矩阵大小为2*50的正态分布的随机数, randn生成的数据数满足:均值为0,方差为1
#用生成的随机数,给X添加噪音
X += np.random.randn(2, num_examples)
x, y = X

#给x补上bias
x_with_bias = np.array([(1., a) for a in x]).astype(np.float32)


losses = []
training_steps = 50
learning_rate = 0.002

with tf.Session() as sess:
    # Set up all the tensors, variables, and operations.
    input = tf.constant(x_with_bias)
    target = tf.constant(np.transpose([y]).astype(np.float32))
    weights = tf.Variable(tf.random_normal([2, 1], 0, 0.1))

    tf.global_variables_initializer().run()

    yhat = tf.matmul(input, weights)
    yerror = tf.subtract(yhat, target)
    
    #l2不是我们常说的l2 regularaztion(正则化),而是平方差也就是标准误差,更接近L2范式
    loss = tf.nn.l2_loss(yerror)
  
    #如何识别weights就是可以update的变量,然后通过调整来逼近最小值?
    #https://stackoverflow.com/questions/34477889/holding-variables-constant-during-optimizer
    update_weights = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
  
    for _ in range(training_steps):
        # Repeatedly run the operations, updating the TensorFlow variable.
        update_weights.run()
        losses.append(loss.eval())

    # Training is done, get the final values for the graphs
    betas = weights.eval()
    #基于最新的Weights计算Y的估计值
    yhat = yhat.eval()

# Show the fit and the loss over time.
fig, (ax1, ax2) = plt.subplots(1, 2)
#调整2个子图的水平间距
plt.subplots_adjust(wspace=.3)
#设置整个画布的大小(inches)
fig.set_size_inches(10, 4)

#x, y的抽样统计,因为有误差,所以会围绕线条波动
ax1.scatter(x, y, alpha=.7)

# X 和 Y的估值画点, 估值是计算出来的,所以会落在线条上
# c="g" => short color code (rgbcmyk), g is green
ax1.scatter(x, np.transpose(yhat)[0], c="g", alpha=.6)

#画线条(只计算起止的2个点, 2点定一线)
line_x_range = (-4, 6)
ax1.plot(line_x_range, [betas[0] + a * betas[1] for a in line_x_range], "g", alpha=0.6)

#Loss逐渐减少的趋势画图
ax2.plot(range(0, training_steps), losses)
ax2.set_ylabel("Loss")
ax2.set_xlabel("Training steps")
plt.show()

 

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics