模型部署实战 #
部署概述 #
模型训练完成后,需要将其部署到生产环境供用户使用。TensorFlow 提供了多种部署方案。
部署方案 #
text
┌─────────────────────────────────────────────────────────────┐
│ 部署方案 │
├─────────────────────────────────────────────────────────────┤
│ │
│ TensorFlow Serving │
│ ├── 服务器部署 │
│ ├── 高性能推理 │
│ └── 支持模型版本管理 │
│ │
│ TensorFlow Lite │
│ ├── 移动端部署 │
│ ├── 模型量化 │
│ └── Android/iOS/嵌入式 │
│ │
│ TensorFlow.js │
│ ├── 浏览器部署 │
│ ├── Node.js 部署 │
│ └── 无需后端服务 │
│ │
└─────────────────────────────────────────────────────────────┘
模型保存与导出 #
SavedModel 格式 #
python
# Build a tiny classifier and persist it in every format Keras supports.
import tensorflow as tf
import numpy as np

# Two-layer MLP over 784-dim inputs (e.g. flattened 28x28 images).
classifier_layers = [
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax'),
]
model = tf.keras.Sequential(classifier_layers)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# SavedModel directory — TensorFlow's language-neutral export format.
model.save('saved_model/my_model')
# Native single-file Keras format.
model.save('my_model.keras')
# Weights only, HDF5 container.
model.save_weights('weights.h5')
导出为 SavedModel #
python
import tensorflow as tf


class MyModel(tf.Module):
    """Minimal linear model exported with the low-level SavedModel API.

    tf.Module tracks the tf.Variable attributes, so they are stored in the
    SavedModel checkpoint under their attribute names.
    """

    def __init__(self):
        # 784 -> 10 linear projection: random-normal weights, zero bias.
        self.weights = tf.Variable(tf.random.normal([784, 10]))
        self.bias = tf.Variable(tf.zeros([10]))

    # input_signature pins the traced signature to float32 [batch, 784],
    # so a concrete serving function is embedded in the export.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 784], dtype=tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.weights) + self.bias


# Export the module (variables + traced __call__) to the 'custom_model' dir.
model = MyModel()
tf.saved_model.save(model, 'custom_model')
TensorFlow Serving #
安装与启动 #
bash
# Install via Docker: pull the TensorFlow Serving image.
docker pull tensorflow/serving
# Start the server. Port 8501 is the REST API endpoint; the host model
# directory is bind-mounted to /models/<MODEL_NAME> inside the container.
docker run -p 8501:8501 \
    --mount type=bind,source=/path/to/model,target=/models/my_model \
    -e MODEL_NAME=my_model \
    tensorflow/serving
REST API 推理 #
python
# REST inference against TensorFlow Serving.
# Fixes: removed unused `import json`, added a request timeout, and check
# the HTTP status before indexing into the response JSON.
import requests
import numpy as np

# One fake 784-dim example; "instances" is TF Serving's row-oriented format.
data = np.random.random((1, 784)).tolist()
payload = {
    "instances": data
}

response = requests.post(
    'http://localhost:8501/v1/models/my_model:predict',
    json=payload,
    timeout=10,
)
# Fail fast on HTTP errors instead of raising KeyError on the line below.
response.raise_for_status()

predictions = response.json()['predictions']
print(f"预测结果: {predictions}")
gRPC 推理 #
python
# gRPC inference against TensorFlow Serving.
# Fix: `tensorflow` was used (make_tensor_proto / make_ndarray) but never
# imported, so the original snippet raised NameError.
import grpc
import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

# Default gRPC port is 8500 (REST uses 8501).
channel = grpc.insecure_channel('localhost:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

# Model name and signature must match the exported SavedModel.
request = predict_pb2.PredictRequest()
request.model_spec.name = 'my_model'
request.model_spec.signature_name = 'serving_default'
request.inputs['input_tensor'].CopyFrom(
    tf.make_tensor_proto(np.random.random((1, 784)).astype(np.float32))
)

# Second positional argument is the RPC timeout in seconds.
response = stub.Predict(request, 10.0)
predictions = tf.make_ndarray(response.outputs['output_0'])
print(f"预测结果: {predictions}")
TensorFlow Lite #
模型转换 #
python
# Convert an in-memory Keras model to a TensorFlow Lite flatbuffer.
import tensorflow as tf

# Same toy classifier as before; any Keras model can be converted.
mlp = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
mlp.save('model.keras')

tflite_converter = tf.lite.TFLiteConverter.from_keras_model(mlp)
flatbuffer = tflite_converter.convert()

# The converted model is an opaque binary blob — write it as bytes.
with open('model.tflite', 'wb') as model_file:
    model_file.write(flatbuffer)
模型量化 #
python
# Full-integer (int8) quantization of a Keras model for TFLite.
# Fix: `numpy` was used in representative_dataset but never imported,
# so the original snippet raised NameError during conversion.
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Enable the default optimization set (includes quantization).
converter.optimizations = [tf.lite.Optimize.DEFAULT]

def representative_dataset():
    """Yield calibration batches so the converter can pick int8 ranges."""
    # NOTE(review): random data only exercises the shapes — use real
    # training samples in practice so activation ranges are meaningful.
    for _ in range(100):
        yield [np.random.random((1, 784)).astype(np.float32)]

converter.representative_dataset = representative_dataset
# Restrict to int8 builtin ops and make the model I/O int8 as well
# (full-integer quantization, suitable for int8-only accelerators).
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

tflite_quant_model = converter.convert()
with open('model_quant.tflite', 'wb') as f:
    f.write(tflite_quant_model)
Python 推理 #
python
# Run one inference with the TFLite interpreter in Python.
import numpy as np
import tensorflow as tf

# Load the flatbuffer and allocate the input/output tensor buffers.
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

in_meta = interpreter.get_input_details()
out_meta = interpreter.get_output_details()

# Feed a random float32 example, run the graph, read the result back.
sample = np.random.random((1, 784)).astype(np.float32)
interpreter.set_tensor(in_meta[0]['index'], sample)
interpreter.invoke()
result = interpreter.get_tensor(out_meta[0]['index'])
print(f"预测结果: {result}")
Android 集成 #
kotlin
// Add the dependency in build.gradle:
// implementation 'org.tensorflow:tensorflow-lite:2.12.0'

// Kotlin code
val interpreter = Interpreter(loadModelFile())

// 4 bytes per float32 * 784 features; the TFLite interpreter requires a
// direct ByteBuffer in native byte order.
val input = ByteBuffer.allocateDirect(4 * 784)
input.order(ByteOrder.nativeOrder())
// Fill the input data here.
val output = Array(1) { FloatArray(10) }
interpreter.run(input, output)

// Memory-map model.tflite from the APK assets (read-only, zero-copy).
fun loadModelFile(): MappedByteBuffer {
    val fileDescriptor = assets.openFd("model.tflite")
    val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
    val fileChannel = inputStream.channel
    val startOffset = fileDescriptor.startOffset
    val declaredLength = fileDescriptor.declaredLength
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}
TensorFlow.js #
模型转换 #
bash
# Install the converter CLI.
pip install tensorflowjs
# Convert a Keras model file to the TF.js format.
tensorflowjs_converter --input_format keras \
    model.keras \
    tfjs_model
# Convert a SavedModel directory to the TF.js format.
tensorflowjs_converter --input_format tf_saved_model \
    saved_model \
    tfjs_model
浏览器推理 #
html
<!DOCTYPE html>
<html>
<head>
  <!-- TensorFlow.js from CDN; exposes the global `tf` object. -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
</head>
<body>
  <script>
    // Load the converted layers model, run one random input, print it.
    (async () => {
      const net = await tf.loadLayersModel('tfjs_model/model.json');
      const sample = tf.randomNormal([1, 784]);
      net.predict(sample).print();
    })();
  </script>
</body>
</html>
Node.js 推理 #
javascript
// Node.js binding: runs the native TF runtime and loads models from disk.
const tf = require('@tensorflow/tfjs-node');

(async () => {
  // The file:// URL points at the converted model's JSON manifest.
  const net = await tf.loadLayersModel('file://./tfjs_model/model.json');
  const sample = tf.randomNormal([1, 784]);
  net.predict(sample).print();
})();
模型优化 #
模型剪枝 #
python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Wrap the model so 50% of each layer's weights are zeroed, with the
# pruning mask updated every 100 steps starting from step 0.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
        0.5,
        begin_step=0,
        frequency=100
    )
}
model_for_pruning = prune_low_magnitude(model, **pruning_params)

# Wrapping produces a new model, so it must be compiled again.
model_for_pruning.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# UpdatePruningStep advances the pruning step counter during training;
# fitting a pruned model without it raises an error.
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep()
]

# NOTE(review): x_train / y_train are not defined in this snippet — they
# are presumably provided by an earlier section; define them before running.
model_for_pruning.fit(
    x_train, y_train,
    epochs=10,
    callbacks=callbacks
)

# Strip the pruning wrappers so the saved model is a plain Keras model.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
model_for_export.save('pruned_model.keras')
知识蒸馏 #
python
import tensorflow as tf


class Distiller(tf.keras.Model):
    """Knowledge-distillation wrapper: trains `student` to match `teacher`.

    Only the student's weights are updated; the teacher runs frozen
    (training=False) and its variables are excluded from the gradient step.
    """

    def __init__(self, teacher, student):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn, alpha=0.1, temperature=3):
        """Configure the two losses and their mixing.

        Args:
            student_loss_fn: loss against the ground-truth labels.
            distillation_loss_fn: loss between the softened teacher and
                student distributions (e.g. KL divergence).
            alpha: weight of the student loss; (1 - alpha) weights the
                distillation loss.
            temperature: softmax temperature used to soften both outputs.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Custom per-batch training logic invoked by model.fit().
        x, y = data
        with tf.GradientTape() as tape:
            teacher_predictions = self.teacher(x, training=False)
            student_predictions = self.student(x, training=True)
            # Hard-label loss on the raw student outputs.
            student_loss = self.student_loss_fn(y, student_predictions)
            # Soft-label loss on temperature-softened distributions.
            # NOTE(review): this assumes both models emit logits (softmax is
            # applied here) — confirm; the classic formulation also scales
            # this term by temperature**2, which is omitted here.
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature),
                tf.nn.softmax(student_predictions / self.temperature)
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        # Gradients flow only into the student's trainable variables.
        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))
        self.compiled_metrics.update_state(y, student_predictions)
        return {m.name: m.result() for m in self.metrics}
# Teacher: higher-capacity network; student: compact network to deploy.
# Both end in raw logits (no softmax), matching from_logits=True below.
teacher = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10)
])
student = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10)
])

distiller = Distiller(teacher, student)
distiller.compile(
    optimizer='adam',
    metrics=['accuracy'],
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=tf.keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=3
)

# NOTE(review): x_train / y_train are not defined in this snippet — supply
# real data. The teacher here is also untrained; in practice it should be
# trained before distillation.
distiller.fit(x_train, y_train, epochs=10)
总结 #
TensorFlow 提供了完整的模型部署解决方案:
- TensorFlow Serving: 适合服务器端高性能推理
- TensorFlow Lite: 适合移动端和嵌入式设备
- TensorFlow.js: 适合浏览器和 Node.js 环境
选择合适的部署方案,可以让模型在各种场景下高效运行。
恭喜完成! #
你已经完成了 TensorFlow 完全指南的学习!从基础的张量操作到高级的分布式训练,从模型构建到生产部署,你现在具备了使用 TensorFlow 进行深度学习开发的完整知识体系。
最后更新:2026-04-04