TensorFlow 模型转换实战 #
环境准备 #
安装依赖 #
bash
# Install TensorFlow, the tf2onnx converter, and the ONNX / ONNX Runtime packages
pip install tensorflow tf2onnx onnx onnxruntime
验证安装 #
python
# Smoke-test the installation: import every package in the conversion
# pipeline and print its version so mismatches are easy to spot.
import tensorflow as tf
import tf2onnx
import onnx
import onnxruntime as ort
print(f"TensorFlow 版本: {tf.__version__}")
print(f"tf2onnx 版本: {tf2onnx.__version__}")
print(f"ONNX 版本: {onnx.__version__}")
print(f"ONNX Runtime 版本: {ort.__version__}")
Keras 模型转换 #
Sequential 模型转换 #
python
import tensorflow as tf
import tf2onnx
import onnx
# Build a small CNN classifier with the Keras Sequential API.
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(224, 224, 3)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(64, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
# Declare the exported input signature; None keeps the batch dimension dynamic.
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')]
# Convert the in-memory Keras model to an ONNX ModelProto (opset 17).
onnx_model, _ = tf2onnx.convert.from_keras(
model,
input_signature=input_signature,
opset=17
)
# Serialize the ONNX model to disk.
onnx.save(onnx_model, "keras_model.onnx")
print("Keras 模型转换成功")
Functional API 模型转换 #
python
import tensorflow as tf
import tf2onnx
import onnx
def create_functional_model():
    """Build a two-stage Conv/BN/ReLU classifier with the Keras Functional API."""
    inputs = tf.keras.Input(shape=(224, 224, 3), name='input')
    # First conv stage, then downsample.
    features = tf.keras.layers.Conv2D(64, 3, padding='same')(inputs)
    features = tf.keras.layers.BatchNormalization()(features)
    features = tf.keras.layers.ReLU()(features)
    features = tf.keras.layers.MaxPooling2D()(features)
    # Second conv stage, then collapse spatial dims to a feature vector.
    features = tf.keras.layers.Conv2D(128, 3, padding='same')(features)
    features = tf.keras.layers.BatchNormalization()(features)
    features = tf.keras.layers.ReLU()(features)
    features = tf.keras.layers.GlobalAveragePooling2D()(features)
    outputs = tf.keras.layers.Dense(10, activation='softmax', name='output')(features)
    return tf.keras.Model(inputs=inputs, outputs=outputs)
model = create_functional_model()
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')]
# output_path writes the converted .onnx file directly; no onnx.save needed.
onnx_model, _ = tf2onnx.convert.from_keras(
model,
input_signature=input_signature,
opset=17,
output_path="functional_model.onnx"
)
print("Functional API 模型转换成功")
预训练模型转换 #
python
import tensorflow as tf
import tf2onnx
import onnx
# Download ResNet50V2 with ImageNet weights and export it to ONNX.
model = tf.keras.applications.ResNet50V2(weights='imagenet')
# Dynamic batch dimension; spatial size fixed at 224x224 RGB.
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')]
onnx_model, _ = tf2onnx.convert.from_keras(
model,
input_signature=input_signature,
opset=17,
output_path="resnet50_tf.onnx"
)
print("ResNet50 模型转换成功")
SavedModel 转换 #
保存 SavedModel #
python
import tensorflow as tf
# Export MobileNetV2 in the SavedModel format (a directory, not a single file).
model = tf.keras.applications.MobileNetV2(weights='imagenet')
tf.saved_model.save(model, "mobilenet_saved_model")
print("SavedModel 已保存")
从 SavedModel 转换 #
python
import tf2onnx
import onnx

# Convert a SavedModel directory to ONNX.
# Bug fix: from_saved_model returns a (model_proto, external_tensor_storage)
# tuple; the original bound the whole tuple to onnx_model instead of
# unpacking it, so onnx_model was not a ModelProto.
onnx_model, _ = tf2onnx.convert.from_saved_model(
    "mobilenet_saved_model",
    opset=17,
    output_path="mobilenet.onnx"
)
print("SavedModel 转换成功")
命令行转换 #
bash
# Convert a SavedModel from the command line:
#   --saved-model  path of the SavedModel directory
#   --output       destination .onnx file
#   --opset        target ONNX opset version
python -m tf2onnx.convert \
--saved-model mobilenet_saved_model \
--output mobilenet.onnx \
--opset 17
Concrete Function 转换 #
使用 @tf.function #
python
import tensorflow as tf
import tf2onnx
import onnx
class MyModel(tf.Module):
    """Two-layer MLP exported through a concrete tf.function."""

    def __init__(self):
        super().__init__()
        # Layers are created once here so their variables belong to the module.
        self.dense1 = tf.keras.layers.Dense(64)
        self.dense2 = tf.keras.layers.Dense(10)

    @tf.function(input_signature=[tf.TensorSpec([None, 32], tf.float32)])
    def __call__(self, x):
        hidden = tf.nn.relu(self.dense1(x))
        return self.dense2(hidden)
model = MyModel()
# Trace the tf.function into a ConcreteFunction so tf2onnx can walk its graph.
concrete_func = model.__call__.get_concrete_function()
onnx_model, _ = tf2onnx.convert.from_function(
concrete_func,
input_signature=[tf.TensorSpec([None, 32], tf.float32, name='input')],
opset=17
)
onnx.save(onnx_model, "tf_function_model.onnx")
print("tf.function 模型转换成功")
多输入多输出 #
python
import tensorflow as tf
import tf2onnx
import onnx
class MultiIOModule(tf.Module):
    """Module with two inputs (32-d and 16-d) and two dense output heads.

    Bug fix: the original constructed the Dense layers inside the
    tf.function-decorated __call__. tf.function forbids creating variables
    on a non-first call (ValueError), and even on the first trace the layer
    weights would be re-created per trace. The layers now live on the
    module, created once in __init__, so tracing and export work reliably.
    """

    def __init__(self):
        super().__init__()
        self.head1 = tf.keras.layers.Dense(10, name='output1')
        self.head2 = tf.keras.layers.Dense(5, name='output2')

    @tf.function(input_signature=[
        tf.TensorSpec([None, 32], tf.float32, name='input_a'),
        tf.TensorSpec([None, 16], tf.float32, name='input_b')
    ])
    def __call__(self, a, b):
        # Concatenate along the feature axis, then fan out to both heads.
        x = tf.concat([a, b], axis=1)
        return self.head1(x), self.head2(x)
module = MultiIOModule()
# The input_signature here must match the tf.function signature on __call__.
onnx_model, _ = tf2onnx.convert.from_function(
module.__call__.get_concrete_function(),
input_signature=[
tf.TensorSpec([None, 32], tf.float32, name='input_a'),
tf.TensorSpec([None, 16], tf.float32, name='input_b')
],
opset=17
)
onnx.save(onnx_model, "multi_io_model.onnx")
print("多输入多输出模型转换成功")
BERT 模型转换 #
HuggingFace Transformers #
python
import tensorflow as tf
from transformers import TFBertModel, BertTokenizer
import tf2onnx
import onnx
model_name = "bert-base-uncased"
model = TFBertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
# BERT takes three int32 tensors; batch and sequence length are both dynamic.
input_signature = [
tf.TensorSpec([None, None], tf.int32, name='input_ids'),
tf.TensorSpec([None, None], tf.int32, name='attention_mask'),
tf.TensorSpec([None, None], tf.int32, name='token_type_ids')
]
onnx_model, _ = tf2onnx.convert.from_keras(
model,
input_signature=input_signature,
opset=17,
output_path="bert_tf.onnx"
)
print("BERT TensorFlow 模型转换成功")
转换选项 #
基本选项 #
python
import tf2onnx
# from_keras shown with all optional keyword arguments at their defaults.
# NOTE(review): argument semantics are defined by the tf2onnx docs — confirm
# against the installed version before relying on any of them.
onnx_model, _ = tf2onnx.convert.from_keras(
model,
input_signature=input_signature,
opset=17,
target=None,
custom_op_handlers=None,
custom_rewriter=None,
extra_opset=None,
shape_override=None
)
目标平台 #
python
import tf2onnx
onnx_model, _ = tf2onnx.convert.from_keras(
model,
input_signature=input_signature,
target="tensorrt",
opset=17
)
自定义算子处理 #
python
import tensorflow as tf
import tf2onnx


def custom_op_handler(ctx, node, name, args):
    """Placeholder handler; a real implementation would emit ONNX nodes here."""
    pass


# Map TF op names to handlers; tf2onnx consults this table for ops it
# cannot convert on its own.
custom_op_handlers = {
    "CustomOp": custom_op_handler,
}
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
    custom_op_handlers=custom_op_handlers,
    opset=17,
)
验证与测试 #
输出一致性验证 #
python
import tensorflow as tf
import onnxruntime as ort
import numpy as np
def validate_tf_onnx(tf_model_path, onnx_path, input_shape, num_tests=5):
    """Compare SavedModel and ONNX Runtime outputs on random inputs.

    Feeds *num_tests* random float32 tensors of *input_shape* through both
    the SavedModel serving signature and the ONNX Runtime session, and
    reports the element-wise maximum absolute difference per run.
    """
    banner = "=" * 60
    print(banner)
    print("TensorFlow vs ONNX 输出验证")
    print(banner)

    tf_model = tf.saved_model.load(tf_model_path)
    infer = tf_model.signatures['serving_default']
    onnx_session = ort.InferenceSession(onnx_path)

    # NOTE(review): input_name is computed but never used below — kept for
    # behavioral parity with the original snippet.
    input_name = list(infer.structured_input_signature[1].keys())[0]
    onnx_input_name = onnx_session.get_inputs()[0].name

    max_diff = 0
    for i in range(num_tests):
        sample = np.random.randn(*input_shape).astype(np.float32)
        tf_out = infer(tf.constant(sample))
        tf_out = list(tf_out.values())[0].numpy()
        ort_out = onnx_session.run(None, {onnx_input_name: sample})[0]
        diff = np.abs(tf_out - ort_out).max()
        max_diff = max(max_diff, diff)
        print(f"测试 {i+1}: 最大差异 = {diff:.6f}")

    print(f"\n最大差异: {max_diff:.6f}")
    if max_diff < 1e-5:
        print("✅ 验证通过")
    else:
        print("⚠️ 存在精度差异")


validate_tf_onnx("mobilenet_saved_model", "mobilenet.onnx", (1, 224, 224, 3))
性能对比 #
python
import tensorflow as tf
import onnxruntime as ort
import numpy as np
import time
def benchmark_tf_vs_onnx(tf_model_path, onnx_path, input_shape, num_runs=100):
    """Benchmark average inference latency of a SavedModel vs its ONNX export.

    Warms both runtimes with 10 untimed iterations, then times *num_runs*
    sequential inferences each and prints the per-call latency in ms plus
    the TF/ORT speed ratio.
    """
    tf_model = tf.saved_model.load(tf_model_path)
    infer = tf_model.signatures['serving_default']
    onnx_session = ort.InferenceSession(onnx_path)
    onnx_input_name = onnx_session.get_inputs()[0].name
    test_input = np.random.randn(*input_shape).astype(np.float32)

    # Warm-up: first calls pay tracing / lazy-init costs that would skew timing.
    for _ in range(10):
        infer(tf.constant(test_input))
        onnx_session.run(None, {onnx_input_name: test_input})

    # Fix: use time.perf_counter() for benchmarking — it is a monotonic
    # high-resolution clock. The original used time.time(), whose resolution
    # is platform-dependent and which can jump on wall-clock adjustments.
    start = time.perf_counter()
    for _ in range(num_runs):
        infer(tf.constant(test_input))
    tf_time = (time.perf_counter() - start) / num_runs * 1000

    start = time.perf_counter()
    for _ in range(num_runs):
        onnx_session.run(None, {onnx_input_name: test_input})
    onnx_time = (time.perf_counter() - start) / num_runs * 1000

    print("=" * 60)
    print("性能对比")
    print("=" * 60)
    print(f"TensorFlow 推理时间: {tf_time:.2f} ms")
    print(f"ONNX Runtime 推理时间: {onnx_time:.2f} ms")
    print(f"加速比: {tf_time / onnx_time:.2f}x")


benchmark_tf_vs_onnx("mobilenet_saved_model", "mobilenet.onnx", (1, 224, 224, 3))
常见问题解决 #
问题 1:不支持的操作 #
python
import tensorflow as tf
import tf2onnx


class ModelWithUnsupportedOp(tf.Module):
    """Demo module built around an op tf2onnx has no builtin mapping for."""

    @tf.function(input_signature=[tf.TensorSpec([None, 10], tf.float32)])
    def __call__(self, x):
        return tf.raw_ops.UniqueWithCountsV2(
            x,
            axis=[0],
            out_idx=tf.int32,
        )


def custom_unique_handler(ctx, node, name, args):
    """Stub handler; a real one would emit equivalent ONNX nodes."""
    pass


# Register the handler under the TF op name and attempt the conversion;
# a failure is reported instead of raised.
try:
    onnx_model, _ = tf2onnx.convert.from_function(
        ModelWithUnsupportedOp().__call__.get_concrete_function(),
        input_signature=[tf.TensorSpec([None, 10], tf.float32)],
        custom_op_handlers={"UniqueWithCountsV2": custom_unique_handler},
    )
except Exception as e:
    print(f"转换失败: {e}")
问题 2:动态形状 #
python
import tensorflow as tf
import tf2onnx
import onnx  # Bug fix: onnx.save below requires this import, which was missing.


class DynamicShapeModel(tf.Module):
    """Module whose input has two dynamic leading dimensions."""

    @tf.function(input_signature=[
        tf.TensorSpec([None, None, 3], tf.float32, name='input')
    ])
    def __call__(self, x):
        # Collapse the two dynamic dimensions into one; the trailing 3 is fixed.
        return tf.reshape(x, [-1, 3])


model = DynamicShapeModel()
onnx_model, _ = tf2onnx.convert.from_function(
    model.__call__.get_concrete_function(),
    input_signature=[tf.TensorSpec([None, None, 3], tf.float32, name='input')],
    opset=17
)
onnx.save(onnx_model, "dynamic_shape.onnx")
问题 3:NHWC vs NCHW #
python
import tensorflow as tf
import tf2onnx
import onnx  # Bug fix: onnx.save below requires this import, which was missing.

# Keras convolutional models are NHWC; inputs_as_nchw asks tf2onnx to make
# the named ONNX input NCHW (per the tf2onnx docs), the layout most ONNX
# tooling expects.
model = tf.keras.applications.ResNet50(weights='imagenet')
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')]
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
    opset=17,
    inputs_as_nchw=['input']
)
onnx.save(onnx_model, "resnet_nchw.onnx")
完整转换脚本 #
python
import os
import tensorflow as tf
import tf2onnx
import onnx
import onnxruntime as ort
import numpy as np
def convert_tf_to_onnx(
    model_or_path,
    output_path,
    input_signature=None,
    opset=17,
    validate=True
):
    """Convert a Keras model or a SavedModel directory to ONNX.

    Args:
        model_or_path: a tf.keras model instance, or the path (str) of a
            SavedModel directory.
        output_path: destination .onnx file.
        input_signature: optional list of tf.TensorSpec. When omitted for a
            Keras model, one is derived from model.input_shape with a
            dynamic batch dimension.
        opset: target ONNX opset version.
        validate: when True, re-load the exported file, run the ONNX model
            checker, and print I/O metadata plus the file size.
    """
    banner = "=" * 60
    print(banner)
    print("TensorFlow 到 ONNX 转换")
    print(banner)

    if isinstance(model_or_path, str):
        print(f"\n从 SavedModel 加载: {model_or_path}")
        onnx_model, _ = tf2onnx.convert.from_saved_model(
            model_or_path,
            opset=opset,
            output_path=output_path
        )
    else:
        print("\n从 Keras 模型转换")
        if input_signature is None:
            # Derive a spec from the model itself, keeping the batch dim dynamic.
            shape = model_or_path.input_shape
            if isinstance(shape, list):
                shape = shape[0]
            input_signature = [
                tf.TensorSpec([None] + list(shape[1:]), tf.float32, name='input')
            ]
        onnx_model, _ = tf2onnx.convert.from_keras(
            model_or_path,
            input_signature=input_signature,
            opset=opset,
            output_path=output_path
        )

    print(f"\nONNX 模型已保存: {output_path}")

    if validate:
        print("\n验证 ONNX 模型...")
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        print("✅ 模型验证通过")

        session = ort.InferenceSession(output_path)
        print("\n模型信息:")
        for inp in session.get_inputs():
            print(f" 输入: {inp.name}, 形状: {inp.shape}")
        for out in session.get_outputs():
            print(f" 输出: {out.name}, 形状: {out.shape}")

        file_size = os.path.getsize(output_path) / (1024 * 1024)
        print(f"\n模型大小: {file_size:.2f} MB")

    print(banner)


model = tf.keras.applications.MobileNetV2(weights='imagenet')
convert_tf_to_onnx(model, "mobilenetv2.onnx")
下一步 #
现在你已经掌握了 TensorFlow 模型转换,接下来学习 模型部署,学习如何将 ONNX 模型部署到生产环境!
最后更新:2026-04-04