`

tensorflow ccn图片学习

阅读更多

 接着上一篇的 读取, 这篇是cnn的识别代码。 很多网上的代码其实都是可以直接跑的, 我只不过自己码了一遍 理解了一遍

 

# -*- coding: utf-8 -*-
import numpy as np

w_alpha=0.01
b_alpha=0.1

IMAGE_HEIGHT = 240
IMAGE_WIDTH = 320
MAX_CAPTCHA = 1
# 图片种类37
CHAR_SET_LEN = 37
dropout = 0.7


conv_dict = {
    # 第一层卷积参数 3*3, 因为是彩色图片 所以第一层输入通道是3, 输出为32
    "w_1": tf.Variable(w_alpha * tf.random_normal([3, 3, 3, 32]), name='w_1'),
    "b_1": tf.Variable(b_alpha * tf.random_normal([32]), name='b_1'),
    # 第二层卷积参数
    "w_2": tf.Variable(w_alpha * tf.random_normal([3, 3, 32, 64]), name='w_2'),
    "b_2": tf.Variable(b_alpha * tf.random_normal([64]), name='b_2'),
    # 第三层卷积参数
    "w_3": tf.Variable(w_alpha * tf.random_normal([3, 3, 64, 128]), name='w_3'),
    "b_3": tf.Variable(b_alpha * tf.random_normal([128]), name='b_3'),
    # 第四层卷积参数
    "w_4": tf.Variable(w_alpha * tf.random_normal([3, 3, 128, 128]), name='w_4'),
    "b_4": tf.Variable(b_alpha * tf.random_normal([128]), name='b_4'),

    'out': tf.Variable(tf.random_normal([1024, CHAR_SET_LEN])),
    'out_add': tf.Variable(tf.random_normal([CHAR_SET_LEN]))
}

# 批量标准化 - 防止 梯度弥散
# wx_plus_b tensor
# out_size  通道数
def batch_normal(wx_plus_b, out_size):
    fc_mean, fc_var = tf.nn.moments(
        wx_plus_b,
        axes=[0, 1, 2],  # 想要 normalize 的维度, [0] 代表 batch 维度
        # 如果是图像数据, 可以传入 [0, 1, 2], 相当于求[batch, height, width] 的均值/方差, 注意不要加入 channel 维度
    )
    # out_size 和wx_plus_b 输出通道数一致
    scale = tf.Variable(tf.ones([out_size]))
    shift = tf.Variable(tf.zeros([out_size]))
    epsilon = 0.001
    wx_plus_b = tf.nn.batch_normalization(wx_plus_b, fc_mean, fc_var, shift, scale, epsilon)
    return wx_plus_b



X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT , IMAGE_WIDTH,3])
Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN])
NOR = tf.placeholder(tf.float32)
keep_prob = tf.placeholder(tf.float32)  # dropout

# 把单个数变成数组
def one_hot_n(x, n):
    x = np.array(x)
    return np.eye(n)[x]


def conv2d(conv, cd1, cd2, out_size, nor):
    conv = tf.nn.bias_add(tf.nn.conv2d(conv, cd1, strides=[1, 1, 1, 1], padding='SAME'), cd2)
    # 做 batch_normal
#    if nor > 1:
#        conv = batch_normal(conv, out_size)
    conv = tf.nn.relu(conv)
    conv = tf.nn.max_pool(conv, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    # dropout 防止 过拟合
    conv = tf.nn.dropout(conv, keep_prob)
    return conv


# 定义CNN
def crack_captcha_cnn():

    # 四层卷积池化
    conv1 = conv2d(X, conv_dict['w_1'], conv_dict['b_1'], 32, NOR)
    conv2 = conv2d(conv1, conv_dict['w_2'], conv_dict['b_2'], 64, NOR)
    conv3 = conv2d(conv2, conv_dict['w_3'], conv_dict['b_3'], 128, NOR)
    conv4 = conv2d(conv3, conv_dict['w_4'], conv_dict['b_4'], 128, NOR)

    # Fully connected layer  全连接
    # 240/16=15  320/16=20
    w_d = tf.Variable(w_alpha * tf.random_normal([15 * 20 * 128, 1024]))
    b_d = tf.Variable(b_alpha * tf.random_normal([1024]))

    dense = tf.reshape(conv4, [-1, w_d.get_shape().as_list()[0]])
    dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))

    out = tf.add(tf.matmul(dense, conv_dict['out']), conv_dict['out_add'])

    return out

# 读取tfrecrods 数据
def read_and_decode(filename):
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
       features={
           'label': tf.FixedLenFeature([], tf.int64),
           'img_raw' : tf.FixedLenFeature([], tf.string),
       })
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
    # normalize
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)
    return img, label

# 训练
def train_crack_captcha_cnn():

    output = crack_captcha_cnn()
    # softmax  ,sigmoid  第一个是用于单结果, 第二个用于多个结果
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output,labels=Y))
    #loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(output, Y))
    # optimizer 为了加快训练 learning_rate应该开始大,然后慢慢衰
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
    # Evaluate model
    # 给出pred在 横向维度上的最大值的 index . prd tensor, 1 横向维度 , 返回的是boolen
    correct_pred = tf.equal(tf.argmax(output, 1), tf.argmax(Y, 1))
    # 把boolean 转成 浮点数据 , 求平均值
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    img, label = read_and_decode("anm_pic_train.tfrecords")
    img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=30, capacity=7000,min_after_dequeue=1000)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # saver.restore(sess, tf.train.latest_checkpoint('/home/root/wtf/yzm/code/'))
        step = 0
#        img, label = read_and_decode("anm_pic_train.tfrecords")
#       img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=64, capacity=70000,min_after_dequeue=1000)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        while True:
#       for i in range(3000):
            imgs, labs = sess.run([img_batch, label_batch])
#           print (labs)
            one_hot_labs = sess.run(tf.cast(one_hot_n(labs, CHAR_SET_LEN), tf.float32))
            sess.run(optimizer, feed_dict={X: imgs, Y: one_hot_labs, keep_prob: dropout, NOR: 1.})
            if step % 50 == 0:
                acc = sess.run( accuracy, feed_dict={X: imgs, Y: one_hot_labs, keep_prob: 1., NOR: 1.})
                print(step, acc)
                if acc > 0.5:
                    saver.save(sess, "crack_capcha.model", global_step=step)
                    print("Complete!!")
                    coord.request_stop()
                    coord.join(threads)
                    sess.close()
                    break
            step += 1
#       print("Complete!!")
#       coord.request_stop()
#        coord.join(threads)
#        sess.close()

train_crack_captcha_cnn()

 

 1.batch_noraml 是为了防止梯度弥散的,但是 到底是不是放在激活之前还不清楚,而且怎么做if 语句眨眼....   

 2.这个代码是用 cpu跑的, 别问我为什么用cpu,穷。 有条件的最好用gpu跑, 省心啊

 

跑了半天的结果:



 

 

推荐博客地址:

http://blog.topspeedsnail.com

https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-13-BN/

 

  • 大小: 12.5 KB
分享到:
评论

相关推荐

    CCN.rar_CCN

    可能包含的文件可能有实验指导文档、代码示例、配置文件、数据分析报告等,这些都能帮助学习者深入理解和应用CCN技术。 为了更深入地学习和实践,你可以从以下几个角度入手: 1. **理论理解**:阅读文档,了解CCN的...

    使用 TensorFlow 实现 YOLOv3.zip

    使用 TensorFlow 实现 YOLOv3YOLOv3 TensorFlow为船只构建实时边界框物体检测系统(使用基于 COCO 数据集上训练的 YOLOv3-416 权重的 TensorFlow 微调)。然后使用我自己的数据集来区分不同类型的船只受到YAD2K、...

    ccn.rar_CCN网络

    在深入研究这个压缩包的内容时,开发者可以学习到如何构建一个基于CCN的网络通信框架,了解内容寻址和数据分发的机制,同时还能掌握网络游戏中常见的同步算法,如状态同步、命令同步等。此外,通过分析源代码,还...

    基于CCN的动态自适应流媒体传输

    ### 基于CCN的动态自适应流媒体传输 #### 概述 本文献介绍了一种基于内容中心网络(Content Centric Networking,简称CCN)的动态自适应流媒体传输系统的设计与实现。该系统主要面向移动网络环境,并且利用了国际标准...

    CCN仿真程序CCNSim0.3

    这时一个基于omnet++的内容中心网络(CCN)仿真程序,你可以通过它实现网络拓扑,并仿真运行路由的一系列的转发策略,缓存决策 和替换策略等。

    内容中心网络(CCN)

    "内容中心网络(CCN)" 内容中心网络(CCN)是一种新型的网络架构体系,其设计原则是方便内容的存取,以解决当前互联网存在的弊端。当前互联网存在的问题包括网络拥堵、信息分享困难、网络安全漏洞等,基于 TCP/IP ...

    Epson QMEMS工艺晶体振荡器SG5032CCN数据手册.zip

    Epson QMEMS工艺晶体振荡器SG5032CCN是一款先进的电子元件,广泛应用于各类电子设备中,如通信系统、计算机、工业控制和消费电子产品等。该器件的核心技术在于其采用的QMEMS(Quantum Micro-Electromechanical ...

    YOLOv3-tensorflow:使用TensorFlow实施YOLOv3

    受启发, 完整的细节在输入到CCN(功能块) 一般3个音阶产品特点测试克隆此文件夹将Darknet中的预训练权重转换为keras(可以将此etape跳过为etape 3) wget python3 convert.py yolov3.cfg yolov3.weights yolov3....

    www.chinafix.com迅维网_G4070mG5070m9CCN32WW.part2.rar

    光驱改机械111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111

    ccn选择题chapter01

    根据提供的文件信息,我们可以归纳出一系列与计算机网络相关的知识点,主要围绕SCU计算机网络课程的选择题展开。下面将对每一道题目所涉及的核心概念进行详细解释。 ### 1.... - **问题**: 在下列选项中,哪一项不是...

    CCN中基于节点状态模型的缓存污染攻击检测算法

    针对内容中心网络中的缓存污染攻击问题,以污染内容数量、分布状态和攻击强度3个参数对缓存污染攻击进行定量描述和分析,建立了攻击下的节点缓存状态模型。通过分析节点关键参数的变化,提出了基于节点状态模型的...

    C-jpn.ccn

    C-jpn.ccn

    18-InternetProtocols_hardware_CCN_

    《互联网协议在计算机通信与网络(CCN)中的应用》 在计算机通信和网络(CCN)领域,互联网协议起着至...通过深入学习"18-InternetProtocols.ppt",我们可以更好地理解和运用这些关键知识,提升在CCN领域的专业能力。

    基于8051和ADC0809CCN的数据采集设计

    《基于8051和ADC0809CCN的数据采集设计》 本文将深入探讨一个基于8051微控制器和ADC0809CCN模数转换器的数据采集系统的设计过程。8051单片机是广泛应用的微处理器,而ADC0809CCN则是一款8位模拟数字转换器,两者...

    CCN仿真软件CCNsim详解

    CCN仿真软件CCNSIM的使用教程,包括各个模块的用途、参数的设置、输出文件的分析处理等信息。

Global site tag (gtag.js) - Google Analytics