Keras 简介 #

什么是 Keras? #

Keras 是一个高级深度学习 API,用 Python 编写,能够在 TensorFlow、PyTorch 或 JAX 之上运行。它的设计理念是"让深度学习变得简单"。

text
┌─────────────────────────────────────────────────────────────┐
│                    Keras 的定位                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  高级 API ──────────────────────────────────────────────►   │
│  │                                                          │
│  │  Keras: 简洁、易用、快速原型开发                          │
│  │                                                          │
│  │  ┌─────────────────────────────────────────────────┐    │
│  │  │  model = Sequential([Dense(64), Dense(10)])     │    │
│  │  │  model.compile(optimizer='adam', loss='mse')    │    │
│  │  │  model.fit(x, y, epochs=10)                     │    │
│  │  └─────────────────────────────────────────────────┘    │
│  │                                                          │
│  中级 API ─────────────────────────────────────────────►    │
│  │                                                          │
│  │  TensorFlow/PyTorch: 灵活、可控、研究导向                │
│  │                                                          │
│  低级 API ─────────────────────────────────────────────►    │
│  │                                                          │
│  │  CUDA/cuDNN: 高性能、底层优化                            │
│  │                                                          │
│  └─────────────────────────────────────────────────────────►│
│                                                             │
└─────────────────────────────────────────────────────────────┘

Keras 的设计理念 #

1. 用户友好 #

text
┌─────────────────────────────────────────────────────────────┐
│                    用户友好设计                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  设计原则:                                                  │
│                                                             │
│  ✅ 简洁的 API                                              │
│     - 一行代码定义层                                        │
│     - 链式调用构建模型                                      │
│     - 直观的参数命名                                        │
│                                                             │
│  ✅ 清晰的错误信息                                          │
│     - 详细的错误描述                                        │
│     - 修复建议                                              │
│                                                             │
│  ✅ 完善的文档                                              │
│     - 丰富的示例代码                                        │
│     - 详细的概念解释                                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2. 模块化 #

text
┌─────────────────────────────────────────────────────────────┐
│                    模块化架构                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  神经网络 = 独立模块的组合                                   │
│                                                             │
│  ┌─────────┐   ┌─────────┐   ┌─────────┐   ┌─────────┐    │
│  │  模型   │   │   层    │   │  损失   │   │ 优化器  │    │
│  ├─────────┤   ├─────────┤   ├─────────┤   ├─────────┤    │
│  │Sequential│   │ Dense   │   │ MSE     │   │ SGD     │    │
│  │Functional│   │ Conv2D  │   │ CrossEnt│   │ Adam    │    │
│  │ Subclass │   │ LSTM    │   │ MAE     │   │ RMSprop │    │
│  └─────────┘   └─────────┘   └─────────┘   └─────────┘    │
│                                                             │
│  每个模块可独立配置、替换、组合                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

3. 可扩展 #

python
import keras

class MyLayer(keras.layers.Layer):
    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
    
    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.output_dim),
            initializer='glorot_uniform',
            trainable=True
        )
    
    def call(self, inputs):
        return keras.ops.matmul(inputs, self.kernel)

model = keras.Sequential([
    keras.layers.Dense(64, activation='relu'),
    MyLayer(32),
    keras.layers.Dense(10)
])

Keras 的发展历史 #

text
┌─────────────────────────────────────────────────────────────┐
│                    Keras 发展历程                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  2015年3月 ─── Keras 首次发布                               │
│       │                                                      │
│       │      François Chollet 开发                          │
│       │      设计理念:简单易用                              │
│       │                                                      │
│  2017年 ─── TensorFlow 集成                                 │
│       │                                                      │
│       │      成为 TensorFlow 高级 API                       │
│       │      tf.keras 模块                                  │
│       │                                                      │
│  2019年 ─── Keras 2.3                                       │
│       │                                                      │
│       │      多后端支持                                      │
│       │      统一 API                                       │
│       │                                                      │
│  2023年 ─── Keras 3.0                                       │
│       │                                                      │
│       │      支持 TensorFlow、PyTorch、JAX                  │
│       │      全新的架构设计                                  │
│       │      更好的性能                                      │
│       │                                                      │
│  现在 ─── 持续发展                                          │
│           活跃的社区                                         │
│           丰富的生态系统                                     │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Keras 3 的核心特性 #

多后端支持 #

text
┌─────────────────────────────────────────────────────────────┐
│                    多后端架构                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│                    ┌─────────────┐                          │
│                    │  Keras 3    │                          │
│                    │  统一 API   │                          │
│                    └──────┬──────┘                          │
│                           │                                 │
│         ┌─────────────────┼─────────────────┐              │
│         │                 │                 │              │
│         ▼                 ▼                 ▼              │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐        │
│  │ TensorFlow  │  │   PyTorch   │  │    JAX      │        │
│  ├─────────────┤  ├─────────────┤  ├─────────────┤        │
│  │ 生产部署    │  │ 研究开发    │  │ 高性能计算  │        │
│  │ 生态完善    │  │ 动态图      │  │ 自动向量化  │        │
│  │ TPU 支持    │  │ 社区活跃    │  │ JIT 编译    │        │
│  └─────────────┘  └─────────────┘  └─────────────┘        │
│                                                             │
│  切换后端只需一行代码:                                      │
│  os.environ["KERAS_BACKEND"] = "torch"                     │
│                                                             │
└─────────────────────────────────────────────────────────────┘

切换后端 #

python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["KERAS_BACKEND"] = "torch"
os.environ["KERAS_BACKEND"] = "jax"

import keras

Keras 的核心组件 #

1. 模型(Model) #

text
┌─────────────────────────────────────────────────────────────┐
│                    模型类型                                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Sequential(序列模型)                                      │
│  ┌───┐   ┌───┐   ┌───┐   ┌───┐                            │
│  │ L1│ → │ L2│ → │ L3│ → │ L4│                            │
│  └───┘   └───┘   └───┘   └───┘                            │
│  简单的线性堆叠,适合简单网络                                │
│                                                             │
│  Functional API(函数式 API)                                │
│       ┌───┐                                                 │
│    ┌──│ L1│──┐                                              │
│    │  └───┘  │                                              │
│    ▼         ▼                                              │
│  ┌───┐     ┌───┐                                            │
│  │ L2│     │ L3│                                            │
│  └───┘     └───┘                                            │
│    │         │                                              │
│    └────┬────┘                                              │
│         ▼                                                   │
│       ┌───┐                                                 │
│       │ L4│                                                 │
│       └───┘                                                 │
│  支持复杂拓扑结构,多输入多输出                              │
│                                                             │
│  Model Subclassing(模型子类化)                             │
│  完全自定义,最大灵活性                                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2. 层(Layer) #

text
┌─────────────────────────────────────────────────────────────┐
│                    层的分类                                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  核心层                                                     │
│  ├── Dense(全连接层)                                      │
│  ├── Activation(激活层)                                   │
│  ├── Flatten(展平层)                                      │
│  └── Dropout(随机失活)                                    │
│                                                             │
│  卷积层                                                     │
│  ├── Conv1D / Conv2D / Conv3D                               │
│  ├── MaxPooling / AveragePooling                            │
│  └── UpSampling                                             │
│                                                             │
│  循环层                                                     │
│  ├── LSTM(长短期记忆)                                     │
│  ├── GRU(门控循环单元)                                    │
│  └── SimpleRNN                                              │
│                                                             │
│  正则化层                                                   │
│  ├── BatchNormalization                                     │
│  ├── LayerNormalization                                     │
│  └── Dropout                                                │
│                                                             │
└─────────────────────────────────────────────────────────────┘

3. 损失函数(Loss) #

text
┌─────────────────────────────────────────────────────────────┐
│                    损失函数选择                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  回归问题                                                   │
│  ├── MeanSquaredError (MSE)                                 │
│  ├── MeanAbsoluteError (MAE)                                │
│  └── Huber Loss                                             │
│                                                             │
│  分类问题                                                   │
│  ├── BinaryCrossentropy(二分类)                           │
│  ├── CategoricalCrossentropy(多分类)                      │
│  └── SparseCategoricalCrossentropy                          │
│                                                             │
└─────────────────────────────────────────────────────────────┘

4. 优化器(Optimizer) #

text
┌─────────────────────────────────────────────────────────────┐
│                    优化器选择                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  SGD ─────────── 随机梯度下降                               │
│  Momentum ─────── 带动量的 SGD                              │
│  Adam ─────────── 自适应学习率(推荐)                      │
│  RMSprop ───────── RMS 优化                                 │
│  Adagrad ───────── 适合稀疏数据                             │
│  AdamW ─────────── 带权重衰减的 Adam                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Keras 的应用场景 #

1. 计算机视觉 #

text
┌─────────────────────────────────────────────────────────────┐
│                    计算机视觉任务                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  图像分类                                                   │
│  输入: 图像 → CNN → 输出: 类别                              │
│                                                             │
│  目标检测                                                   │
│  输入: 图像 → 检测网络 → 输出: 边界框 + 类别                │
│                                                             │
│  图像分割                                                   │
│  输入: 图像 → 分割网络 → 输出: 像素级标签                   │
│                                                             │
│  图像生成                                                   │
│  输入: 噪声/条件 → GAN/VAE → 输出: 生成图像                 │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2. 自然语言处理 #

text
┌─────────────────────────────────────────────────────────────┐
│                    NLP 任务                                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  文本分类                                                   │
│  输入: 文本 → Embedding + LSTM/Transformer → 输出: 类别     │
│                                                             │
│  序列标注                                                   │
│  输入: 文本 → BiLSTM + CRF → 输出: 标签序列                 │
│                                                             │
│  机器翻译                                                   │
│  输入: 源语言 → Encoder-Decoder → 输出: 目标语言            │
│                                                             │
│  文本生成                                                   │
│  输入: 提示 → Transformer → 输出: 生成文本                  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

3. 时间序列 #

text
┌─────────────────────────────────────────────────────────────┐
│                    时间序列任务                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  预测                                                       │
│  历史数据 → LSTM/GRU → 未来值预测                           │
│                                                             │
│  异常检测                                                   │
│  数据序列 → AutoEncoder → 重构误差 → 异常判断               │
│                                                             │
│  分类                                                       │
│  时间序列 → CNN + LSTM → 类别判断                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Keras vs PyTorch #

特性 Keras PyTorch
学习曲线 平缓,适合初学者 中等,需要理解更多概念
代码风格 声明式 命令式
调试体验 较难调试 原生 Python 调试
灵活性 中等
原型开发 快速 中等
研究用途 适合应用研究 适合算法研究
生产部署 TensorFlow Serving TorchServe
社区规模 非常大

学习建议 #

text
┌─────────────────────────────────────────────────────────────┐
│                    学习路径                                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  第1周: 基础入门                                            │
│  ├── 环境搭建                                              │
│  ├── 第一个模型                                            │
│  └── 理解模型训练流程                                      │
│                                                             │
│  第2-3周: 模型构建                                         │
│  ├── Sequential API                                        │
│  ├── Functional API                                        │
│  └── 常用网络层                                            │
│                                                             │
│  第4-5周: 训练技巧                                         │
│  ├── 损失函数与优化器                                      │
│  ├── 回调函数                                              │
│  └── 数据预处理                                            │
│                                                             │
│  第6-8周: 高级主题                                         │
│  ├── 迁移学习                                              │
│  ├── 自定义组件                                            │
│  └── 模型部署                                              │
│                                                             │
│  持续: 实战项目                                             │
│  ├── 图像分类                                              │
│  ├── 文本处理                                              │
│  └── 时间序列                                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

下一步 #

现在你已经了解了 Keras 的基本概念,接下来学习 安装与配置,搭建你的深度学习开发环境!

最后更新:2026-04-04