接着上一篇的 读取, 这篇是cnn的识别代码。 很多网上的代码其实都是可以直接跑的, 我只不过自己码了一遍 理解了一遍
# -*- coding: utf-8 -*- import numpy as np w_alpha=0.01 b_alpha=0.1 IMAGE_HEIGHT = 240 IMAGE_WIDTH = 320 MAX_CAPTCHA = 1 # 图片种类37 CHAR_SET_LEN = 37 dropout = 0.7 conv_dict = { # 第一层卷积参数 3*3, 因为是彩色图片 所以第一层输入通道是3, 输出为32 "w_1": tf.Variable(w_alpha * tf.random_normal([3, 3, 3, 32]), name='w_1'), "b_1": tf.Variable(b_alpha * tf.random_normal([32]), name='b_1'), # 第二层卷积参数 "w_2": tf.Variable(w_alpha * tf.random_normal([3, 3, 32, 64]), name='w_2'), "b_2": tf.Variable(b_alpha * tf.random_normal([64]), name='b_2'), # 第三层卷积参数 "w_3": tf.Variable(w_alpha * tf.random_normal([3, 3, 64, 128]), name='w_3'), "b_3": tf.Variable(b_alpha * tf.random_normal([128]), name='b_3'), # 第四层卷积参数 "w_4": tf.Variable(w_alpha * tf.random_normal([3, 3, 128, 128]), name='w_4'), "b_4": tf.Variable(b_alpha * tf.random_normal([128]), name='b_4'), 'out': tf.Variable(tf.random_normal([1024, CHAR_SET_LEN])), 'out_add': tf.Variable(tf.random_normal([CHAR_SET_LEN])) } # 批量标准化 - 防止 梯度弥散 # wx_plus_b tensor # out_size 通道数 def batch_normal(wx_plus_b, out_size): fc_mean, fc_var = tf.nn.moments( wx_plus_b, axes=[0, 1, 2], # 想要 normalize 的维度, [0] 代表 batch 维度 # 如果是图像数据, 可以传入 [0, 1, 2], 相当于求[batch, height, width] 的均值/方差, 注意不要加入 channel 维度 ) # out_size 和wx_plus_b 输出通道数一致 scale = tf.Variable(tf.ones([out_size])) shift = tf.Variable(tf.zeros([out_size])) epsilon = 0.001 wx_plus_b = tf.nn.batch_normalization(wx_plus_b, fc_mean, fc_var, shift, scale, epsilon) return wx_plus_b X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT , IMAGE_WIDTH,3]) Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN]) NOR = tf.placeholder(tf.float32) keep_prob = tf.placeholder(tf.float32) # dropout # 把单个数变成数组 def one_hot_n(x, n): x = np.array(x) return np.eye(n)[x] def conv2d(conv, cd1, cd2, out_size, nor): conv = tf.nn.bias_add(tf.nn.conv2d(conv, cd1, strides=[1, 1, 1, 1], padding='SAME'), cd2) # 做 batch_normal # if nor > 1: # conv = batch_normal(conv, out_size) conv = tf.nn.relu(conv) conv = tf.nn.max_pool(conv, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') # dropout 防止 过拟合 conv = tf.nn.dropout(conv, keep_prob) return conv # 定义CNN def crack_captcha_cnn(): # 四层卷积池化 conv1 = conv2d(X, conv_dict['w_1'], conv_dict['b_1'], 32, NOR) conv2 = conv2d(conv1, conv_dict['w_2'], conv_dict['b_2'], 64, NOR) conv3 = conv2d(conv2, conv_dict['w_3'], conv_dict['b_3'], 128, NOR) conv4 = conv2d(conv3, conv_dict['w_4'], conv_dict['b_4'], 128, NOR) # Fully connected layer 全连接 # 240/16=15 320/16=20 w_d = tf.Variable(w_alpha * tf.random_normal([15 * 20 * 128, 1024])) b_d = tf.Variable(b_alpha * tf.random_normal([1024])) dense = tf.reshape(conv4, [-1, w_d.get_shape().as_list()[0]]) dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d)) out = tf.add(tf.matmul(dense, conv_dict['out']), conv_dict['out_add']) return out # 读取tfrecrods 数据 def read_and_decode(filename): filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string), }) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [IMAGE_HEIGHT, IMAGE_WIDTH, 3]) # normalize img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 label = tf.cast(features['label'], tf.int32) return img, label # 训练 def train_crack_captcha_cnn(): output = crack_captcha_cnn() # softmax ,sigmoid 第一个是用于单结果, 第二个用于多个结果 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output,labels=Y)) #loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(output, Y)) # optimizer 为了加快训练 learning_rate应该开始大,然后慢慢衰 optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) # Evaluate model # 给出pred在 横向维度上的最大值的 index . prd tensor, 1 横向维度 , 返回的是boolen correct_pred = tf.equal(tf.argmax(output, 1), tf.argmax(Y, 1)) # 把boolean 转成 浮点数据 , 求平均值 accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) img, label = read_and_decode("anm_pic_train.tfrecords") img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=30, capacity=7000,min_after_dequeue=1000) saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # saver.restore(sess, tf.train.latest_checkpoint('/home/root/wtf/yzm/code/')) step = 0 # img, label = read_and_decode("anm_pic_train.tfrecords") # img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=64, capacity=70000,min_after_dequeue=1000) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord, sess=sess) while True: # for i in range(3000): imgs, labs = sess.run([img_batch, label_batch]) # print (labs) one_hot_labs = sess.run(tf.cast(one_hot_n(labs, CHAR_SET_LEN), tf.float32)) sess.run(optimizer, feed_dict={X: imgs, Y: one_hot_labs, keep_prob: dropout, NOR: 1.}) if step % 50 == 0: acc = sess.run( accuracy, feed_dict={X: imgs, Y: one_hot_labs, keep_prob: 1., NOR: 1.}) print(step, acc) if acc > 0.5: saver.save(sess, "crack_capcha.model", global_step=step) print("Complete!!") coord.request_stop() coord.join(threads) sess.close() break step += 1 # print("Complete!!") # coord.request_stop() # coord.join(threads) # sess.close() train_crack_captcha_cnn()
1.batch_noraml 是为了防止梯度弥散的,但是 到底是不是放在激活之前还不清楚,而且怎么做if 语句....
2.这个代码是用 cpu跑的, 别问我为什么用cpu,穷。 有条件的最好用gpu跑, 省心啊
跑了半天的结果:
推荐博客地址:
http://blog.topspeedsnail.com
https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-13-BN/
相关推荐
可能包含的文件可能有实验指导文档、代码示例、配置文件、数据分析报告等,这些都能帮助学习者深入理解和应用CCN技术。 为了更深入地学习和实践,你可以从以下几个角度入手: 1. **理论理解**:阅读文档,了解CCN的...
使用 TensorFlow 实现 YOLOv3YOLOv3 TensorFlow为船只构建实时边界框物体检测系统(使用基于 COCO 数据集上训练的 YOLOv3-416 权重的 TensorFlow 微调)。然后使用我自己的数据集来区分不同类型的船只受到YAD2K、...
在深入研究这个压缩包的内容时,开发者可以学习到如何构建一个基于CCN的网络通信框架,了解内容寻址和数据分发的机制,同时还能掌握网络游戏中常见的同步算法,如状态同步、命令同步等。此外,通过分析源代码,还...
### 基于CCN的动态自适应流媒体传输 #### 概述 本文献介绍了一种基于内容中心网络(Content Centric Networking,简称CCN)的动态自适应流媒体传输系统的设计与实现。该系统主要面向移动网络环境,并且利用了国际标准...
这时一个基于omnet++的内容中心网络(CCN)仿真程序,你可以通过它实现网络拓扑,并仿真运行路由的一系列的转发策略,缓存决策 和替换策略等。
"内容中心网络(CCN)" 内容中心网络(CCN)是一种新型的网络架构体系,其设计原则是方便内容的存取,以解决当前互联网存在的弊端。当前互联网存在的问题包括网络拥堵、信息分享困难、网络安全漏洞等,基于 TCP/IP ...
Epson QMEMS工艺晶体振荡器SG5032CCN是一款先进的电子元件,广泛应用于各类电子设备中,如通信系统、计算机、工业控制和消费电子产品等。该器件的核心技术在于其采用的QMEMS(Quantum Micro-Electromechanical ...
受启发, 完整的细节在输入到CCN(功能块) 一般3个音阶产品特点测试克隆此文件夹将Darknet中的预训练权重转换为keras(可以将此etape跳过为etape 3) wget python3 convert.py yolov3.cfg yolov3.weights yolov3....
光驱改机械111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111
根据提供的文件信息,我们可以归纳出一系列与计算机网络相关的知识点,主要围绕SCU计算机网络课程的选择题展开。下面将对每一道题目所涉及的核心概念进行详细解释。 ### 1.... - **问题**: 在下列选项中,哪一项不是...
针对内容中心网络中的缓存污染攻击问题,以污染内容数量、分布状态和攻击强度3个参数对缓存污染攻击进行定量描述和分析,建立了攻击下的节点缓存状态模型。通过分析节点关键参数的变化,提出了基于节点状态模型的...
C-jpn.ccn
《互联网协议在计算机通信与网络(CCN)中的应用》 在计算机通信和网络(CCN)领域,互联网协议起着至...通过深入学习"18-InternetProtocols.ppt",我们可以更好地理解和运用这些关键知识,提升在CCN领域的专业能力。
《基于8051和ADC0809CCN的数据采集设计》 本文将深入探讨一个基于8051微控制器和ADC0809CCN模数转换器的数据采集系统的设计过程。8051单片机是广泛应用的微处理器,而ADC0809CCN则是一款8位模拟数字转换器,两者...
CCN仿真软件CCNSIM的使用教程,包括各个模块的用途、参数的设置、输出文件的分析处理等信息。