图像分类实战 #

项目概述 #

text
┌─────────────────────────────────────────────────────────────┐
│                    项目流程                                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 数据准备                                                │
│     ├── 数据加载                                           │
│     ├── 数据预处理                                         │
│     └── 数据增强                                           │
│                                                             │
│  2. 模型构建                                                │
│     ├── CNN 架构设计                                       │
│     └── 迁移学习                                           │
│                                                             │
│  3. 模型训练                                                │
│     ├── 编译配置                                           │
│     └── 回调函数                                           │
│                                                             │
│  4. 模型评估                                                │
│     ├── 性能指标                                           │
│     └── 可视化分析                                         │
│                                                             │
│  5. 模型部署                                                │
│     ├── 模型保存                                           │
│     └── 预测接口                                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

CIFAR-10 图像分类 #

数据准备 #

python
import keras
import numpy as np
import matplotlib.pyplot as plt

# Load the CIFAR-10 training and test splits bundled with Keras.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Class names; later code looks a name up with the integer label as index.
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

print(f'训练集: {x_train.shape}')
print(f'测试集: {x_test.shape}')

# Cast the images to float32 and divide by 255 (presumably 8-bit pixels,
# which this maps into [0, 1]).
x_train, x_test = (
    pixels.astype('float32') / 255.0 for pixels in (x_train, x_test)
)

# Collapse the label arrays to 1-D so each entry is a plain class index.
y_train, y_test = y_train.flatten(), y_test.flatten()

数据增强 #

python
import keras_cv

# Augmentation pipeline assembled step by step: horizontal flip, then a
# random rotation (factor 0.1), then a random zoom (factor 0.1 on each axis).
augmenter = keras.Sequential()
augmenter.add(keras_cv.layers.RandomFlip(mode='horizontal'))
augmenter.add(keras_cv.layers.RandomRotation(factor=0.1))
augmenter.add(keras_cv.layers.RandomZoom(height_factor=0.1, width_factor=0.1))

模型构建 #

python
import keras

# Layers are collected in a list, in exactly the order they are stacked.
# The network starts with the input spec followed by the augmentation stack.
cnn_layers = [
    keras.layers.Input(shape=(32, 32, 3)),
    augmenter,
]

# Three convolutional stages (32, 64, 128 filters). Each stage is two
# Conv -> BatchNorm -> ReLU units, then max-pooling and 25% dropout.
for stage_filters in (32, 64, 128):
    for _ in range(2):
        cnn_layers.append(keras.layers.Conv2D(stage_filters, 3, padding='same'))
        cnn_layers.append(keras.layers.BatchNormalization())
        cnn_layers.append(keras.layers.Activation('relu'))
    cnn_layers.append(keras.layers.MaxPooling2D())
    cnn_layers.append(keras.layers.Dropout(0.25))

# Classifier head: flatten, one 256-unit hidden layer, 10-way softmax.
cnn_layers.extend([
    keras.layers.Flatten(),
    keras.layers.Dense(256),
    keras.layers.BatchNormalization(),
    keras.layers.Activation('relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation='softmax'),
])

model = keras.Sequential(cnn_layers)

model.summary()

模型训练 #

python
# Integer labels are used directly, hence the sparse cross-entropy loss.
adam = keras.optimizers.Adam(learning_rate=0.001)
model.compile(
    optimizer=adam,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Stop after 15 epochs without a val_accuracy improvement and roll back to
# the best weights seen.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_accuracy', patience=15, restore_best_weights=True
)
# Halve the learning rate after 5 epochs without a val_loss improvement,
# never going below 1e-6.
lr_schedule = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6
)
# Save to disk only when val_accuracy improves.
checkpoint = keras.callbacks.ModelCheckpoint(
    'best_cifar10.keras', monitor='val_accuracy', save_best_only=True
)
callbacks = [early_stopping, lr_schedule, checkpoint]

# Train for up to 100 epochs, holding out 10% of the training data for
# validation.
fit_options = dict(
    validation_split=0.1,
    epochs=100,
    batch_size=64,
    callbacks=callbacks,
)
history = model.fit(x_train, y_train, **fit_options)

模型评估 #

python
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Loss and accuracy on the test split.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f'测试准确率: {test_acc:.4f}')

# Pick the highest-scoring class for every test image.
predictions = model.predict(x_test)
predicted_classes = predictions.argmax(axis=1)

# Per-class precision / recall / F1 report.
report = classification_report(
    y_test, predicted_classes, target_names=class_names
)
print(report)

# Confusion matrix drawn as an annotated heatmap (rows: true class,
# columns: predicted class).
cm = confusion_matrix(y_test, predicted_classes)
_, cm_axis = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=cm_axis,
)
cm_axis.set_xlabel('Predicted')
cm_axis.set_ylabel('True')
cm_axis.set_title('Confusion Matrix')
plt.show()

可视化训练过程 #

python
import matplotlib.pyplot as plt

# One figure with two side-by-side panels: loss on the left, accuracy on
# the right, each showing the training and the validation curve.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# (train key, val key, train legend, val legend, y-axis label, panel title)
panel_specs = (
    ('loss', 'val_loss', 'Train Loss', 'Val Loss',
     'Loss', 'Training and Validation Loss'),
    ('accuracy', 'val_accuracy', 'Train Acc', 'Val Acc',
     'Accuracy', 'Training and Validation Accuracy'),
)

for panel, spec in zip(axes, panel_specs):
    train_key, val_key, train_legend, val_legend, y_label, panel_title = spec
    panel.plot(history.history[train_key], label=train_legend)
    panel.plot(history.history[val_key], label=val_legend)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()

plt.tight_layout()
plt.show()

预测示例 #

python
def predict_and_plot(images, labels, model, class_names, n=10):
    """Predict the first ``n`` images and show them in a grid.

    Each panel is titled with the true and the predicted class name. The
    title is green when the two agree and red when they differ.

    Args:
        images: Indexable collection of images accepted by ``model.predict``
            and ``plt.imshow``.
        labels: Integer class indices aligned with ``images``.
        model: Object whose ``predict`` returns per-class scores, one row
            per image.
        class_names: Sequence mapping a class index to its display name.
        n: Maximum number of images to show. It is clamped to the number of
            images available; nothing is drawn when that is zero.
    """
    # Never index past the data that was actually supplied.
    n = min(n, len(images))
    if n <= 0:
        return

    predictions = model.predict(images[:n])
    predicted_classes = np.argmax(predictions, axis=1)

    # Five panels per row. At least two rows keeps the original 2x5 layout
    # (and 15x5 figure) for n <= 10; larger n adds rows instead of failing.
    cols = 5
    rows = max(2, -(-n // cols))  # -(-a // b) is ceiling division
    plt.figure(figsize=(15, 2.5 * rows))
    for i in range(n):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(images[i])
        true_label = class_names[labels[i]]
        pred_label = class_names[predicted_classes[i]]
        color = 'green' if labels[i] == predicted_classes[i] else 'red'
        plt.title(f'True: {true_label}\nPred: {pred_label}', color=color)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

predict_and_plot(x_test, y_test, model, class_names)

使用迁移学习 #

python
import keras

# ImageNet-pretrained EfficientNetB0 backbone without its classification head.
base_model = keras.applications.EfficientNetB0(
    weights='imagenet',
    include_top=False,
    input_shape=(32, 32, 3)
)
# Stage 1: freeze the backbone so that only the new head is trained.
base_model.trainable = False

inputs = keras.Input(shape=(32, 32, 3))
# keras.applications EfficientNet models ship their own input preprocessing
# and expect raw pixel values in [0, 255]. x_train / x_test were already
# divided by 255 during data preparation, so map them back to [0, 255] here
# instead of dividing by 255 a second time.
x = keras.layers.Rescaling(255.0)(inputs)
# training=False keeps the backbone in inference mode, including later when
# some of its layers are unfrozen for fine-tuning.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(256, activation='relu')(x)
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)

model = keras.Model(inputs, outputs)

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the head on top of the frozen backbone.
history = model.fit(
    x_train, y_train,
    validation_split=0.1,
    epochs=20,
    batch_size=64
)

# Stage 2: unfreeze the backbone, then re-freeze everything except its last
# 20 layers so only those are fine-tuned.
base_model.trainable = True
for layer in base_model.layers[:-20]:
    layer.trainable = False

# Recompile so the changed trainable flags take effect. The much smaller
# learning rate limits how far the pretrained weights move.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

history_fine = model.fit(
    x_train, y_train,
    validation_split=0.1,
    epochs=10,
    batch_size=64
)

下一步 #

现在你已经完成了图像分类实战,接下来学习文本分类,处理文本数据!

最后更新:2026-04-04