Keras 简介 #
什么是 Keras? #
Keras 是一个高级深度学习 API,用 Python 编写,能够在 TensorFlow、PyTorch 或 JAX 之上运行。它的设计理念是"让深度学习变得简单"。
text
┌─────────────────────────────────────────────────────────────┐
│ Keras 的定位 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 高级 API ──────────────────────────────────────────────► │
│ │ │
│ │ Keras: 简洁、易用、快速原型开发 │
│ │ │
│ │ ┌─────────────────────────────────────────────────┐ │
│ │ │ model = Sequential([Dense(64), Dense(10)]) │ │
│ │ │ model.compile(optimizer='adam', loss='mse') │ │
│ │ │ model.fit(x, y, epochs=10) │ │
│ │ └─────────────────────────────────────────────────┘ │
│ │ │
│ 中级 API ─────────────────────────────────────────────► │
│ │ │
│ │ TensorFlow/PyTorch: 灵活、可控、研究导向 │
│ │ │
│ 低级 API ─────────────────────────────────────────────► │
│ │ │
│ │ CUDA/cuDNN: 高性能、底层优化 │
│ │ │
│ └─────────────────────────────────────────────────────────►│
│ │
└─────────────────────────────────────────────────────────────┘
Keras 的设计理念 #
1. 用户友好 #
text
┌─────────────────────────────────────────────────────────────┐
│ 用户友好设计 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 设计原则: │
│ │
│ ✅ 简洁的 API │
│ - 一行代码定义层 │
│ - 链式调用构建模型 │
│ - 直观的参数命名 │
│ │
│ ✅ 清晰的错误信息 │
│ - 详细的错误描述 │
│ - 修复建议 │
│ │
│ ✅ 完善的文档 │
│ - 丰富的示例代码 │
│ - 详细的概念解释 │
│ │
└─────────────────────────────────────────────────────────────┘
2. 模块化 #
text
┌─────────────────────────────────────────────────────────────┐
│ 模块化架构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 神经网络 = 独立模块的组合 │
│ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ 模型 │ │ 层 │ │ 损失 │ │ 优化器 │ │
│ ├─────────┤ ├─────────┤ ├─────────┤ ├─────────┤ │
│ │Sequential│ │ Dense │ │ MSE │ │ SGD │ │
│ │Functional│ │ Conv2D │ │ CrossEnt│ │ Adam │ │
│ │ Subclass │ │ LSTM │ │ MAE │ │ RMSprop │ │
│ └─────────┘ └─────────┘ └─────────┘ └─────────┘ │
│ │
│ 每个模块可独立配置、替换、组合 │
│ │
└─────────────────────────────────────────────────────────────┘
3. 可扩展 #
python
import keras
class MyLayer(keras.layers.Layer):
def __init__(self, output_dim, **kwargs):
super().__init__(**kwargs)
self.output_dim = output_dim
def build(self, input_shape):
self.kernel = self.add_weight(
shape=(input_shape[-1], self.output_dim),
initializer='glorot_uniform',
trainable=True
)
def call(self, inputs):
return keras.ops.matmul(inputs, self.kernel)
model = keras.Sequential([
keras.layers.Dense(64, activation='relu'),
MyLayer(32),
keras.layers.Dense(10)
])
Keras 的发展历史 #
text
┌─────────────────────────────────────────────────────────────┐
│ Keras 发展历程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 2015年3月 ─── Keras 首次发布 │
│ │ │
│ │ François Chollet 开发 │
│ │ 设计理念:简单易用 │
│ │ │
│ 2017年 ─── TensorFlow 集成 │
│ │ │
│ │ 成为 TensorFlow 高级 API │
│ │ tf.keras 模块 │
│ │ │
│ 2019年 ─── Keras 2.3 │
│ │ │
│ │ 多后端支持 │
│ │ 统一 API │
│ │ │
│ 2023年 ─── Keras 3.0 │
│ │ │
│ │ 支持 TensorFlow、PyTorch、JAX │
│ │ 全新的架构设计 │
│ │ 更好的性能 │
│ │ │
│ 现在 ─── 持续发展 │
│ 活跃的社区 │
│ 丰富的生态系统 │
│ │
└─────────────────────────────────────────────────────────────┘
Keras 3 的核心特性 #
多后端支持 #
text
┌─────────────────────────────────────────────────────────────┐
│ 多后端架构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ │
│ │ Keras 3 │ │
│ │ 统一 API │ │
│ └──────┬──────┘ │
│ │ │
│ ┌─────────────────┼─────────────────┐ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ TensorFlow │ │ PyTorch │ │ JAX │ │
│ ├─────────────┤ ├─────────────┤ ├─────────────┤ │
│ │ 生产部署 │ │ 研究开发 │ │ 高性能计算 │ │
│ │ 生态完善 │ │ 动态图 │ │ 自动向量化 │ │
│ │ TPU 支持 │ │ 社区活跃 │ │ JIT 编译 │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
│ 切换后端只需一行代码: │
│ os.environ["KERAS_BACKEND"] = "torch" │
│ │
└─────────────────────────────────────────────────────────────┘
切换后端 #
python
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["KERAS_BACKEND"] = "torch"
os.environ["KERAS_BACKEND"] = "jax"
import keras
Keras 的核心组件 #
1. 模型(Model) #
text
┌─────────────────────────────────────────────────────────────┐
│ 模型类型 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Sequential(序列模型) │
│ ┌───┐ ┌───┐ ┌───┐ ┌───┐ │
│ │ L1│ → │ L2│ → │ L3│ → │ L4│ │
│ └───┘ └───┘ └───┘ └───┘ │
│ 简单的线性堆叠,适合简单网络 │
│ │
│ Functional API(函数式 API) │
│ ┌───┐ │
│ ┌──│ L1│──┐ │
│ │ └───┘ │ │
│ ▼ ▼ │
│ ┌───┐ ┌───┐ │
│ │ L2│ │ L3│ │
│ └───┘ └───┘ │
│ │ │ │
│ └────┬────┘ │
│ ▼ │
│ ┌───┐ │
│ │ L4│ │
│ └───┘ │
│ 支持复杂拓扑结构,多输入多输出 │
│ │
│ Model Subclassing(模型子类化) │
│ 完全自定义,最大灵活性 │
│ │
└─────────────────────────────────────────────────────────────┘
2. 层(Layer) #
text
┌─────────────────────────────────────────────────────────────┐
│ 层的分类 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 核心层 │
│ ├── Dense(全连接层) │
│ ├── Activation(激活层) │
│ ├── Flatten(展平层) │
│ └── Dropout(随机失活) │
│ │
│ 卷积层 │
│ ├── Conv1D / Conv2D / Conv3D │
│ ├── MaxPooling / AveragePooling │
│ └── UpSampling │
│ │
│ 循环层 │
│ ├── LSTM(长短期记忆) │
│ ├── GRU(门控循环单元) │
│ └── SimpleRNN │
│ │
│ 正则化层 │
│ ├── BatchNormalization │
│ ├── LayerNormalization │
│ └── Dropout │
│ │
└─────────────────────────────────────────────────────────────┘
3. 损失函数(Loss) #
text
┌─────────────────────────────────────────────────────────────┐
│ 损失函数选择 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 回归问题 │
│ ├── MeanSquaredError (MSE) │
│ ├── MeanAbsoluteError (MAE) │
│ └── Huber Loss │
│ │
│ 分类问题 │
│ ├── BinaryCrossentropy(二分类) │
│ ├── CategoricalCrossentropy(多分类) │
│ └── SparseCategoricalCrossentropy │
│ │
└─────────────────────────────────────────────────────────────┘
4. 优化器(Optimizer) #
text
┌─────────────────────────────────────────────────────────────┐
│ 优化器选择 │
├─────────────────────────────────────────────────────────────┤
│ │
│ SGD ─────────── 随机梯度下降 │
│ Momentum ─────── 带动量的 SGD │
│ Adam ─────────── 自适应学习率(推荐) │
│ RMSprop ───────── RMS 优化 │
│ Adagrad ───────── 适合稀疏数据 │
│ AdamW ─────────── 带权重衰减的 Adam │
│ │
└─────────────────────────────────────────────────────────────┘
Keras 的应用场景 #
1. 计算机视觉 #
text
┌─────────────────────────────────────────────────────────────┐
│ 计算机视觉任务 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 图像分类 │
│ 输入: 图像 → CNN → 输出: 类别 │
│ │
│ 目标检测 │
│ 输入: 图像 → 检测网络 → 输出: 边界框 + 类别 │
│ │
│ 图像分割 │
│ 输入: 图像 → 分割网络 → 输出: 像素级标签 │
│ │
│ 图像生成 │
│ 输入: 噪声/条件 → GAN/VAE → 输出: 生成图像 │
│ │
└─────────────────────────────────────────────────────────────┘
2. 自然语言处理 #
text
┌─────────────────────────────────────────────────────────────┐
│ NLP 任务 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 文本分类 │
│ 输入: 文本 → Embedding + LSTM/Transformer → 输出: 类别 │
│ │
│ 序列标注 │
│ 输入: 文本 → BiLSTM + CRF → 输出: 标签序列 │
│ │
│ 机器翻译 │
│ 输入: 源语言 → Encoder-Decoder → 输出: 目标语言 │
│ │
│ 文本生成 │
│ 输入: 提示 → Transformer → 输出: 生成文本 │
│ │
└─────────────────────────────────────────────────────────────┘
3. 时间序列 #
text
┌─────────────────────────────────────────────────────────────┐
│ 时间序列任务 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 预测 │
│ 历史数据 → LSTM/GRU → 未来值预测 │
│ │
│ 异常检测 │
│ 数据序列 → AutoEncoder → 重构误差 → 异常判断 │
│ │
│ 分类 │
│ 时间序列 → CNN + LSTM → 类别判断 │
│ │
└─────────────────────────────────────────────────────────────┘
Keras vs PyTorch #
| 特性 | Keras | PyTorch |
|---|---|---|
| 学习曲线 | 平缓,适合初学者 | 中等,需要理解更多概念 |
| 代码风格 | 声明式 | 命令式 |
| 调试体验 | 较难调试 | 原生 Python 调试 |
| 灵活性 | 中等 | 高 |
| 原型开发 | 快速 | 中等 |
| 研究用途 | 适合应用研究 | 适合算法研究 |
| 生产部署 | TensorFlow Serving | TorchServe |
| 社区规模 | 大 | 非常大 |
学习建议 #
text
┌─────────────────────────────────────────────────────────────┐
│ 学习路径 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 第1周: 基础入门 │
│ ├── 环境搭建 │
│ ├── 第一个模型 │
│ └── 理解模型训练流程 │
│ │
│ 第2-3周: 模型构建 │
│ ├── Sequential API │
│ ├── Functional API │
│ └── 常用网络层 │
│ │
│ 第4-5周: 训练技巧 │
│ ├── 损失函数与优化器 │
│ ├── 回调函数 │
│ └── 数据预处理 │
│ │
│ 第6-8周: 高级主题 │
│ ├── 迁移学习 │
│ ├── 自定义组件 │
│ └── 模型部署 │
│ │
│ 持续: 实战项目 │
│ ├── 图像分类 │
│ ├── 文本处理 │
│ └── 时间序列 │
│ │
└─────────────────────────────────────────────────────────────┘
下一步 #
现在你已经了解了 Keras 的基本概念,接下来学习 安装与配置,搭建你的深度学习开发环境!
最后更新:2026-04-04