wang_peng1

Write to and read from a TFRecords file in TensorFlow

This article is a full copy of http://www.machinelearninguru.com/deep_learning/tensorflow/basics/tfrecord/tfrecord.html

Introduction

In the previous post we explained the benefits of saving a large dataset in a single HDF5 file. In this post we will learn how to convert our data into TensorFlow's standard format, called TFRecords. When training a deep network, we have two options for feeding data into our TensorFlow program. The first is to load the data with pure Python code at each step and feed it into the computation graph. The second is to use an input pipeline. Such a pipeline takes a list of filenames (in any supported format), optionally shuffles them, creates a file queue, and then reads and decodes the data. Of the formats such a pipeline can read, TFRecords is the one TensorFlow recommends.

In this post, we load, resize, and save all the images inside the train folder of the well-known Dogs vs. Cats data set into a single TFRecords file. We then load and plot a few of them as samples. To follow the rest of this post, you need to download the train part of the Dogs vs. Cats data set.

List images and their labels

First, we need to list all images and label them. We give each cat image label = 0 and each dog image label = 1. The following code lists all images, gives them the proper labels, and then shuffles the data. We also divide the data set into three parts: train (60%), validation (20%), and test (20%).

import glob
import os
from random import shuffle

shuffle_data = True  # shuffle the addresses before saving
cat_dog_train_path = 'Cat vs Dog/train/*.jpg'

# read addresses and labels from the 'train' folder
addrs = glob.glob(cat_dog_train_path)
# label by file name (cat.*.jpg / dog.*.jpg): 0 = Cat, 1 = Dog
labels = [0 if 'cat' in os.path.basename(addr) else 1 for addr in addrs]

# shuffle addresses and labels together so they stay aligned
if shuffle_data:
    c = list(zip(addrs, labels))
    shuffle(c)
    addrs, labels = zip(*c)

# Divide the data into 60% train, 20% validation, and 20% test
n = len(addrs)
train_end, val_end = int(0.6 * n), int(0.8 * n)

train_addrs, train_labels = addrs[:train_end], labels[:train_end]
val_addrs, val_labels = addrs[train_end:val_end], labels[train_end:val_end]
test_addrs, test_labels = addrs[val_end:], labels[val_end:]
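As a quick sanity check of the slicing arithmetic, the sketch below wraps the same 60/20/20 split in a small helper. `split_60_20_20` is a hypothetical name and is not part of the original code. The Dogs vs. Cats train folder holds 25,000 images, and for that count the helper yields 15,000 / 5,000 / 5,000 items. All three slices share the same two boundaries, so no image is lost or duplicated.

```python
from typing import Sequence, Tuple


def split_60_20_20(items: Sequence) -> Tuple[Sequence, Sequence, Sequence]:
    """Hypothetical helper: split a sequence into 60% train, 20% val, 20% test."""
    n = len(items)
    train_end = int(0.6 * n)  # first index after the training slice
    val_end = int(0.8 * n)    # first index after the validation slice
    return items[:train_end], items[train_end:val_end], items[val_end:]


train, val, test = split_60_20_20(list(range(25000)))
print(len(train), len(val), len(test))  # 15000 5000 5000
```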

Create a TFRecords file

First we need to load the image and convert it to the data type we want to save in the TFRecords file (float32 in this example). Let's write a function that takes an image address, loads and resizes the image, and returns it in the proper data type:

import cv2
import numpy as np

def load_image(addr):
    # read an image and resize it to (224, 224)
    # cv2 loads images as BGR, so convert it to RGB
    img = cv2.imread(addr)
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)
    return img
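The image is stored as raw float32 bytes, so each 224×224×3 image takes 224·224·3·4 = 602,112 bytes in the file. A uint8 copy would need only 150,528 bytes, one quarter of that. The NumPy sketch below demonstrates this, using a random array as a stand-in for a real photo. It also shows the round trip the reader must perform later. The raw bytes carry neither shape nor dtype, so both have to be supplied again when decoding. On the TensorFlow side, tf.decode_raw(..., tf.float32) followed by tf.reshape(image, [224, 224, 3]) does the same job.

```python
import numpy as np

# Stand-in for load_image(): a random 224x224 RGB image converted to float32
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(224, 224, 3)).astype(np.float32)

raw = img.tobytes()  # what ends up in the 'train/image' bytes feature
print(len(raw))                              # 602112 = 224 * 224 * 3 * 4 bytes
print(len(img.astype(np.uint8).tobytes()))   # 150528 bytes if stored as uint8

# Decoding: dtype and shape are not stored, so they must be given explicitly
decoded = np.frombuffer(raw, dtype=np.float32).reshape(224, 224, 3)
print(np.array_equal(decoded, img))          # True
```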

Before we can store the data in a TFRecords file, we need to wrap it in a protocol buffer called Example. We then serialize the protocol buffer to a string and write it to the TFRecords file. An Example protocol buffer contains Features. A Feature describes the data and can hold one of three types: bytes, float, or int64. In summary, to store your data you need to follow these steps:

Open a TFRecords file using tf.python_io.TFRecordWriter
Convert your data into the proper data type of the feature using tf.train.Int64List, tf.train.BytesList, or  tf.train.FloatList
Create a feature using tf.train.Feature and pass the converted data to it
Create an Example protocol buffer using tf.train.Example and pass the feature to it
Serialize the Example to string using example.SerializeToString()
Write the serialized example to TFRecords file using writer.write
We are going to use the following two functions to create features (the functions come from a TensorFlow tutorial):

import sys
import tensorflow as tf

def _int64_feature(value):
    # wrap a single integer (e.g. the label) in an Int64List feature
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    # wrap a single byte string (e.g. the raw image) in a BytesList feature
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

train_filename = 'train.tfrecords'  # path of the TFRecords file to write
# open the TFRecords file
writer = tf.python_io.TFRecordWriter(train_filename)
for i in range(len(train_addrs)):
    # print progress every 1000 images
    if not i % 1000:
        print('Train data: {}/{}'.format(i, len(train_addrs)))
        sys.stdout.flush()
    # Load the image and its label
    img = load_image(train_addrs[i])
    label = train_labels[i]
    # Create a feature
    feature = {'train/label': _int64_feature(label),
               'train/image': _bytes_feature(tf.compat.as_bytes(img.tobytes()))}
    # Create an example protocol buffer
    example = tf.train.Example(features=tf.train.Features(feature=feature))

    # Serialize to string and write it to the file
    writer.write(example.SerializeToString())

writer.close()
sys.stdout.flush()

Finally, we close the file with writer.close(). Similarly, we write the validation and test data to two other TFRecords files.

# open the TFRecords file (its feature keys are 'val/...', so read it with matching keys)
val_filename = 'val.tfrecords'  # path of the TFRecords file to write
writer = tf.python_io.TFRecordWriter(val_filename)
for i in range(len(val_addrs)):
    # print progress every 1000 images
    if not i % 1000:
        print('Val data: {}/{}'.format(i, len(val_addrs)))
        sys.stdout.flush()
    # Load the image and its label
    img = load_image(val_addrs[i])
    label = val_labels[i]
    # Create a feature
    feature = {'val/label': _int64_feature(label),
               'val/image': _bytes_feature(tf.compat.as_bytes(img.tobytes()))}
    # Create an example protocol buffer
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    # Serialize to string and write it to the file
    writer.write(example.SerializeToString())
writer.close()
sys.stdout.flush()

# open the TFRecords file (its feature keys are 'test/...', so read it with matching keys)
test_filename = 'test.tfrecords'  # path of the TFRecords file to write
writer = tf.python_io.TFRecordWriter(test_filename)
for i in range(len(test_addrs)):
    # print progress every 1000 images
    if not i % 1000:
        print('Test data: {}/{}'.format(i, len(test_addrs)))
        sys.stdout.flush()
    # Load the image and its label
    img = load_image(test_addrs[i])
    label = test_labels[i]
    # Create a feature
    feature = {'test/label': _int64_feature(label),
               'test/image': _bytes_feature(tf.compat.as_bytes(img.tobytes()))}
    # Create an example protocol buffer
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    # Serialize to string and write it to the file
    writer.write(example.SerializeToString())
writer.close()
sys.stdout.flush()
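Each writer.write call appends one length-prefixed record to the file. The pure-Python sketch below makes that concrete by writing and reading records with the framing TensorFlow uses for TFRecord files. Each record consists of four parts:

- a little-endian uint64 length;
- a masked CRC-32C of those 8 length bytes;
- the payload;
- a masked CRC-32C of the payload.

It is an illustration only. It uses a slow bitwise CRC, and placeholder byte strings stand in for serialized Examples. The helper names (`write_record`, `read_records`, etc.) are hypothetical. In practice, use tf.python_io.TFRecordWriter and the TensorFlow readers described below.

```python
import io
import struct


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli, reflected polynomial 0x82F63B78)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(f, payload: bytes) -> None:
    header = struct.pack('<Q', len(payload))          # uint64 length
    f.write(header)
    f.write(struct.pack('<I', masked_crc(header)))    # masked CRC of the length
    f.write(payload)
    f.write(struct.pack('<I', masked_crc(payload)))   # masked CRC of the data


def read_records(f):
    while True:
        header = f.read(8)
        if not header:
            return  # clean end of file
        (length,) = struct.unpack('<Q', header)
        (length_crc,) = struct.unpack('<I', f.read(4))
        if length_crc != masked_crc(header):
            raise ValueError('corrupted length header')
        payload = f.read(length)
        (data_crc,) = struct.unpack('<I', f.read(4))
        if data_crc != masked_crc(payload):
            raise ValueError('corrupted record payload')
        yield payload


buf = io.BytesIO()
for record in [b'first record', b'second record']:
    write_record(buf, record)
buf.seek(0)
print(list(read_records(buf)))   # [b'first record', b'second record']
print(hex(crc32c(b'123456789')))  # 0xe3069283, the standard CRC-32C check value
```

Each record therefore costs 16 bytes of framing on top of the serialized Example. The checksums let a reader detect a truncated or corrupted file.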

Read the TFRecords file

It's time to learn how to read data from the TFRecords file. To do so, we load the training data in batches of an arbitrary size and plot images from 5 batches. We also check the label of each image. To read from files in TensorFlow, you need to do the following steps:

Create a list of filenames: In our case we only have a single file, data_path = 'train.tfrecords', so our list is simply [data_path].
Create a queue to hold the filenames: To do so, we use tf.train.string_input_producer, which holds the filenames in a FIFO queue. It takes the list of filenames plus some optional arguments. These include num_epochs, the number of times (epochs) you want to read through the data, and shuffle, which indicates whether to shuffle the filenames in the list. shuffle is set to True by default.
Define a reader: For TFRecords files we need a TFRecordReader, created with reader = tf.TFRecordReader(). The reader then returns the next record using reader.read(filename_queue).
Define a decoder: A decoder is needed to decode the record returned by the reader. For TFRecords files the decoder should be tf.parse_single_example. It takes a serialized Example and a dictionary that maps feature keys to FixedLenFeature or VarLenFeature values. It returns a dictionary that maps feature keys to Tensor values: features = tf.parse_single_example(serialized_example, features=feature)
Convert the data from strings back to numbers: tf.decode_raw(bytes, out_type) takes a Tensor of type string and converts it to type out_type. Labels were never converted to strings, so they only need a cast using tf.cast(x, dtype).
Reshape the data into its original shape: The raw bytes carry no shape information. Reshape the image back to the shape it had before serialization using image = tf.reshape(image, [224, 224, 3]).
Preprocessing: If you want to do any preprocessing, do it now.
Batching: Another queue is needed to create batches from the examples. You can create the batch queue using tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, num_threads=1, min_after_dequeue=10). In this call:
- capacity is the maximum size of the queue;
- min_after_dequeue is the minimum number of elements that must remain in the queue after a dequeue, which ensures a level of mixing;
- num_threads is the number of threads enqueuing examples, and using more than one thread makes reading faster.
The first argument is the list of tensors you want to batch.
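The pure-Python model below is one way to build intuition for min_after_dequeue. It is a deliberately simplified, single-threaded sketch and not TensorFlow's implementation, which runs enqueue threads against a random shuffle queue. `model_shuffle_batch` is a hypothetical helper. The model assumes an eager consumer: an element is dequeued uniformly at random as soon as at least min_after_dequeue elements would remain afterwards. capacity is omitted, because in this sequential model the queue never exceeds min_after_dequeue + 1 elements. It only matters when real reader threads run ahead of the consumer.

```python
import random
from typing import Iterable, List, TypeVar

T = TypeVar('T')


def model_shuffle_batch(examples: Iterable[T], batch_size: int,
                        min_after_dequeue: int, seed: int = 0) -> List[List[T]]:
    """Simplified model of a shuffling batch queue (not TensorFlow's code).

    Dequeues a random element whenever at least min_after_dequeue elements
    would remain, drains the queue once the input is exhausted, and drops a
    final incomplete batch.
    """
    rng = random.Random(seed)
    queue: List[T] = []
    order: List[T] = []

    def dequeue() -> None:
        i = rng.randrange(len(queue))
        queue[i], queue[-1] = queue[-1], queue[i]  # O(1) removal of a random element
        order.append(queue.pop())

    for example in examples:
        queue.append(example)
        while len(queue) - 1 >= min_after_dequeue:
            dequeue()
    while queue:  # input closed: drain what is left
        dequeue()
    full = len(order) // batch_size * batch_size
    return [order[i:i + batch_size] for i in range(0, full, batch_size)]


print(model_shuffle_batch(range(6), batch_size=3, min_after_dequeue=0))
# [[0, 1, 2], [3, 4, 5]]: with min_after_dequeue=0 nothing is mixed
print(len(model_shuffle_batch(range(25), batch_size=10, min_after_dequeue=10)))
# 2: the last 5 dequeued examples do not fill a batch and are dropped
```

A larger min_after_dequeue widens the window that each random pick is drawn from, giving better mixing. The cost is more memory and a slower start, since that many examples must be read before the first batch.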

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
data_path = 'train.tfrecords'  # path of the TFRecords file to read
with tf.Session() as sess:
    feature = {'train/image': tf.FixedLenFeature([], tf.string),
               'train/label': tf.FixedLenFeature([], tf.int64)}
    # Create a list of filenames and pass it to a queue
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)
    # Convert the image data from string back to the numbers
    image = tf.decode_raw(features['train/image'], tf.float32)
    
    # Cast label data into int32
    label = tf.cast(features['train/label'], tf.int32)
    # Reshape image data into the original shape
    image = tf.reshape(image, [224, 224, 3])
    
    # Any preprocessing here ...
    
    # Creates batches by randomly shuffling tensors
    images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, num_threads=1, min_after_dequeue=10)


Initialize all global and local variables: string_input_producer with num_epochs keeps its epoch counter in a local variable, so the local variables must be initialized too.
Filling the example queue: Some functions of tf.train, such as tf.train.shuffle_batch, add tf.train.QueueRunner objects to your graph. Each of these objects holds a list of enqueue ops for a queue, each run in its own thread. To fill the queues, you therefore need to call tf.train.start_queue_runners, which starts threads for all the queue runners in the graph. To manage these threads, you also need a tf.train.Coordinator, which terminates them at the proper time.
Everything is ready. Now you can read a batch and plot its images and labels. When you are done reading, do not forget to stop the threads by requesting a stop from the coordinator and then joining the threads.

    # Initialize all global and local variables
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)
    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for batch_index in range(5):
        img, lbl = sess.run([images, labels])
        # pixel values are floats in 0-255; imshow expects uint8 for that range
        img = img.astype(np.uint8)
        # plot the first 6 images of each batch of 10
        for j in range(6):
            plt.subplot(2, 3, j+1)
            plt.imshow(img[j, ...])
            plt.title('cat' if lbl[j]==0 else 'dog')
        plt.show()
    # Stop the threads
    coord.request_stop()

    # Wait for threads to stop
    coord.join(threads)
    # the session is closed automatically when the with block exits
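The Coordinator/QueueRunner protocol is essentially the classic stop-flag pattern for worker threads. The sketch below reproduces it with only Python's standard library. It is an analogy, not TensorFlow code: a threading.Event plays the role of tf.train.Coordinator, and a bounded queue.Queue stands in for the example queue. `producer` and `run_pipeline` are hypothetical names. As with coord.request_stop() followed by coord.join(threads), the main thread signals the stop and then waits for every worker to exit.

```python
import queue
import threading
from typing import List


def producer(stop: threading.Event, q: queue.Queue, worker_id: int) -> None:
    """Enqueue items until a stop is requested (like a QueueRunner thread)."""
    n = 0
    while not stop.is_set():
        try:
            # time out so the stop flag is re-checked even when the queue is full
            q.put(worker_id * 1_000_000 + n, timeout=0.05)
            n += 1
        except queue.Full:
            continue


def run_pipeline(num_threads: int = 2, num_batches: int = 5,
                 batch_size: int = 10) -> List[List[int]]:
    stop = threading.Event()      # stands in for tf.train.Coordinator
    q = queue.Queue(maxsize=30)   # bounded queue, like capacity=30
    threads = [threading.Thread(target=producer, args=(stop, q, i))
               for i in range(num_threads)]
    for t in threads:
        t.start()                 # like tf.train.start_queue_runners
    try:
        batches = [[q.get() for _ in range(batch_size)] for _ in range(num_batches)]
    finally:
        stop.set()                # like coord.request_stop()
        for t in threads:
            t.join()              # like coord.join(threads)
    return batches


batches = run_pipeline()
print(len(batches), [len(b) for b in batches])  # 5 [10, 10, 10, 10, 10]
```

Putting the stop and join in a finally clause mirrors the advice above. The worker threads are shut down even if reading a batch raises an error.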
