Model Quantization #

Quantization Overview #

Quantization converts a model from high precision (e.g. FP32) to low precision (e.g. INT8), significantly reducing model size and speeding up inference.
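
Under the hood, INT8 quantization is an affine mapping between real values and integers: `q = round(x / scale) + zero_point`. A minimal NumPy sketch of that round trip (the min-max scale/zero-point scheme below is illustrative, not tied to any particular library):

python
import numpy as np

def quantize_int8(x):
    # Min-max (asymmetric) scheme: map [x.min(), x.max()] onto [-128, 127].
    qmin, qmax = -128, 127
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = int(round(qmin - x.min() / scale))
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.int8)
    return q, scale, zero_point

def dequantize_int8(q, scale, zero_point):
    # Recover an approximation of the original values.
    return (q.astype(np.float32) - zero_point) * scale

x = np.random.randn(8).astype(np.float32)
q, scale, zp = quantize_int8(x)
print(np.abs(x - dequantize_int8(q, scale, zp)).max())  # error bounded by ~scale / 2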

text
┌─────────────────────────────────────────────────────────────┐
│                    Quantization Types                       │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Dynamic Quantization:                                      │
│  ├── Weights quantized ahead of time                        │
│  ├── Activations quantized at runtime                       │
│  ├── No calibration data needed                             │
│  └── Good fit for RNNs / Transformers                       │
│                                                             │
│  Static Quantization:                                       │
│  ├── Weights and activations quantized ahead of time        │
│  ├── Requires calibration data                              │
│  ├── Fastest inference                                      │
│  └── Good fit for CNNs                                      │
│                                                             │
│  Quantization-Aware Training (QAT):                         │
│  ├── Simulates quantization during training                 │
│  ├── Smallest accuracy loss                                 │
│  ├── Requires retraining                                    │
│  └── Good fit for accuracy-critical scenarios               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Precision Comparison #

| Precision | Bits | Numeric range | Typical use |
|-----------|------|---------------|-------------|
| FP32 | 32 | ±3.4e38 | Training, high-precision inference |
| FP16 | 16 | ±65504 | GPU inference |
| BF16 | 16 | ±3.4e38 | Training acceleration |
| INT8 | 8 | -128 to 127 | Quantized inference |
| INT4 | 4 | -8 to 7 | Extreme compression |
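
The numeric ranges above can be checked against NumPy's dtype metadata (BF16 and INT4 have no native NumPy dtype, so they are listed by definition only):

python
import numpy as np

print(np.finfo(np.float32).max)  # 3.4028235e+38
print(np.finfo(np.float16).max)  # 65504.0
print(np.iinfo(np.int8))         # min = -128, max = 127
# A signed 4-bit integer spans -8..7 (two's complement) by definition.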

ONNX Runtime Quantization #

Installation #

bash
# the quantization utilities ship inside the onnxruntime package
pip install onnx onnxruntime

Dynamic Quantization #

python
from onnxruntime.quantization import quantize_dynamic, QuantType

model_path = "model.onnx"
quantized_model_path = "model_dynamic_quantized.onnx"

# Weights are quantized offline; activations are quantized on the fly at inference time.
quantize_dynamic(
    model_path,
    quantized_model_path,
    weight_type=QuantType.QUInt8,
    op_types_to_quantize=['MatMul', 'Gemm', 'Conv'],
    per_channel=False,
    reduce_range=False
)

print(f"Quantized model saved to: {quantized_model_path}")

Dynamic Quantization Parameters #

python
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    # Integer type used for the quantized weights (QInt8 or QUInt8).
    weight_type=QuantType.QInt8,
    # Restrict quantization to these operator types.
    op_types_to_quantize=['MatMul', 'Gemm'],
    # One scale/zero-point per output channel instead of per tensor.
    per_channel=True,
    # Quantize weights to a 7-bit range; can help accuracy on non-VNNI x86 CPUs.
    reduce_range=True,
    extra_options={
        'WeightSymmetric': True,       # symmetric range for weights
        'ActivationSymmetric': False   # asymmetric range for activations
    }
)

Static Quantization #

python
import os
import numpy as np
from onnxruntime.quantization import quantize_static, QuantType, QuantFormat, CalibrationDataReader

class ImageCalibrationDataReader(CalibrationDataReader):
    def __init__(self, calibration_dir, batch_size=1, input_name="input"):
        self.calibration_dir = calibration_dir
        self.batch_size = batch_size
        self.input_name = input_name
        self.data_list = self._load_data()
        self.index = 0
    
    def _load_data(self):
        # Collect up to 100 preprocessed .npy arrays as calibration samples.
        files = [f for f in sorted(os.listdir(self.calibration_dir)) if f.endswith('.npy')]
        return [np.load(os.path.join(self.calibration_dir, f)) for f in files[:100]]
    
    def get_next(self):
        if self.index >= len(self.data_list):
            return None
        
        batch_data = self.data_list[self.index]
        self.index += 1
        
        return {self.input_name: batch_data}
    
    def rewind(self):
        self.index = 0

dr = ImageCalibrationDataReader("calibration_data", input_name="input")

quantize_static(
    model_input="model.onnx",
    model_output="model_static_quantized.onnx",
    calibration_data_reader=dr,
    quant_format=QuantFormat.QDQ,
    per_channel=False,
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QUInt8
)

Using Calibration Data #

python
import numpy as np
from onnxruntime.quantization import quantize_static, QuantType, CalibrationDataReader

class RandomCalibrationDataReader(CalibrationDataReader):
    """Feeds random tensors as calibration inputs. Handy for smoke-testing
    the pipeline; real calibration should use representative data."""

    def __init__(self, input_name, input_shape, num_samples=100):
        self.input_name = input_name
        self.input_shape = input_shape
        self.num_samples = num_samples
        self.index = 0
    
    def get_next(self):
        if self.index >= self.num_samples:
            return None
        
        self.index += 1
        return {self.input_name: np.random.randn(*self.input_shape).astype(np.float32)}
    
    def rewind(self):
        self.index = 0

import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
input_shape = [1, 3, 224, 224]

calibration_reader = RandomCalibrationDataReader(
    input_name=input_name,
    input_shape=input_shape,
    num_samples=100
)

quantize_static(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    calibration_data_reader=calibration_reader,
    quant_format=QuantFormat.QDQ,
    per_channel=True,
    weight_type=QuantType.QInt8
)

Quantization Formats #

QDQ Format #

python
from onnxruntime.quantization import quantize_static, QuantFormat, QuantType

quantize_static(
    model_input="model.onnx",
    model_output="model_qdq.onnx",
    calibration_data_reader=calibration_reader,
    quant_format=QuantFormat.QDQ,
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QUInt8
)

QOperator Format #

python
from onnxruntime.quantization import quantize_static, QuantFormat, QuantType

quantize_static(
    model_input="model.onnx",
    model_output="model_qop.onnx",
    calibration_data_reader=calibration_reader,
    quant_format=QuantFormat.QOperator,
    weight_type=QuantType.QInt8
)

Format Comparison #

text
┌─────────────────────────────────────────────────────────────┐
│                      Format Comparison                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  QDQ (Quantize-Dequantize):                                 │
│  ├── Inserts QuantizeLinear / DequantizeLinear nodes        │
│  ├── Broad compatibility                                    │
│  ├── Easy to debug                                          │
│  └── Recommended default                                    │
│                                                             │
│  QOperator:                                                 │
│  ├── Uses quantized operators (e.g. MatMulInteger)          │
│  ├── More compact graph                                     │
│  ├── Narrower compatibility                                 │
│  └── Hardware-specific optimizations                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘
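
Which format a quantized model ended up in can be read off its node types. A small inspection sketch (it assumes the `model_qdq.onnx` and `model_qop.onnx` files produced above exist):

python
from collections import Counter

import onnx

def count_ops(path):
    # Histogram of operator types in the ONNX graph.
    model = onnx.load(path)
    return Counter(node.op_type for node in model.graph.node)

# A QDQ model shows QuantizeLinear / DequantizeLinear pairs around float ops;
# a QOperator model shows fused integer ops such as MatMulInteger or QLinearConv.
print(count_ops("model_qdq.onnx"))
print(count_ops("model_qop.onnx"))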

Quantization Configuration #

Quantizing Specific Operator Types #

python
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    op_types_to_quantize=['MatMul', 'Gemm', 'Conv'],
    weight_type=QuantType.QInt8
)

Quantizing Specific Nodes #

python
from onnxruntime.quantization import quantize_static, QuantType

nodes_to_quantize = ["conv1", "conv2", "fc1", "fc2"]

quantize_static(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    calibration_data_reader=calibration_reader,
    nodes_to_quantize=nodes_to_quantize,
    weight_type=QuantType.QInt8
)

Excluding Nodes #

python
from onnxruntime.quantization import quantize_static, QuantType

nodes_to_exclude = ["conv_first", "conv_last"]

quantize_static(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    calibration_data_reader=calibration_reader,
    nodes_to_exclude=nodes_to_exclude,
    weight_type=QuantType.QInt8
)

Extra Options #

python
from onnxruntime.quantization import quantize_static, QuantType

extra_options = {
    'WeightSymmetric': True,            # symmetric range for weights
    'ActivationSymmetric': False,       # asymmetric range for activations
    'CalibTensorRangeSymmetric': False, # do not force symmetric calibration ranges
    'CalibMovingAverage': True,         # smooth calibration min/max with a moving average
    'CalibMovingAverageConstant': 0.01  # smoothing factor for that average
}

quantize_static(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    calibration_data_reader=calibration_reader,
    extra_options=extra_options
)

Evaluating Quantized Models #

Accuracy Evaluation #

python
import onnxruntime as ort
import numpy as np

def evaluate_accuracy(model_path, test_data, test_labels):
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    
    correct = 0
    total = len(test_data)
    
    for i in range(0, len(test_data), 32):
        batch_data = test_data[i:i+32]
        batch_labels = test_labels[i:i+32]
        
        outputs = session.run(None, {input_name: batch_data})
        predictions = np.argmax(outputs[0], axis=1)
        
        correct += np.sum(predictions == batch_labels)
    
    accuracy = correct / total
    return accuracy

# test_data / test_labels: a representative evaluation set as NumPy arrays
fp32_accuracy = evaluate_accuracy("model.onnx", test_data, test_labels)
int8_accuracy = evaluate_accuracy("model_quantized.onnx", test_data, test_labels)

print(f"FP32 accuracy: {fp32_accuracy:.4f}")
print(f"INT8 accuracy: {int8_accuracy:.4f}")
print(f"Accuracy drop: {(fp32_accuracy - int8_accuracy) * 100:.2f} percentage points")

Performance Benchmarking #

python
import onnxruntime as ort
import numpy as np
import time

def benchmark_model(model_path, num_runs=100, warmup=10):
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    input_shape = session.get_inputs()[0].shape
    
    # Replace dynamic dimensions (strings or None) with 1 for benchmarking.
    input_shape = [1 if isinstance(d, str) or d is None else d for d in input_shape]
    input_data = np.random.randn(*input_shape).astype(np.float32)
    
    # Warm-up runs exclude session initialization from the measurement.
    for _ in range(warmup):
        session.run(None, {input_name: input_data})
    
    start = time.time()
    for _ in range(num_runs):
        session.run(None, {input_name: input_data})
    end = time.time()
    
    avg_time = (end - start) / num_runs * 1000  # milliseconds per inference
    return avg_time

fp32_time = benchmark_model("model.onnx")
int8_time = benchmark_model("model_quantized.onnx")

print(f"FP32 latency: {fp32_time:.2f} ms")
print(f"INT8 latency: {int8_time:.2f} ms")
print(f"Speedup: {fp32_time / int8_time:.2f}x")

Model Size Comparison #

python
import os

def get_model_size(model_path):
    # File size in megabytes.
    return os.path.getsize(model_path) / (1024 * 1024)

fp32_size = get_model_size("model.onnx")
int8_size = get_model_size("model_quantized.onnx")

print(f"FP32 模型大小: {fp32_size:.2f} MB")
print(f"INT8 模型大小: {int8_size:.2f} MB")
print(f"压缩比: {fp32_size / int8_size:.2f}x")

Quantization-Aware Training #

PyTorch QAT #

python
import torch
import torch.nn as nn
import torch.quantization as quant  # torch.ao.quantization in newer PyTorch releases

class QATModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc = nn.Linear(64 * 224 * 224, 10)
        
        self.quant = quant.QuantStub()
        self.dequant = quant.DeQuantStub()
    
    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.dequant(x)
        return x

model = QATModel()
model.train()  # prepare_qat expects the model in training mode

# fbgemm targets x86; use 'qnnpack' for ARM/mobile backends.
model.qconfig = quant.get_default_qat_qconfig('fbgemm')

quant.prepare_qat(model, inplace=True)

# Regular training loop with fake quantization inserted; num_epochs,
# train_one_epoch, train_loader, optimizer and criterion are assumed
# to be defined elsewhere.
for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, optimizer, criterion)

# Switch to eval mode, then convert the fake-quantized model to real INT8.
model.eval()
model_int8 = quant.convert(model)

# Note: ONNX export of eager-mode quantized PyTorch models has limited
# operator coverage and may fail for some architectures.
torch.onnx.export(model_int8, torch.randn(1, 3, 224, 224), "qat_model.onnx")

Best Practices #

Choosing a Quantization Method #

text
┌─────────────────────────────────────────────────────────────┐
│               Choosing a Quantization Method                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Choose dynamic quantization when:                          │
│  ├── RNN / LSTM / Transformer models                        │
│  ├── No calibration data available                          │
│  ├── Quick deployment needed                                │
│  └── Some accuracy loss is acceptable                       │
│                                                             │
│  Choose static quantization when:                           │
│  ├── CNN models                                             │
│  ├── Calibration data available                             │
│  ├── Best inference performance wanted                      │
│  └── Calibration time acceptable                            │
│                                                             │
│  Choose quantization-aware training when:                   │
│  ├── Accuracy-critical scenario                             │
│  ├── Retraining is possible                                 │
│  ├── Smallest accuracy loss wanted                          │
│  └── Enough time and compute budget                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘
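
The decision tree above condenses into a few lines of code; the helper below is only an illustration of the guideline, not an API from any library:

python
def choose_quantization(has_calibration_data, can_retrain, accuracy_critical):
    # Mirrors the selection guide: QAT first when accuracy is critical and
    # retraining is possible, then static if calibration data exists,
    # otherwise dynamic.
    if accuracy_critical and can_retrain:
        return "quantization-aware training"
    if has_calibration_data:
        return "static quantization"
    return "dynamic quantization"

print(choose_quantization(has_calibration_data=True, can_retrain=False, accuracy_critical=False))
# -> static quantization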

End-to-End Quantization Pipeline #

python
import onnx
import onnxruntime as ort
import numpy as np
from onnxruntime.quantization import quantize_static, QuantType, QuantFormat, CalibrationDataReader

def quantize_model_pipeline(
    input_model,
    output_model,
    calibration_data=None,
    quantization_type="static"
):
    print(f"加载模型: {input_model}")
    model = onnx.load(input_model)
    onnx.checker.check_model(model)
    
    if quantization_type == "static":
        print("执行静态量化...")
        
        session = ort.InferenceSession(input_model)
        input_name = session.get_inputs()[0].name
        input_shape = session.get_inputs()[0].shape
        # Replace dynamic dimensions (strings or None) with 1.
        input_shape = [1 if isinstance(d, str) or d is None else d for d in input_shape]
        
        class SimpleCalibrationDataReader(CalibrationDataReader):
            def __init__(self, data, name):
                # Fall back to random tensors when no calibration data is
                # given; fine for smoke tests, not for production accuracy.
                self.data = data if data is not None else [
                    np.random.randn(*input_shape).astype(np.float32) 
                    for _ in range(100)
                ]
                self.name = name
                self.index = 0
            
            def get_next(self):
                if self.index >= len(self.data):
                    return None
                data = self.data[self.index]
                self.index += 1
                return {self.name: data}
            
            def rewind(self):
                self.index = 0
        
        calibration_reader = SimpleCalibrationDataReader(calibration_data, input_name)
        
        quantize_static(
            model_input=input_model,
            model_output=output_model,
            calibration_data_reader=calibration_reader,
            quant_format=QuantFormat.QDQ,
            per_channel=True,
            weight_type=QuantType.QInt8,
            activation_type=QuantType.QUInt8
        )
    
    print(f"量化完成: {output_model}")
    
    original_size = onnx.load(input_model).ByteSize() / 1024 / 1024
    quantized_size = onnx.load(output_model).ByteSize() / 1024 / 1024
    print(f"原始大小: {original_size:.2f} MB")
    print(f"量化大小: {quantized_size:.2f} MB")
    print(f"压缩比: {original_size / quantized_size:.2f}x")

quantize_model_pipeline("model.onnx", "model_quantized.onnx")

Next Steps #

Now that you understand model quantization, continue with PyTorch model conversion and consolidate what you have learned through hands-on examples!

Last updated: 2026-04-04