徐昭春

记录、学习、成长~

自我要求:认真、沉稳、有原则


花式tfrecord读取

最近因为推荐服务项目的需要,我们在尝试实现阿里的BST(behavior sequence transformer)模型用于微博的用户行为序列建模时,给模型输入训练数据遇到了一些问题:1) 训练文件过大不能一次全部读入内存(tf限定了tensor的最大大小为2g); 2) 输入batch需要获取shape,placeholder不能满足要求。几番资料查阅后,最终选择了tfrecord格式用于模型训练。好了背景先介绍到此,开始直击今天的主题吧—————tfrecord的读取/转换。

1. tfrecord格式介绍

tf现有3种方式进行数据读取:

  1. 直接加载数据,适合于数据量小的情况
  2. 文件读取,例如csv格式读取
  3. tfrecord读取,适合于数据量大的情况

tfrecord数据文件是一种将特征数据与标签统一存储的二进制文件,能够更好的利用内存,可以tf中快速复制、移动、读取、存储。tfrecord以字典的方式一次写一个样本。这样在写入tfrecord文件时,将输入$x_i$与标签$y$一起存储时,将$x_i$和$y$均看作feature,后续读取时按字典以key的方式得到特征和标签。

2. tfrecord生成

这里,我这边一条样本包含如下元素:

  1. label,样本标签
  2. steps, 样本的序列长度
  3. types, 样本序列中各step的action_type
  4. embedding, 样本序列各step的embedding向量

假设拿到上述各元素列表,将其转换为tfrecord,具体实现方式如下:

import collections
import tensorflow as tf

def make_tfrecord(labels, steps, types, embeddings, tfpath):
    """
    labels, label数组,shape=(n,)
    steps, step长度数组,shape=(n,)
    types, action_type数组,shape=(n, max_len)
    embeddings, embedding数组,shape=(n, max_len, hidden_size)
    """

    writer = tf.python_io.TFRecordWriter(out_path)

    for i in range(len(labels)):
        feature_ = collections.OrderedDict()
        feature_["label"] = tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[i]]))
        feature_["steps"] = tf.train.Feature(int64_list=tf.train.Int64List(value=[steps[i]]))
        feature_["types"] = tf.train.Feature(int64_list=tf.train.Int64List(value=types[i].tolist()))
        feature_["embedding"] = tf.train.Feature(float_list=tf.train.FloatList(value=embedding[i].reshape(-1).tolist()))

        tf_example = tf.train.Example(features=tf.train.Features(feature=feature_))
        writer.write(tf_example.SerializeToString())
    
    writer.close()

3. tfrecord读取

在读取tfrecord时,tf提供了很多方便的API供开发者使用,读取时方式比较多。

1. 使用tf.data.TFRecordDataset

需要实现一个parse_function,解析每一个example

def parse_funciton(example):
    name_to_features = {
        "label": tf.FixedLenFeature(shape=(), dtype=tf.int64),
        "steps": tf.FixedLenFeature(shape=(), dtype=tf.int64),
        "types": tf.FixedLenFeature(shape=(max_len,), dtype=tf.int64),
        "embedding": tf.FixedLenFeature(shape=(max_len, hidden_size))
    }
    parsed_example = tf.parse_single_example(example, name_to_feature)

    return parsed_example

def get_features(tfpath, num_epoch, batch_size, shuffle=False):
    """
    tfpath, tfrecord文件路径
    num_epoch, 读取tfrecord重复次数,对应于训练的epoch数
    batch_size, 每次返回的样本数,对应于训练的batch大小
    shuffle, 是否需要打散
    """

    dataset = tf.data.TFRecordDataset([tfpath]).map(parse_function)
    if shuffle:
        dataset = dataset.shuffle(128*batch_size)   # 在128*batch_size大小的缓存中进行打散
    dataset = dataset.repeat(num_epoch)
    dataset = dataset.batch(batch_size).prefech(1)  # 预先读取一个batch大小

    return dataset

2. 使用tf.data.TFRecordDataset

使用同上API的另一种方式

def read_tfrecord(filepath, batch_size, num_epochs, is_training=True):
    """
    tfrecord样本读取
    """ 
    name_to_features = {
        "label" : tf.FixedLenFeature(shape=(), dtype=tf.int64),
        "steps" : tf.FixedLenFeature(shape=(), dtype=tf.int64),
        "types" : tf.FixedLenFeature(shape=(max_len,), dtype=tf.int64),
        "embedding" : tf.FixedLenFeature(shape=(max_len, hidden_units), dtype=tf.float32)
    }

    def _decode_record(record, name_to_features):
        example = tf.parse_single_example(record, name_to_features)
        example["label"] = tf.cast(example["label"], tf.int64)
        example["steps"] = tf.cast(example["steps"], tf.int64)
        example["types"] = tf.cast(example["types"], tf.int64)
        example["embedding"] = tf.cast(example["embedding"], tf.float32)
        return example

    data = tf.data.TFRecordDataset(filepath)

    if is_training:
        data = data.shuffle(buffer_size=50000)
        data = data.repeat(num_epochs)

    data = data.apply(
        tf.contrib.data.map_and_batch(
            lambda record:_decode_record(record, name_to_features),
            batch_size=batch_size
        )
    )

    return data

3. 使用tf.train.string_input_producer

相对来说稍繁琐些,直接贴代码了

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
 
    name_to_features = {
        "label":tf.FixedLenFeature(shape=(), dtype=tf.int64),
        "steps":tf.FixedLenFeature(shape=(), dtype=tf.int64),
        "types":tf.FixedLenFeature(shape=(max_len,), dtype=tf.int64),
        "embedding":tf.FixedLenFeature(shape=(max_len, hidden_size), dtype=tf.float32)
    }

    parsed_example = tf.parse_single_example(serialized_example, name_to_features)
    return parsed_example

def read_tfrecord(tfpath, num_epochs):
    filename_queue = tf.train.string_input_producer([tfpath], num_epochs=num_epochs, shuffle=True)
    parsed_example = read_and_decode(filename_queue)

    min_after_dequeue = 10000
    capacity = min_after_dequeue + 3 * pm.batch_size
    parsed_example_batch = tf.train.shuffle_batch(
        parsed_example, 
        batch_size=pm.batch_size, 
        capacity=capacity,
        min_after_dequeue=min_after_dequeue
    )

    return parsed_example_batch

4. session中batch读取

在前面得到batch后,还需要一些其它步骤将数据打印出来。这里以get_batch为例读取数据

features = get_batch(tfpath, 1, 256)
iter = tf.data.Iterator.from_structure(features.output_types, features.output_shapes)
batch = iter.get_next()
labels, steps, types, embeddings = batch["label"], batch["steps"], batch["types"], batch["embedding"]
init_op = iter.make_initializer(features)

with tf.Session as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(init_op)
    batch_steps = sess.run(steps)

    print(batch_steps.shape)

至此,完成tfrecord的生成和读取,谢谢阅读!

打赏一个呗

取消

感谢您的支持,我会继续努力的!

扫码支持
扫码支持
扫码打赏,你说多少就多少

打开支付宝扫一扫,即可进行扫码打赏哦