Model Quantization #
Quantization Overview #
Quantization converts a model from high precision (e.g. FP32) to low precision (e.g. INT8), significantly shrinking the model and speeding up inference.
text
┌─────────────────────────────────────────────────────────────┐
│                     Quantization Types                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Dynamic Quantization:                                      │
│  ├── Weights quantized ahead of time                        │
│  ├── Activations quantized at runtime                       │
│  ├── No calibration data needed                             │
│  └── Good fit for RNN/Transformer models                    │
│                                                             │
│  Static Quantization:                                       │
│  ├── Weights and activations quantized ahead of time        │
│  ├── Requires calibration data                              │
│  ├── Fastest inference                                      │
│  └── Good fit for CNN models                                │
│                                                             │
│  Quantization Aware Training (QAT):                         │
│  ├── Simulates quantization during training                 │
│  ├── Smallest accuracy loss                                 │
│  ├── Requires retraining                                    │
│  └── Good fit for accuracy-sensitive scenarios              │
│                                                             │
└─────────────────────────────────────────────────────────────┘
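At its core, INT8 quantization maps real values through an affine transform, x ≈ scale × (q − zero_point). A minimal NumPy sketch (illustrative values, asymmetric uint8) shows the round trip and the error it introduces:
python
import numpy as np

x = np.array([-1.2, 0.0, 0.7, 2.5], dtype=np.float32)

# Derive scale / zero-point from the observed range (asymmetric, uint8)
qmin, qmax = 0, 255
scale = (float(x.max()) - float(x.min())) / (qmax - qmin)
zero_point = int(round(qmin - float(x.min()) / scale))

q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.uint8)
x_hat = scale * (q.astype(np.float32) - zero_point)

print(q)                        # quantized integers
print(np.abs(x - x_hat).max())  # round-trip quantization error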
Precision Comparison #
| Precision | Bits | Value Range | Typical Use |
|---|---|---|---|
| FP32 | 32 | ±3.4e38 | Training, high-accuracy inference |
| FP16 | 16 | ±65504 | GPU inference |
| BF16 | 16 | ±3.4e38 | Accelerated training |
| INT8 | 8 | -128 to 127 | Quantized inference |
| INT4 | 4 | -8 to 7 | Extreme compression |
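The FP32/FP16/INT8 rows above can be verified directly with NumPy (BF16 is not a built-in NumPy dtype, so it is omitted here):
python
import numpy as np

print(np.finfo(np.float32).max)  # ~3.4e38
print(np.finfo(np.float16).max)  # 65504.0
print(np.iinfo(np.int8))         # min = -128, max = 127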
Quantization with ONNX Runtime #
Installation #
bash
pip install onnx onnxruntime
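The quantization tooling ships inside the onnxruntime package itself (there is no separate quantization package on PyPI), so a quick import check confirms the setup:
python
import onnxruntime
from onnxruntime.quantization import quantize_dynamic, quantize_static

print(onnxruntime.__version__)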
Dynamic Quantization #
python
from onnxruntime.quantization import quantize_dynamic, QuantType

model_path = "model.onnx"
quantized_model_path = "model_dynamic_quantized.onnx"

# Weights are quantized ahead of time; activations are quantized at runtime
quantize_dynamic(
    model_path,
    quantized_model_path,
    weight_type=QuantType.QUInt8,
    op_types_to_quantize=['MatMul', 'Gemm', 'Conv'],
    per_channel=False,
    reduce_range=False
)
print(f"Quantized model saved to: {quantized_model_path}")
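As a quick smoke test (a minimal sketch; symbolic dimensions in the input shape are replaced with 1), confirm the quantized model still loads and runs:
python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model_dynamic_quantized.onnx")
inp = session.get_inputs()[0]
# Replace symbolic (non-integer) dimensions with 1 to build a dummy input
shape = [d if isinstance(d, int) else 1 for d in inp.shape]
dummy = np.random.randn(*shape).astype(np.float32)

outputs = session.run(None, {inp.name: dummy})
print([o.shape for o in outputs])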
Dynamic Quantization Parameters #
python
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    weight_type=QuantType.QInt8,              # signed 8-bit weights
    op_types_to_quantize=['MatMul', 'Gemm'],  # only quantize these op types
    per_channel=True,                         # one scale/zero-point per output channel
    reduce_range=True,                        # 7-bit weight range, avoids overflow on some CPUs
    extra_options={
        'WeightSymmetric': True,              # weights: zero-point fixed at 0
        'ActivationSymmetric': False          # activations: asymmetric range
    }
)
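WeightSymmetric and ActivationSymmetric control how scale and zero-point are derived. A small NumPy sketch (illustrative values, int8) makes the difference concrete:
python
import numpy as np

w = np.array([-0.8, -0.1, 0.3, 1.5], dtype=np.float32)

# Symmetric: range centered on zero, zero_point fixed at 0
scale_sym = max(abs(float(w.min())), abs(float(w.max()))) / 127
q_sym = np.round(w / scale_sym).astype(np.int8)

# Asymmetric: the full [min, max] range is mapped onto [-128, 127]
scale_asym = (float(w.max()) - float(w.min())) / 255
zero_point = int(round(-128 - float(w.min()) / scale_asym))
q_asym = np.clip(np.round(w / scale_asym) + zero_point, -128, 127).astype(np.int8)

print(q_sym)   # symmetric keeps zero exact but wastes range on one side
print(q_asym)  # asymmetric uses the full integer range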
Static Quantization #
python
import os
import numpy as np
from onnxruntime.quantization import (
    CalibrationDataReader, QuantFormat, QuantType, quantize_static
)

class ImageCalibrationDataReader(CalibrationDataReader):
    """Feeds pre-processed .npy tensors to the calibrator, one batch at a time."""
    def __init__(self, calibration_dir, batch_size=1, input_name="input"):
        self.calibration_dir = calibration_dir
        self.batch_size = batch_size
        self.input_name = input_name
        self.data_list = self._load_data()
        self.index = 0

    def _load_data(self):
        # Use up to 100 .npy files as calibration samples
        files = [f for f in os.listdir(self.calibration_dir) if f.endswith('.npy')]
        return [np.load(os.path.join(self.calibration_dir, f)) for f in files[:100]]

    def get_next(self):
        # Returning None tells the calibrator the data is exhausted
        if self.index >= len(self.data_list):
            return None
        batch_data = self.data_list[self.index]
        self.index += 1
        return {self.input_name: batch_data}

    def rewind(self):
        self.index = 0

dr = ImageCalibrationDataReader("calibration_data", input_name="input")

quantize_static(
    model_input="model.onnx",
    model_output="model_static_quantized.onnx",
    calibration_data_reader=dr,
    quant_format=QuantFormat.QDQ,
    per_channel=False,
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QUInt8
)
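Before static quantization, ONNX Runtime's documentation recommends a pre-processing pass (shape inference plus graph optimization) so that fusable patterns quantize cleanly:
bash
python -m onnxruntime.quantization.preprocess --input model.onnx --output model_infer.onnx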
Using Calibration Data #
python
import numpy as np
import onnxruntime as ort
from onnxruntime.quantization import (
    CalibrationDataReader, QuantFormat, QuantType, quantize_static
)

class RandomCalibrationDataReader(CalibrationDataReader):
    """Generates random calibration batches; real samples give far better ranges."""
    def __init__(self, input_name, input_shape, num_samples=100):
        self.input_name = input_name
        self.input_shape = input_shape
        self.num_samples = num_samples
        self.index = 0

    def get_next(self):
        if self.index >= self.num_samples:
            return None
        self.index += 1
        return {self.input_name: np.random.randn(*self.input_shape).astype(np.float32)}

    def rewind(self):
        self.index = 0

# Read the input name from the model itself
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
input_shape = [1, 3, 224, 224]

calibration_reader = RandomCalibrationDataReader(
    input_name=input_name,
    input_shape=input_shape,
    num_samples=100
)

quantize_static(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    calibration_data_reader=calibration_reader,
    quant_format=QuantFormat.QDQ,
    per_channel=True,
    weight_type=QuantType.QInt8
)
Quantization Formats #
QDQ Format #
python
from onnxruntime.quantization import quantize_static, QuantFormat, QuantType
quantize_static(
    model_input="model.onnx",
    model_output="model_qdq.onnx",
    calibration_data_reader=calibration_reader,
    quant_format=QuantFormat.QDQ,
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QUInt8
)
QOperator Format #
python
from onnxruntime.quantization import quantize_static, QuantFormat, QuantType
quantize_static(
    model_input="model.onnx",
    model_output="model_qop.onnx",
    calibration_data_reader=calibration_reader,
    quant_format=QuantFormat.QOperator,
    weight_type=QuantType.QInt8
)
Format Comparison #
text
┌─────────────────────────────────────────────────────────────┐
│               Quantization Format Comparison                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  QDQ (Quantize-Dequantize):                                 │
│  ├── Inserts QuantizeLinear / DequantizeLinear nodes        │
│  ├── Broad compatibility                                    │
│  ├── Easy to debug                                          │
│  └── Recommended default                                    │
│                                                             │
│  QOperator:                                                 │
│  ├── Uses quantized operators (e.g. MatMulInteger)          │
│  ├── More compact graph                                     │
│  ├── Narrower compatibility                                 │
│  └── Hardware-specific optimization                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘
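The difference is visible directly in the exported graphs: counting node types (file names from the examples above) shows QuantizeLinear/DequantizeLinear pairs in the QDQ model versus fused operators such as QLinearConv in the QOperator model:
python
from collections import Counter
import onnx

for path in ["model_qdq.onnx", "model_qop.onnx"]:
    graph = onnx.load(path).graph
    op_counts = Counter(node.op_type for node in graph.node)
    print(path, dict(op_counts))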
Quantization Configuration #
Quantizing Specific Operator Types #
python
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    op_types_to_quantize=['MatMul', 'Gemm', 'Conv'],
    weight_type=QuantType.QInt8
)
Quantizing Specific Nodes #
python
from onnxruntime.quantization import quantize_static, QuantType
nodes_to_quantize = ["conv1", "conv2", "fc1", "fc2"]

quantize_static(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    calibration_data_reader=calibration_reader,
    nodes_to_quantize=nodes_to_quantize,
    weight_type=QuantType.QInt8
)
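The names passed to nodes_to_quantize (and to nodes_to_exclude in the next example) must match node names in the graph exactly; listing them takes a few lines:
python
import onnx

model = onnx.load("model.onnx")
for node in model.graph.node:
    print(node.op_type, node.name)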
Excluding Nodes #
The first and last layers are often the most accuracy-sensitive, so excluding them from quantization is a common compromise:
python
from onnxruntime.quantization import quantize_static, QuantType
nodes_to_exclude = ["conv_first", "conv_last"]

quantize_static(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    calibration_data_reader=calibration_reader,
    nodes_to_exclude=nodes_to_exclude,
    weight_type=QuantType.QInt8
)
Extra Options #
python
from onnxruntime.quantization import quantize_static, QuantType
extra_options = {
    'WeightSymmetric': True,             # weights: symmetric, zero-point fixed at 0
    'ActivationSymmetric': False,        # activations: asymmetric range
    'CalibTensorRangeSymmetric': False,  # whether calibration ranges are symmetric
    'CalibMovingAverage': True,          # smooth calibration ranges with a moving average
    'CalibMovingAverageConstant': 0.01   # smoothing constant for the moving average
}

quantize_static(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    calibration_data_reader=calibration_reader,
    extra_options=extra_options
)
Evaluating Quantization #
Accuracy Evaluation #
python
import onnxruntime as ort
import numpy as np

def evaluate_accuracy(model_path, test_data, test_labels):
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    correct = 0
    total = len(test_data)
    # Evaluate in batches of 32
    for i in range(0, len(test_data), 32):
        batch_data = test_data[i:i+32]
        batch_labels = test_labels[i:i+32]
        outputs = session.run(None, {input_name: batch_data})
        predictions = np.argmax(outputs[0], axis=1)
        correct += np.sum(predictions == batch_labels)
    return correct / total

fp32_accuracy = evaluate_accuracy("model.onnx", test_data, test_labels)
int8_accuracy = evaluate_accuracy("model_quantized.onnx", test_data, test_labels)
print(f"FP32 accuracy: {fp32_accuracy:.4f}")
print(f"INT8 accuracy: {int8_accuracy:.4f}")
print(f"Accuracy drop: {(fp32_accuracy - int8_accuracy) * 100:.2f}%")
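When no labeled test set is at hand, comparing raw outputs of the two models on the same input gives a fast error estimate (a minimal sketch assuming a single float input):
python
import numpy as np
import onnxruntime as ort

fp32_session = ort.InferenceSession("model.onnx")
int8_session = ort.InferenceSession("model_quantized.onnx")

inp = fp32_session.get_inputs()[0]
shape = [d if isinstance(d, int) else 1 for d in inp.shape]
x = np.random.randn(*shape).astype(np.float32)

y_fp32 = fp32_session.run(None, {inp.name: x})[0]
y_int8 = int8_session.run(None, {inp.name: x})[0]
print("max abs diff:", np.abs(y_fp32 - y_int8).max())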
Performance Evaluation #
python
import onnxruntime as ort
import numpy as np
import time

def benchmark_model(model_path, num_runs=100, warmup=10):
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    input_shape = session.get_inputs()[0].shape
    # Replace symbolic (non-integer) dimensions with 1
    input_shape = [d if isinstance(d, int) else 1 for d in input_shape]
    input_data = np.random.randn(*input_shape).astype(np.float32)

    # Warm up to exclude one-off session initialization cost
    for _ in range(warmup):
        session.run(None, {input_name: input_data})

    start = time.perf_counter()
    for _ in range(num_runs):
        session.run(None, {input_name: input_data})
    end = time.perf_counter()

    return (end - start) / num_runs * 1000  # average ms per inference

fp32_time = benchmark_model("model.onnx")
int8_time = benchmark_model("model_quantized.onnx")
print(f"FP32 inference time: {fp32_time:.2f} ms")
print(f"INT8 inference time: {int8_time:.2f} ms")
print(f"Speedup: {fp32_time / int8_time:.2f}x")
Model Size Comparison #
python
import os

def get_model_size(model_path):
    # Size on disk in MB
    return os.path.getsize(model_path) / (1024 * 1024)

fp32_size = get_model_size("model.onnx")
int8_size = get_model_size("model_quantized.onnx")
print(f"FP32 model size: {fp32_size:.2f} MB")
print(f"INT8 model size: {int8_size:.2f} MB")
print(f"Compression ratio: {fp32_size / int8_size:.2f}x")
Quantization Aware Training #
PyTorch QAT #
python
import torch
import torch.nn as nn
import torch.quantization as quant

class QATModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        # Separate ReLU instances: eager-mode quantization attaches one
        # observer per module, so modules must not be reused in the graph
        self.relu2 = nn.ReLU()
        self.fc = nn.Linear(64 * 224 * 224, 10)
        # Mark where tensors enter and leave the quantized region
        self.quant = quant.QuantStub()
        self.dequant = quant.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.dequant(x)
        return x

model = QATModel()
model.train()
model.qconfig = quant.get_default_qat_qconfig('fbgemm')
quant.prepare_qat(model, inplace=True)

# Fine-tune with fake quantization in place (training utilities defined elsewhere)
for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, optimizer, criterion)

model.eval()
model_int8 = quant.convert(model)
# Note: ONNX export support for converted quantized models is limited
torch.onnx.export(model_int8, torch.randn(1, 3, 224, 224), "qat_model.onnx")
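Eager-mode QAT usually benefits from fusing Conv+BN(+ReLU) patterns before prepare_qat, so each pattern is observed and quantized as one unit. A sketch using fuse_modules_qat from torch.ao.quantization (module names match QATModel above; availability depends on your PyTorch version):
python
import torch.ao.quantization as quant

model = QATModel()
model.train()
# Fuse each Conv+BN+ReLU pattern for QAT; BN is folded into the conv at convert time
model_fused = quant.fuse_modules_qat(
    model, [['conv1', 'bn1', 'relu1'], ['conv2', 'bn2', 'relu2']]
)
model_fused.qconfig = quant.get_default_qat_qconfig('fbgemm')
quant.prepare_qat(model_fused, inplace=True)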
Best Practices #
Choosing a Quantization Method #
text
┌─────────────────────────────────────────────────────────────┐
│               Choosing a Quantization Method                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Choose dynamic quantization when:                          │
│  ├── The model is an RNN / LSTM / Transformer               │
│  ├── No calibration data is available                       │
│  ├── You need a quick deployment                            │
│  └── Some accuracy loss is acceptable                       │
│                                                             │
│  Choose static quantization when:                           │
│  ├── The model is a CNN                                     │
│  ├── Calibration data is available                          │
│  ├── You want the best inference performance                │
│  └── Calibration time is acceptable                         │
│                                                             │
│  Choose quantization aware training when:                   │
│  ├── The scenario is accuracy-sensitive                     │
│  ├── Retraining is possible                                 │
│  ├── You want minimal accuracy loss                         │
│  └── Time and compute budget allow it                       │
│                                                             │
└─────────────────────────────────────────────────────────────┘
Complete Quantization Pipeline #
python
import onnx
import onnxruntime as ort
import numpy as np
from onnxruntime.quantization import quantize_static, QuantType, QuantFormat, CalibrationDataReader

def quantize_model_pipeline(
    input_model,
    output_model,
    calibration_data=None,
    quantization_type="static"
):
    print(f"Loading model: {input_model}")
    model = onnx.load(input_model)
    onnx.checker.check_model(model)

    if quantization_type == "static":
        print("Running static quantization...")
        session = ort.InferenceSession(input_model)
        input_name = session.get_inputs()[0].name
        input_shape = session.get_inputs()[0].shape
        # Replace symbolic (non-integer) dimensions with 1
        input_shape = [d if isinstance(d, int) else 1 for d in input_shape]

        class SimpleCalibrationDataReader(CalibrationDataReader):
            def __init__(self, data, name):
                # Fall back to random calibration data if none is provided
                self.data = data if data is not None else [
                    np.random.randn(*input_shape).astype(np.float32)
                    for _ in range(100)
                ]
                self.name = name
                self.index = 0

            def get_next(self):
                if self.index >= len(self.data):
                    return None
                data = self.data[self.index]
                self.index += 1
                return {self.name: data}

            def rewind(self):
                self.index = 0

        calibration_reader = SimpleCalibrationDataReader(calibration_data, input_name)
        quantize_static(
            model_input=input_model,
            model_output=output_model,
            calibration_data_reader=calibration_reader,
            quant_format=QuantFormat.QDQ,
            per_channel=True,
            weight_type=QuantType.QInt8,
            activation_type=QuantType.QUInt8
        )

    print(f"Quantization finished: {output_model}")
    original_size = onnx.load(input_model).ByteSize() / 1024 / 1024
    quantized_size = onnx.load(output_model).ByteSize() / 1024 / 1024
    print(f"Original size: {original_size:.2f} MB")
    print(f"Quantized size: {quantized_size:.2f} MB")
    print(f"Compression ratio: {original_size / quantized_size:.2f}x")

quantize_model_pipeline("model.onnx", "model_quantized.onnx")
Next Steps #
Now that you understand model quantization, continue to PyTorch model conversion and consolidate what you've learned through hands-on practice!
Last updated: 2026-04-04