您现在的位置是:网站首页> 编程资料编程资料

Tensorflow中TFRecord生成与读取的实现_python_

2023-05-26 374人已围观

简介 Tensorflow中TFRecord生成与读取的实现_python_

一、为什么使用TFRecord?

正常情况下我们训练文件夹经常会生成 train, test 或者val文件夹,这些文件夹内部往往会存着成千上万的图片或文本等文件,这些文件被散列存着,这样不仅占用磁盘空间,并且再被一个个读取的时候会非常慢,繁琐。占用大量内存空间(有的大型数据不足以一次性加载)。此时我们TFRecord格式的文件存储形式会很合理的帮我们存储数据。TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。

二、 生成TFRecord简单实现方式

我们可以分成两个部分来介绍如何生成TFRecord,分别是TFRecord生成器以及样本Example模块。

  • TFRecord生成器
writer = tf.python_io.TFRecordWriter(record_path) writer.write(tf_example.SerializeToString()) writer.close() 

这里面writer就是我们TFrecord生成器。接着我们就可以通过writer.write(tf_example.SerializeToString())来生成我们所要的tfrecord文件了。这里需要注意的是我们TFRecord生成器在写完文件后需要关闭writer.close()。这里tf_example.SerializeToString()是将Example中的map压缩为二进制文件,更好的节省空间。那么tf_example是如何生成的呢?那就是下面所要介绍的样本Example模块了。

  • Example模块
    首先们来看一下Example协议块是什么样子的。
message Example { Features features = 1; }; message Features { map feature = 1; }; message Feature { oneof kind { BytesList bytes_list = 1; FloatList float_list = 2; Int64List int64_list = 3; } }; 

我们可以看出上面的tf_example可以写入的数据形式有三种,分别是BytesList, FloatList以及Int64List的类型。那我们如何写一个tf_example呢?下面有一个简单的例子。

def int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) tf_example = tf.train.Example( features=tf.train.Features(feature={ 'image/encoded': bytes_feature(encoded_jpg), 'image/format': bytes_feature('jpg'.encode()), 'image/class/label': int64_feature(label), 'image/height': int64_feature(height), 'image/width': int64_feature(width)})) 

下面我们来好好从外部往内部分解来解释一下上面的内容。
(1)tf.train.Example(features = None) 这里的features是tf.train.Features类型的特征实例。
(2)tf.train.Features(feature = None) 这里的feature是以字典的形式存在,*key:要保存数据的名字    value:要保存的数据,但是格式必须符合tf.train.Feature实例要求。

三、 生成TFRecord文件完整代码实例

首先我们需要提供数据集

图片文件夹

通过图片文件夹我们可以知道这里面总共有七种分类图片,类别的名称就是每个文件夹名称,每个类别文件夹存储各自的对应类别的很多图片。下面我们通过一下代码(generate_annotation_json.pygenerate_tfrecord.py)生成train.record。

  • generate_annotation_json.py
# -*- coding: utf-8 -*- # @Time : 2018/11/22 22:12 # @Author : MaochengHu # @Email : wojiaohumaocheng@gmail.com # @File : generate_annotation_json.py # @Software: PyCharm import os import json def get_annotation_dict(input_folder_path, word2number_dict): label_dict = {} father_file_list = os.listdir(input_folder_path) for father_file in father_file_list: full_father_file = os.path.join(input_folder_path, father_file) son_file_list = os.listdir(full_father_file) for image_name in son_file_list: label_dict[os.path.join(full_father_file, image_name)] = word2number_dict[father_file] return label_dict def save_json(label_dict, json_path): with open(json_path, 'w') as json_path: json.dump(label_dict, json_path) print("label json file has been generated successfully!") 
  • generate_tfrecord.py
# -*- coding: utf-8 -*- # @Time : 2018/11/23 0:09 # @Author : MaochengHu # @Email : wojiaohumaocheng@gmail.com # @File : generate_tfrecord.py # @Software: PyCharm import os import tensorflow as tf import io from PIL import Image from generate_annotation_json import get_annotation_dict flags = tf.app.flags flags.DEFINE_string('images_dir', '/data2/raycloud/jingxiong_datasets/six_classes/images', 'Path to image(directory)') flags.DEFINE_string('annotation_path', '/data1/humaoc_file/classify/data/annotations/annotations.json', 'Path to annotation') flags.DEFINE_string('record_path', '/data1/humaoc_file/classify/data/train_tfrecord/train.record', 'Path to TFRecord') FLAGS = flags.FLAGS def int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def process_image_channels(image): process_flag = False # process the 4 channels .png if image.mode == 'RGBA': r, g, b, a = image.split() image = Image.merge("RGB", (r,g,b)) process_flag = True # process the channel image elif image.mode != 'RGB': image = image.convert("RGB") process_flag = True return image, process_flag def process_image_reshape(image, resize): width, height = image.size if resize is not None: if width > height: width = int(width * resize / height) height = resize else: width = resize height = int(height * resize / width) image = image.resize((width, height), Image.ANTIALIAS) return image def create_tf_example(image_path, label, resize=None): with tf.gfile.GFile(image_path, 'rb') as fid: encode_jpg = fid.read() encode_jpg_io = io.BytesIO(encode_jpg) image = Image.open(encode_jpg_io) # process png pic with four channels image, process_flag = process_image_channels(image) # reshape image image = process_image_reshape(image, resize) if process_flag == True or resize is not None: bytes_io = io.BytesIO() image.save(bytes_io, format='JPEG') encoded_jpg = bytes_io.getvalue() width, height = image.size tf_example = tf.train.Example( features=tf.train.Features( feature={ 'image/encoded': bytes_feature(encode_jpg), 'image/format': bytes_feature(b'jpg'), 'image/class/label': int64_feature(label), 'image/height': int64_feature(height), 'image/width': int64_feature(width) } )) return tf_example def generate_tfrecord(annotation_dict, record_path, resize=None): num_tf_example = 0 writer = tf.python_io.TFRecordWriter(record_path) for image_path, label in annotation_dict.items(): if not tf.gfile.GFile(image_path): print("{} does not exist".format(image_path)) tf_example = create_tf_example(image_path, label, resize) writer.write(tf_example.SerializeToString()) num_tf_example += 1 if num_tf_example % 100 == 0: print("Create %d TF_Example" % num_tf_example) writer.close() print("{} tf_examples has been created successfully, which are saved in {}".format(num_tf_example, record_path)) def main(_): word2number_dict = { "combinations": 0, "details": 1, "sizes": 2, "tags": 3, "models": 4, "tileds": 5, "hangs": 6 } images_dir = FLAGS.images_dir #annotation_path = FLAGS.annotation_path record_path = FLAGS.record_path annotation_dict = get_annotation_dict(images_dir, word2number_dict) generate_tfrecord(annotation_dict, record_path) if __name__ == '__main__': tf.app.run() 

* 这里需要说明的是generate_annotation_json.py是为了得到图片标注的label_dict。通过这个代码块可以获得我们需要的图片标注字典,key是图片具体地址, value是图片的类别,具体实例如下:

{ "/images/hangs/862e67a8-5bd9-41f1-8c6d-876a3cb270df.JPG": 6, "/images/tags/adc264af-a76b-4477-9573-ac6c435decab.JPG": 3, "/images/tags/fd231f5a-b42c-43ba-9e9d-4abfbaf38853.JPG": 3, "/images/hangs/2e47d877-1954-40d6-bfa2-1b8e3952ebf9.jpg": 6, "/images/tileds/a07beddc-4b39-4865-8ee2-017e6c257e92.png": 5, "/images/models/642015c8-f29d-4930-b1a9-564f858c40e5.png": 4 } 
  • 如何运行代码

(1)首先我们的文件夹构成形式是如下结构,其中images_root是图片根文件夹,combinations, details, sizes, tags, models, tileds, hangs分别存放不同类别的图片文件夹。

- - -图片.jpg -
-图片.jpg - -图片.jpg - -图片.jpg - -图片.jpg - -图片.jpg - -图片.jpg

(2)建立文件夹TFRecord,并将generate_tfrecord.pygenerate_annotation_json.py这两个python文件放入文件夹内,需要注意的是我们需要将 generate_tfrecord.py文件中字典word2number_dict换成自己的字典(即key是放不同类别的图片文件夹名称,value是对应的分类number)

 word2number_dict = { "combinations": 0, "details": 1, "sizes": 2, "tags": 3, "models": 4, "tileds": 5, "hangs": 6 } 

(3)直接执行代码 python3/python2 ./TFRecord/generate_tfrecord.py --image_dir="images_root地址" --record_path="你想要保存record地址(.record文件全路径)"即可。如下是一个实例:

python3 generate_tfrecord.py --image_dir /images/ --record_path /classify/data/train_tfrecord/train.record 

TFRecord读取

上面我们介绍了如何生成TFRecord,现在我们尝试如何通过使用队列读取读取我们的TFRecord。
读取TFRecord可以通过tensorflow两个个重要的函数实现,分别是tf.train.string_input_producertf.TFRecordReadertf.parse_single_example解析器。如下图

AnimatedFileQueues.gif

四、 读取TFRecord的简单实现方式

解析TFRecord有两种解析方式一种是利用tf.parse_single_example, 另一种是通过tf.contrib.slim(* 推荐使用)。

 第一种方式(tf.parse_single_example)解析步骤如下

(1).第一步,我们将train.record文件读入到队列中,如下所示:
filename_queue = tf.train.string_input_producer([tf

-六神源码网