数据管道 #

tf.data API 概述 #

tf.data API 是 TensorFlow 提供的高效数据管道构建工具,支持从各种数据源创建数据集,并进行灵活的数据转换和批处理。

核心概念 #

text
┌─────────────────────────────────────────────────────────────┐
│                    tf.data 数据管道                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  数据源 ──► 转换 ──► 批处理 ──► 预取 ──► 训练              │
│                                                             │
│  Dataset ──► map ──► batch ──► prefetch ──► Model          │
│            filter                                            │
│            shuffle                                           │
│            repeat                                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘

创建数据集 #

从内存数据创建 #

python
import tensorflow as tf
import numpy as np

# From NumPy arrays: from_tensor_slices slices along the first axis, so each
# element is one (features, label) pair -> 1000 elements here.
x_train = np.random.random((1000, 784)).astype(np.float32)
y_train = np.random.randint(10, size=(1000,))

dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# cardinality() returns a scalar int64 tensor; .numpy() prints the plain
# number instead of the "tf.Tensor(1000, shape=(), dtype=int64)" repr.
print(f"数据集大小: {dataset.cardinality().numpy()}")

# From tensors
x_tensor = tf.random.normal([1000, 784])
y_tensor = tf.random.uniform([1000], maxval=10, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x_tensor, y_tensor))

# From a single tensor: each element is one scalar (0..9).
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
for item in dataset:
    print(item.numpy(), end=' ')
print()  # terminate the space-separated line

从文件创建 #

python
import tensorflow as tf

# Text files: every line of every listed file becomes one string element.
text_dataset = tf.data.TextLineDataset(filenames=['file1.txt', 'file2.txt'])

# CSV file: batched reader that splits the 'label' column off as the target
# and makes a single pass over the data (num_epochs=1).
csv_dataset = tf.data.experimental.make_csv_dataset(
    file_pattern='data.csv',
    batch_size=32,
    label_name='label',
    num_epochs=1,
)

# TFRecord file: each element is one serialized record.
tfrecord_dataset = tf.data.TFRecordDataset(filenames=['data.tfrecord'])

# Many files: open up to 4 TFRecord files at a time and interleave their
# records, letting tf.data tune the parallelism.
files = tf.data.Dataset.list_files(file_pattern='data/*.tfrecord')
dataset = files.interleave(
    map_func=tf.data.TFRecordDataset,
    cycle_length=4,
    num_parallel_calls=tf.data.AUTOTUNE,
)

从生成器创建 #

python
import tensorflow as tf

def data_generator():
    """Yield 100 synthetic (features, label) pairs.

    features: float32 vector of shape (784,) drawn from a standard normal.
    label: scalar int32 drawn uniformly from [0, 10).
    """
    for _ in range(100):
        features = tf.random.normal([784])
        label = tf.random.uniform([], maxval=10, dtype=tf.int32)
        yield features, label

# output_signature must describe exactly what data_generator yields.
generator_signature = (
    tf.TensorSpec(shape=(784,), dtype=tf.float32),
    tf.TensorSpec(shape=(), dtype=tf.int32),
)
dataset = tf.data.Dataset.from_generator(
    data_generator, output_signature=generator_signature
)

# Peek at the first three elements.
for x, y in dataset.take(3):
    print(f"x shape: {x.shape}, y: {y.numpy()}")

范围数据集 #

python
import tensorflow as tf

# Dataset.range mirrors Python's range:
#   range(10)          -> 0, 1, ..., 9
#   range(0, 100, 10)  -> 0, 10, ..., 90  (start, stop, step)
for dataset in (tf.data.Dataset.range(10), tf.data.Dataset.range(0, 100, 10)):
    print(list(dataset.as_numpy_iterator()))

数据转换 #

map 转换 #

python
import tensorflow as tf

# Simple element-wise map: 0, 2, 4, ..., 18
dataset = tf.data.Dataset.range(10).map(lambda x: x * 2)
print(list(dataset.as_numpy_iterator()))

# Parallel map: AUTOTUNE lets tf.data pick how many calls run concurrently.
dataset = tf.data.Dataset.range(100).map(
    lambda x: x ** 2, num_parallel_calls=tf.data.AUTOTUNE
)

# 复杂映射
def preprocess(x, y):
    """Turn one (features, label) pair into a (28x28x1 image, one-hot label) pair.

    Features are cast to float32 and divided by 255.0, which assumes raw
    pixel values in [0, 255] (note: the x_train built earlier is already in
    [0, 1)). The reshape requires exactly 28 * 28 = 784 feature values.
    The label is one-hot encoded over 10 classes.
    """
    image = tf.reshape(tf.cast(x, tf.float32) / 255.0, [28, 28, 1])
    return image, tf.one_hot(y, depth=10)

# Apply preprocess to every (x, y) pair, in parallel.
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
)

filter 过滤 #

python
import tensorflow as tf

# Keep only even numbers, then show the first ten survivors (0, 2, ..., 18).
dataset = tf.data.Dataset.range(100).filter(lambda n: n % 2 == 0)
print(list(dataset.as_numpy_iterator())[:10])

# Keep only examples whose label is one of the classes 0-4.
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).filter(
    lambda features, label: label < 5
)

shuffle 打乱 #

python
import tensorflow as tf

# A buffer at least as large as the dataset (1000 >= 100) gives a full
# uniform shuffle; a fixed seed makes the order reproducible, and a new
# order is drawn every epoch.
dataset = tf.data.Dataset.range(100).shuffle(
    buffer_size=1000,
    seed=42,
    reshuffle_each_iteration=True,
)

# Typical usage: shuffle individual examples first, then batch them.
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000)
    .batch(32)
)

repeat 重复 #

python
import tensorflow as tf

# repeat(3): the 0..4 sequence three times in a row.
dataset = tf.data.Dataset.range(5).repeat(3)
print(list(dataset.as_numpy_iterator()))

# repeat() with no count cycles forever, so cap the iteration at 15 items.
dataset = tf.data.Dataset.range(5).repeat()
for item in dataset.take(15):
    print(item.numpy(), end=' ')

take 和 skip #

python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)

# First 10 elements.
subset = dataset.take(10)
print("take(10):", list(subset.as_numpy_iterator()))

# Everything after the first 10; preview only 5 of them.
rest = dataset.skip(10)
print("skip(10) first 5:", list(rest.take(5).as_numpy_iterator()))

# Disjoint 80/20 split: take(k) and skip(k) with the same k never overlap.
train_data, test_data = dataset.take(80), dataset.skip(80)

批处理 #

batch 批处理 #

python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)

# 100 elements in batches of 32 -> shapes (32,), (32,), (32,), (4,).
for batch in dataset.batch(32):
    print(f"Batch shape: {batch.shape}")

# drop_remainder=True discards the final short batch, leaving 3 full ones.
batched = dataset.batch(32, drop_remainder=True)
print(f"批次数: {sum(1 for _ in batched)}")

padded_batch 填充批处理 #

python
import tensorflow as tf

# Variable-length int32 sequences of lengths 3, 5, 2, 7 and 4.
dataset = tf.data.Dataset.from_generator(
    lambda: (tf.range(n) for n in (3, 5, 2, 7, 4)),
    output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int32),
)

# Group sequences in pairs and pad each pair with -1 up to the length of the
# longest sequence in that pair.
batched = dataset.padded_batch(batch_size=2, padded_shapes=(None,), padding_values=-1)

for batch in batched:
    print(batch.numpy())

unbatch 解批处理 #

python
import tensorflow as tf

# batch(3) groups 0..9 into [0 1 2] [3 4 5] [6 7 8] [9];
# unbatch() flattens them back into individual elements.
dataset = tf.data.Dataset.range(10).batch(3)
unbatched = dataset.unbatch()

for title, ds in (("Batched:", dataset), ("Unbatched:", unbatched)):
    print(title, list(ds.as_numpy_iterator()))

预取与缓存 #

prefetch 预取 #

python
import tensorflow as tf

# Step-by-step form: shuffle, batch, then prefetch so the next batch is
# prepared while the current one is being consumed.
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.shuffle(1000).batch(32)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

# Best practice: write the pipeline as one chain with prefetch as the last step.
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

cache 缓存 #

python
import tensorflow as tf

# In-memory cache: the (deterministic) preprocess map runs only during the
# first full pass; later epochs replay the cached elements. cache() is placed
# before shuffle() so every epoch still gets a fresh order.
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(preprocess)
    .cache()
    .shuffle(1000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

# File cache: elements are written to files prefixed with './cache/train'.
# Create the parent directory explicitly (makedirs is a no-op if it already
# exists) so the first pass cannot fail on a missing directory.
# The cache is only complete after one full pass over the data.
tf.io.gfile.makedirs('./cache')
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(preprocess)
    .cache('./cache/train')
    .shuffle(1000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

高级操作 #

interleave 交错读取 #

python
import tensorflow as tf

# 从多个文件交错读取
files = tf.data.Dataset.list_files('data/*.csv')
dataset = files.interleave(
    lambda file: tf.data.TextLineDataset(file).skip(1),
    cycle_length=4,
    block_length=16,
    num_parallel_calls=tf.data.AUTOTUNE
)

flat_map 扁平映射 #

python
import tensorflow as tf

# Each element x expands into the sub-range [2x, 2x + 2):
# 0 -> 0, 1;  1 -> 2, 3;  2 -> 4, 5.
dataset = tf.data.Dataset.range(3).flat_map(
    lambda x: tf.data.Dataset.range(2 * x, 2 * (x + 1))
)
print(list(dataset.as_numpy_iterator()))

concatenate 连接 #

python
import tensorflow as tf

dataset1, dataset2 = tf.data.Dataset.range(5), tf.data.Dataset.range(10, 15)

# All elements of dataset1 (0..4) followed by all of dataset2 (10..14).
concatenated = dataset1.concatenate(dataset2)
print(list(concatenated.as_numpy_iterator()))

zip 合并 #

python
import tensorflow as tf

dataset1, dataset2 = tf.data.Dataset.range(5), tf.data.Dataset.range(100, 105)

# Pair elements position by position: (0, 100), (1, 101), ..., (4, 104).
zipped = tf.data.Dataset.zip((dataset1, dataset2))
print(list(zipped.as_numpy_iterator()))

完整数据管道示例 #

图像分类数据管道 #

python
import tensorflow as tf

def load_image(path, label):
    """Read a JPEG from `path` and return it as a normalized 224x224x3 image.

    The file is decoded with 3 channels, resized to 224x224 and scaled by
    1/255 into float32. The label is passed through unchanged.
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)
    pixels = tf.image.resize(pixels, size=[224, 224])
    return tf.cast(pixels, tf.float32) / 255.0, label

def augment(image, label):
    """Apply random flip / brightness / contrast augmentation to one image.

    Assumes `image` is a float image scaled to [0, 1] (as produced by
    load_image in this pipeline). Brightness and contrast jitter can push
    pixels outside that range, so the result is clipped back to [0, 1].
    The label is passed through unchanged.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    # random_brightness adds a delta in [-0.2, 0.2] and random_contrast
    # rescales around the mean, so values can leave [0, 1]; restore the range.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

image_paths = ['img1.jpg', 'img2.jpg', 'img3.jpg']
labels = [0, 1, 0]

train_dataset = (
    tf.data.Dataset.from_tensor_slices((image_paths, labels))
    # Shuffle the cheap (path, label) pairs *before* decoding. A shuffle placed
    # after load_image would buffer up to buffer_size decoded 224x224x3 float32
    # images (~0.6 MB each, ~600 MB for a 1000-element buffer). A buffer as
    # large as the path list gives a full uniform shuffle at negligible cost.
    .shuffle(buffer_size=len(image_paths), reshuffle_each_iteration=True)
    .map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

文本数据管道 #

python
import tensorflow as tf

texts = ['hello world', 'tensorflow is great', 'deep learning']
labels = [0, 1, 0]

# TextVectorization replaces tf.keras.preprocessing.text.Tokenizer and
# pad_sequences, which are removed in Keras 3. adapt() builds the vocabulary
# from the training texts; the old Tokenizer was never fitted, so every text
# encoded to all-zero padding. Sequences are padded/truncated at the END to
# exactly 100 tokens (pad_sequences defaulted to the front).
vectorizer = tf.keras.layers.TextVectorization(
    max_tokens=10000,
    output_mode='int',
    output_sequence_length=100,
)
vectorizer.adapt(tf.constant(texts))


def encode_text(text, label):
    """Vectorize a batch of raw strings into fixed-length token-id rows.

    Args:
        text: string tensor of shape (batch,); this map runs after .batch().
        label: the matching labels, passed through unchanged.

    Returns:
        (int64 token ids of shape (batch, 100), label)
    """
    return vectorizer(text), label


# Batch first, then vectorize whole batches in-graph (no tf.py_function, so
# no Python/GIL bottleneck and static output shapes are preserved).
dataset = (
    tf.data.Dataset.from_tensor_slices((texts, labels))
    .batch(32)
    .map(encode_text, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

下一步 #

现在你已经掌握了数据管道构建,接下来学习 损失函数,了解如何选择合适的损失函数!

最后更新:2026-04-04