模型持久化 #
概述 #
模型持久化是将训练好的模型保存到磁盘,以便后续加载和使用的过程。
为什么需要持久化? #
| 原因 | 描述 |
|---|---|
| 避免重复训练 | 训练可能耗时很长 |
| 模型部署 | 在生产环境使用模型 |
| 版本管理 | 跟踪模型迭代 |
| 团队共享 | 与他人共享模型 |
持久化方法 #
| 方法 | 特点 | 推荐场景 |
|---|---|---|
| joblib | 高效,适合大数据 | sklearn 模型 |
| pickle | 通用,Python 标准 | 小型模型 |
| ONNX | 跨平台 | 跨框架部署 |
| PMML | 标准格式 | 企业级部署 |
joblib #
基本使用 #
python
# Train a RandomForest on the iris dataset, persist it with joblib,
# then reload it and predict once to confirm the round trip works.
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from joblib import dump, load
iris = load_iris()
X, y = iris.data, iris.target
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X, y)
# Serialize the fitted estimator to disk.
dump(model, 'model.joblib')
# Deserialize; the loaded object is immediately ready to predict.
loaded_model = load('model.joblib')
print(f"预测结果: {loaded_model.predict([[5.1, 3.5, 1.4, 0.2]])}")
压缩保存 #
python
# compress accepts 0-9: higher levels give smaller files but slower save/load.
dump(model, 'model_compressed.joblib', compress=3)
压缩级别 #
| 级别 | 压缩率 | 速度 |
|---|---|---|
| 0 | 无压缩 | 最快 |
| 1-3 | 低压缩 | 快 |
| 4-6 | 中等压缩 | 中等 |
| 7-9 | 高压缩 | 慢 |
保存 Pipeline #
python
# Persist the preprocessing + model Pipeline as one artifact so the
# exact same scaling learned at fit time is applied at inference time.
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('classifier', LogisticRegression())
])
pipe.fit(X, y)
dump(pipe, 'pipeline.joblib')
loaded_pipe = load('pipeline.joblib')
pickle #
基本使用 #
python
# Standard-library alternative: pickle. Files must be opened in binary mode.
import pickle
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)
with open('model.pkl', 'rb') as f:
    loaded_model = pickle.load(f)
协议版本 #
python
# HIGHEST_PROTOCOL gives the smallest/fastest pickles, but older
# Python versions may not be able to read the file.
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)
协议对比 #
| 协议 | 特点 |
|---|---|
| 0 | ASCII,兼容性好 |
| 2 | Python 2 兼容 |
| 3 | Python 3.0–3.7 默认 |
| 4 | 支持大对象,Python 3.8+ 默认 |
| 5 | 支持带外数据 |
joblib vs pickle #
| 特性 | joblib | pickle |
|---|---|---|
| 大数组处理 | 高效 | 一般 |
| 压缩支持 | 内置 | 需额外处理 |
| 通用性 | sklearn 优化 | Python 标准 |
| 文件大小 | 较小 | 较大 |
python
# Rough benchmark: serialization time and on-disk size, joblib vs pickle.
# (Single-run timings are noisy; treat the numbers as indicative only.)
import time
import os
start = time.time()
dump(model, 'model_joblib.joblib')
joblib_time = time.time() - start
joblib_size = os.path.getsize('model_joblib.joblib')
start = time.time()
with open('model_pickle.pkl', 'wb') as f:
    pickle.dump(model, f)
pickle_time = time.time() - start
pickle_size = os.path.getsize('model_pickle.pkl')
print(f"joblib: {joblib_time:.4f}s, {joblib_size} bytes")
print(f"pickle: {pickle_time:.4f}s, {pickle_size} bytes")
模型版本管理 #
带版本保存 #
python
# A timestamped filename gives every saved model a unique, sortable version.
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'model_{timestamp}.joblib'
dump(model, filename)
版本信息 #
python
# Bundle the estimator together with metadata (library version, params,
# feature names) so the saved artifact is self-describing when reloaded.
import sklearn
import json
metadata = {
    'model_type': type(model).__name__,
    'sklearn_version': sklearn.__version__,
    'timestamp': datetime.now().isoformat(),
    'parameters': model.get_params(),
    'features': iris.feature_names
}
dump({'model': model, 'metadata': metadata}, 'model_with_meta.joblib')
加载并验证 #
python
# Reload the bundle and warn when the current sklearn version differs
# from the one recorded at save time — pickled estimators are not
# guaranteed to be compatible across sklearn versions.
data = load('model_with_meta.joblib')
model = data['model']
metadata = data['metadata']
if metadata['sklearn_version'] != sklearn.__version__:
    print(f"警告: 版本不匹配! 保存版本: {metadata['sklearn_version']}, 当前版本: {sklearn.__version__}")
安全考虑 #
pickle 安全风险 #
python
# WARNING: demonstration of why unpickling untrusted data is dangerous.
# __reduce__ lets an object dictate what gets executed during pickle.load().
import pickle
class Malicious:
    def __reduce__(self):
        # The returned callable and args are invoked at load time.
        import os
        return (os.system, ('echo "恶意代码执行"',))
with open('malicious.pkl', 'wb') as f:
    pickle.dump(Malicious(), f)
# Loading runs the payload — never unpickle files from untrusted sources.
with open('malicious.pkl', 'rb') as f:
    pickle.load(f)
安全加载 #
python
# joblib.load has the same trust model as pickle: a try/except guards
# against corrupt files, but only ever load models from trusted sources.
import joblib
try:
    model = joblib.load('model.joblib')
except Exception as e:
    print(f"加载失败: {e}")
最佳实践 #
- 只加载可信来源的模型
- 使用文件校验和验证
- 限制文件权限
python
import hashlib

def save_with_checksum(model, filename):
    """Save *model* to *filename* and record its SHA-256 checksum.

    The hex digest is written to ``<filename>.sha256`` next to the model
    file and is also returned, so callers can verify integrity later.
    """
    dump(model, filename)
    # Hash the bytes actually written to disk, not the in-memory object.
    with open(filename, 'rb') as artifact:
        digest = hashlib.sha256(artifact.read()).hexdigest()
    with open(filename + '.sha256', 'w') as sidecar:
        sidecar.write(digest)
    return digest
def load_with_checksum(filename):
    """Load *filename* only if it matches its recorded SHA-256 checksum.

    Reads the expected digest from ``<filename>.sha256`` and raises
    ValueError when the file's current contents do not match it.
    """
    with open(filename + '.sha256', 'r') as sidecar:
        recorded = sidecar.read().strip()
    with open(filename, 'rb') as artifact:
        current = hashlib.sha256(artifact.read()).hexdigest()
    if current != recorded:
        raise ValueError("校验和不匹配!")
    return load(filename)
跨平台部署 #
ONNX 格式 #
python
# Export the sklearn model to ONNX for framework-independent deployment.
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
# None = dynamic batch dimension; 4 = number of iris input features.
initial_type = [('float_input', FloatTensorType([None, 4]))]
onnx_model = convert_sklearn(model, initial_types=initial_type)
with open('model.onnx', 'wb') as f:
    f.write(onnx_model.SerializeToString())
加载 ONNX #
python
# Run inference with ONNX Runtime. Tensor inputs must be NumPy arrays
# whose dtype matches the FloatTensorType declared at export time —
# a plain Python list of lists is rejected by InferenceSession.run.
import numpy as np
import onnxruntime as rt
sess = rt.InferenceSession('model.onnx')
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
features = np.array([[5.1, 3.5, 1.4, 0.2]], dtype=np.float32)
pred = sess.run([output_name], {input_name: features})
模型压缩 #
模型简化(示例:限制树深度以减小体积) #
python
# A depth-limited tree (max_depth=5) serializes to a much smaller file.
# NOTE(review): despite the section heading, this is size reduction via
# model simplification, not numeric quantization.
from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier(max_depth=5)
dt.fit(X, y)
dump(dt, 'tree_model.joblib')
剪枝 #
python
# Fewer, depth-limited trees shrink the forest's on-disk footprint
# (pre-pruning via n_estimators and max_depth).
from sklearn.ensemble import RandomForestClassifier
rf = RandomForestClassifier(
    n_estimators=50,
    max_depth=10,
    random_state=42
)
rf.fit(X, y)
完整示例 #
训练保存流程 #
python
# End-to-end flow: train a Pipeline on a train/test split, evaluate it,
# and persist the model bundled with JSON-serializable metadata.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from joblib import dump
import json
from datetime import datetime
import sklearn
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)
pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('classifier', RandomForestClassifier(n_estimators=100, random_state=42))
])
pipe.fit(X_train, y_train)
y_pred = pipe.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
metadata = {
    'model_type': 'RandomForestClassifier',
    'sklearn_version': sklearn.__version__,
    'timestamp': datetime.now().isoformat(),
    'accuracy': accuracy,
    # BUG FIX: iris.feature_names is a plain Python list (not an ndarray),
    # so .tolist() raised AttributeError; list() handles both cases.
    'features': list(iris.feature_names),
    # target_names IS an ndarray, so .tolist() is correct here.
    'target_names': iris.target_names.tolist()
}
# Persist the model and metadata together as a single artifact.
dump({
    'model': pipe,
    'metadata': metadata
}, 'iris_classifier.joblib')
# Also write the metadata as human-readable JSON next to the model.
with open('iris_classifier_meta.json', 'w') as f:
    json.dump(metadata, f, indent=2)
print(f"模型保存完成,准确率: {accuracy:.4f}")
加载预测流程 #
python
# Load the saved bundle plus its JSON metadata and classify a new sample.
from joblib import load
import json
data = load('iris_classifier.joblib')
model = data['model']
metadata = data['metadata']
with open('iris_classifier_meta.json', 'r') as f:
    loaded_meta = json.load(f)
print(f"模型类型: {metadata['model_type']}")
print(f"训练准确率: {metadata['accuracy']:.4f}")
new_sample = [[5.1, 3.5, 1.4, 0.2]]
# The Pipeline re-applies the saved scaler before the classifier predicts.
prediction = model.predict(new_sample)
probability = model.predict_proba(new_sample)
print(f"预测类别: {metadata['target_names'][prediction[0]]}")
print(f"预测概率: {probability}")
最佳实践 #
1. 保存完整 Pipeline #
python
# Saving the full Pipeline keeps preprocessing and model in sync.
dump(pipe, 'pipeline.joblib')
2. 记录元数据 #
python
# Minimal metadata worth recording alongside every saved model.
metadata = {
    'version': '1.0',
    'date': datetime.now().isoformat(),
    'sklearn_version': sklearn.__version__
}
3. 版本兼容检查 #
python
# Warn when loading a model saved under a different sklearn version.
import sklearn
if metadata['sklearn_version'] != sklearn.__version__:
    print("警告: sklearn 版本不匹配")
4. 定期备份 #
python
# Keep dated backup copies of the current production model file.
import shutil
from datetime import datetime
backup_name = f"model_backup_{datetime.now().strftime('%Y%m%d')}.joblib"
shutil.copy('model.joblib', backup_name)
下一步 #
掌握模型持久化后,继续学习 实战案例 了解完整项目实践!
最后更新:2026-04-04