自定义模型 #
为什么需要自定义模型? #
text
┌─────────────────────────────────────────────────────────────┐
│ 模型构建方式对比 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Sequential: │
│ ✅ 简单易用 │
│ ❌ 只能线性堆叠 │
│ │
│ Functional API: │
│ ✅ 支持复杂拓扑 │
│ ✅ 多输入多输出 │
│ ❌ 静态图结构 │
│ │
│ Model Subclassing: │
│ ✅ 最大灵活性 │
│ ✅ 动态计算图 │
│ ✅ 自定义前向传播 │
│ ❌ 代码更复杂 │
│ │
└─────────────────────────────────────────────────────────────┘
基本结构 #
python
import keras
class MyModel(keras.Model):
    """Minimal skeleton for a subclassed Keras model.

    Override `__init__` for configuration, `build` for layer/weight
    creation (invoked once the input shape is known), and `call` for
    the forward pass.
    """

    def __init__(self, **kwargs):
        # Forward standard Model kwargs (e.g. `name`) to the base class.
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Create layers/weights here; this skeleton defines none.
        pass

    def call(self, inputs):
        # Placeholder forward pass: identity.
        return inputs
简单示例 #
python
import keras
class SimpleModel(keras.Model):
    """MLP classifier: two ReLU hidden layers followed by a softmax head."""

    def __init__(self, hidden_units, num_classes, **kwargs):
        super().__init__(**kwargs)
        # Store the configuration; layers are created lazily in build().
        self.hidden_units = hidden_units
        self.num_classes = num_classes

    def build(self, input_shape):
        # Deferred layer creation: runs once the input shape is known.
        self.dense1 = keras.layers.Dense(self.hidden_units, activation='relu')
        self.dense2 = keras.layers.Dense(self.hidden_units // 2, activation='relu')
        self.output_layer = keras.layers.Dense(self.num_classes, activation='softmax')

    def call(self, inputs):
        # Two hidden transformations, then the softmax classification head.
        hidden = self.dense2(self.dense1(inputs))
        return self.output_layer(hidden)
# Build and compile; categorical_crossentropy expects one-hot encoded labels.
model = SimpleModel(hidden_units=128, num_classes=10)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
变分自编码器(VAE) #
python
import keras
import keras.ops as ops
class VAE(keras.Model):
    """Convolutional variational autoencoder for 28x28x1 images (e.g. MNIST).

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian over the latent space; the decoder maps a latent sample back
    to image space. The KL-divergence term is registered via `add_loss`,
    so only the reconstruction loss needs to be passed to `compile`.
    """

    def __init__(self, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        # Encoder outputs 2 * latent_dim units: [mean, log_var] concatenated.
        self.encoder = keras.Sequential([
            # Keras 3 spells the keyword `shape=`; `input_shape=` is the
            # legacy Keras 2 name.
            keras.layers.InputLayer(shape=(28, 28, 1)),
            keras.layers.Conv2D(32, 3, activation='relu', strides=2, padding='same'),
            keras.layers.Conv2D(64, 3, activation='relu', strides=2, padding='same'),
            keras.layers.Flatten(),
            keras.layers.Dense(latent_dim + latent_dim)
        ])
        # Decoder mirrors the encoder: 7x7 feature map upsampled back to 28x28.
        self.decoder = keras.Sequential([
            keras.layers.InputLayer(shape=(latent_dim,)),
            keras.layers.Dense(7 * 7 * 64, activation='relu'),
            keras.layers.Reshape((7, 7, 64)),
            keras.layers.Conv2DTranspose(64, 3, activation='relu', strides=2, padding='same'),
            keras.layers.Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same'),
            keras.layers.Conv2DTranspose(1, 3, padding='same')
        ])

    def encode(self, x):
        """Return (mean, log_var) of the approximate posterior q(z|x)."""
        mean, log_var = ops.split(self.encoder(x), 2, axis=-1)
        return mean, log_var

    def reparameterize(self, mean, log_var):
        """Sample z = mean + sigma * eps, eps ~ N(0, I) (reparameterization trick)."""
        eps = keras.random.normal(shape=ops.shape(mean))
        return eps * ops.exp(log_var * 0.5) + mean

    def decode(self, z):
        """Map latent codes back to image space (no final activation)."""
        return self.decoder(z)

    def call(self, inputs):
        mean, log_var = self.encode(inputs)
        z = self.reparameterize(mean, log_var)
        reconstructed = self.decode(z)
        # Analytic KL divergence between q(z|x) and the standard normal prior;
        # added here so compile() only needs the reconstruction loss.
        kl_loss = -0.5 * ops.sum(1 + log_var - ops.square(mean) - ops.exp(log_var), axis=-1)
        self.add_loss(ops.mean(kl_loss))
        return reconstructed
# Only the reconstruction (MSE) loss is passed here; the KL term is
# contributed by add_loss() inside VAE.call().
vae = VAE(latent_dim=2)
vae.compile(optimizer='adam', loss=keras.losses.MeanSquaredError())
残差网络 #
python
import keras
class ResidualBlock(keras.layers.Layer):
    """Basic ResNet block: two 3x3 convs with batch norm plus a skip connection.

    When the spatial stride or the channel count changes, the shortcut is a
    1x1 projection (with batch norm) so both branches agree in shape;
    otherwise it is the identity.
    """

    def __init__(self, filters, strides=1, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.strides = strides

    def build(self, input_shape):
        self.conv1 = keras.layers.Conv2D(
            self.filters, 3, strides=self.strides, padding='same'
        )
        self.bn1 = keras.layers.BatchNormalization()
        self.conv2 = keras.layers.Conv2D(
            self.filters, 3, padding='same'
        )
        self.bn2 = keras.layers.BatchNormalization()
        if self.strides != 1 or input_shape[-1] != self.filters:
            # Shape mismatch: project the input with a strided 1x1 conv.
            self.shortcut = keras.Sequential([
                keras.layers.Conv2D(self.filters, 1, strides=self.strides),
                keras.layers.BatchNormalization()
            ])
        else:
            # Use a real Identity layer (not a bare lambda) so the shortcut
            # can uniformly be called with a `training` argument.
            self.shortcut = keras.layers.Identity()

    def call(self, inputs, training=False):
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = keras.activations.relu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        # Pass `training` explicitly so a projection shortcut's batch norm
        # uses batch statistics during training.
        shortcut = self.shortcut(inputs, training=training)
        # Plain tensor addition: instantiating keras.layers.Add() inside
        # call() would create a new, untracked layer on every forward pass.
        x = x + shortcut
        x = keras.activations.relu(x)
        return x
class ResNet(keras.Model):
    """ResNet-18-style image classifier built from ResidualBlock stages."""

    def __init__(self, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes

    def build(self, input_shape):
        # Stem: strided 7x7 conv followed by max pooling.
        self.conv1 = keras.layers.Conv2D(64, 7, strides=2, padding='same')
        self.bn1 = keras.layers.BatchNormalization()
        self.maxpool = keras.layers.MaxPooling2D(3, strides=2, padding='same')
        # Four stages of two residual blocks each; stages 2-4 halve the
        # spatial size while doubling the channel count.
        self.layer1 = self._make_layer(64, 2, strides=1)
        self.layer2 = self._make_layer(128, 2, strides=2)
        self.layer3 = self._make_layer(256, 2, strides=2)
        self.layer4 = self._make_layer(512, 2, strides=2)
        self.avgpool = keras.layers.GlobalAveragePooling2D()
        self.fc = keras.layers.Dense(self.num_classes, activation='softmax')

    def _make_layer(self, filters, blocks, strides):
        # Only the first block of a stage may downsample (carry the stride).
        stage = [ResidualBlock(filters, strides)]
        stage.extend(ResidualBlock(filters) for _ in range(blocks - 1))
        return keras.Sequential(stage)

    def call(self, inputs, training=False):
        x = keras.activations.relu(self.bn1(self.conv1(inputs), training=training))
        x = self.maxpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x, training=training)
        x = self.avgpool(x)
        return self.fc(x)
# Integer class labels -> sparse_categorical_crossentropy.
model = ResNet(num_classes=10)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
带条件的模型 #
python
import keras
class ConditionalModel(keras.Model):
    """Classifier whose call() takes extra flags (`training`, `use_dropout`).

    Demonstrates how a subclassed model can branch its forward pass on
    arguments beyond the inputs themselves.
    """

    def __init__(self, hidden_units, num_classes, **kwargs):
        super().__init__(**kwargs)
        self.hidden_units = hidden_units
        self.num_classes = num_classes

    def build(self, input_shape):
        # Layers created lazily; two dense+BN pairs plus the softmax head.
        self.dense1 = keras.layers.Dense(self.hidden_units)
        self.dense2 = keras.layers.Dense(self.hidden_units)
        self.bn1 = keras.layers.BatchNormalization()
        self.bn2 = keras.layers.BatchNormalization()
        self.output_layer = keras.layers.Dense(self.num_classes, activation='softmax')
        self.dropout = keras.layers.Dropout(0.3)

    def call(self, inputs, training=False, use_dropout=True):
        # First block: dense -> batch norm -> ReLU.
        hidden = keras.activations.relu(self.bn1(self.dense1(inputs), training=training))
        # Dropout only when the caller asks for it (and only active in training).
        if use_dropout:
            hidden = self.dropout(hidden, training=training)
        # Second block, then the classification head.
        hidden = keras.activations.relu(self.bn2(self.dense2(hidden), training=training))
        return self.output_layer(hidden)
# Integer class labels -> sparse_categorical_crossentropy.
model = ConditionalModel(hidden_units=128, num_classes=10)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
多输入模型 #
python
import keras
class MultiInputModel(keras.Model):
    """Fuses an image branch (small CNN) with a text branch (embedding + LSTM).

    call() expects a pair `(image, text)` and returns 10-way softmax scores.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shapes):
        # Image branch: two conv blocks collapsed by global average pooling.
        self.image_encoder = keras.Sequential([
            keras.layers.Conv2D(32, 3, activation='relu'),
            keras.layers.MaxPooling2D(),
            keras.layers.Conv2D(64, 3, activation='relu'),
            keras.layers.GlobalAveragePooling2D()
        ])
        # Text branch: token embedding followed by a single LSTM.
        self.text_encoder = keras.Sequential([
            keras.layers.Embedding(10000, 64),
            keras.layers.LSTM(64)
        ])
        self.fusion = keras.layers.Dense(128, activation='relu')
        self.classifier = keras.layers.Dense(10, activation='softmax')

    def call(self, inputs):
        image_input, text_input = inputs
        # Concatenate both feature vectors before the fusion layer.
        features = keras.ops.concatenate(
            [self.image_encoder(image_input), self.text_encoder(text_input)],
            axis=-1,
        )
        return self.classifier(self.fusion(features))
# One-hot labels -> categorical_crossentropy.
model = MultiInputModel()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
自定义训练循环 #
python
import keras
class CustomModel(keras.Model):
    """MLP classifier with hand-written train/test steps (TensorFlow backend).

    Overriding `train_step`/`test_step` customizes what happens inside
    `fit()`/`evaluate()` while keeping callbacks, distribution, etc.
    """

    def __init__(self, hidden_units, num_classes, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = keras.layers.Dense(hidden_units, activation='relu')
        self.dense2 = keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

    def train_step(self, data):
        # GradientTape lives in TensorFlow, not in keras.backend, so this
        # custom step requires the TensorFlow backend.
        import tensorflow as tf

        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Keras 3 API: compute_loss applies the loss given to compile()
            # (`compiled_loss` is the deprecated Keras 2 spelling).
            loss = self.compute_loss(y=y, y_pred=y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Update the compiled metrics; the built-in loss metric tracks the
        # loss value itself rather than (y, y_pred).
        for metric in self.metrics:
            if metric.name == 'loss':
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        loss = self.compute_loss(y=y, y_pred=y_pred)
        for metric in self.metrics:
            if metric.name == 'loss':
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
# fit()/evaluate() will route through the custom train_step/test_step above.
model = CustomModel(hidden_units=128, num_classes=10)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
下一步 #
现在你已经掌握了自定义模型,接下来学习 数据预处理,了解如何处理训练数据!
最后更新:2026-04-04