TensorFlow 模型转换实战 #

环境准备 #

安装依赖 #

bash
# Install the four packages imported by the Python snippets in this guide.
pip install tensorflow tf2onnx onnx onnxruntime

验证安装 #

python
import tensorflow as tf
import tf2onnx
import onnx
import onnxruntime as ort

# Report the installed version of each package imported above, one per line.
for label, module in (
    ("TensorFlow", tf),
    ("tf2onnx", tf2onnx),
    ("ONNX", onnx),
    ("ONNX Runtime", ort),
):
    print(f"{label} 版本: {module.__version__}")

Keras 模型转换 #

Sequential 模型转换 #

python
import tensorflow as tf
import tf2onnx
import onnx

# Small CNN classifier: two Conv2D + MaxPooling2D stages, global average
# pooling, then a 10-unit softmax layer.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(224, 224, 3)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Compile with the Adam optimizer and a sparse categorical cross-entropy loss.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Input spec handed to the converter: the leading dimension is left as None,
# followed by 224x224x3 float32 data; the input is named 'input'.
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')]

# Convert the Keras model at opset 17; the second return value is discarded.
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
    opset=17
)

# Write the converted model to disk explicitly.
onnx.save(onnx_model, "keras_model.onnx")
print("Keras 模型转换成功")

Functional API 模型转换 #

python
import tensorflow as tf
import tf2onnx
import onnx

def create_functional_model():
    """Build a two-stage CNN classifier with the Keras Functional API.

    Each stage is Conv2D (3x3, 'same' padding) -> BatchNormalization -> ReLU
    followed by a pooling layer: MaxPooling2D after the 64-filter stage and
    GlobalAveragePooling2D after the 128-filter stage. A 10-unit softmax
    Dense layer named 'output' produces the result.

    Returns:
        A tf.keras.Model taking a (224, 224, 3) input named 'input'.
    """
    layers = tf.keras.layers
    image = tf.keras.Input(shape=(224, 224, 3), name='input')

    # (number of filters, pooling layer class) for each stage, in order.
    stages = (
        (64, layers.MaxPooling2D),
        (128, layers.GlobalAveragePooling2D),
    )

    features = image
    for filters, pooling in stages:
        features = layers.Conv2D(filters, 3, padding='same')(features)
        features = layers.BatchNormalization()(features)
        features = layers.ReLU()(features)
        features = pooling()(features)

    probabilities = layers.Dense(10, activation='softmax', name='output')(features)
    return tf.keras.Model(inputs=image, outputs=probabilities)

# Build the Functional API model defined above.
model = create_functional_model()

# Input spec for the converter: leading dimension None, then 224x224x3 float32.
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')]

# Convert at opset 17. output_path is given to the converter, and this snippet
# makes no separate onnx.save call.
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
    opset=17,
    output_path="functional_model.onnx"
)

print("Functional API 模型转换成功")

预训练模型转换 #

python
import tensorflow as tf
import tf2onnx
import onnx

# Load ResNet50V2 with ImageNet weights. Note this is the V2 variant, although
# the output file name and the message below only say "ResNet50".
model = tf.keras.applications.ResNet50V2(weights='imagenet')

# Input spec for the converter: leading dimension None, then 224x224x3 float32.
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')]

# Convert at opset 17, passing the destination file via output_path.
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
    opset=17,
    output_path="resnet50_tf.onnx"
)

print("ResNet50 模型转换成功")

SavedModel 转换 #

保存 SavedModel #

python
import tensorflow as tf

# Load MobileNetV2 with ImageNet weights.
model = tf.keras.applications.MobileNetV2(weights='imagenet')

# Export it in SavedModel format to the "mobilenet_saved_model" directory,
# which the later SavedModel snippets read from.
tf.saved_model.save(model, "mobilenet_saved_model")

print("SavedModel 已保存")

从 SavedModel 转换 #

python
import subprocess
import sys

import onnx
import tf2onnx  # imported only to fail fast when the converter is not installed

# tf2onnx's documented Python API converts Keras models, tf.functions,
# GraphDefs and TFLite files. A SavedModel directory is converted through the
# command-line entry point, run here with the current interpreter.
subprocess.run(
    [
        sys.executable, "-m", "tf2onnx.convert",
        "--saved-model", "mobilenet_saved_model",
        "--output", "mobilenet.onnx",
        "--opset", "17",
    ],
    check=True,  # raise CalledProcessError if the conversion fails
)

# Load the result back and run the ONNX checker on it.
onnx_model = onnx.load("mobilenet.onnx")
onnx.checker.check_model(onnx_model)

print("SavedModel 转换成功")

命令行转换 #

bash
# Convert the SavedModel directory to mobilenet.onnx at opset 17 using the
# tf2onnx command-line entry point.
python -m tf2onnx.convert \
    --saved-model mobilenet_saved_model \
    --output mobilenet.onnx \
    --opset 17

Concrete Function 转换 #

使用 @tf.function #

python
import tensorflow as tf
import tf2onnx
import onnx

class MyModel(tf.Module):
    """Two dense layers exposed through a tf.function with a fixed input signature."""

    def __init__(self):
        super().__init__()
        # A 64-unit layer followed by a 10-unit layer; both are created once
        # here and reused by __call__.
        self.dense1 = tf.keras.layers.Dense(64)
        self.dense2 = tf.keras.layers.Dense(10)

    @tf.function(input_signature=[tf.TensorSpec([None, 32], tf.float32)])
    def __call__(self, x):
        """Apply dense1, ReLU, then dense2 to a float32 input of shape [None, 32]."""
        hidden = self.dense1(x)
        activated = tf.nn.relu(hidden)
        return self.dense2(activated)

model = MyModel()

# from_function expects the tf.function itself and derives the concrete
# function from input_signature, so the decorated method is passed directly
# rather than an already-traced ConcreteFunction.
onnx_model, _ = tf2onnx.convert.from_function(
    model.__call__,
    input_signature=[tf.TensorSpec([None, 32], tf.float32, name='input')],
    opset=17
)

# Write the converted model to disk.
onnx.save(onnx_model, "tf_function_model.onnx")
print("tf.function 模型转换成功")

多输入多输出 #

python
import tensorflow as tf
import tf2onnx
import onnx

class MultiIOModule(tf.Module):
    """Module with two inputs and two outputs.

    The inputs are concatenated along axis 1 and fed to two independent Dense
    layers (10 and 5 units).
    """

    def __init__(self):
        super().__init__()
        # Create the layers once, outside the traced function. Building them
        # inside __call__ would create new variables on every trace, which
        # tf.function does not allow after the first call.
        self.output1 = tf.keras.layers.Dense(10, name='output1')
        self.output2 = tf.keras.layers.Dense(5, name='output2')

    @tf.function(input_signature=[
        tf.TensorSpec([None, 32], tf.float32, name='input_a'),
        tf.TensorSpec([None, 16], tf.float32, name='input_b')
    ])
    def __call__(self, a, b):
        """Return (output1, output2) computed from the concatenation of a and b."""
        x = tf.concat([a, b], axis=1)
        return self.output1(x), self.output2(x)

module = MultiIOModule()

# Pass the tf.function itself: from_function builds the concrete function
# from input_signature, so a pre-traced ConcreteFunction must not be given.
onnx_model, _ = tf2onnx.convert.from_function(
    module.__call__,
    input_signature=[
        tf.TensorSpec([None, 32], tf.float32, name='input_a'),
        tf.TensorSpec([None, 16], tf.float32, name='input_b')
    ],
    opset=17
)

# Write the converted model to disk.
onnx.save(onnx_model, "multi_io_model.onnx")
print("多输入多输出模型转换成功")

BERT 模型转换 #

HuggingFace Transformers #

python
import tensorflow as tf
from transformers import TFBertModel, BertTokenizer
import tf2onnx
import onnx

# Load the TensorFlow BERT model and its tokenizer from the same checkpoint
# name. The tokenizer is not used by the conversion below.
model_name = "bert-base-uncased"
model = TFBertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)

# Three int32 inputs whose two dimensions are both left as None.
input_signature = [
    tf.TensorSpec([None, None], tf.int32, name='input_ids'),
    tf.TensorSpec([None, None], tf.int32, name='attention_mask'),
    tf.TensorSpec([None, None], tf.int32, name='token_type_ids')
]

# Convert at opset 17, passing the destination file via output_path.
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
    opset=17,
    output_path="bert_tf.onnx"
)

print("BERT TensorFlow 模型转换成功")

转换选项 #

基本选项 #

python
import tf2onnx

# `model` and `input_signature` are not defined in this snippet; they must
# already exist (for example from one of the Keras snippets earlier in this
# guide). Every optional argument other than opset is passed as None here,
# so the call only lists the available keyword names.
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
    opset=17,
    target=None,
    custom_op_handlers=None,
    custom_rewriter=None,
    extra_opset=None,
    shape_override=None
)

目标平台 #

python
import tf2onnx

# `model` and `input_signature` are not defined in this snippet; they must
# already exist. The only difference from a plain conversion is the
# target="tensorrt" argument.
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
    target="tensorrt",
    opset=17
)

自定义算子处理 #

python
import tensorflow as tf
import tf2onnx

def custom_op_handler(ctx, node, name, args):
    """Placeholder custom-op handler: accepts the four handler arguments and does nothing."""
    return None

# Each value must be a (handler, args) tuple: tf2onnx indexes the value to
# separate the callable from its argument list, so a bare function fails.
custom_op_handlers = {
    "CustomOp": (custom_op_handler, [])
}

# `model` and `input_signature` are not defined in this snippet; they must
# already exist.
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
    custom_op_handlers=custom_op_handlers,
    opset=17
)

验证与测试 #

输出一致性验证 #

python
import tensorflow as tf
import onnxruntime as ort
import numpy as np

def validate_tf_onnx(tf_model_path, onnx_path, input_shape, num_tests=5):
    """Compare SavedModel and ONNX Runtime outputs on random inputs.

    Only the first input and the first output of each model are compared.

    Args:
        tf_model_path: Directory of the TensorFlow SavedModel.
        onnx_path: Path of the converted ONNX file.
        input_shape: Full shape of each random test input (unpacked into
            np.random.randn).
        num_tests: Number of random inputs to compare.

    Returns:
        The largest absolute element-wise difference seen across all tests
        (0.0 when num_tests is 0).
    """
    print("=" * 60)
    print("TensorFlow vs ONNX 输出验证")
    print("=" * 60)

    tf_model = tf.saved_model.load(tf_model_path)
    infer = tf_model.signatures['serving_default']

    onnx_session = ort.InferenceSession(onnx_path)

    # Name of the first keyword input of the serving signature, and of the
    # first ONNX graph input.
    input_name = next(iter(infer.structured_input_signature[1]))
    onnx_input_name = onnx_session.get_inputs()[0].name

    max_diff = 0.0

    for i in range(num_tests):
        test_input = np.random.randn(*input_shape).astype(np.float32)

        # Feed the signature by keyword so the tensor is bound to the named
        # input rather than relying on positional matching.
        tf_result = infer(**{input_name: tf.constant(test_input)})
        tf_output = list(tf_result.values())[0].numpy()

        onnx_output = onnx_session.run(
            None,
            {onnx_input_name: test_input}
        )[0]

        diff = float(np.abs(tf_output - onnx_output).max())
        max_diff = max(max_diff, diff)

        print(f"测试 {i+1}: 最大差异 = {diff:.6f}")

    print(f"\n最大差异: {max_diff:.6f}")

    if max_diff < 1e-5:
        print("✅ 验证通过")
    else:
        print("⚠️ 存在精度差异")

    return max_diff

# Validate the MobileNet SavedModel against its ONNX conversion using random
# inputs of shape (1, 224, 224, 3).
validate_tf_onnx("mobilenet_saved_model", "mobilenet.onnx", (1, 224, 224, 3))

性能对比 #

python
import tensorflow as tf
import onnxruntime as ort
import numpy as np
import time

def benchmark_tf_vs_onnx(tf_model_path, onnx_path, input_shape, num_runs=100):
    """Time SavedModel inference against ONNX Runtime on one random input.

    Args:
        tf_model_path: Directory of the TensorFlow SavedModel.
        onnx_path: Path of the converted ONNX file.
        input_shape: Full shape of the random test input.
        num_runs: Number of timed iterations per runtime; must be positive.

    Returns:
        A (tf_time, onnx_time) tuple of mean per-call latencies in
        milliseconds.

    Raises:
        ValueError: If num_runs is not positive.
    """
    if num_runs <= 0:
        raise ValueError("num_runs must be a positive integer")

    tf_model = tf.saved_model.load(tf_model_path)
    infer = tf_model.signatures['serving_default']

    onnx_session = ort.InferenceSession(onnx_path)
    onnx_input_name = onnx_session.get_inputs()[0].name

    test_input = np.random.randn(*input_shape).astype(np.float32)

    # Warm-up: 10 untimed calls to each runtime before measuring.
    for _ in range(10):
        infer(tf.constant(test_input))
        onnx_session.run(None, {onnx_input_name: test_input})

    # perf_counter is a monotonic, high-resolution clock intended for
    # measuring short durations; time.time() can jump with wall-clock changes.
    start = time.perf_counter()
    for _ in range(num_runs):
        infer(tf.constant(test_input))
    tf_time = (time.perf_counter() - start) / num_runs * 1000

    start = time.perf_counter()
    for _ in range(num_runs):
        onnx_session.run(None, {onnx_input_name: test_input})
    onnx_time = (time.perf_counter() - start) / num_runs * 1000

    # Avoid ZeroDivisionError if the ONNX timing rounds down to zero.
    speedup = tf_time / onnx_time if onnx_time > 0 else float("inf")

    print("=" * 60)
    print("性能对比")
    print("=" * 60)
    print(f"TensorFlow 推理时间: {tf_time:.2f} ms")
    print(f"ONNX Runtime 推理时间: {onnx_time:.2f} ms")
    print(f"加速比: {speedup:.2f}x")

    return tf_time, onnx_time

# Benchmark the MobileNet SavedModel against its ONNX conversion with a
# random input of shape (1, 224, 224, 3).
benchmark_tf_vs_onnx("mobilenet_saved_model", "mobilenet.onnx", (1, 224, 224, 3))

常见问题解决 #

问题 1:不支持的操作 #

python
import tensorflow as tf
import tf2onnx

class ModelWithUnsupportedOp(tf.Module):
    """Module whose call wraps the raw UniqueWithCountsV2 op, used as the conversion-failure example."""

    @tf.function(input_signature=[tf.TensorSpec([None, 10], tf.float32)])
    def __call__(self, x):
        """Run UniqueWithCountsV2 over axis 0 of x with int32 index outputs."""
        # tf.raw_ops functions accept keyword arguments only, so the input
        # tensor is passed as x=x rather than positionally.
        return tf.raw_ops.UniqueWithCountsV2(
            x=x,
            axis=[0],
            out_idx=tf.int32
        )

def custom_unique_handler(ctx, node, name, args):
    """Placeholder handler for the UniqueWithCountsV2 op: accepts the handler arguments and does nothing."""
    return None

# Attempt the conversion and report any failure instead of crashing.
try:
    onnx_model, _ = tf2onnx.convert.from_function(
        # Pass the tf.function itself; from_function derives the concrete
        # function from input_signature.
        ModelWithUnsupportedOp().__call__,
        input_signature=[tf.TensorSpec([None, 10], tf.float32)],
        # Handler values must be (handler, args) tuples, not bare functions.
        custom_op_handlers={"UniqueWithCountsV2": (custom_unique_handler, [])}
    )
except Exception as e:
    print(f"转换失败: {e}")

问题 2:动态形状 #

python
import tensorflow as tf
import tf2onnx

class DynamicShapeModel(tf.Module):
    """Module that flattens an input with two unknown leading dimensions to rows of width 3."""

    @tf.function(input_signature=[
        tf.TensorSpec([None, None, 3], tf.float32, name='input')
    ])
    def __call__(self, x):
        """Reshape a [None, None, 3] float32 tensor to shape [-1, 3]."""
        flattened = tf.reshape(x, shape=[-1, 3])
        return flattened

model = DynamicShapeModel()

# Pass the tf.function itself (from_function derives the concrete function
# from input_signature). The file is written through output_path because this
# snippet does not import onnx, so onnx.save would raise NameError.
onnx_model, _ = tf2onnx.convert.from_function(
    model.__call__,
    input_signature=[tf.TensorSpec([None, None, 3], tf.float32, name='input')],
    opset=17,
    output_path="dynamic_shape.onnx"
)

问题 3:NHWC vs NCHW #

python
import tensorflow as tf
import tf2onnx

# Load ResNet50 with ImageNet weights.
model = tf.keras.applications.ResNet50(weights='imagenet')

# The Keras model takes NHWC input: leading dimension None, then 224x224x3.
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')]

# inputs_as_nchw lists the inputs to expose in NCHW layout. The file is
# written through output_path because this snippet does not import onnx, so
# onnx.save would raise NameError.
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
    opset=17,
    inputs_as_nchw=['input'],
    output_path="resnet_nchw.onnx"
)

完整转换脚本 #

python
import os
import tensorflow as tf
import tf2onnx
import onnx
import onnxruntime as ort
import numpy as np

def _default_input_signature(model):
    """Build one float32 TensorSpec per Keras model input, leaving the first dimension as None."""
    input_shape = model.input_shape
    # A list of shapes means a multi-input model: emit one spec per input so
    # the signature matches the model's arity.
    if isinstance(input_shape, list):
        return [
            tf.TensorSpec([None] + list(shape[1:]), tf.float32, name=f'input_{index}')
            for index, shape in enumerate(input_shape)
        ]
    return [tf.TensorSpec([None] + list(input_shape[1:]), tf.float32, name='input')]


def convert_tf_to_onnx(
    model_or_path,
    output_path,
    input_signature=None,
    opset=17,
    validate=True
):
    """Convert a SavedModel directory or a Keras model to an ONNX file.

    Args:
        model_or_path: Path (str or os.PathLike) of a SavedModel directory,
            or an in-memory Keras model.
        output_path: Destination path of the ONNX file.
        input_signature: Optional list of tf.TensorSpec for the Keras branch.
            When omitted, float32 specs are derived from model.input_shape.
            Ignored for SavedModel paths.
        opset: ONNX opset version to target.
        validate: When true, run the ONNX checker on the written file and
            print the input/output names and shapes seen by ONNX Runtime.

    Raises:
        subprocess.CalledProcessError: If the SavedModel conversion command
            exits with a non-zero status.
    """
    import subprocess
    import sys

    print("=" * 60)
    print("TensorFlow 到 ONNX 转换")
    print("=" * 60)

    if isinstance(model_or_path, (str, os.PathLike)):
        print(f"\n从 SavedModel 加载: {model_or_path}")
        # tf2onnx's documented Python API covers Keras models, tf.functions,
        # GraphDefs and TFLite files; SavedModel directories are converted
        # through the command-line entry point.
        subprocess.run(
            [
                sys.executable, "-m", "tf2onnx.convert",
                "--saved-model", os.fspath(model_or_path),
                "--output", os.fspath(output_path),
                "--opset", str(opset),
            ],
            check=True,
        )
    else:
        print("\n从 Keras 模型转换")
        if input_signature is None:
            input_signature = _default_input_signature(model_or_path)

        tf2onnx.convert.from_keras(
            model_or_path,
            input_signature=input_signature,
            opset=opset,
            output_path=output_path
        )

    print(f"\nONNX 模型已保存: {output_path}")

    if validate:
        print("\n验证 ONNX 模型...")
        # Re-load from disk so the check covers the file that was written.
        onnx_model = onnx.load(output_path)
        onnx.checker.check_model(onnx_model)
        print("✅ 模型验证通过")

        session = ort.InferenceSession(output_path)
        print("\n模型信息:")
        for inp in session.get_inputs():
            print(f"  输入: {inp.name}, 形状: {inp.shape}")
        for out in session.get_outputs():
            print(f"  输出: {out.name}, 形状: {out.shape}")

    file_size = os.path.getsize(output_path) / (1024 * 1024)
    print(f"\n模型大小: {file_size:.2f} MB")
    print("=" * 60)

# Example run: convert an ImageNet-weighted MobileNetV2 to mobilenetv2.onnx
# with the default opset and validation enabled.
model = tf.keras.applications.MobileNetV2(weights='imagenet')
convert_tf_to_onnx(model, "mobilenetv2.onnx")

下一步 #

现在你已经掌握了 TensorFlow 模型转换，接下来可以继续学习模型部署，了解如何将 ONNX 模型部署到生产环境！

最后更新:2026-04-04