Functional API #
什么是 Functional API? #
Functional API 是 Keras 中更灵活的模型构建方式,它将神经网络视为层的有向无环图(DAG)。
text
┌─────────────────────────────────────────────────────────────┐
│ Functional API vs Sequential │
├─────────────────────────────────────────────────────────────┤
│ │
│ Sequential: 线性堆叠 │
│ 输入 ──► 层1 ──► 层2 ──► 输出 │
│ │
│ Functional: 任意有向无环图 │
│ ┌──► 层2 ──┐ │
│ 输入 ──► 层1 ────┤ ├──► 输出 │
│ └──► 层3 ──┘ │
│ │
│ Functional API 优势: │
│ ✅ 多输入 / 多输出 │
│ ✅ 层共享 │
│ ✅ 残差连接 │
│ ✅ 任意复杂拓扑 │
│ │
└─────────────────────────────────────────────────────────────┘
基本用法 #
创建简单模型 #
python
import keras
inputs = keras.Input(shape=(784,))
x = keras.layers.Dense(64, activation='relu')(inputs)
x = keras.layers.Dense(32, activation='relu')(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs)
text
┌─────────────────────────────────────────────────────────────┐
│ 基本模型结构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ │
│ │ Input │ shape: (None, 784) │
│ │ (inputs) │ │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Dense │ 64 units, ReLU │
│ │ (x) │ │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Dense │ 32 units, ReLU │
│ │ (x) │ │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Dense │ 10 units, Softmax │
│ │ (outputs) │ │
│ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
多输入模型 #
python
import keras
title_input = keras.Input(shape=(100,), name='title')
body_input = keras.Input(shape=(500,), name='body')
tags_input = keras.Input(shape=(10,), name='tags')
title_features = keras.layers.Embedding(10000, 64)(title_input)
title_features = keras.layers.LSTM(128)(title_features)
body_features = keras.layers.Embedding(10000, 64)(body_input)
body_features = keras.layers.LSTM(128)(body_features)
tags_features = keras.layers.Dense(32, activation='relu')(tags_input)
x = keras.layers.concatenate([title_features, body_features, tags_features])
priority = keras.layers.Dense(1, activation='sigmoid', name='priority')(x)
department = keras.layers.Dense(4, activation='softmax', name='department')(x)
model = keras.Model(
inputs=[title_input, body_input, tags_input],
outputs=[priority, department]
)
text
┌─────────────────────────────────────────────────────────────┐
│ 多输入模型 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────┐ │
│ │ title │──► Embedding ──► LSTM ──┐ │
│ └──────────┘ │ │
│ │ │
│ ┌──────────┐ │ │
│ │ body │──► Embedding ──► LSTM ──┼──► Concatenate │
│ └──────────┘ │ │ │
│ │ │ │
│ ┌──────────┐ │ ▼ │
│ │ tags │──► Dense ───────────────┘ ┌─────────┐ │
│ └──────────┘ │ x │ │
│ └────┬────┘ │
│ │ │
│ ┌──────┴──────┐ │
│ ▼ ▼ │
│ ┌─────────┐ ┌────────┐│
│ │priority │ │depart- ││
│ │ (sigmoid)│ │ment ││
│ └─────────┘ └────────┘│
│ │
└─────────────────────────────────────────────────────────────┘
训练多输入模型 #
python
model.compile(
optimizer='adam',
loss={
'priority': 'binary_crossentropy',
'department': 'categorical_crossentropy'
},
loss_weights={
'priority': 1.0,
'department': 0.5
},
metrics=['accuracy']
)
history = model.fit(
x={
'title': title_data,
'body': body_data,
'tags': tags_data
},
y={
'priority': priority_targets,
'department': department_targets
},
epochs=10,
batch_size=32
)
多输出模型 #
python
import keras
inputs = keras.Input(shape=(224, 224, 3))
x = keras.layers.Conv2D(64, 3, activation='relu')(inputs)
x = keras.layers.MaxPooling2D()(x)
x = keras.layers.Conv2D(128, 3, activation='relu')(x)
x = keras.layers.MaxPooling2D()(x)
x = keras.layers.Conv2D(256, 3, activation='relu')(x)
x = keras.layers.GlobalAveragePooling2D()(x)
class_output = keras.layers.Dense(10, activation='softmax', name='class')(x)
bbox_output = keras.layers.Dense(4, activation='sigmoid', name='bbox')(x)
model = keras.Model(inputs=inputs, outputs=[class_output, bbox_output])
层共享 #
python
import keras
shared_embedding = keras.layers.Embedding(10000, 128)
input_a = keras.Input(shape=(100,), name='input_a')
input_b = keras.Input(shape=(100,), name='input_b')
embedded_a = shared_embedding(input_a)
embedded_b = shared_embedding(input_b)
lstm = keras.layers.LSTM(64)
features_a = lstm(embedded_a)
features_b = lstm(embedded_b)
x = keras.layers.concatenate([features_a, features_b])
output = keras.layers.Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs=[input_a, input_b], outputs=output)
text
┌─────────────────────────────────────────────────────────────┐
│ 层共享模型 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────┐ ┌─────────────────────┐ │
│ │ input_a │─────►│ │ │
│ └──────────┘ │ Shared Embedding │ │
│ │ │ │
│ ┌──────────┐ │ │ │
│ │ input_b │─────►│ │ │
│ └──────────┘ └──────────┬──────────┘ │
│ │ │
│ ┌─────────┴─────────┐ │
│ ▼ ▼ │
│ ┌──────────┐ ┌──────────┐ │
│ │ embedded_a│ │ embedded_b│ │
│ └─────┬────┘ └─────┬────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌──────────┐ ┌──────────┐ │
│ │ LSTM │ │ LSTM │ │
│ │ (共享) │ │ (共享) │ │
│ └─────┬────┘ └─────┬────┘ │
│ │ │ │
│ └─────────┬─────────┘ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Concatenate │ │
│ └────────┬────────┘ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Dense(1) │ │
│ └─────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
残差连接 #
python
import keras
def residual_block(x, filters):
shortcut = x
x = keras.layers.Conv2D(filters, 3, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU()(x)
x = keras.layers.Conv2D(filters, 3, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
if shortcut.shape[-1] != filters:
shortcut = keras.layers.Conv2D(filters, 1)(shortcut)
x = keras.layers.Add()([x, shortcut])
x = keras.layers.ReLU()(x)
return x
inputs = keras.Input(shape=(224, 224, 3))
x = keras.layers.Conv2D(64, 7, strides=2, padding='same')(inputs)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU()(x)
x = keras.layers.MaxPooling2D()(x)
x = residual_block(x, 64)
x = residual_block(x, 128)
x = residual_block(x, 256)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs, outputs)
text
┌─────────────────────────────────────────────────────────────┐
│ 残差连接结构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ │
│ ┌───────►│ shortcut │─────────────┐ │
│ │ └──────────────┘ │ │
│ │ │ │
│ ┌────────┴────────┐ │ │
│ │ 输入 │ │ │
│ └────────┬────────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌────────────────┐ │ │
│ │ Conv2D │ │ │
│ └────────┬───────┘ │ │
│ ▼ │ │
│ ┌────────────────┐ │ │
│ │ BatchNorm+ReLU │ │ │
│ └────────┬───────┘ │ │
│ ▼ │ │
│ ┌────────────────┐ │ │
│ │ Conv2D │ │ │
│ └────────┬───────┘ │ │
│ ▼ │ │
│ ┌────────────────┐ │ │
│ │ BatchNorm │ │ │
│ └────────┬───────┘ │ │
│ │ │ │
│ ▼ ▼ │
│ ┌────────────────────────────────────────────────┐ │
│ │ Add │ │
│ └───────────────────────┬────────────────────────┘ │
│ ▼ │
│ ┌────────────┐ │
│ │ ReLU │ │
│ └────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
Inception 模块 #
python
import keras
def inception_module(x, filters_1x1, filters_3x3, filters_5x5, filters_pool):
branch_1x1 = keras.layers.Conv2D(filters_1x1, 1, activation='relu')(x)
branch_3x3 = keras.layers.Conv2D(filters_3x3 // 2, 1, activation='relu')(x)
branch_3x3 = keras.layers.Conv2D(filters_3x3, 3, padding='same', activation='relu')(branch_3x3)
branch_5x5 = keras.layers.Conv2D(filters_5x5 // 2, 1, activation='relu')(x)
branch_5x5 = keras.layers.Conv2D(filters_5x5, 5, padding='same', activation='relu')(branch_5x5)
branch_pool = keras.layers.MaxPooling2D(3, strides=1, padding='same')(x)
branch_pool = keras.layers.Conv2D(filters_pool, 1, activation='relu')(branch_pool)
return keras.layers.concatenate([branch_1x1, branch_3x3, branch_5x5, branch_pool])
inputs = keras.Input(shape=(224, 224, 3))
x = keras.layers.Conv2D(64, 7, strides=2, padding='same', activation='relu')(inputs)
x = keras.layers.MaxPooling2D()(x)
x = inception_module(x, 64, 128, 32, 32)
x = inception_module(x, 128, 192, 96, 64)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs, outputs)
自动编码器 #
python
import keras
encoder_input = keras.Input(shape=(28, 28, 1), name='img')
x = keras.layers.Conv2D(16, 3, activation='relu')(encoder_input)
x = keras.layers.Conv2D(32, 3, activation='relu')(x)
x = keras.layers.MaxPooling2D()(x)
x = keras.layers.Conv2D(32, 3, activation='relu')(x)
x = keras.layers.Conv2D(16, 3, activation='relu')(x)
encoder_output = keras.layers.GlobalMaxPooling2D()(x)
encoder = keras.Model(encoder_input, encoder_output, name='encoder')
decoder_input = keras.Input(shape=(16,), name='encoded_img')
x = keras.layers.Reshape((4, 4, 1))(decoder_input)
x = keras.layers.Conv2DTranspose(16, 3, activation='relu')(x)
x = keras.layers.Conv2DTranspose(32, 3, activation='relu')(x)
x = keras.layers.UpSampling2D()(x)
x = keras.layers.Conv2DTranspose(16, 3, activation='relu')(x)
decoder_output = keras.layers.Conv2DTranspose(1, 3, activation='sigmoid')(x)
decoder = keras.Model(decoder_input, decoder_output, name='decoder')
autoencoder_input = keras.Input(shape=(28, 28, 1), name='img')
encoded = encoder(autoencoder_input)
decoded = decoder(encoded)
autoencoder = keras.Model(autoencoder_input, decoded, name='autoencoder')
模型可视化 #
python
keras.utils.plot_model(model, 'model.png', show_shapes=True)
下一步 #
现在你已经掌握了 Functional API,接下来学习 核心层,深入了解各种网络层!
最后更新:2026-04-04