模型持久化 #

概述 #

模型持久化是将训练好的模型保存到磁盘,以便后续加载和使用的过程。

为什么需要持久化? #

原因 描述
避免重复训练 训练可能耗时很长
模型部署 在生产环境使用模型
版本管理 跟踪模型迭代
团队共享 与他人共享模型

持久化方法 #

方法 特点 推荐场景
joblib 高效,适合大数据 sklearn 模型
pickle 通用,Python 标准 小型模型
ONNX 跨平台 跨框架部署
PMML 标准格式 企业级部署

joblib #

基本使用 #

python
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from joblib import dump, load

iris = load_iris()
X, y = iris.data, iris.target

model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X, y)

dump(model, 'model.joblib')

loaded_model = load('model.joblib')
print(f"预测结果: {loaded_model.predict([[5.1, 3.5, 1.4, 0.2]])}")

压缩保存 #

python
dump(model, 'model_compressed.joblib', compress=3)

压缩级别 #

级别 压缩率 速度
0 无压缩 最快
1-3 低压缩 较快
4-6 中等压缩 中等
7-9 高压缩 最慢

保存 Pipeline #

python
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('classifier', LogisticRegression())
])

pipe.fit(X, y)

dump(pipe, 'pipeline.joblib')

loaded_pipe = load('pipeline.joblib')

pickle #

基本使用 #

python
import pickle

with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)

with open('model.pkl', 'rb') as f:
    loaded_model = pickle.load(f)

协议版本 #

python
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)

协议对比 #

协议 特点
0 ASCII,兼容性好
2 Python 2 兼容
3 Python 3 默认
4 支持大对象
5 支持带外数据

joblib vs pickle #

特性 joblib pickle
大数组处理 高效 一般
压缩支持 内置 需额外处理
通用性 sklearn 优化 Python 标准
文件大小 较小 较大

性能对比 #

python
import time
import os

start = time.time()
dump(model, 'model_joblib.joblib')
joblib_time = time.time() - start
joblib_size = os.path.getsize('model_joblib.joblib')

start = time.time()
with open('model_pickle.pkl', 'wb') as f:
    pickle.dump(model, f)
pickle_time = time.time() - start
pickle_size = os.path.getsize('model_pickle.pkl')

print(f"joblib: {joblib_time:.4f}s, {joblib_size} bytes")
print(f"pickle: {pickle_time:.4f}s, {pickle_size} bytes")

模型版本管理 #

带版本保存 #

python
from datetime import datetime

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'model_{timestamp}.joblib'
dump(model, filename)

版本信息 #

python
import sklearn
import json

metadata = {
    'model_type': type(model).__name__,
    'sklearn_version': sklearn.__version__,
    'timestamp': datetime.now().isoformat(),
    'parameters': model.get_params(),
    'features': iris.feature_names
}

dump({'model': model, 'metadata': metadata}, 'model_with_meta.joblib')

加载并验证 #

python
data = load('model_with_meta.joblib')
model = data['model']
metadata = data['metadata']

if metadata['sklearn_version'] != sklearn.__version__:
    print(f"警告: 版本不匹配! 保存版本: {metadata['sklearn_version']}, 当前版本: {sklearn.__version__}")

安全考虑 #

pickle 安全风险 #

python
import pickle

# Demonstration of the pickle deserialization risk: a class can override
# __reduce__ to make unpickling execute arbitrary code.
class Malicious:
    def __reduce__(self):
        # pickle records the (callable, args) pair returned here; pickle.load
        # invokes it on deserialization — running a shell command in this demo.
        import os
        return (os.system, ('echo "恶意代码执行"',))

with open('malicious.pkl', 'wb') as f:
    pickle.dump(Malicious(), f)

with open('malicious.pkl', 'rb') as f:
    pickle.load(f)

安全加载 #

python
import joblib

try:
    model = joblib.load('model.joblib')
except Exception as e:
    print(f"加载失败: {e}")

最佳实践 #

  1. 只加载可信来源的模型
  2. 使用文件校验和验证
  3. 限制文件权限

校验和验证示例 #

python
import hashlib

def save_with_checksum(model, filename):
    """Persist *model* with joblib and write a SHA-256 sidecar file.

    The digest of the saved artifact is written to ``filename + '.sha256'``
    so the file can be verified before a later load.

    Returns the hex-encoded SHA-256 checksum of the saved file.
    """
    dump(model, filename)
    digest = hashlib.sha256()
    # Hash in chunks so large model files are never read into memory at once.
    with open(filename, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    checksum = digest.hexdigest()
    with open(filename + '.sha256', 'w') as f:
        f.write(checksum)
    return checksum

def load_with_checksum(filename):
    """Verify a joblib artifact against its SHA-256 sidecar, then load it.

    Reads the expected digest from ``filename + '.sha256'``, recomputes the
    digest of *filename*, and raises ``ValueError`` on mismatch.

    Returns the object loaded from *filename*.
    """
    with open(filename + '.sha256', 'r') as f:
        expected_checksum = f.read().strip()
    digest = hashlib.sha256()
    # Hash in chunks so large model files are never read into memory at once.
    with open(filename, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    actual_checksum = digest.hexdigest()
    if actual_checksum != expected_checksum:
        raise ValueError("校验和不匹配!")
    return load(filename)

跨平台部署 #

ONNX 格式 #

python
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

initial_type = [('float_input', FloatTensorType([None, 4]))]
onnx_model = convert_sklearn(model, initial_types=initial_type)

with open('model.onnx', 'wb') as f:
    f.write(onnx_model.SerializeToString())

加载 ONNX #

python
import onnxruntime as rt

sess = rt.InferenceSession('model.onnx')
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

pred = sess.run([output_name], {input_name: [[5.1, 3.5, 1.4, 0.2]]})

模型压缩 #

量化 #

python
from sklearn.tree import DecisionTreeClassifier

dt = DecisionTreeClassifier(max_depth=5)
dt.fit(X, y)

dump(dt, 'tree_model.joblib')

剪枝 #

python
from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(
    n_estimators=50,
    max_depth=10,
    random_state=42
)
rf.fit(X, y)

完整示例 #

训练保存流程 #

python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from joblib import dump
import json
from datetime import datetime
import sklearn

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('classifier', RandomForestClassifier(n_estimators=100, random_state=42))
])

pipe.fit(X_train, y_train)

y_pred = pipe.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)

metadata = {
    'model_type': 'RandomForestClassifier',
    'sklearn_version': sklearn.__version__,
    'timestamp': datetime.now().isoformat(),
    'accuracy': accuracy,
    'features': iris.feature_names.tolist(),
    'target_names': iris.target_names.tolist()
}

dump({
    'model': pipe,
    'metadata': metadata
}, 'iris_classifier.joblib')

with open('iris_classifier_meta.json', 'w') as f:
    json.dump(metadata, f, indent=2)

print(f"模型保存完成,准确率: {accuracy:.4f}")

加载预测流程 #

python
from joblib import load
import json

data = load('iris_classifier.joblib')
model = data['model']
metadata = data['metadata']

with open('iris_classifier_meta.json', 'r') as f:
    loaded_meta = json.load(f)

print(f"模型类型: {metadata['model_type']}")
print(f"训练准确率: {metadata['accuracy']:.4f}")

new_sample = [[5.1, 3.5, 1.4, 0.2]]
prediction = model.predict(new_sample)
probability = model.predict_proba(new_sample)

print(f"预测类别: {metadata['target_names'][prediction[0]]}")
print(f"预测概率: {probability}")

最佳实践 #

1. 保存完整 Pipeline #

python
dump(pipe, 'pipeline.joblib')

2. 记录元数据 #

python
metadata = {
    'version': '1.0',
    'date': datetime.now().isoformat(),
    'sklearn_version': sklearn.__version__
}

3. 版本兼容检查 #

python
import sklearn

if metadata['sklearn_version'] != sklearn.__version__:
    print("警告: sklearn 版本不匹配")

4. 定期备份 #

python
import shutil
from datetime import datetime

backup_name = f"model_backup_{datetime.now().strftime('%Y%m%d')}.joblib"
shutil.copy('model.joblib', backup_name)

下一步 #

掌握模型持久化后,继续学习 实战案例 了解完整项目实践!

最后更新:2026-04-04