关于TensorFlow中tfrecords文件的的简易教程

it2022-05-05  146

本文属于课程笔记,源自曹健老师的”人工智能实践:Tensorflow笔记”(侵删):https://www.icourse163.org/learn/PKU-1002536002#/learn/announce

tfrecords是一种二进制文件格式,理论上它可以保存任何格式的信息,可将图片和标签制作该改格式文件,使用tfrecords进行存储,可提高内存利用率。 其中:tf.train.Example用来存储训练数据 。 训练数据的特征用键值对的形式表示。 如“ img_raw ” :值, ”label ”: 值 。值的参数分别是 BytesList/FloatList/Int64List,别对应于取值为二进制数,浮点数,整数特征。SerializeToString( ) 把数据序列化成字符串存储。

下列代码为制作,获取,使用tfrecords格式数据集,其中输入图像为28*28的手写数字图像,输出为0-9的数字标签

#coding:utf-8 import tensorflow as tf import numpy as np from PIL import Image import os image_train_path='./mnist_data_jpg/mnist_train_jpg_60000/' label_train_path='./mnist_data_jpg/mnist_train_jpg_60000.txt' tfRecord_train='./data/mnist_train.tfrecords' image_test_path='./mnist_data_jpg/mnist_test_jpg_10000/' label_test_path='./mnist_data_jpg/mnist_test_jpg_10000.txt' tfRecord_test='./data/mnist_test.tfrecords' data_path='./data' resize_height = 28 resize_width = 28 def write_tfRecord(tfRecordName, image_path, label_path): #接收路径/文件名,图像路径,标签路径 #创建一个新的writer(实例化) writer = tf.python_io.TFRecordWriter(tfRecordName) #图片数量以显示进度 num_pic = 0 #打开标签文件(txt文件,格式:图片名(空格)标签),读取内容 f = open(label_path, 'r') contents = f.readlines() f.close() for content in contents: #分隔每行内容 value = content.split() #图片路径:图片路径+图片名 img_path = image_path + value[0] img = Image.open(img_path) #转换为二进制数据 img_raw = img.tobytes() #初始化,并将标签位赋值为1 labels = [0] * 10 labels[int(value[1])] = 1 # 把每张图片和标签封装到example中 (img_raw与labels) example = tf.train.Example(features=tf.train.Features(feature={ 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels)) })) #将example序列化 writer.write(example.SerializeToString()) num_pic += 1 print ("the number of picture:", num_pic) writer.close() print("write tfrecord successful") def generate_tfRecord(): #判断保存路径是否存在,不存在则创建路径,存在则打印已存在 isExists = os.path.exists(data_path) if not isExists: os.makedirs(data_path) print 'The directory was created successfully' else: print 'directory already exists' #调用write_tfRecord将训练集和验证集的图片和标签写成tfrecords文件。 write_tfRecord(tfRecord_train, image_train_path, label_train_path) write_tfRecord(tfRecord_test, image_test_path, label_test_path) def read_tfRecord(tfRecord_path): #新建文件名队列,告知文件名队列包括那些文件 #tf.train.string_input_producer(string_tensor,num_epochs=None,shuffle=True,seed=None,capacity=32,#shared_name=None,name=None,cancel_op=None) #该函数会生成一个先入先出的队列,文件阅读器会使用它来读取数据。 #string_tensor: 存储图像和标签信息的 TFRecord 文件名列表, #num_epochs: 循环读取的轮数(可选),shuffle :布尔值(可选),如果为 True ,则在每轮随机打乱读取顺序, #seed; 随机读取时设置的种子(可选),capacity :队列容量 #shared_name :(可选 如果设置,该队列将在多个会话中以给定名称共享。所有具有此队列的设备都可以通过 shared_name 访问它。在分布式设置中使用这种方法意味着每个名称只能被访问此操作的其中一个会话看到。 #name :操作的名称(可选),cancel_op :取消队列 None filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True) reader = tf.TFRecordReader() #将读出的每个样本保存到serialized_example中,进行解序列化 _, serialized_example = reader.read(filename_queue) #标签要给出实际分类数 features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([10], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string) }) img = tf.decode_raw(features['img_raw'], tf.uint8) img.set_shape([784]) img = tf.cast(img, tf.float32) * (1. / 255) label = tf.cast(features['label'], tf.float32) return img, label def get_tfrecord(num, isTrain=True): #获取tfrecords文件,num每次获取的数据量,isTrain--训练集--True,测试集--False if isTrain: tfRecord_path = tfRecord_train else: tfRecord_path = tfRecord_test img, label = read_tfRecord(tfRecord_path) #这个函数随机读取一个batch的数据 。 #从总样本中顺序取出capacity组数据打乱顺序,每次输出batch_size组 #如果少于min_after_dequeue,会从总样中取数据填满capacity,共使用2个线程 img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size = num, num_threads = 2, capacity = 1000, min_after_dequeue = 700) return img_batch, label_batch def main(): generate_tfRecord() if __name__ == '__main__': main()

除此之外,还可以在方向传播过程(文件)中利用多线程提高图片和标签的批获取效率。 方法:将批获取的操作放到线程协调器开启和关闭之间 开启线程协调器: coord = tf.train.Coordinator( )tf.train.Coordinator( ) threads = tf.train.start_queue_runners(sess=sess, coord=coord)

关闭线程协调器: coord.request_stop( ) coord.join(threads)


最新回复(0)