Data Pipelines #
Overview of the tf.data API #
The tf.data API is TensorFlow's tool for building efficient input pipelines. It supports creating datasets from a variety of data sources and applying flexible transformations and batching to them.
Core Concepts #
text
┌──────────────────────────────────────────────────────────────┐
│                    tf.data input pipeline                     │
├──────────────────────────────────────────────────────────────┤
│                                                               │
│  Data source ──► Transform ──► Batch ──► Prefetch ──► Train  │
│                                                               │
│  Dataset ──────► map ────────► batch ──► prefetch ──► Model  │
│                  filter                                       │
│                  shuffle                                      │
│                  repeat                                       │
│                                                               │
└──────────────────────────────────────────────────────────────┘
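Each stage in the diagram corresponds directly to a Dataset method call. A minimal end-to-end sketch (with made-up toy data) that chains all of the stages together:
python
import tensorflow as tf
import numpy as np

# Toy data standing in for a real data source
features = np.random.random((256, 8)).astype(np.float32)
labels = np.random.randint(2, size=(256,))

pipeline = (
    tf.data.Dataset.from_tensor_slices((features, labels))  # data source
    .shuffle(256)                    # transform
    .batch(32)                       # batch
    .prefetch(tf.data.AUTOTUNE)      # prefetch
)

for x, y in pipeline.take(1):        # consume, e.g. in a training loop
    print(x.shape, y.shape)          # (32, 8) (32,)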
Creating Datasets #
From In-Memory Data #
python
import tensorflow as tf
import numpy as np

# From NumPy arrays
x_train = np.random.random((1000, 784)).astype(np.float32)
y_train = np.random.randint(10, size=(1000,))
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
print(f"Dataset size: {dataset.cardinality()}")

# From tensors
x_tensor = tf.random.normal([1000, 784])
y_tensor = tf.random.uniform([1000], maxval=10, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x_tensor, y_tensor))

# From a single tensor
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
for item in dataset:
    print(item.numpy(), end=' ')
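Before iterating, a dataset's structure can be inspected through element_spec, which reports the dtype and shape of each component; this is handy when debugging pipeline stages:
python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.zeros([100, 784]), tf.zeros([100], dtype=tf.int32))
)
# Describes one element without reading any data
print(dataset.element_spec)
# (TensorSpec(shape=(784,), dtype=tf.float32, name=None),
#  TensorSpec(shape=(), dtype=tf.int32, name=None))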
From Files #
python
import tensorflow as tf

# From text files
text_dataset = tf.data.TextLineDataset(['file1.txt', 'file2.txt'])

# From a CSV file
csv_dataset = tf.data.experimental.make_csv_dataset(
    'data.csv',
    batch_size=32,
    label_name='label',
    num_epochs=1
)

# From TFRecord files
tfrecord_dataset = tf.data.TFRecordDataset(['data.tfrecord'])

# From multiple files, read in parallel
files = tf.data.Dataset.list_files('data/*.tfrecord')
dataset = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=4,
    num_parallel_calls=tf.data.AUTOTUNE
)
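Records read from TFRecord files are serialized tf.train.Example protos and still need parsing. A sketch assuming a hypothetical record layout with a JPEG-encoded image bytes feature and an int64 label (adjust the feature spec to match how the records were actually written):
python
import tensorflow as tf

# Hypothetical schema -- must match the writer's feature names and types
feature_spec = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
}

def parse_example(serialized):
    parsed = tf.io.parse_single_example(serialized, feature_spec)
    image = tf.io.decode_jpeg(parsed['image'], channels=3)
    return image, parsed['label']

files = tf.data.Dataset.list_files('data/*.tfrecord')
dataset = (
    files.interleave(
        tf.data.TFRecordDataset,
        cycle_length=4,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
)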
From a Generator #
python
import tensorflow as tf

def data_generator():
    for i in range(100):
        yield tf.random.normal([784]), tf.random.uniform([], maxval=10, dtype=tf.int32)

dataset = tf.data.Dataset.from_generator(
    data_generator,
    output_signature=(
        tf.TensorSpec(shape=(784,), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32)
    )
)

for x, y in dataset.take(3):
    print(f"x shape: {x.shape}, y: {y.numpy()}")
Range Datasets #
python
import tensorflow as tf
# A range of integers
dataset = tf.data.Dataset.range(10)
print(list(dataset.as_numpy_iterator()))

# With a step
dataset = tf.data.Dataset.range(0, 100, 10)
print(list(dataset.as_numpy_iterator()))
Data Transformations #
map #
python
import tensorflow as tf
dataset = tf.data.Dataset.range(10)
# Simple element-wise mapping
dataset = dataset.map(lambda x: x * 2)
print(list(dataset.as_numpy_iterator()))

# With parallel processing
dataset = tf.data.Dataset.range(100)
dataset = dataset.map(
    lambda x: x ** 2,
    num_parallel_calls=tf.data.AUTOTUNE
)

# A more involved preprocessing function
def preprocess(x, y):
    x = tf.cast(x, tf.float32) / 255.0
    x = tf.reshape(x, [28, 28, 1])
    y = tf.one_hot(y, depth=10)
    return x, y

dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
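map functions may also return dictionaries, which keeps components self-describing and matches models with named inputs; a small sketch:
python
import tensorflow as tf

dataset = tf.data.Dataset.range(5)
# Each element becomes a dict of tensors
dataset = dataset.map(lambda x: {'index': x, 'square': x ** 2})
for element in dataset.take(2):
    print({k: v.numpy() for k, v in element.items()})
# {'index': 0, 'square': 0}
# {'index': 1, 'square': 1}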
filter #
python
import tensorflow as tf
dataset = tf.data.Dataset.range(100)
# Keep only even numbers
dataset = dataset.filter(lambda x: x % 2 == 0)
print(list(dataset.as_numpy_iterator())[:10])

# Filter on (feature, label) pairs
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.filter(lambda x, y: y < 5)
shuffle #
python
import tensorflow as tf
dataset = tf.data.Dataset.range(100)
# Shuffle with a buffer
dataset = dataset.shuffle(buffer_size=1000, seed=42, reshuffle_each_iteration=True)

# Typical usage: shuffle, then batch
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.shuffle(10000).batch(32)
repeat #
python
import tensorflow as tf
dataset = tf.data.Dataset.range(5)
# Repeat 3 times
dataset = dataset.repeat(3)
print(list(dataset.as_numpy_iterator()))

# Repeat indefinitely
dataset = tf.data.Dataset.range(5).repeat()
for i, item in enumerate(dataset):
    if i >= 15:
        break
    print(item.numpy(), end=' ')
take and skip #
python
import tensorflow as tf
dataset = tf.data.Dataset.range(100)
# Take the first 10 elements
subset = dataset.take(10)
print("take(10):", list(subset.as_numpy_iterator()))

# Skip the first 10 elements
rest = dataset.skip(10)
print("skip(10) first 5:", list(rest.take(5).as_numpy_iterator()))

# Split into train and test subsets
train_data = dataset.take(80)
test_data = dataset.skip(80)
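One caveat with this split pattern: if the dataset was shuffled with reshuffle_each_iteration=True (the default) before calling take/skip, the permutation changes every epoch and examples can leak between the two subsets. A sketch of a stable split, shuffling once with a fixed seed:
python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
# Freeze the shuffle order so take/skip always see the same permutation
dataset = dataset.shuffle(100, seed=42, reshuffle_each_iteration=False)
train_data = dataset.take(80)
test_data = dataset.skip(80)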
Batching #
batch #
python
import tensorflow as tf
dataset = tf.data.Dataset.range(100)
# Group elements into batches
batched = dataset.batch(32)
for batch in batched:
    print(f"Batch shape: {batch.shape}")

# Drop the final incomplete batch
batched = dataset.batch(32, drop_remainder=True)
print(f"Number of batches: {len(list(batched))}")
padded_batch #
python
import tensorflow as tf
# Variable-length sequences
dataset = tf.data.Dataset.from_generator(
    lambda: [tf.range(n) for n in [3, 5, 2, 7, 4]],
    output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int32)
)

# Pad each batch to the length of its longest sequence
batched = dataset.padded_batch(
    batch_size=2,
    padded_shapes=(None,),
    padding_values=-1
)
for batch in batched:
    print(batch.numpy())
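padded_batch also works on tuple-structured elements; padded_shapes and padding_values are then given as matching tuples, one entry per component. A sketch with made-up variable-length sequences and scalar labels:
python
import tensorflow as tf

def gen():
    yield [1, 2, 3], 0
    yield [4, 5], 1
    yield [6, 7, 8, 9], 0

dataset = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(None,), dtype=tf.int32),
        tf.TensorSpec(shape=(), dtype=tf.int32)
    )
)
batched = dataset.padded_batch(
    batch_size=2,
    padded_shapes=([None], []),   # pad sequences; labels are scalars
    padding_values=(-1, 0)        # one padding value per component
)
for seqs, labels in batched:
    print(seqs.numpy(), labels.numpy())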
unbatch #
python
import tensorflow as tf
dataset = tf.data.Dataset.range(10).batch(3)
print("Batched:", list(dataset.as_numpy_iterator()))
unbatched = dataset.unbatch()
print("Unbatched:", list(unbatched.as_numpy_iterator()))
Prefetching and Caching #
prefetch #
python
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.shuffle(1000)
dataset = dataset.batch(32)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
# Best practice: put prefetch at the very end of the pipeline
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
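The benefit of prefetch is that preprocessing of the next element overlaps with the training step on the current one. A rough benchmarking sketch (the sleeps stand in for expensive preprocessing and a training step; exact numbers will vary by machine):
python
import time
import tensorflow as tf

def slow_square(x):
    time.sleep(0.005)   # pretend this is expensive preprocessing
    return x * x

def tf_slow_square(x):
    y = tf.py_function(slow_square, [x], tf.int64)
    y.set_shape([])
    return y

def timed_epoch(ds):
    start = time.perf_counter()
    for _ in ds:
        time.sleep(0.005)   # pretend this is a training step
    return time.perf_counter() - start

base = tf.data.Dataset.range(200).map(tf_slow_square)
print(f"no prefetch: {timed_epoch(base):.2f}s")
# With prefetch, producer and consumer run concurrently, so the total
# time approaches max(preprocess, train) per element instead of their sum
print(f"prefetch:    {timed_epoch(base.prefetch(tf.data.AUTOTUNE)):.2f}s")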
cache #
python
import tensorflow as tf

# In-memory cache
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(preprocess)
    .cache()
    .shuffle(1000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

# On-disk cache
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(preprocess)
    .cache('./cache/train')
    .shuffle(1000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
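Ordering matters when cache is combined with random transformations: everything before cache runs only once and is frozen into the cache, so deterministic, expensive work belongs before it, while random augmentation must come after it, or every epoch will replay the same "random" results. A sketch of the recommended ordering, with placeholder map functions:
python
import tensorflow as tf

def expensive_decode(x, y):
    # Deterministic preprocessing: safe to cache
    return tf.cast(x, tf.float32) / 255.0, y

def random_augment(x, y):
    # Random preprocessing: must run after cache to differ per epoch
    return x + tf.random.uniform(tf.shape(x)), y

x = tf.zeros([100, 8])
y = tf.zeros([100], dtype=tf.int32)
dataset = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .map(expensive_decode)    # cached once
    .cache()
    .map(random_augment)      # re-sampled every epoch
    .shuffle(100)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)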
Advanced Operations #
interleave #
python
import tensorflow as tf
# Interleave lines from multiple CSV files (skipping each header row)
files = tf.data.Dataset.list_files('data/*.csv')
dataset = files.interleave(
    lambda file: tf.data.TextLineDataset(file).skip(1),
    cycle_length=4,
    block_length=16,
    num_parallel_calls=tf.data.AUTOTUNE
)
flat_map #
python
import tensorflow as tf
dataset = tf.data.Dataset.range(3)
# Map each element to a small dataset, then flatten the results
dataset = dataset.flat_map(lambda x: tf.data.Dataset.range(x * 2, (x + 1) * 2))
print(list(dataset.as_numpy_iterator()))
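A common use of flat_map is flattening the sub-datasets produced by Dataset.window into fixed-length tensors, e.g. to build sliding windows over a time series:
python
import tensorflow as tf

series = tf.data.Dataset.range(10)
# Overlapping windows of length 4, sliding one step at a time
windows = series.window(4, shift=1, drop_remainder=True)
# Each window is itself a Dataset; batch(4) turns it into one tensor
dataset = windows.flat_map(lambda w: w.batch(4))
print(list(dataset.as_numpy_iterator()))
# [array([0, 1, 2, 3]), array([1, 2, 3, 4]), ..., array([6, 7, 8, 9])]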
concatenate #
python
import tensorflow as tf
dataset1 = tf.data.Dataset.range(5)
dataset2 = tf.data.Dataset.range(10, 15)
concatenated = dataset1.concatenate(dataset2)
print(list(concatenated.as_numpy_iterator()))
zip #
python
import tensorflow as tf
dataset1 = tf.data.Dataset.range(5)
dataset2 = tf.data.Dataset.range(100, 105)
zipped = tf.data.Dataset.zip((dataset1, dataset2))
print(list(zipped.as_numpy_iterator()))
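zip is the usual way to pair separately built feature and label datasets into the (x, y) tuples expected during training:
python
import tensorflow as tf

features = tf.data.Dataset.from_tensor_slices(tf.random.normal([100, 8]))
labels = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([100], maxval=2, dtype=tf.int32))
train = tf.data.Dataset.zip((features, labels)).batch(32)
for x, y in train.take(1):
    print(x.shape, y.shape)   # (32, 8) (32,)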
Complete Pipeline Examples #
Image Classification Data Pipeline #
python
import tensorflow as tf
def load_image(path, label):
    # Read, decode, resize, and scale an image to [0, 1]
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [224, 224])
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

def augment(image, label):
    # Random augmentations, re-sampled every epoch
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    return image, label

image_paths = ['img1.jpg', 'img2.jpg', 'img3.jpg']
labels = [0, 1, 0]

train_dataset = (
    tf.data.Dataset.from_tensor_slices((image_paths, labels))
    .map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(1000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
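A dataset built this way plugs straight into Keras training; a sketch with a hypothetical minimal model (the image paths above must point at real JPEG files for the pipeline to run):
python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(224, 224, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(2, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
# Keras iterates the tf.data pipeline directly; no manual batching loop
model.fit(train_dataset, epochs=3)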
Text Data Pipeline #
python
import tensorflow as tf
texts = ['hello world', 'tensorflow is great', 'deep learning']
labels = [0, 1, 0]

# The tokenizer must be fit on the corpus before it can encode anything
tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=10000)
tokenizer.fit_on_texts(texts)

def encode_text(text, label):
    seq = tokenizer.texts_to_sequences([text.numpy().decode()])[0]
    seq = tf.keras.preprocessing.sequence.pad_sequences([seq], maxlen=100)[0]
    return seq, label

def tf_encode(text, label):
    encoded, label = tf.py_function(encode_text, [text, label], [tf.int32, tf.int32])
    # py_function loses static shape information; restore it for padded_batch
    encoded.set_shape([100])
    label.set_shape([])
    return encoded, label

dataset = (
    tf.data.Dataset.from_tensor_slices((texts, labels))
    .map(tf_encode, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(32, padded_shapes=([100], []))
    .prefetch(tf.data.AUTOTUNE)
)
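The Tokenizer/pad_sequences workflow above is the legacy Keras preprocessing API. In current TensorFlow, the same pipeline is usually built with the tf.keras.layers.TextVectorization layer, which runs inside the graph and avoids tf.py_function entirely; a sketch:
python
import tensorflow as tf

texts = ['hello world', 'tensorflow is great', 'deep learning']
labels = [0, 1, 0]

vectorizer = tf.keras.layers.TextVectorization(
    max_tokens=10000, output_sequence_length=100)
vectorizer.adapt(texts)   # build the vocabulary from the corpus

dataset = (
    tf.data.Dataset.from_tensor_slices((texts, labels))
    .batch(32)
    # Vectorize whole batches of strings inside the graph
    .map(lambda t, l: (vectorizer(t), l),
         num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)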
Next Steps #
Now that you know how to build data pipelines, continue on to Loss Functions to learn how to choose the right loss function for your model!
Last updated: 2026-04-04