图像数据增强 #

什么是数据增强? #

text
┌─────────────────────────────────────────────────────────────┐
│                    数据增强原理                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  原始图像 ──► 变换 ──► 增强图像                             │
│                                                             │
│  变换类型:                                                  │
│  ├── 几何变换: 旋转、翻转、缩放、平移                       │
│  ├── 颜色变换: 亮度、对比度、饱和度                         │
│  └── 噪声添加: 高斯噪声、椒盐噪声                           │
│                                                             │
│  优点:                                                      │
│  ├── 扩充训练数据                                          │
│  ├── 提高模型泛化能力                                      │
│  └── 减少过拟合                                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘

ImageDataGenerator #

基本用法 #

python
import keras

# Random geometric augmentation: rotate by up to 20 degrees, shift by up to
# 20% of the width/height, and randomly mirror horizontally.
# NOTE(review): ImageDataGenerator is deprecated in recent Keras releases and
# may not be exported from `keras.preprocessing.image` in Keras 3 — confirm
# against the installed version.
datagen = keras.preprocessing.image.ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True
)

(x_train, y_train), _ = keras.datasets.cifar10.load_data()

# datagen.fit(x_train) is deliberately not called: it only computes dataset
# statistics for featurewise_center / featurewise_std_normalization /
# zca_whitening, and none of those options is enabled above.

# `model` is expected to be a compiled Keras model defined beforehand.
# flow() yields augmented batches of 32 images.
model.fit(
    datagen.flow(x_train, y_train, batch_size=32),
    epochs=100
)

参数详解 #

python
# Reference listing of the ImageDataGenerator arguments. Every value shown is
# a neutral one, so no normalisation or augmentation is enabled (presumably
# these are the library defaults — verify against the installed version).
datagen = keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False,  # subtract the dataset-wide mean (requires datagen.fit)
    samplewise_center=False,  # subtract each image's own mean
    featurewise_std_normalization=False,  # divide by the dataset-wide std (requires datagen.fit)
    samplewise_std_normalization=False,  # divide each image by its own std
    rotation_range=0,  # rotation range in degrees, e.g. 20 -> random angle in [-20, 20]
    width_shift_range=0.0,  # horizontal shift, e.g. 0.2 -> up to 20% of the width
    height_shift_range=0.0,  # vertical shift, as a fraction of the height
    brightness_range=None,  # brightness adjustment range, e.g. [0.8, 1.2]
    shear_range=0.0,  # shear transformation angle
    zoom_range=0.0,  # zoom range, e.g. 0.2 -> scale factor in [0.8, 1.2]
    channel_shift_range=0.0,  # range for random channel shifts
    fill_mode='nearest',  # how newly exposed pixels are filled: 'nearest', 'constant', 'reflect', ...
    cval=0.0,  # fill value used when fill_mode='constant'
    horizontal_flip=False,  # randomly flip images horizontally
    vertical_flip=False,  # randomly flip images vertically
    rescale=None  # multiplier applied to pixel values (e.g. 1./255); None disables it
)
text
┌─────────────────────────────────────────────────────────────┐
│                    ImageDataGenerator 参数                   │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  rotation_range: 旋转角度范围 (度)                          │
│  例: 20 表示随机旋转 -20° 到 20°                           │
│                                                             │
│  width_shift_range: 水平平移范围                            │
│  例: 0.2 表示平移 0-20% 的宽度                             │
│                                                             │
│  height_shift_range: 垂直平移范围                           │
│                                                             │
│  horizontal_flip: 水平翻转                                  │
│                                                             │
│  vertical_flip: 垂直翻转                                    │
│                                                             │
│  zoom_range: 缩放范围                                       │
│  例: 0.2 表示缩放 0.8-1.2 倍                               │
│                                                             │
│  shear_range: 剪切变换角度                                  │
│                                                             │
│  brightness_range: 亮度调整范围                             │
│  例: [0.8, 1.2]                                            │
│                                                             │
│  fill_mode: 填充方式                                        │
│  ├── 'nearest': 最近邻填充                                 │
│  ├── 'constant': 常数填充                                  │
│  ├── 'reflect': 反射填充                                   │
│  └── 'wrap': 环绕填充                                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘

KerasCV 数据增强 #

基本用法 #

python
import keras
import keras_cv

# Augmentation stage: random horizontal flip, then rotation, then zoom.
augmentation_layers = [
    keras_cv.layers.RandomFlip(mode='horizontal'),
    keras_cv.layers.RandomRotation(factor=0.2),
    keras_cv.layers.RandomZoom(height_factor=0.2, width_factor=0.2),
]
augmenter = keras.Sequential(augmentation_layers)

# Classifier for 32x32x3 inputs: augment -> multiply pixels by 1/255 ->
# one 3x3 convolution -> global average pooling -> 10-way softmax.
model_layers = [
    keras.layers.Input(shape=(32, 32, 3)),
    augmenter,
    keras.layers.Rescaling(1./255),
    keras.layers.Conv2D(32, 3, activation='relu'),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(10, activation='softmax'),
]
model = keras.Sequential(model_layers)

常用增强层 #

python
import keras
import keras_cv

# A broader augmentation pipeline: geometric transforms first, then colour
# transforms. The colour layers assume raw pixel values in [0, 255].
augmenter = keras.Sequential([
    keras_cv.layers.RandomFlip(mode='horizontal_and_vertical'),
    keras_cv.layers.RandomRotation(factor=0.3),
    keras_cv.layers.RandomZoom(height_factor=(-0.2, 0.2), width_factor=(-0.2, 0.2)),
    keras_cv.layers.RandomTranslation(height_factor=0.2, width_factor=0.2),
    # value_range is a required argument of KerasCV's RandomContrast.
    keras_cv.layers.RandomContrast(value_range=(0, 255), factor=0.2),
    keras_cv.layers.RandomBrightness(factor=0.2),
    keras_cv.layers.RandomHue(factor=0.2, value_range=[0, 255]),
    # For RandomSaturation, 0.5 means "unchanged" and 0.0 means grayscale; a
    # single float f samples from [0, f], so factor=0.2 would almost remove
    # colour. Sample symmetrically around 0.5 instead.
    keras_cv.layers.RandomSaturation(factor=(0.3, 0.7)),
])

RandAugment #

python
import keras
import keras_cv

# RandAugment applies several randomly chosen augmentations to each image.
rand_augment = keras_cv.layers.RandAugment(
    value_range=(0, 255),        # pixel range of the incoming images
    augmentations_per_image=3,   # number of augmentations applied per image
    magnitude=0.5,               # mean augmentation strength
    magnitude_stddev=0.15        # spread of the sampled strength
)

# Illustrative model stub: augmentation runs before pixels are multiplied by
# 1/255. A real classifier still needs pooling and an output layer after the
# convolution.
model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    rand_augment,
    keras.layers.Rescaling(1./255),
    keras.layers.Conv2D(32, 3, activation='relu'),
])

MixUp #

python
import keras
import keras_cv
import tensorflow as tf

# MixUp blends pairs of images *and* their labels, so it must receive both.
# It therefore cannot be a layer inside the model (a model layer only sees
# images); apply it to batched {"images", "labels"} dictionaries in the
# input pipeline instead.
mixup = keras_cv.layers.MixUp(alpha=0.2)

(x_train, y_train), _ = keras.datasets.cifar10.load_data()


def to_dict(images, labels):
    """Pack a batch into the dictionary format KerasCV expects, with one-hot float labels."""
    class_ids = tf.cast(tf.squeeze(labels, axis=-1), tf.int32)
    return {
        "images": tf.cast(images, tf.float32),
        "labels": tf.one_hot(class_ids, 10),
    }


def to_tuple(samples):
    """Unpack an augmented dictionary into the (images, labels) pair that fit() expects."""
    return samples["images"], samples["labels"]


dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10_000)
    .batch(32)  # MixUp mixes samples within a batch, so batch first
    .map(to_dict, num_parallel_calls=tf.data.AUTOTUNE)
    .map(lambda samples: mixup(samples, training=True),
         num_parallel_calls=tf.data.AUTOTUNE)
    .map(to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

# Illustrative model stub; train it on `dataset`. The labels are now soft
# (mixed one-hot vectors), so use loss='categorical_crossentropy' rather than
# the sparse variant.
model = keras.Sequential([
    keras.layers.Input(shape=(32, 32, 3)),
    keras.layers.Rescaling(1./255),
    keras.layers.Conv2D(32, 3, activation='relu'),
])

CutMix #

python
import keras
import keras_cv
import tensorflow as tf

# CutMix pastes a patch from one image into another and mixes their labels
# accordingly, so it must receive both images and labels. It therefore cannot
# be a layer inside the model; apply it to batched {"images", "labels"}
# dictionaries in the input pipeline instead.
cutmix = keras_cv.layers.CutMix(alpha=1.0)

(x_train, y_train), _ = keras.datasets.cifar10.load_data()


def to_dict(images, labels):
    """Pack a batch into the dictionary format KerasCV expects, with one-hot float labels."""
    class_ids = tf.cast(tf.squeeze(labels, axis=-1), tf.int32)
    return {
        "images": tf.cast(images, tf.float32),
        "labels": tf.one_hot(class_ids, 10),
    }


def to_tuple(samples):
    """Unpack an augmented dictionary into the (images, labels) pair that fit() expects."""
    return samples["images"], samples["labels"]


dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10_000)
    .batch(32)  # CutMix mixes samples within a batch, so batch first
    .map(to_dict, num_parallel_calls=tf.data.AUTOTUNE)
    .map(lambda samples: cutmix(samples, training=True),
         num_parallel_calls=tf.data.AUTOTUNE)
    .map(to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

# Illustrative model stub; train it on `dataset`. The labels are now soft
# (mixed one-hot vectors), so use loss='categorical_crossentropy' rather than
# the sparse variant.
model = keras.Sequential([
    keras.layers.Input(shape=(32, 32, 3)),
    keras.layers.Rescaling(1./255),
    keras.layers.Conv2D(32, 3, activation='relu'),
])

自定义增强函数 #

python
import keras
import tensorflow as tf

def custom_augment(image, label):
    """Randomly flip and colour-jitter one image; the label passes through unchanged.

    Steps, in order: left-right flip, up-down flip, brightness
    (max_delta=0.1), contrast and saturation (each with a factor in
    [0.9, 1.1]), and hue (max_delta=0.1).

    Returns:
        The ``(augmented_image, label)`` pair.
    """
    flipped = tf.image.random_flip_up_down(
        tf.image.random_flip_left_right(image)
    )
    relit = tf.image.random_brightness(flipped, max_delta=0.1)
    recontrasted = tf.image.random_contrast(relit, lower=0.9, upper=1.1)
    resaturated = tf.image.random_saturation(recontrasted, lower=0.9, upper=1.1)
    augmented = tf.image.random_hue(resaturated, max_delta=0.1)
    return augmented, label

(x_train, y_train), _ = keras.datasets.cifar10.load_data()

# Input pipeline: shuffle -> augment each image -> batch -> prefetch.
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    # fit() does not shuffle a tf.data.Dataset on its own, so shuffle here;
    # by default the order is re-drawn on every epoch.
    .shuffle(buffer_size=len(x_train))
    .map(custom_augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

# `model` is expected to be a compiled Keras model defined beforehand.
model.fit(dataset, epochs=10)

完整示例 #

python
import keras
import keras_cv

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Light augmentation applied to raw pixels, before they are rescaled inside
# the model.
augmenter = keras.Sequential([
    keras_cv.layers.RandomFlip(mode='horizontal'),
    keras_cv.layers.RandomRotation(factor=0.1),
    keras_cv.layers.RandomZoom(height_factor=0.1, width_factor=0.1),
    # value_range is a required argument of KerasCV's RandomContrast; the
    # images are still in [0, 255] at this point.
    keras_cv.layers.RandomContrast(value_range=(0, 255), factor=0.1),
])

def _conv_block(filters):
    """Build one conv stage: two 3x3 'same' ReLU convolutions, each followed
    by batch normalisation, then max-pooling and 25% dropout."""
    return [
        keras.layers.Conv2D(filters, 3, padding='same', activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(filters, 3, padding='same', activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPooling2D(),
        keras.layers.Dropout(0.25),
    ]


# CNN classifier for 32x32x3 inputs: augment -> multiply pixels by 1/255 ->
# conv stage with 32 filters -> conv stage with 64 filters -> dense head with
# a 10-way softmax.
model = keras.Sequential([
    keras.layers.Input(shape=(32, 32, 3)),
    augmenter,
    keras.layers.Rescaling(1./255),

    *_conv_block(32),
    *_conv_block(64),

    keras.layers.Flatten(),
    keras.layers.Dense(512, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation='softmax'),
])

# Integer class labels, so use the sparse cross-entropy loss.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Hold out 10% of the training set for validation. Early stopping picks the
# best epoch from the validation metrics, so validating on the test set would
# leak it into model selection and bias the reported score.
history = model.fit(
    x_train, y_train,
    validation_split=0.1,
    epochs=100,
    batch_size=64,
    callbacks=[
        # Stop after 10 epochs without improvement and restore the best weights.
        keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
    ]
)

# The test set is used only once, for the final report.
test_loss, test_accuracy = model.evaluate(x_test, y_test)

下一步 #

现在你已经掌握了图像数据增强,接下来学习「文本数据处理」,了解如何处理文本数据!

最后更新:2026-04-04