模型注册中心 #
Model Registry 概述 #
MLflow Model Registry 是一个集中式模型存储库,提供模型版本管理、阶段转换和元数据管理功能。
text
┌─────────────────────────────────────────────────────────────┐
│ Model Registry 架构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Registered Model │ │
│ │ ├── 名称、描述 │ │
│ │ ├── 标签 │ │
│ │ └── 模型版本列表 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Model Version │ │
│ │ ├── 版本号 │ │
│ │ ├── 来源 Run ID │ │
│ │ ├── 阶段 (Stage) │ │
│ │ ├── 描述 │ │
│ │ └── 标签 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Model Stage │ │
│ │ ├── None/Archived (归档) │ │
│ │ ├── Staging (预发布) │ │
│ │ └── Production (生产环境) │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
核心概念 #
Registered Model(注册模型) #
text
┌─────────────────────────────────────────────────────────────┐
│ Registered Model │
├─────────────────────────────────────────────────────────────┤
│ │
│ 属性: │
│ ├── name: 模型名称(唯一标识) │
│ ├── description: 模型描述 │
│ ├── tags: 标签 │
│ ├── creation_timestamp: 创建时间 │
│ └── last_updated_timestamp: 最后更新时间 │
│ │
│ 包含: │
│ └── 多个 Model Versions(模型版本) │
│ │
└─────────────────────────────────────────────────────────────┘
Model Version(模型版本) #
text
┌─────────────────────────────────────────────────────────────┐
│ Model Version │
├─────────────────────────────────────────────────────────────┤
│ │
│ 属性: │
│ ├── version: 版本号(自动递增) │
│ ├── name: 所属模型名称 │
│ ├── source: 模型文件路径 │
│ ├── run_id: 来源 Run ID │
│ ├── status: 状态 │
│ ├── current_stage: 当前阶段 │
│ ├── description: 版本描述 │
│ ├── tags: 标签 │
│ └── creation_timestamp: 创建时间 │
│ │
└─────────────────────────────────────────────────────────────┘
Model Stage(模型阶段) #
text
┌─────────────────────────────────────────────────────────────┐
│ Model Stage │
├─────────────────────────────────────────────────────────────┤
│ │
│ None / Archived │
│ ───────────────────────────────────────────────────────── │
│ 归档状态,不用于生产 │
│ 可以随时转换到其他阶段 │
│ │
│ Staging │
│ ───────────────────────────────────────────────────────── │
│ 预发布阶段 │
│ 用于测试和验证 │
│ 准备进入生产环境 │
│ │
│ Production │
│ ───────────────────────────────────────────────────────── │
│ 生产环境 │
│ 正式对外服务 │
│ 每个模型只能有一个 Production 版本 │
│ │
└─────────────────────────────────────────────────────────────┘
注册模型 #
通过 log_model 注册 #
python
import mlflow.sklearn
with mlflow.start_run():
mlflow.sklearn.log_model(
model,
"model",
registered_model_name="my_model"
)
使用 API 注册 #
python
import mlflow
from mlflow.tracking import MlflowClient
client = MlflowClient()
client.create_registered_model(
name="my_model",
tags={"project": "customer-churn", "team": "data-science"},
description="Customer churn prediction model"
)
result = client.create_model_version(
name="my_model",
source="runs:/<run_id>/model",
description="Version 1 of the model"
)
print(f"Version: {result.version}")
通过 UI 注册 #
text
1. 在 Run 详情页面找到 Artifacts
2. 点击模型目录
3. 点击 "Register Model" 按钮
4. 选择或创建模型名称
5. 点击 Register
模型版本管理 #
列出所有模型 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
for rm in client.search_registered_models():
print(f"Name: {rm.name}")
print(f"Description: {rm.description}")
print(f"Tags: {rm.tags}")
搜索模型 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
models = client.search_registered_models(
filter_string="tags.project = 'customer-churn'"
)
for model in models:
print(f"Model: {model.name}")
获取模型版本 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
versions = client.search_model_versions("name='my_model'")
for version in versions:
print(f"Version: {version.version}")
print(f"Stage: {version.current_stage}")
print(f"Status: {version.status}")
获取特定版本 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
version = client.get_model_version("my_model", "1")
print(f"Version: {version.version}")
print(f"Stage: {version.current_stage}")
print(f"Source: {version.source}")
阶段转换 #
转换模型阶段 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
client.transition_model_version_stage(
name="my_model",
version="1",
stage="Staging"
)
client.transition_model_version_stage(
name="my_model",
version="2",
stage="Production",
archive_existing_versions=True
)
阶段转换流程 #
text
┌─────────────────────────────────────────────────────────────┐
│ 阶段转换流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 开发阶段 │
│ ───────────────────────────────────────────────────── │
│ 1. 训练新模型 │
│ 2. 注册模型版本 │
│ 3. 初始阶段为 None │
│ │
│ 测试阶段 │
│ ───────────────────────────────────────────────────── │
│ 4. 转换到 Staging │
│ 5. 进行 A/B 测试 │
│ 6. 验证模型性能 │
│ │
│ 生产阶段 │
│ ───────────────────────────────────────────────────── │
│ 7. 转换到 Production │
│ 8. 归档旧版本 │
│ 9. 监控模型表现 │
│ │
│ 归档阶段 │
│ ───────────────────────────────────────────────────── │
│ 10. 不再使用的版本归档 │
│ 11. 保留历史记录 │
│ │
└─────────────────────────────────────────────────────────────┘
获取特定阶段的模型 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
versions = client.get_latest_versions("my_model", stages=["Production"])
for version in versions:
print(f"Production Version: {version.version}")
模型别名 #
设置模型别名 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
client.set_registered_model_alias(
name="my_model",
alias="champion",
version="5"
)
client.set_registered_model_alias(
name="my_model",
alias="challenger",
version="6"
)
删除模型别名 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
client.delete_registered_model_alias(
name="my_model",
alias="champion"
)
使用别名加载模型 #
python
import mlflow.pyfunc
model = mlflow.pyfunc.load_model("models:/my_model@champion")
predictions = model.predict(data)
别名最佳实践 #
text
┌─────────────────────────────────────────────────────────────┐
│ 别名最佳实践 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 推荐别名命名: │
│ ├── champion: 当前生产环境模型 │
│ ├── challenger: 挑战者模型(待测试) │
│ ├── canary: 金丝雀发布模型 │
│ └── backup: 备份模型 │
│ │
│ 别名 vs 阶段: │
│ ├── 阶段:固定的生命周期状态 │
│ ├── 别名:灵活的引用方式 │
│ └── 可以同时使用两者 │
│ │
└─────────────────────────────────────────────────────────────┘
模型标签 #
设置模型标签 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
client.set_registered_model_tag(
name="my_model",
key="project",
value="customer-churn"
)
client.set_model_version_tag(
name="my_model",
version="1",
key="validated",
value="true"
)
删除模型标签 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
client.delete_registered_model_tag(
name="my_model",
key="project"
)
client.delete_model_version_tag(
name="my_model",
version="1",
key="validated"
)
模型描述 #
设置模型描述 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
client.update_registered_model(
name="my_model",
description="Customer churn prediction model using Random Forest"
)
client.update_model_version(
name="my_model",
version="1",
description="Initial version with basic features"
)
加载注册模型 #
通过名称和版本加载 #
python
import mlflow.sklearn
model = mlflow.sklearn.load_model("models:/my_model/1")
predictions = model.predict(X_test)
通过阶段加载 #
python
import mlflow.sklearn
model = mlflow.sklearn.load_model("models:/my_model/Production")
model = mlflow.sklearn.load_model("models:/my_model/Staging")
通过别名加载 #
python
import mlflow.sklearn
model = mlflow.sklearn.load_model("models:/my_model@champion")
获取最新版本 #
python
import mlflow.sklearn
model = mlflow.sklearn.load_model("models:/my_model/latest")
删除模型 #
删除模型版本 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
client.delete_model_version(
name="my_model",
version="1"
)
删除注册模型 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
for version in client.search_model_versions("name='my_model'"):
client.delete_model_version("my_model", version.version)
client.delete_registered_model("my_model")
重命名模型 #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
client.rename_registered_model(
name="old_model_name",
new_name="new_model_name"
)
Webhooks #
创建 Webhook #
python
from mlflow.tracking import MlflowClient
client = MlflowClient()
client.create_webhook(
name="model-promotion-webhook",
events=["MODEL_VERSION_TRANSITIONED_STAGE"],
description="Triggered when model stage changes",
status="ACTIVE",
job_spec={
"job_type": "HTTP",
"url": "https://example.com/webhook",
"secret": "my-secret"
}
)
Webhook 事件类型 #
text
┌─────────────────────────────────────────────────────────────┐
│ Webhook 事件类型 │
├─────────────────────────────────────────────────────────────┤
│ │
│ MODEL_VERSION_CREATED │
│ ───────────────────────────────────────────────────────── │
│ 新模型版本创建时触发 │
│ │
│ MODEL_VERSION_TRANSITIONED_STAGE │
│ ───────────────────────────────────────────────────────── │
│ 模型阶段转换时触发 │
│ │
│ REGISTERED_MODEL_CREATED │
│ ───────────────────────────────────────────────────────── │
│ 新注册模型创建时触发 │
│ │
│ COMMENT_CREATED │
│ ───────────────────────────────────────────────────────── │
│ 新评论创建时触发 │
│ │
└─────────────────────────────────────────────────────────────┘
完整工作流示例 #
python
import mlflow
import mlflow.sklearn
from mlflow.tracking import MlflowClient
from sklearn.ensemble import RandomForestClassifier
client = MlflowClient()
model_name = "customer-churn-model"
try:
client.create_registered_model(
name=model_name,
tags={"project": "churn", "team": "ml"},
description="Customer churn prediction model"
)
except Exception:
print(f"Model {model_name} already exists")
with mlflow.start_run() as run:
model = RandomForestClassifier(n_estimators=100, max_depth=10)
model.fit(X_train, y_train)
mlflow.log_params({"n_estimators": 100, "max_depth": 10})
mlflow.log_metric("accuracy", model.score(X_test, y_test))
model_uri = mlflow.sklearn.log_model(
model,
"model",
registered_model_name=model_name
)
model_version = client.create_model_version(
name=model_name,
source=f"runs:/{run.info.run_id}/model",
description="Version with improved features"
)
client.transition_model_version_stage(
name=model_name,
version=model_version.version,
stage="Staging"
)
print("Model validation passed, promoting to Production")
client.transition_model_version_stage(
name=model_name,
version=model_version.version,
stage="Production",
archive_existing_versions=True
)
client.set_registered_model_alias(
name=model_name,
alias="champion",
version=model_version.version
)
production_model = mlflow.sklearn.load_model(f"models:/{model_name}/Production")
最佳实践 #
1. 命名规范 #
python
model_name = f"{project}-{task}-{version}"
model_name = "customer-churn-prediction-v1"
2. 版本描述 #
python
client.update_model_version(
name="my_model",
version="1",
description="""
Version 1.0.0
Features:
- Added feature engineering
- Improved accuracy by 5%
Training:
- Dataset: customer_data_v2.csv
- Run ID: abc123
"""
)
3. 阶段转换审批 #
python
def promote_model(model_name, version, stage):
client = MlflowClient()
current_version = client.get_model_version(model_name, version)
if current_version.status != "READY":
raise ValueError("Model version is not ready")
client.transition_model_version_stage(
name=model_name,
version=version,
stage=stage,
archive_existing_versions=(stage == "Production")
)
client.set_model_version_tag(
name=model_name,
version=version,
key="promoted_by",
value="data-team"
)
下一步 #
现在你已经掌握了 Model Registry 的核心功能,接下来学习 模型部署,了解如何将模型部署到生产环境!
最后更新:2026-04-04