Model Deployment #

Deployment Overview #

MLflow offers several deployment paths, from local testing to production serving, to fit different scenarios.

text
┌─────────────────────────────────────────────────────────────┐
│               MLflow Deployment Architecture                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────────────────────────────────────────────────┐   │
│  │                    MLflow Model                     │   │
│  │  ├── Unified model format                           │   │
│  │  ├── Multiple flavor support                        │   │
│  │  └── Packaged environment dependencies              │   │
│  └─────────────────────────────────────────────────────┘   │
│                          │                                  │
│                          ▼                                  │
│  ┌─────────────────────────────────────────────────────┐   │
│  │                 Deployment Targets                  │   │
│  │  ├── Local inference                                │   │
│  │  ├── REST API serving                               │   │
│  │  ├── Docker containers                              │   │
│  │  ├── Apache Spark                                   │   │
│  │  ├── AWS SageMaker                                  │   │
│  │  ├── Azure ML                                       │   │
│  │  └── Other cloud platforms                          │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Local Inference #

Loading a Model Directly #

python
import mlflow.pyfunc

# Load the model registered under the Production stage
model = mlflow.pyfunc.load_model("models:/my_model/Production")

# `data` is a pandas DataFrame matching the model's input signature
predictions = model.predict(data)

Batch Prediction #

python
import mlflow.pyfunc
import pandas as pd

model = mlflow.pyfunc.load_model("models:/my_model/Production")

data = pd.read_csv("data/batch_input.csv")

predictions = model.predict(data)

pd.DataFrame(predictions, columns=["prediction"]).to_csv("predictions.csv", index=False)
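When the input file is too large to fit in memory, predictions can be streamed chunk by chunk with `pd.read_csv(chunksize=...)`. A minimal sketch with a stand-in model (substitute the `mlflow.pyfunc` model loaded above; the column name `feature1` and the in-memory CSV are illustrative):

```python
import io

import pandas as pd

class StubModel:
    """Stand-in for the mlflow.pyfunc model loaded above."""
    def predict(self, df):
        return df["feature1"] * 2  # placeholder scoring logic

model = StubModel()

# In practice this would be a file path such as "data/batch_input.csv"
csv_source = io.StringIO("feature1\n1.0\n2.0\n3.0\n")

chunks = []
for chunk in pd.read_csv(csv_source, chunksize=2):
    # Score one chunk at a time so memory use stays bounded
    chunks.append(pd.DataFrame({"prediction": model.predict(chunk)}))

result = pd.concat(chunks, ignore_index=True)
```

The chunked results can then be written out incrementally (for example with `to_csv(..., mode="a")`) instead of concatenating, if even the predictions do not fit in memory.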

Batch Inference with Spark #

python
import mlflow
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

model_uri = "models:/my_model/Production"

# Wrap the model as a Spark UDF for distributed scoring
predict_udf = mlflow.pyfunc.spark_udf(
    spark,
    model_uri,
    result_type="double"
)

df = spark.read.parquet("data/input.parquet")

df_with_predictions = df.withColumn("prediction", predict_udf(*df.columns))

df_with_predictions.write.parquet("data/predictions.parquet")

REST API Serving #

Starting a Local Server #

bash
mlflow models serve -m "models:/my_model/Production" -p 5001

Server Options #

text
┌─────────────────────────────────────────────────────────────┐
│                 mlflow models serve options                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  -m, --model-uri          Model URI                         │
│  ─────────────────────────────────────────────────────────  │
│  -m "models:/my_model/Production"                           │
│  -m "runs:/<run_id>/model"                                  │
│                                                             │
│  -p, --port               Port to serve on                  │
│  ─────────────────────────────────────────────────────────  │
│  -p 5001                                                    │
│                                                             │
│  -h, --host               Host to listen on                 │
│  ─────────────────────────────────────────────────────────  │
│  -h 0.0.0.0                                                 │
│                                                             │
│  --env-manager local      Serve from the current            │
│                           environment (replaces the         │
│                           deprecated --no-conda)            │
│  ─────────────────────────────────────────────────────────  │
│                                                             │
│  --enable-mlserver        Serve with MLServer               │
│  ─────────────────────────────────────────────────────────  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Sending a Prediction Request #

python
import requests
import pandas as pd

data = pd.DataFrame({
    "feature1": [1.0, 2.0, 3.0],
    "feature2": [4.0, 5.0, 6.0]
})

response = requests.post(
    "http://localhost:5001/invocations",
    json={"dataframe_split": data.to_dict(orient="split")},
    headers={"Content-Type": "application/json"}
)
response.raise_for_status()  # fail fast on HTTP errors

predictions = response.json()
print(predictions)

Request Formats #

text
┌─────────────────────────────────────────────────────────────┐
│                    API Request Formats                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. DataFrame split format                                  │
│     ─────────────────────────────────────────────────────   │
│     {                                                      │
│       "dataframe_split": {                                 │
│         "columns": ["feature1", "feature2"],               │
│         "data": [[1.0, 4.0], [2.0, 5.0]]                   │
│       }                                                    │
│     }                                                      │
│                                                             │
│  2. DataFrame records format                                │
│     ─────────────────────────────────────────────────────   │
│     {                                                      │
│       "dataframe_records": [                               │
│         {"feature1": 1.0, "feature2": 4.0},                │
│         {"feature1": 2.0, "feature2": 5.0}                 │
│       ]                                                    │
│     }                                                      │
│                                                             │
│  3. Tensor format                                           │
│     ─────────────────────────────────────────────────────   │
│     {                                                      │
│       "instances": [                                        │
│         [1.0, 4.0],                                        │
│         [2.0, 5.0]                                         │
│       ]                                                    │
│     }                                                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘
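All three payloads can be built from the same pandas DataFrame; a quick sketch (note that `to_dict(orient="split")` also includes an `index` key, which the scoring server accepts):

```python
import json

import pandas as pd

df = pd.DataFrame({
    "feature1": [1.0, 2.0],
    "feature2": [4.0, 5.0]
})

# 1. DataFrame split format
split_payload = {"dataframe_split": df.to_dict(orient="split")}

# 2. DataFrame records format
records_payload = {"dataframe_records": df.to_dict(orient="records")}

# 3. Tensor format (row-major list of feature vectors)
tensor_payload = {"instances": df.to_numpy().tolist()}

# All three serialize to JSON bodies for POST /invocations
for payload in (split_payload, records_payload, tensor_payload):
    json.dumps(payload)
```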

API Endpoints #

text
┌─────────────────────────────────────────────────────────────┐
│                      API Endpoints                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  POST /invocations                                          │
│  ─────────────────────────────────────────────────────────  │
│  Run model prediction                                       │
│  Content-Type: application/json                            │
│                                                             │
│  GET /health                                                │
│  ─────────────────────────────────────────────────────────  │
│  Health check                                               │
│                                                             │
│  GET /version                                               │
│  ─────────────────────────────────────────────────────────  │
│  Version information                                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘
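Before routing traffic to a freshly started server, it helps to poll the health endpoint until it responds. A self-contained sketch using only the standard library (the throwaway HTTP server below stands in for `mlflow models serve`):

```python
import http.server
import threading
import time
import urllib.request

def wait_until_healthy(url, timeout=10.0, interval=0.2):
    """Poll a health endpoint until it returns 200 or the timeout expires."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=2) as resp:
                if resp.status == 200:
                    return True
        except OSError:
            pass  # server not up yet; retry
        time.sleep(interval)
    return False

class HealthHandler(http.server.BaseHTTPRequestHandler):
    def do_GET(self):
        self.send_response(200 if self.path == "/health" else 404)
        self.end_headers()

    def log_message(self, *args):
        pass  # keep the demo quiet

# Stand-in for the model server; port 0 picks a free port
server = http.server.HTTPServer(("127.0.0.1", 0), HealthHandler)
threading.Thread(target=server.serve_forever, daemon=True).start()
port = server.server_address[1]

healthy = wait_until_healthy(f"http://127.0.0.1:{port}/health")
server.shutdown()
```

In a deployment script you would point `wait_until_healthy` at the real server's `/health` URL before registering it with a load balancer.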

Docker Deployment #

Building a Docker Image #

bash
mlflow models build-docker \
    -m "models:/my_model/Production" \
    -n my-model-image \
    --enable-mlserver

Running the Docker Container #

bash
docker run -p 5001:8080 my-model-image

Docker Compose Deployment #

yaml
version: '3.8'

services:
  mlflow-model:
    image: my-model-image:latest
    ports:
      - "5001:8080"
    environment:
      - MLFLOW_TRACKING_URI=http://mlflow-server:5000
    deploy:
      resources:
        limits:
          memory: 2G
          cpus: '1'
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
      interval: 30s
      timeout: 10s
      retries: 3

  mlflow-server:
    image: ghcr.io/mlflow/mlflow:v2.10.0
    ports:
      - "5000:5000"
    environment:
      - MLFLOW_BACKEND_STORE_URI=postgresql://user:pass@postgres:5432/mlflow
      - MLFLOW_DEFAULT_ARTIFACT_ROOT=s3://my-bucket/mlruns
    depends_on:
      - postgres

  postgres:
    image: postgres:14
    environment:
      POSTGRES_USER: user
      POSTGRES_PASSWORD: pass
      POSTGRES_DB: mlflow
    volumes:
      - postgres_data:/var/lib/postgresql/data

volumes:
  postgres_data:

Customizing the Dockerfile #

dockerfile
FROM python:3.10-slim

WORKDIR /app

# Install the serving dependencies directly into the image
RUN pip install --no-cache-dir mlflow scikit-learn pandas

COPY model/ /app/model/

EXPOSE 8080

# --env-manager local: reuse the image's environment instead of creating a new one
CMD ["mlflow", "models", "serve", "-m", "/app/model", "-h", "0.0.0.0", "-p", "8080", "--env-manager", "local"]

Kubernetes Deployment #

Deployment Manifest #

yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: mlflow-model-server
spec:
  replicas: 3
  selector:
    matchLabels:
      app: mlflow-model
  template:
    metadata:
      labels:
        app: mlflow-model
    spec:
      containers:
      - name: model-server
        image: my-model-image:latest
        ports:
        - containerPort: 8080
        resources:
          requests:
            memory: "1Gi"
            cpu: "500m"
          limits:
            memory: "2Gi"
            cpu: "1000m"
        livenessProbe:
          httpGet:
            path: /health
            port: 8080
          initialDelaySeconds: 30
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /health
            port: 8080
          initialDelaySeconds: 5
          periodSeconds: 5
        env:
        - name: MLFLOW_TRACKING_URI
          value: "http://mlflow-server:5000"
---
apiVersion: v1
kind: Service
metadata:
  name: mlflow-model-service
spec:
  selector:
    app: mlflow-model
  ports:
  - port: 80
    targetPort: 8080
  type: LoadBalancer
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: mlflow-model-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: mlflow-model-server
  minReplicas: 2
  maxReplicas: 10
  metrics:
  - type: Resource
    resource:
      name: cpu
      target:
        type: Utilization
        averageUtilization: 70

AWS SageMaker Deployment #

Deploying to SageMaker #

python
from mlflow.deployments import get_deploy_client

# mlflow.sagemaker.deploy() was removed in MLflow 2.x;
# use the deployments client, with the region in the target URI
client = get_deploy_client("sagemaker:/us-west-2")

client.create_deployment(
    name="my-model-endpoint",
    model_uri="models:/my_model/Production",
    config={
        "execution_role_arn": "arn:aws:iam::account-id:role/SageMakerRole",
        "instance_type": "ml.m5.xlarge",
        "instance_count": 1
    }
)

Updating a SageMaker Endpoint #

python
from mlflow.deployments import get_deploy_client

client = get_deploy_client("sagemaker:/us-west-2")

# Replace the model behind the existing endpoint
client.update_deployment(
    name="my-model-endpoint",
    model_uri="models:/my_model/Production",
    config={
        "execution_role_arn": "arn:aws:iam::account-id:role/SageMakerRole"
    }
)

Deleting a SageMaker Endpoint #

python
from mlflow.deployments import get_deploy_client

client = get_deploy_client("sagemaker:/us-west-2")

client.delete_deployment(name="my-model-endpoint")

SageMaker Prediction #

python
import boto3
import json

runtime = boto3.client("runtime.sagemaker", region_name="us-west-2")

data = {"instances": [[1.0, 2.0, 3.0]]}

response = runtime.invoke_endpoint(
    EndpointName="my-model-endpoint",
    ContentType="application/json",
    Body=json.dumps(data)
)

predictions = json.loads(response["Body"].read())

Azure ML Deployment #

Deploying to Azure ML #

python
import mlflow
from azure.ai.ml import MLClient
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities import Model
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="subscription-id",
    resource_group_name="resource-group",
    workspace_name="workspace"
)

# mlflow.azureml was removed in MLflow 2.x; register the MLflow model
# with the Azure ML v2 SDK instead. Download the artifacts locally first.
local_path = mlflow.artifacts.download_artifacts("models:/my_model/Production")

azure_model = ml_client.models.create_or_update(
    Model(
        path=local_path,
        type=AssetTypes.MLFLOW_MODEL,
        name="my-model",
        description="Customer churn prediction model",
        tags={"project": "churn"}
    )
)

Apache Spark Deployment #

Deploying as a Spark UDF #

python
import mlflow
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

model_uri = "models:/my_model/Production"

predict_udf = mlflow.pyfunc.spark_udf(
    spark,
    model_uri,
    result_type="double"
)

df = spark.read.parquet("s3://bucket/data/input.parquet")

df_predictions = df.withColumn(
    "prediction",
    predict_udf("feature1", "feature2", "feature3")
)

df_predictions.write.parquet("s3://bucket/data/predictions.parquet")

Spark Batch Job #

python
from pyspark.sql import SparkSession
import mlflow.pyfunc
import pandas as pd

spark = SparkSession.builder \
    .appName("MLflow Batch Prediction") \
    .getOrCreate()

model_uri = "models:/my_model/Production"

def predict_batch(iterator):
    # Load the model inside the task so it is not serialized with the closure
    model = mlflow.pyfunc.load_model(model_uri)
    for batch in iterator:
        # mapInPandas expects each yielded value to be a pandas DataFrame
        yield pd.DataFrame({"prediction": model.predict(batch)})

df = spark.read.parquet("s3://bucket/data/input.parquet")

predictions_df = df.mapInPandas(
    predict_batch,
    schema="prediction double"
)

predictions_df.write.parquet("s3://bucket/data/predictions.parquet")

Deployment Best Practices #

1. Model Versioning #

python
import mlflow

# Prefer alias-based URIs for deployment targets (MLflow 2.x)
model = mlflow.pyfunc.load_model("models:/my_model@champion")

# Legacy stage-based URI (deprecated in favor of aliases)
model = mlflow.pyfunc.load_model("models:/my_model/Production")

2. Health Checks #

python
from fastapi import FastAPI
import mlflow.pyfunc
import pandas as pd

app = FastAPI()
model = mlflow.pyfunc.load_model("models:/my_model/Production")

@app.get("/health")
def health_check():
    return {"status": "healthy", "model_loaded": model is not None}

@app.post("/predict")
def predict(payload: dict):
    # Expect the dataframe_split request format used by the serving API
    body = payload["dataframe_split"]
    data = pd.DataFrame(body["data"], columns=body["columns"])
    predictions = model.predict(data)
    # Convert array output to a plain list so it is JSON-serializable
    return {"predictions": list(predictions)}
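Since the `/predict` handler accepts an arbitrary dict, it is worth validating the payload before it reaches the model. A minimal converter, assuming the `dataframe_split` request format (the expected column names are illustrative):

```python
import pandas as pd

EXPECTED_COLUMNS = ["feature1", "feature2"]  # illustrative model schema

def payload_to_frame(payload: dict) -> pd.DataFrame:
    """Convert a dataframe_split payload to a DataFrame, validating columns."""
    body = payload.get("dataframe_split")
    if body is None:
        raise ValueError("payload must contain 'dataframe_split'")
    df = pd.DataFrame(body["data"], columns=body["columns"])
    missing = set(EXPECTED_COLUMNS) - set(df.columns)
    if missing:
        raise ValueError(f"missing columns: {sorted(missing)}")
    # Reorder columns so positional models see a stable layout
    return df[EXPECTED_COLUMNS]

frame = payload_to_frame({
    "dataframe_split": {
        "columns": ["feature1", "feature2"],
        "data": [[1.0, 4.0], [2.0, 5.0]]
    }
})
```

Raising `ValueError` for malformed input lets the web framework map it to a 4xx response instead of surfacing a model stack trace.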

3. Monitoring and Logging #

python
import logging
import time
from functools import wraps

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def log_prediction(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        duration = time.time() - start_time
        logger.info(f"Prediction completed in {duration:.3f}s")
        return result
    return wrapper

@log_prediction
def predict(data):
    return model.predict(data)
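Beyond per-call log lines, keeping the raw durations makes it easy to report latency percentiles. A small extension of the decorator above (in-memory only; a real service would export these to a metrics system, and the doubling `predict` is a stand-in):

```python
import statistics
import time
from functools import wraps

durations = []

def timed(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            # Record the duration even if the call raises
            durations.append(time.perf_counter() - start)
    return wrapper

@timed
def predict(data):
    return [x * 2 for x in data]  # stand-in for model.predict

for _ in range(20):
    predict([1.0, 2.0])

# quantiles(n=100) yields the 1st..99th percentile cut points
cuts = statistics.quantiles(durations, n=100)
p50, p95 = cuts[49], cuts[94]
print(f"p50={p50:.6f}s p95={p95:.6f}s")
```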

4. A/B Testing #

python
import mlflow.pyfunc
import random

model_a = mlflow.pyfunc.load_model("models:/my_model@champion")
model_b = mlflow.pyfunc.load_model("models:/my_model@challenger")

def predict_with_ab_test(data, traffic_split=0.1):
    # Route ~traffic_split of requests to the challenger,
    # and return the variant name so outcomes can be attributed
    if random.random() < traffic_split:
        return model_b.predict(data), "model_b"
    else:
        return model_a.predict(data), "model_a"
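The split can be sanity-checked offline with stand-in models before wiring in the real ones; with a seeded RNG the challenger's share of traffic lands close to `traffic_split`:

```python
import random

class StubModel:
    """Stand-in for an mlflow.pyfunc model."""
    def __init__(self, name):
        self.name = name
    def predict(self, data):
        return [self.name] * len(data)

model_a = StubModel("model_a")  # stand-in for the champion
model_b = StubModel("model_b")  # stand-in for the challenger

def predict_with_ab_test(data, traffic_split=0.1):
    if random.random() < traffic_split:
        return model_b.predict(data), "model_b"
    return model_a.predict(data), "model_a"

random.seed(42)
variants = [predict_with_ab_test([1.0])[1] for _ in range(10_000)]
share_b = variants.count("model_b") / len(variants)
print(f"challenger share: {share_b:.3f}")  # close to 0.1
```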

5. Canary Release #

python
import mlflow.pyfunc

stable_model = mlflow.pyfunc.load_model("models:/my_model@champion")
canary_model = mlflow.pyfunc.load_model("models:/my_model@canary")

def predict_with_canary(data, canary_percentage=5):
    # Route the first canary_percentage% of rows to the canary model
    canary_count = int(len(data) * canary_percentage / 100)

    canary_data = data[:canary_count]
    stable_data = data[canary_count:]

    canary_predictions = canary_model.predict(canary_data)
    stable_predictions = stable_model.predict(stable_data)

    # Concatenate explicitly: `+` on numpy arrays would add element-wise
    return list(canary_predictions) + list(stable_predictions)
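Because the routing is deterministic (the first N rows go to the canary), the split is easy to verify with stub models; a quick check assuming 100 rows and a 5% canary:

```python
class StubModel:
    """Stand-in for an mlflow.pyfunc model; labels each row it scores."""
    def __init__(self, label):
        self.label = label
    def predict(self, rows):
        return [self.label] * len(rows)

stable_model = StubModel("stable")
canary_model = StubModel("canary")

def predict_with_canary(data, canary_percentage=5):
    canary_count = int(len(data) * canary_percentage / 100)
    canary_predictions = canary_model.predict(data[:canary_count])
    stable_predictions = stable_model.predict(data[canary_count:])
    # Explicit list concatenation keeps row order intact
    return list(canary_predictions) + list(stable_predictions)

rows = list(range(100))
labels = predict_with_canary(rows, canary_percentage=5)
```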

Deployment Checklist #

text
┌─────────────────────────────────────────────────────────────┐
│                   Deployment Checklist                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Model preparation:                                         │
│  ├── Model registered in the Model Registry                 │
│  ├── Model signature defined                                │
│  ├── Environment dependencies packaged                      │
│  └── Input example added                                    │
│                                                             │
│  Deployment configuration:                                  │
│  ├── Compute resources allocated                            │
│  ├── Network configuration completed                        │
│  ├── Security groups / firewalls configured                 │
│  └── Load balancing set up                                  │
│                                                             │
│  Monitoring and alerting:                                   │
│  ├── Health checks configured                               │
│  ├── Performance metrics monitored                          │
│  ├── Error logs collected                                   │
│  └── Alert rules defined                                    │
│                                                             │
│  Rollback readiness:                                        │
│  ├── Previous model version archived                        │
│  ├── Rollback procedure tested                              │
│  └── Backup strategy in place                               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Next Steps #

Now that you have the core of model deployment down, continue to Advanced Configuration to explore MLflow's advanced features and settings!

Last updated: 2026-04-04