`
zhimaruanjian
  • 浏览: 32650 次
  • 性别: Icon_minigender_1
文章分类
社区版块
存档分类
最新评论

芝麻HTTP:TensorFlow LSTM MNIST分类

 
阅读更多

本节来介绍一下使用 RNN 的 LSTM 来做 MNIST 分类的方法,RNN 相比 CNN 来说,速度可能会慢,但可以节省更多的内存空间。

初始化 首先我们可以先初始化一些变量,如学习率、节点单元数、RNN 层数等:

learning_rate = 1e-3
num_units = 256
num_layer = 3
input_size = 28
time_step = 28
total_steps = 2000
category_num = 10
steps_per_validate = 100
steps_per_test = 500
batch_size = tf.placeholder(tf.int32, [])
keep_prob = tf.placeholder(tf.float32, [])

然后还需要声明一下 MNIST 数据生成器:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

接下来常规声明一下输入的数据,输入数据用 x 表示,标注数据用 y_label 表示:

x = tf.placeholder(tf.float32, [None, 784])
y_label = tf.placeholder(tf.float32, [None, 10])

这里输入的 x 维度是 [None, 784],代表 batch_size 不确定,输入维度 784,y_label 同理。

接下来我们需要对输入的 x 进行 reshape 操作,因为我们需要将一张图分为多个 time_step 来输入,这样才能构建一个 RNN 序列,所以这里直接将 time_step 设成 28,这样一来 input_size 就变为了 28,batch_size 不变,所以reshape 的结果是一个三维的矩阵:

x_shape = tf.reshape(x, [-1, time_step, input_size])

RNN 层 接下来我们需要构建一个 RNN 模型了,这里我们使用的 RNN Cell 是 LSTMCell,而且要搭建一个三层的 RNN,所以这里还需要用到 MultiRNNCell,它的输入参数是 LSTMCell 的列表。

所以我们可以先声明一个方法用于创建 LSTMCell,方法如下:

def cell(num_units):
    cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units)
    return DropoutWrapper(cell, output_keep_prob=keep_prob)

这里还加入了 Dropout,来减少训练过程中的过拟合。

接下来我们再利用它来构建多层的 RNN:

cells = tf.nn.rnn_cell.MultiRNNCell([cell(num_units) for _ in range(num_layer)])

注意这里使用了 for 循环,每循环一次新生成一个 LSTMCell,而不是直接使用乘法来扩展列表,因为这样会导致 LSTMCell 是同一个对象,导致构建完 MultiRNNCell 之后出现维度不匹配的问题。

接下来我们需要声明一个初始状态:

h0 = cells.zero_state(batch_size, dtype=tf.float32)

然后接下来调用 dynamic_rnn() 方法即可完成模型的构建了:

output, hs = tf.nn.dynamic_rnn(cells, inputs=x_shape, initial_state=h0)

这里 inputs 的输入就是 x 做了 reshape 之后的结果,初始状态通过 initial_state 传入,其返回结果有两个,一个 output 是所有 time_step 的输出结果,赋值为 output,它是三维的,第一维长度等于 batch_size,第二维长度等于 time_step,第三维长度等于 num_units。另一个 hs 是隐含状态,是元组形式,长度即 RNN 的层数 3,每一个元素都包含了 c 和 h,即 LSTM 的两个隐含状态。

这样的话 output 的最终结果可以取最后一个 time_step 的结果,所以可以使用:

output = output[:, -1, :]

或者直接取隐藏状态最后一层的 h 也是相同的:

h = hs[-1].h

在此模型中,二者是等价的。但注意如果用于文本处理,可能由于文本长度不一,而 padding,导致二者不同。

输出层 接下来我们再做一次线性变换和 Softmax 输出结果即可:

# Output Layer
w = tf.Variable(tf.truncated_normal([num_units, category_num], stddev=0.1), dtype=tf.float32)
b = tf.Variable(tf.constant(0.1, shape=[category_num]), dtype=tf.float32)
y = tf.matmul(output, w) + b
# Loss
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y)

这里的 Loss 直接调用了 softmax_cross_entropy_with_logits 先计算了 Softmax,然后计算了交叉熵。

训练和评估 最后再定义训练和评估的流程即可,在训练过程中每隔一定的 step 就输出 Train Accuracy 和 Test Accuracy:

# Train
train = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cross_entropy)

# Prediction
correction_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32))

# Train
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(total_steps + 1):
        batch_x, batch_y = mnist.train.next_batch(100)
        sess.run(train, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 0.5, batch_size: batch_x.shape[0]})
        # Train Accuracy
        if step % steps_per_validate == 0:
            print('Train', step, sess.run(accuracy, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 0.5,
                                                               batch_size: batch_x.shape[0]}))
        # Test Accuracy
        if step % steps_per_test == 0:
            test_x, test_y = mnist.test.images, mnist.test.labels
            print('Test', step,
                  sess.run(accuracy, feed_dict={x: test_x, y_label: test_y, keep_prob: 1, batch_size: test_x.shape[0]}))

运行 直接运行之后,只训练了几轮就可以达到 98% 的准确率:

Train 0 0.27
Test 0 0.2223
Train 100 0.87
Train 200 0.91
Train 300 0.94
Train 400 0.94
Train 500 0.99
Test 500 0.9595
Train 600 0.95
Train 700 0.97
Train 800 0.98

可以看出来 LSTM 在做 MNIST 字符分类的任务上还是比较有效的。

分享到:
评论

相关推荐

    PyTorch案例:使用LSTM进行文本分类.zip

    PyTorch案例:使用LSTM进行文本分类.zip PyTorch案例:使用LSTM进行文本分类.zip PyTorch案例:使用LSTM进行文本分类.zip PyTorch案例:使用LSTM进行文本分类.zip PyTorch案例:使用LSTM进行文本分类.zip PyTorch...

    基于tensorflow lstm模型的彩票预测.zip

    基于tensorflow lstm模型的彩票预测

    Stanford CS 20: Tensorflow for Deep Learning Research

    【Tensorflow深度学习详解】 Tensorflow,由Google Brain团队开发,是目前最广泛使用的深度学习框架之一。在“Stanford CS 20: Tensorflow for Deep Learning Research”课程中,学员将深入理解如何利用Tensorflow...

    基于tensorflow LSTM+CNN+CRF的命名实体识别算法python源码+项目说明.zip

    【资源说明】 1、该资源包括项目的全部源码,下载可以直接使用! 2、本项目适合作为计算机、数学、电子信息等专业的课程设计、期末大作业和...基于tensorflow LSTM+CNN+CRF的命名实体识别算法python源码+项目说明.zip

    自然语言处理课程实验:基于LSTM的命名实体识别

    python编写的简单程序,一共只有130多行,...给每个输入和其对应编号建立一个张量 构成训练批 输入LSTM单元 输入全连接层 使用sorftmax或其他分类器进行预测 模型构建 pytorch自带LSTM类/其他工具也可以/自己编码也可以

    上海理工大学C语言课程设计作业:基于LSTM模型的头条号热词分析.zip

    上海理工大学C语言课程设计作业:基于LSTM模型的头条号热词分析.zip上海理工大学C语言课程设计作业:基于LSTM模型的头条号热词分析.zip上海理工大学C语言课程设计作业:基于LSTM模型的头条号热词分析.zip上海理工...

    TensorFlow LSTM 写诗代码与数据

    标题中的“TensorFlow LSTM 写诗代码与数据”揭示了我们今天将深入探讨的主题:如何使用TensorFlow库中的长短期记忆网络(LSTM)来创作诗歌。LSTM是一种特殊的循环神经网络(RNN),在处理序列数据,如文本,语音等...

    social-lstm-tf-master.zip_LSTM tensorflow_TensorFlow LSTM_social

    标题中的“social-lstm-tf-master.zip_LSTM tensorflow_TensorFlow LSTM_social”表明这是一个与社交场景相关的项目,使用了TensorFlow库实现LSTM(长短期记忆网络)模型。描述中的“social lstm tensorflow”进一步...

    tensorflow下用LSTM网络进行时间序列预测

    在TensorFlow框架下,利用LSTM(长短期记忆网络)进行时间序列预测是一种常见的机器学习技术,尤其适用于处理具有时间依赖性的序列数据,如股票价格、天气预报、电力消耗等。LSTM是一种特殊的循环神经网络(RNN),...

    人工智能实践:Tensorflow笔记.zip

    《人工智能实践:TensorFlow笔记》是一份深度探讨人工智能领域中TensorFlow框架的综合学习资料。TensorFlow是由Google Brain团队开发的开源库,主要用于数值计算,广泛应用于机器学习和深度学习领域。本压缩包包含了...

    Tensorflow-LSTM-股票预测DEMO-注释版

    **Tensorflow LSTM 股票预测详解** 在金融领域,股票价格预测是一项具有挑战性的任务,因为价格走势受到多种复杂因素的影响,如市场情绪、宏观经济数据、公司新闻等。近年来,随着深度学习技术的发展,尤其是长短期...

    基于TensorFlow的LSTM循环神经网络短期电力负荷预测.pdf

    基于TensorFlow的LSTM循环神经网络短期电力负荷预测 本文主要介绍了基于TensorFlow的LSTM循环神经网络短期电力负荷预测算法。该算法通过利用深度学习技术和LSTM循环神经网络,实现了对电力负荷的高精度预测。实验...

    keras tensorflow lstm 多变量序列的预测 + 数据文件

    标题中的“keras tensorflow lstm 多变量序列的预测”指的是使用Keras库,一个高级神经网络API,构建基于TensorFlow的LSTM(长短期记忆)模型来预测多变量时间序列数据。LSTM是一种特殊的循环神经网络(RNN),在...

    人工智能实践:Tensorflow个人学习笔记

    1. **图像分类**:使用TensorFlow构建卷积神经网络(CNN)进行图像识别,如MNIST手写数字识别。 2. **自然语言处理**:利用LSTM处理文本序列数据,如情感分析、机器翻译。 3. **推荐系统**:通过协同过滤或深度学习...

    LSTMCryptoPricePrediction:TensorFlow LSTM神经网络可预测加密货币的未来价格。 从Sentdex的教程开始

    在从事LSTM和Transformers预测加密货币价格之前的最后一个介绍项目。 (供我在Deepn担任AI开发人员时的实习) 依存关系 Python 水蟒 Jupyter笔记本 TensorFlow 脾气暴躁的 Matplotlib 大熊猫 斯克莱恩 Unidecode...

    LSTMBtcPricePrediction:TensorFlow LSTM神经网络可预测加密货币的未来价格

    TensorFlowTest介绍了解有关tensorflow,jupyter笔记本,numpy,matplotlib,sklearn和其他库的信息,以快速构建和测试神经网络。依存关系Python 水蟒Jupyter笔记本TensorFlow 脾气暴躁的Matplotlib 大熊猫斯克莱恩...

    LSTM.zip_LSTM_LSTM tensorflow_TensorFlow LSTM_图像识别;

    在TensorFlow中实现LSTM,通常使用`tf.keras.layers.LSTM` API。以下是一个简单的示例: ```python import tensorflow as tf model = tf.keras.models.Sequential([ tf.keras.layers.LSTM(64, input_shape=...

    CNN和LSTM识别MNIST数据集.zip

    在本项目中,我们主要探讨了两种深度学习模型——卷积神经网络(CNN)和长短时记忆网络(LSTM)在识别MNIST手写数字数据集上的应用。MNIST数据集是机器学习领域中非常经典的图像识别数据集,它包含了60,000个训练...

    人工智能实践:Tensorflow笔记2 课件&源码.7z

    4. **数据输入与预处理**:如何使用Tensorflow的数据集API加载和预处理数据,例如MNIST、CIFAR-10等经典数据集。 5. **模型构建**:介绍如何构建常见的神经网络结构,如全连接层(Dense)、卷积神经网络(CNN)、...

Global site tag (gtag.js) - Google Analytics