自定义模型 #

为什么需要自定义模型? #

text
┌─────────────────────────────────────────────────────────────┐
│                    模型构建方式对比                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Sequential:                                                │
│  ✅ 简单易用                                                │
│  ❌ 只能线性堆叠                                            │
│                                                             │
│  Functional API:                                            │
│  ✅ 支持复杂拓扑                                            │
│  ✅ 多输入多输出                                            │
│  ❌ 静态图结构                                              │
│                                                             │
│  Model Subclassing:                                         │
│  ✅ 最大灵活性                                              │
│  ✅ 动态计算图                                              │
│  ✅ 自定义前向传播                                          │
│  ❌ 代码更复杂                                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

基本结构 #

python
import keras

class MyModel(keras.Model):
    """Minimal skeleton for subclassing `keras.Model`.

    Override `__init__` for configuration, `build()` for creating weights
    or sub-layers once the input shape is known, and `call()` for the
    forward pass.
    """

    def __init__(self, **kwargs):
        # Forward standard Model kwargs (e.g. `name`) to the base class.
        super().__init__(**kwargs)
    
    def build(self, input_shape):
        # Create weights / sub-layers here; intentionally empty in this skeleton.
        pass
    
    def call(self, inputs):
        # Identity forward pass placeholder.
        return inputs

简单示例 #

python
import keras

class SimpleModel(keras.Model):
    """Two-hidden-layer MLP classifier with a softmax output.

    The second hidden layer has half as many units as the first; the
    output layer produces `num_classes` class probabilities.
    """

    def __init__(self, hidden_units, num_classes, **kwargs):
        super().__init__(**kwargs)
        # Stored configuration; sub-layers are created lazily in build().
        self.hidden_units = hidden_units
        self.num_classes = num_classes

    def build(self, input_shape):
        # Deferred layer creation: runs once, when the input shape is known.
        self.dense1 = keras.layers.Dense(self.hidden_units, activation='relu')
        self.dense2 = keras.layers.Dense(self.hidden_units // 2, activation='relu')
        self.output_layer = keras.layers.Dense(self.num_classes, activation='softmax')

    def call(self, inputs):
        # dense1 -> dense2 -> softmax head
        hidden = self.dense2(self.dense1(inputs))
        return self.output_layer(hidden)

# One-hot labels are expected here, hence categorical_crossentropy.
model = SimpleModel(hidden_units=128, num_classes=10)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

变分自编码器(VAE) #

python
import keras
import keras.ops as ops

class VAE(keras.Model):
    """Convolutional variational autoencoder for 28x28x1 images.

    The encoder emits `2 * latent_dim` units, interpreted as the
    concatenation of the latent mean and log-variance. The KL-divergence
    term is registered via `add_loss()` inside `call()`, so only the
    reconstruction loss needs to be supplied at `compile()` time.
    """

    def __init__(self, latent_dim, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim

        # Fix: Keras 3 InputLayer takes `shape=`; the Keras 2 argument
        # `input_shape=` was removed (this file otherwise uses Keras 3
        # APIs such as keras.ops and keras.random).
        self.encoder = keras.Sequential([
            keras.layers.InputLayer(shape=(28, 28, 1)),
            keras.layers.Conv2D(32, 3, activation='relu', strides=2, padding='same'),
            keras.layers.Conv2D(64, 3, activation='relu', strides=2, padding='same'),
            keras.layers.Flatten(),
            # mean and log-variance, concatenated along the last axis
            keras.layers.Dense(latent_dim + latent_dim)
        ])

        self.decoder = keras.Sequential([
            keras.layers.InputLayer(shape=(latent_dim,)),
            keras.layers.Dense(7 * 7 * 64, activation='relu'),
            keras.layers.Reshape((7, 7, 64)),
            keras.layers.Conv2DTranspose(64, 3, activation='relu', strides=2, padding='same'),
            keras.layers.Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same'),
            # Logit output; reconstruction loss is chosen at compile time.
            keras.layers.Conv2DTranspose(1, 3, padding='same')
        ])

    def encode(self, x):
        """Return (mean, log_var) of the approximate posterior q(z|x)."""
        mean, log_var = ops.split(self.encoder(x), 2, axis=-1)
        return mean, log_var

    def reparameterize(self, mean, log_var):
        """Sample z = mean + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
        eps = keras.random.normal(shape=ops.shape(mean))
        return eps * ops.exp(log_var * 0.5) + mean

    def decode(self, z):
        """Decode a latent sample back to image space."""
        return self.decoder(z)

    def call(self, inputs):
        mean, log_var = self.encode(inputs)
        z = self.reparameterize(mean, log_var)
        reconstructed = self.decode(z)

        # Analytical KL divergence between q(z|x) and the N(0, I) prior,
        # summed over latent dimensions, averaged over the batch.
        kl_loss = -0.5 * ops.sum(1 + log_var - ops.square(mean) - ops.exp(log_var), axis=-1)
        self.add_loss(ops.mean(kl_loss))

        return reconstructed

# MSE serves as the reconstruction loss; the KL term is added inside call().
vae = VAE(latent_dim=2)
vae.compile(optimizer='adam', loss=keras.losses.MeanSquaredError())

残差网络 #

python
import keras

class ResidualBlock(keras.layers.Layer):
    """Basic two-conv residual block (ResNet v1 style).

    Uses a 1x1 projection shortcut when the stride or channel count
    changes; otherwise the shortcut is the identity.
    """

    def __init__(self, filters, strides=1, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.strides = strides

    def build(self, input_shape):
        # Main path: conv-BN-ReLU -> conv-BN.
        self.conv1 = keras.layers.Conv2D(
            self.filters, 3, strides=self.strides, padding='same'
        )
        self.bn1 = keras.layers.BatchNormalization()
        self.conv2 = keras.layers.Conv2D(
            self.filters, 3, padding='same'
        )
        self.bn2 = keras.layers.BatchNormalization()

        # Projection shortcut only when shapes differ; None marks identity.
        if self.strides != 1 or input_shape[-1] != self.filters:
            self.shortcut = keras.Sequential([
                keras.layers.Conv2D(self.filters, 1, strides=self.strides),
                keras.layers.BatchNormalization()
            ])
        else:
            self.shortcut = None

    def call(self, inputs, training=False):
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = keras.activations.relu(x)

        x = self.conv2(x)
        x = self.bn2(x, training=training)

        # Fix 1: the projection shortcut contains BatchNormalization, so it
        # must receive the `training` flag (the original dropped it).
        if self.shortcut is not None:
            shortcut = self.shortcut(inputs, training=training)
        else:
            shortcut = inputs

        # Fix 2: use a plain tensor add instead of instantiating a new
        # keras.layers.Add() layer on every forward pass.
        x = keras.activations.relu(x + shortcut)

        return x

class ResNet(keras.Model):
    """ResNet-style classifier built from four stages of ResidualBlocks.

    Layout: 7x7 stem conv + max-pool, stages of widths 64/128/256/512
    (two blocks each), global average pooling, softmax head.
    """

    def __init__(self, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes

    def build(self, input_shape):
        # Stem: 7x7 conv with stride 2, then 3x3 max-pool with stride 2.
        self.conv1 = keras.layers.Conv2D(64, 7, strides=2, padding='same')
        self.bn1 = keras.layers.BatchNormalization()
        self.maxpool = keras.layers.MaxPooling2D(3, strides=2, padding='same')

        # Four stages; every stage after the first halves spatial resolution.
        self.layer1 = self._make_layer(64, 2, strides=1)
        self.layer2 = self._make_layer(128, 2, strides=2)
        self.layer3 = self._make_layer(256, 2, strides=2)
        self.layer4 = self._make_layer(512, 2, strides=2)

        # Classification head.
        self.avgpool = keras.layers.GlobalAveragePooling2D()
        self.fc = keras.layers.Dense(self.num_classes, activation='softmax')

    def _make_layer(self, filters, blocks, strides):
        """Stack `blocks` residual blocks; only the first may downsample."""
        stack = [ResidualBlock(filters, strides)]
        stack.extend(ResidualBlock(filters) for _ in range(blocks - 1))
        return keras.Sequential(stack)

    def call(self, inputs, training=False):
        x = keras.activations.relu(self.bn1(self.conv1(inputs), training=training))
        x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x, training=training)

        return self.fc(self.avgpool(x))

# Integer labels are expected here, hence sparse_categorical_crossentropy.
model = ResNet(num_classes=10)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

带条件的模型 #

python
import keras

class ConditionalModel(keras.Model):
    """MLP classifier whose dropout can be toggled per call.

    `call()` accepts an extra `use_dropout` flag, demonstrating how a
    subclassed model can branch its forward pass on arguments beyond
    `training`.
    """

    def __init__(self, hidden_units, num_classes, **kwargs):
        super().__init__(**kwargs)
        self.hidden_units = hidden_units
        self.num_classes = num_classes

    def build(self, input_shape):
        # Two Dense/BatchNorm pairs, a dropout layer, and a softmax head.
        self.dense1 = keras.layers.Dense(self.hidden_units)
        self.dense2 = keras.layers.Dense(self.hidden_units)
        self.bn1 = keras.layers.BatchNormalization()
        self.bn2 = keras.layers.BatchNormalization()
        self.output_layer = keras.layers.Dense(self.num_classes, activation='softmax')
        self.dropout = keras.layers.Dropout(0.3)

    def call(self, inputs, training=False, use_dropout=True):
        hidden = keras.activations.relu(self.bn1(self.dense1(inputs), training=training))

        # Dropout applies only when requested (and is a no-op at inference
        # anyway, since the layer respects `training`).
        if use_dropout:
            hidden = self.dropout(hidden, training=training)

        hidden = keras.activations.relu(self.bn2(self.dense2(hidden), training=training))
        return self.output_layer(hidden)

# Integer labels are expected here, hence sparse_categorical_crossentropy.
model = ConditionalModel(hidden_units=128, num_classes=10)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

多输入模型 #

python
import keras

class MultiInputModel(keras.Model):
    """Fuses a CNN image encoder with an LSTM text encoder for classification.

    `call()` expects a pair `(image, text)`; the two feature vectors are
    concatenated, passed through a fusion Dense layer, and classified
    into 10 classes.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shapes):
        # Small conv net ending in global average pooling -> 64-d vector.
        self.image_encoder = keras.Sequential([
            keras.layers.Conv2D(32, 3, activation='relu'),
            keras.layers.MaxPooling2D(),
            keras.layers.Conv2D(64, 3, activation='relu'),
            keras.layers.GlobalAveragePooling2D(),
        ])

        # Embedding + LSTM -> 64-d vector (vocabulary of 10000 tokens).
        self.text_encoder = keras.Sequential([
            keras.layers.Embedding(10000, 64),
            keras.layers.LSTM(64),
        ])

        self.fusion = keras.layers.Dense(128, activation='relu')
        self.classifier = keras.layers.Dense(10, activation='softmax')

    def call(self, inputs):
        image, text = inputs
        combined = keras.ops.concatenate(
            [self.image_encoder(image), self.text_encoder(text)], axis=-1
        )
        return self.classifier(self.fusion(combined))

# One-hot labels are expected here, hence categorical_crossentropy.
model = MultiInputModel()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

自定义训练循环 #

python
import keras

class CustomModel(keras.Model):
    """MLP classifier with hand-written train/test steps (Keras 3 style).

    Fixes relative to the original snippet:
    - `keras.backend.GradientTape` does not exist; `GradientTape` comes
      from TensorFlow, which ties this example to the TF backend.
    - `self.compiled_loss` / `self.compiled_metrics` are Keras 2 APIs
      removed in Keras 3; the replacements are `self.compute_loss()` and
      updating `self.metrics` directly.
    """

    def __init__(self, hidden_units, num_classes, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = keras.layers.Dense(hidden_units, activation='relu')
        self.dense2 = keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

    def train_step(self, data):
        # TensorFlow-backend specific: GradientTape is a TF construct, not
        # part of the backend-agnostic Keras 3 API.
        import tensorflow as tf

        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Keras 3 replacement for the removed `self.compiled_loss`.
            loss = self.compute_loss(y=y, y_pred=y_pred)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Keras 3 replacement for the removed `self.compiled_metrics`:
        # update each metric explicitly (the loss tracker takes the loss).
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        loss = self.compute_loss(y=y, y_pred=y_pred)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

# Integer labels are expected here, hence sparse_categorical_crossentropy.
model = CustomModel(hidden_units=128, num_classes=10)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

下一步 #

现在你已经掌握了自定义模型,接下来学习 数据预处理,了解如何处理训练数据!

最后更新:2026-04-04