队列读取样本数据
来自软件实验室
在tensorflow中,读取数据一般有三种。一就是constant定义,二是通过placeholder,三是通过队列的方式读取二进制文件。 比较一下,第一种就不说了; 第二种通过批量读取,数据转换,手动输入,在数据量不大的情况下还是听不错的选择,但是在数据量大的情况下,显然速度就跟不上了,效率太低,比如在集群分布中,不可能使用这种方式来喂数据,我们平时来玩玩还是挺不错的。 第三种通过队列的形式,把需要喂的数据放在一个队列中,学过数据结构就知道,速度必然上升,官方都说了:在使用TensorFlow进行异步计算时,队列是一种强大的机制。但官方这里说的队列是图中的tensor,就是节点。 下面说一下如何通过队列读取文件,文件的制作看上一篇就好。下面的代码就是读取上一节制作的二进制文件。 def read_and_decode(filename): filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string), }) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [224, 224, 3]) img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 label = tf.cast(features['label'], tf.int32) return img, label if __name__ == '__main__': img, label = read_and_decode("train.tfrecords") img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=30, capacity=2000, min_after_dequeue=1000) #初始化所有的op init = tf.initialize_all_variables() with tf.Session() as sess: sess.run(init) #启动队列 threads = tf.train.start_queue_runners(sess=sess) for i in range(4): val, l= sess.run([img_batch, label_batch]) #l = to_categorical(l, 12) print(val.shape, l) 下面是控制台打印: Use `tf.global_variables_initializer` instead. ((30, 224, 224, 3), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)) ((30, 224, 224, 3), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)) ((30, 224, 224, 3), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)) ((30, 224, 224, 3), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)) W tensorflow/core/kernels/queue_base.cc:294] _1_input_producer: Skipping cancelled enqueue attempt with queue not closed