Functional API #

什么是 Functional API? #

Functional API 是 Keras 中更灵活的模型构建方式,它将神经网络视为层的有向无环图(DAG)。

text
┌─────────────────────────────────────────────────────────────┐
│                    Functional API vs Sequential             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Sequential: 线性堆叠                                       │
│  输入 ──► 层1 ──► 层2 ──► 输出                              │
│                                                             │
│  Functional: 任意有向无环图                                  │
│                    ┌──► 层2 ──┐                             │
│  输入 ──► 层1 ────┤          ├──► 输出                      │
│                    └──► 层3 ──┘                             │
│                                                             │
│  Functional API 优势:                                       │
│  ✅ 多输入 / 多输出                                         │
│  ✅ 层共享                                                  │
│  ✅ 残差连接                                                │
│  ✅ 任意复杂拓扑                                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘

基本用法 #

创建简单模型 #

python
import keras

inputs = keras.Input(shape=(784,))

x = keras.layers.Dense(64, activation='relu')(inputs)
x = keras.layers.Dense(32, activation='relu')(x)

outputs = keras.layers.Dense(10, activation='softmax')(x)

model = keras.Model(inputs=inputs, outputs=outputs)
text
┌─────────────────────────────────────────────────────────────┐
│                    基本模型结构                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────────┐                                           │
│  │   Input     │ shape: (None, 784)                        │
│  │  (inputs)   │                                           │
│  └──────┬──────┘                                           │
│         │                                                   │
│         ▼                                                   │
│  ┌─────────────┐                                           │
│  │   Dense     │ 64 units, ReLU                            │
│  │    (x)      │                                           │
│  └──────┬──────┘                                           │
│         │                                                   │
│         ▼                                                   │
│  ┌─────────────┐                                           │
│  │   Dense     │ 32 units, ReLU                            │
│  │    (x)      │                                           │
│  └──────┬──────┘                                           │
│         │                                                   │
│         ▼                                                   │
│  ┌─────────────┐                                           │
│  │   Dense     │ 10 units, Softmax                         │
│  │  (outputs)  │                                           │
│  └─────────────┘                                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

多输入模型 #

python
import keras

title_input = keras.Input(shape=(100,), name='title')
body_input = keras.Input(shape=(500,), name='body')
tags_input = keras.Input(shape=(10,), name='tags')

title_features = keras.layers.Embedding(10000, 64)(title_input)
title_features = keras.layers.LSTM(128)(title_features)

body_features = keras.layers.Embedding(10000, 64)(body_input)
body_features = keras.layers.LSTM(128)(body_features)

tags_features = keras.layers.Dense(32, activation='relu')(tags_input)

x = keras.layers.concatenate([title_features, body_features, tags_features])

priority = keras.layers.Dense(1, activation='sigmoid', name='priority')(x)
department = keras.layers.Dense(4, activation='softmax', name='department')(x)

model = keras.Model(
    inputs=[title_input, body_input, tags_input],
    outputs=[priority, department]
)
text
┌─────────────────────────────────────────────────────────────┐
│                    多输入模型                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌──────────┐                                              │
│  │  title   │──► Embedding ──► LSTM ──┐                    │
│  └──────────┘                         │                    │
│                                       │                    │
│  ┌──────────┐                         │                    │
│  │  body    │──► Embedding ──► LSTM ──┼──► Concatenate    │
│  └──────────┘                         │         │          │
│                                       │         │          │
│  ┌──────────┐                         │         ▼          │
│  │  tags    │──► Dense ───────────────┘    ┌─────────┐    │
│  └──────────┘                              │    x    │    │
│                                            └────┬────┘    │
│                                                 │          │
│                                          ┌──────┴──────┐  │
│                                          ▼             ▼  │
│                                    ┌─────────┐   ┌────────┐│
│                                    │priority │   │depart- ││
│                                    │ (sigmoid)│   │ment   ││
│                                    └─────────┘   └────────┘│
│                                                             │
└─────────────────────────────────────────────────────────────┘

训练多输入模型 #

python
model.compile(
    optimizer='adam',
    loss={
        'priority': 'binary_crossentropy',
        'department': 'categorical_crossentropy'
    },
    loss_weights={
        'priority': 1.0,
        'department': 0.5
    },
    metrics=['accuracy']
)

history = model.fit(
    x={
        'title': title_data,
        'body': body_data,
        'tags': tags_data
    },
    y={
        'priority': priority_targets,
        'department': department_targets
    },
    epochs=10,
    batch_size=32
)

多输出模型 #

python
import keras

inputs = keras.Input(shape=(224, 224, 3))

x = keras.layers.Conv2D(64, 3, activation='relu')(inputs)
x = keras.layers.MaxPooling2D()(x)
x = keras.layers.Conv2D(128, 3, activation='relu')(x)
x = keras.layers.MaxPooling2D()(x)
x = keras.layers.Conv2D(256, 3, activation='relu')(x)
x = keras.layers.GlobalAveragePooling2D()(x)

class_output = keras.layers.Dense(10, activation='softmax', name='class')(x)
bbox_output = keras.layers.Dense(4, activation='sigmoid', name='bbox')(x)

model = keras.Model(inputs=inputs, outputs=[class_output, bbox_output])

层共享 #

python
import keras

shared_embedding = keras.layers.Embedding(10000, 128)

input_a = keras.Input(shape=(100,), name='input_a')
input_b = keras.Input(shape=(100,), name='input_b')

embedded_a = shared_embedding(input_a)
embedded_b = shared_embedding(input_b)

lstm = keras.layers.LSTM(64)
features_a = lstm(embedded_a)
features_b = lstm(embedded_b)

x = keras.layers.concatenate([features_a, features_b])
output = keras.layers.Dense(1, activation='sigmoid')(x)

model = keras.Model(inputs=[input_a, input_b], outputs=output)
text
┌─────────────────────────────────────────────────────────────┐
│                    层共享模型                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌──────────┐      ┌─────────────────────┐                 │
│  │ input_a  │─────►│                     │                 │
│  └──────────┘      │   Shared Embedding  │                 │
│                    │                     │                 │
│  ┌──────────┐      │                     │                 │
│  │ input_b  │─────►│                     │                 │
│  └──────────┘      └──────────┬──────────┘                 │
│                              │                              │
│                    ┌─────────┴─────────┐                   │
│                    ▼                   ▼                   │
│              ┌──────────┐        ┌──────────┐             │
│              │ embedded_a│        │ embedded_b│             │
│              └─────┬────┘        └─────┬────┘             │
│                    │                   │                   │
│                    ▼                   ▼                   │
│              ┌──────────┐        ┌──────────┐             │
│              │   LSTM   │        │   LSTM   │             │
│              │ (共享)   │        │ (共享)   │             │
│              └─────┬────┘        └─────┬────┘             │
│                    │                   │                   │
│                    └─────────┬─────────┘                   │
│                              ▼                              │
│                    ┌─────────────────┐                     │
│                    │   Concatenate   │                     │
│                    └────────┬────────┘                     │
│                             ▼                              │
│                    ┌─────────────────┐                     │
│                    │   Dense(1)      │                     │
│                    └─────────────────┘                     │
│                                                             │
└─────────────────────────────────────────────────────────────┘

残差连接 #

python
import keras

def residual_block(x, filters):
    shortcut = x
    
    x = keras.layers.Conv2D(filters, 3, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    
    x = keras.layers.Conv2D(filters, 3, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    
    if shortcut.shape[-1] != filters:
        shortcut = keras.layers.Conv2D(filters, 1)(shortcut)
    
    x = keras.layers.Add()([x, shortcut])
    x = keras.layers.ReLU()(x)
    
    return x

inputs = keras.Input(shape=(224, 224, 3))
x = keras.layers.Conv2D(64, 7, strides=2, padding='same')(inputs)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU()(x)
x = keras.layers.MaxPooling2D()(x)

x = residual_block(x, 64)
x = residual_block(x, 128)
x = residual_block(x, 256)

x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)

model = keras.Model(inputs, outputs)
text
┌─────────────────────────────────────────────────────────────┐
│                    残差连接结构                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│                    ┌──────────────┐                         │
│           ┌───────►│   shortcut   │─────────────┐           │
│           │        └──────────────┘             │           │
│           │                                     │           │
│  ┌────────┴────────┐                           │           │
│  │      输入       │                           │           │
│  └────────┬────────┘                           │           │
│           │                                     │           │
│           ▼                                     │           │
│  ┌────────────────┐                            │           │
│  │    Conv2D      │                            │           │
│  └────────┬───────┘                            │           │
│           ▼                                     │           │
│  ┌────────────────┐                            │           │
│  │ BatchNorm+ReLU │                            │           │
│  └────────┬───────┘                            │           │
│           ▼                                     │           │
│  ┌────────────────┐                            │           │
│  │    Conv2D      │                            │           │
│  └────────┬───────┘                            │           │
│           ▼                                     │           │
│  ┌────────────────┐                            │           │
│  │   BatchNorm    │                            │           │
│  └────────┬───────┘                            │           │
│           │                                     │           │
│           ▼                                     ▼           │
│  ┌────────────────────────────────────────────────┐         │
│  │                    Add                         │         │
│  └───────────────────────┬────────────────────────┘         │
│                          ▼                                  │
│                   ┌────────────┐                            │
│                   │    ReLU    │                            │
│                   └────────────┘                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Inception 模块 #

python
import keras

def inception_module(x, filters_1x1, filters_3x3, filters_5x5, filters_pool):
    branch_1x1 = keras.layers.Conv2D(filters_1x1, 1, activation='relu')(x)
    
    branch_3x3 = keras.layers.Conv2D(filters_3x3 // 2, 1, activation='relu')(x)
    branch_3x3 = keras.layers.Conv2D(filters_3x3, 3, padding='same', activation='relu')(branch_3x3)
    
    branch_5x5 = keras.layers.Conv2D(filters_5x5 // 2, 1, activation='relu')(x)
    branch_5x5 = keras.layers.Conv2D(filters_5x5, 5, padding='same', activation='relu')(branch_5x5)
    
    branch_pool = keras.layers.MaxPooling2D(3, strides=1, padding='same')(x)
    branch_pool = keras.layers.Conv2D(filters_pool, 1, activation='relu')(branch_pool)
    
    return keras.layers.concatenate([branch_1x1, branch_3x3, branch_5x5, branch_pool])

inputs = keras.Input(shape=(224, 224, 3))
x = keras.layers.Conv2D(64, 7, strides=2, padding='same', activation='relu')(inputs)
x = keras.layers.MaxPooling2D()(x)

x = inception_module(x, 64, 128, 32, 32)
x = inception_module(x, 128, 192, 96, 64)

x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)

model = keras.Model(inputs, outputs)

自动编码器 #

python
import keras

encoder_input = keras.Input(shape=(28, 28, 1), name='img')
x = keras.layers.Conv2D(16, 3, activation='relu')(encoder_input)
x = keras.layers.Conv2D(32, 3, activation='relu')(x)
x = keras.layers.MaxPooling2D()(x)
x = keras.layers.Conv2D(32, 3, activation='relu')(x)
x = keras.layers.Conv2D(16, 3, activation='relu')(x)
encoder_output = keras.layers.GlobalMaxPooling2D()(x)

encoder = keras.Model(encoder_input, encoder_output, name='encoder')

decoder_input = keras.Input(shape=(16,), name='encoded_img')
x = keras.layers.Reshape((4, 4, 1))(decoder_input)
x = keras.layers.Conv2DTranspose(16, 3, activation='relu')(x)
x = keras.layers.Conv2DTranspose(32, 3, activation='relu')(x)
x = keras.layers.UpSampling2D()(x)
x = keras.layers.Conv2DTranspose(16, 3, activation='relu')(x)
decoder_output = keras.layers.Conv2DTranspose(1, 3, activation='sigmoid')(x)

decoder = keras.Model(decoder_input, decoder_output, name='decoder')

autoencoder_input = keras.Input(shape=(28, 28, 1), name='img')
encoded = encoder(autoencoder_input)
decoded = decoder(encoded)

autoencoder = keras.Model(autoencoder_input, decoded, name='autoencoder')

模型可视化 #

python
keras.utils.plot_model(model, 'model.png', show_shapes=True)

下一步 #

现在你已经掌握了 Functional API,接下来学习 核心层,深入了解各种网络层!

最后更新:2026-04-04