模型部署实战 #

部署概述 #

模型训练完成后,需要将其部署到生产环境供用户使用。TensorFlow 提供了多种部署方案。

部署方案 #

text
┌─────────────────────────────────────────────────────────────┐
│                    部署方案                                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  TensorFlow Serving                                         │
│  ├── 服务器部署                                             │
│  ├── 高性能推理                                             │
│  └── 支持模型版本管理                                       │
│                                                             │
│  TensorFlow Lite                                            │
│  ├── 移动端部署                                             │
│  ├── 模型量化                                               │
│  └── Android/iOS/嵌入式                                     │
│                                                             │
│  TensorFlow.js                                              │
│  ├── 浏览器部署                                             │
│  ├── Node.js 部署                                           │
│  └── 无需后端服务                                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

模型保存与导出 #

SavedModel 格式 #

python
# Build a toy classifier and export it in the three common Keras persistence formats.
import tensorflow as tf
import numpy as np

# Minimal classifier: 784 flattened input features -> 10-class softmax.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# SavedModel format: writes a directory (graph + weights + signatures).
# NOTE(review): Keras 3 (TF >= 2.16) rejects an extension-less path here;
# use model.export(...) or tf.saved_model.save(...) there — confirm target TF version.
model.save('saved_model/my_model')

# Native single-file Keras format (.keras), recommended for TF 2.12+.
model.save('my_model.keras')

# Weights only (HDF5).
# NOTE(review): Keras 3 requires a '.weights.h5' suffix for save_weights — verify.
model.save_weights('weights.h5')

导出为 SavedModel #

python
import tensorflow as tf

class MyModel(tf.Module):
    """Minimal tf.Module: one affine layer (y = x @ W + b) exported with a
    fixed input signature so it can be saved as a SavedModel."""

    def __init__(self):
        # Trainable parameters of the affine map: 784 inputs -> 10 outputs.
        self.weights = tf.Variable(tf.random.normal([784, 10]))
        self.bias = tf.Variable(tf.zeros([10]))

    # input_signature pins the traced graph to float32 batches of 784
    # features, giving the SavedModel a stable serving signature.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 784], dtype=tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.weights) + self.bias

# Serialize the module (variables + traced signature) into ./custom_model.
model = MyModel()
tf.saved_model.save(model, 'custom_model')

TensorFlow Serving #

安装与启动 #

bash
# Install via Docker (official TensorFlow Serving image)
docker pull tensorflow/serving

# Start the server: REST API on port 8501, the host model directory is
# bind-mounted into the container at /models/<MODEL_NAME>
docker run -p 8501:8501 \
  --mount type=bind,source=/path/to/model,target=/models/my_model \
  -e MODEL_NAME=my_model \
  tensorflow/serving

REST API 推理 #

python
# Query TensorFlow Serving's REST predict endpoint with one random sample.
import requests
import numpy as np

# The REST API expects plain JSON-serializable nested lists under "instances".
data = np.random.random((1, 784)).tolist()

payload = {
    "instances": data
}

# Predict endpoint: /v1/models/<name>:predict on the REST port (8501).
# A timeout keeps the client from hanging forever if the server is down.
response = requests.post(
    'http://localhost:8501/v1/models/my_model:predict',
    json=payload,
    timeout=10,
)
# Surface HTTP errors explicitly instead of a confusing KeyError below.
response.raise_for_status()

predictions = response.json()['predictions']
print(f"预测结果: {predictions}")

gRPC 推理 #

python
# Query TensorFlow Serving over gRPC (lower latency than REST).
import grpc
import numpy as np
# Bug fix: tensorflow was never imported, but make_tensor_proto /
# make_ndarray below are tf functions — the original raised NameError.
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

# gRPC endpoint (default port 8500; the REST port is 8501).
channel = grpc.insecure_channel('localhost:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

# Build the request: model name + signature + input tensor proto.
request = predict_pb2.PredictRequest()
request.model_spec.name = 'my_model'
request.model_spec.signature_name = 'serving_default'
request.inputs['input_tensor'].CopyFrom(
    tf.make_tensor_proto(np.random.random((1, 784)).astype(np.float32))
)

# The second positional argument is the RPC timeout in seconds.
response = stub.Predict(request, 10.0)
predictions = tf.make_ndarray(response.outputs['output_0'])
print(f"预测结果: {predictions}")

TensorFlow Lite #

模型转换 #

python
# Convert a Keras model to a TensorFlow Lite FlatBuffer.
import tensorflow as tf

# Same toy classifier used throughout this guide.
layers = [
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax'),
]
model = tf.keras.Sequential(layers)

# Keep a copy in the native Keras format before converting.
model.save('model.keras')

# Convert the in-memory Keras model straight to TFLite bytes.
tflite_bytes = tf.lite.TFLiteConverter.from_keras_model(model).convert()

# The .tflite file is the binary FlatBuffer consumed by the interpreter.
with open('model.tflite', 'wb') as out:
    out.write(tflite_bytes)

模型量化 #

python
# Full-integer (INT8) quantization of a Keras model for TFLite.
# Bug fix: numpy was never imported, but representative_dataset uses np —
# the original raised NameError during convert().
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])

converter = tf.lite.TFLiteConverter.from_keras_model(model)

# Enable the default optimization set (weight quantization).
converter.optimizations = [tf.lite.Optimize.DEFAULT]

def representative_dataset():
    """Yield ~100 calibration batches so the converter can measure
    activation ranges for full-integer quantization.

    In real use, yield samples drawn from the training distribution,
    not random noise.
    """
    for _ in range(100):
        yield [np.random.random((1, 784)).astype(np.float32)]

converter.representative_dataset = representative_dataset
# Force full INT8: all ops, plus model inputs and outputs, are int8.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

tflite_quant_model = converter.convert()

with open('model_quant.tflite', 'wb') as f:
    f.write(tflite_quant_model)

Python 推理 #

python
# Run one inference with the TFLite interpreter in Python.
import numpy as np
import tensorflow as tf

# Load the FlatBuffer and allocate the input/output tensor buffers.
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

# Metadata (index, shape, dtype) for the first input and output tensors.
in_info = interpreter.get_input_details()[0]
out_info = interpreter.get_output_details()[0]

# Feed one random sample, run the graph, then read the result back.
sample = np.random.random((1, 784)).astype(np.float32)
interpreter.set_tensor(in_info['index'], sample)
interpreter.invoke()

output_data = interpreter.get_tensor(out_info['index'])
print(f"预测结果: {output_data}")

Android 集成 #

kotlin
// Add the TFLite runtime dependency in build.gradle:
// implementation 'org.tensorflow:tensorflow-lite:2.12.0'

// Kotlin code
val interpreter = Interpreter(loadModelFile())

// Input buffer: 784 float32 features = 4 bytes * 784, in native byte order.
val input = ByteBuffer.allocateDirect(4 * 784)
input.order(ByteOrder.nativeOrder())
// Fill the input buffer with feature data here.

// Output shape [1][10] matches the model's 10-class softmax head.
val output = Array(1) { FloatArray(10) }
interpreter.run(input, output)

// Memory-map the .tflite model bundled in the app's assets directory.
fun loadModelFile(): MappedByteBuffer {
    val fileDescriptor = assets.openFd("model.tflite")
    val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
    val fileChannel = inputStream.channel
    val startOffset = fileDescriptor.startOffset
    val declaredLength = fileDescriptor.declaredLength
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}

TensorFlow.js #

模型转换 #

bash
# Install the converter CLI
pip install tensorflowjs

# Convert a Keras model file to the TF.js layers format
tensorflowjs_converter --input_format keras \
  model.keras \
  tfjs_model

# Convert a SavedModel directory to the TF.js graph format
tensorflowjs_converter --input_format tf_saved_model \
  saved_model \
  tfjs_model

浏览器推理 #

html
<!DOCTYPE html>
<html>
<head>
    <!-- Load TensorFlow.js from the CDN -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
</head>
<body>
    <script>
        // Fetch the converted layers-format model and run one forward pass.
        async function predict() {
            const model = await tf.loadLayersModel('tfjs_model/model.json');
            
            // Random stand-in for a real 784-feature input batch.
            const input = tf.randomNormal([1, 784]);
            
            const prediction = model.predict(input);
            
            // Prints the tensor's values to the browser console.
            prediction.print();
        }
        
        predict();
    </script>
</body>
</html>

Node.js 推理 #

javascript
// Node.js inference using the native tfjs-node backend.
const tf = require('@tensorflow/tfjs-node');

function predict() {
    // loadLayersModel returns a promise; chain it instead of awaiting.
    return tf.loadLayersModel('file://./tfjs_model/model.json')
        .then((model) => {
            const input = tf.randomNormal([1, 784]);
            const prediction = model.predict(input);
            // Prints the tensor's values to the console.
            prediction.print();
        });
}

predict();

模型优化 #

模型剪枝 #

python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Base model to be pruned (same toy classifier as elsewhere in this guide).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Keep a constant 50% of each layer's weights at zero, re-applying the
# pruning mask every 100 training steps.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
        0.5,
        begin_step=0,
        frequency=100
    )
}

# Wrap every prunable layer; the wrapped model must be re-compiled.
model_for_pruning = prune_low_magnitude(model, **pruning_params)

model_for_pruning.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# UpdatePruningStep is mandatory: it advances the pruning schedule each step.
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep()
]

# NOTE(review): x_train / y_train are not defined in this snippet — load a
# dataset (e.g. MNIST flattened to 784 features) before running it.
model_for_pruning.fit(
    x_train, y_train,
    epochs=10,
    callbacks=callbacks
)

# strip_pruning removes the wrappers, leaving plain (now sparse) layers
# suitable for export.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
model_for_export.save('pruned_model.keras')

知识蒸馏 #

python
import tensorflow as tf

class Distiller(tf.keras.Model):
    """Knowledge-distillation wrapper: trains `student` to match both the
    ground-truth labels and the temperature-softened predictions of a
    (typically pre-trained) `teacher`."""

    def __init__(self, teacher, student):
        super().__init__()
        self.teacher = teacher
        self.student = student
    
    def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn, alpha=0.1, temperature=3):
        """Configure distillation.

        Args:
            student_loss_fn: loss against the hard labels.
            distillation_loss_fn: loss between the softened teacher and
                student distributions (e.g. KL divergence).
            alpha: weight of the hard-label loss; (1 - alpha) weights the
                distillation loss.
            temperature: softmax temperature used to soften both outputs.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature
    
    def train_step(self, data):
        x, y = data
        
        with tf.GradientTape() as tape:
            # Teacher runs in inference mode and is never updated here.
            teacher_predictions = self.teacher(x, training=False)
            student_predictions = self.student(x, training=True)
            
            student_loss = self.student_loss_fn(y, student_predictions)
            # Both outputs are softened by the same temperature before
            # comparing distributions.
            # NOTE(review): the classic Hinton formulation also scales this
            # term by temperature**2 — confirm whether that is wanted here.
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature),
                tf.nn.softmax(student_predictions / self.temperature)
            )
            
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        
        # Only the student's weights receive gradient updates.
        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))
        
        self.compiled_metrics.update_state(y, student_predictions)
        
        return {m.name: m.result() for m in self.metrics}

# Large teacher vs. small student; both heads emit raw logits, matching
# from_logits=True in the student loss below.
teacher = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10)
])

student = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10)
])

distiller = Distiller(teacher, student)
distiller.compile(
    optimizer='adam',
    metrics=['accuracy'],
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=tf.keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=3
)

# NOTE(review): x_train / y_train must be defined (e.g. MNIST) before fit().
distiller.fit(x_train, y_train, epochs=10)

总结 #

TensorFlow 提供了完整的模型部署解决方案:

  • TensorFlow Serving: 适合服务器端高性能推理
  • TensorFlow Lite: 适合移动端和嵌入式设备
  • TensorFlow.js: 适合浏览器和 Node.js 环境

选择合适的部署方案,可以让模型在各种场景下高效运行。

恭喜完成! #

你已经完成了 TensorFlow 完全指南的学习!从基础的张量操作到高级的分布式训练,从模型构建到生产部署,你现在具备了使用 TensorFlow 进行深度学习开发的完整知识体系。

最后更新:2026-04-04