高级配置 #

认证与授权 #

启用认证 #

bash
# Pass the backend store URI (which embeds the database password) through the
# environment instead of argv: command-line arguments are visible to every
# local user via `ps`. `mlflow server` reads --backend-store-uri from
# MLFLOW_BACKEND_STORE_URI when the flag is omitted. A value that is already
# exported takes precedence over the placeholder below.
: "${MLFLOW_BACKEND_STORE_URI:=postgresql://user:pass@host:5432/mlflow}"
export MLFLOW_BACKEND_STORE_URI

# --app-name basic-auth turns on MLflow's username/password authentication.
mlflow server \
    --default-artifact-root s3://my-bucket/mlruns \
    --app-name basic-auth \
    --host 0.0.0.0 \
    --port 5000

配置认证 #

python
import os

import mlflow

mlflow.set_tracking_uri("http://localhost:5000")

# Credentials for a server started with `--app-name basic-auth`; the MLflow
# client sends them with each request to the tracking server.
# Do not hard-code real passwords: load them from a secret store / CI secret.
os.environ["MLFLOW_TRACKING_USERNAME"] = "admin"
os.environ["MLFLOW_TRACKING_PASSWORD"] = "password"

用户管理 #

python
from mlflow.server import get_app_client

# Users are managed through the basic-auth app's REST client; there are no
# module-level create_user/update_user_password helpers to import.
# The calling identity (MLFLOW_TRACKING_USERNAME/PASSWORD) must be an admin.
auth_client = get_app_client("basic-auth", tracking_uri="http://localhost:5000")

auth_client.create_user(username="username", password="password")

auth_client.update_user_password(username="username", password="new_password")

权限管理 #

text
┌─────────────────────────────────────────────────────────────┐
│                    MLflow 权限级别                           │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  READ                                                       │
│  ─────────────────────────────────────────────────────────  │
│  查看实验和运行                                             │
│  查看模型和版本                                             │
│                                                             │
│  EDIT                                                       │
│  ─────────────────────────────────────────────────────────  │
│  创建和修改实验                                             │
│  记录参数、指标和工件                                       │
│  注册模型版本                                               │
│                                                             │
│  MANAGE                                                     │
│  ─────────────────────────────────────────────────────────  │
│  删除实验和运行                                             │
│  管理模型阶段转换                                           │
│  管理用户权限                                               │
│                                                             │
│  NO_PERMISSIONS                                             │
│  ─────────────────────────────────────────────────────────  │
│  无任何访问权限                                             │
│  管理员用户 (is_admin) 拥有全部权限                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘

分布式跟踪 #

配置远程 Tracking Server #

python
import mlflow

# Choose exactly ONE of the options below: each set_tracking_uri() call
# replaces the previous value, so calling them in sequence would leave only
# the last one in effect.

# Option 1: remote MLflow Tracking Server over HTTP(S).
mlflow.set_tracking_uri("http://tracking-server:5000")

# Option 2: Databricks workspace, using the Databricks CLI profile "my-profile".
# mlflow.set_tracking_uri("databricks://my-profile")

# Option 3: write directly to the backend database (no tracking server).
# mlflow.set_tracking_uri("postgresql://user:pass@host:5432/mlflow")

多工作进程配置 #

bash
# Keep the database password out of argv (visible to all local users via
# `ps`): `mlflow server` reads --backend-store-uri from
# MLFLOW_BACKEND_STORE_URI when the flag is omitted. An already-exported
# value takes precedence over the placeholder below.
: "${MLFLOW_BACKEND_STORE_URI:=postgresql://user:pass@host:5432/mlflow}"
export MLFLOW_BACKEND_STORE_URI

# --workers: number of server worker processes.
# --gunicorn-opts: passed through to gunicorn (120 s timeout, 5 s keep-alive).
mlflow server \
    --default-artifact-root s3://my-bucket/mlruns \
    --host 0.0.0.0 \
    --port 5000 \
    --workers 4 \
    --gunicorn-opts "--timeout 120 --keep-alive 5"

负载均衡配置 #

nginx
# Tracking server replicas; with no balancing directive nginx uses round-robin.
# All replicas should point at the same backend store and artifact root.
upstream mlflow_backend {
    server mlflow-server-1:5000;
    server mlflow-server-2:5000;
    server mlflow-server-3:5000;
}

server {
    listen 80;
    server_name mlflow.example.com;

    # nginx rejects request bodies over 1 MB by default (HTTP 413), which
    # breaks artifact uploads proxied through the tracking server.
    # 0 disables the limit; set an explicit cap if you prefer.
    client_max_body_size 0;

    location / {
        proxy_pass http://mlflow_backend;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
    }
}

存储优化 #

数据库优化 #

sql
-- Extra indexes for common run filters / sort orders.
-- IF NOT EXISTS makes the script safe to re-run (PostgreSQL 9.5+, SQLite;
-- MySQL does not support it for CREATE INDEX).
-- On a busy PostgreSQL database consider CREATE INDEX CONCURRENTLY so the
-- build does not block writes (it cannot run inside a transaction block).
CREATE INDEX IF NOT EXISTS idx_runs_experiment_id ON runs(experiment_id);
CREATE INDEX IF NOT EXISTS idx_runs_status ON runs(status);
CREATE INDEX IF NOT EXISTS idx_runs_start_time ON runs(start_time);
-- NOTE(review): if the params/metrics primary keys already start with "key"
-- (check with \d params and \d metrics), these two indexes are redundant and
-- only add write overhead.
CREATE INDEX IF NOT EXISTS idx_params_key ON params(key);
CREATE INDEX IF NOT EXISTS idx_metrics_key ON metrics(key);

工件存储优化 #

text
┌─────────────────────────────────────────────────────────────┐
│                    工件存储优化策略                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 使用对象存储                                             │
│     ─────────────────────────────────────────────────────   │
│     - AWS S3                                                │
│     - Azure Blob Storage                                    │
│     - Google Cloud Storage                                  │
│     - MinIO (自托管)                                        │
│                                                             │
│  2. 生命周期管理                                             │
│     ─────────────────────────────────────────────────────   │
│     - 设置自动过期策略                                      │
│     - 归档旧版本工件                                        │
│     - 压缩存储                                              │
│                                                             │
│  3. 缓存策略                                                 │
│     ─────────────────────────────────────────────────────   │
│     - 本地缓存频繁访问的工件                                │
│     - CDN 加速                                              │
│     - 预热缓存                                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

配置 S3 生命周期 #

json
{
    "Rules": [
        {
            "ID": "ArchiveOldArtifacts",
            "Status": "Enabled",
            "Filter": {
                "Prefix": "mlruns/"
            },
            "Transitions": [
                {
                    "Days": 90,
                    "StorageClass": "STANDARD_IA"
                },
                {
                    "Days": 180,
                    "StorageClass": "GLACIER"
                }
            ]
        }
    ]
}

性能优化 #

连接池配置 #

python
import os

# MLflow's SQL store builds its own SQLAlchemy engine, so an engine created
# here with create_engine() would never be used by MLflow. Pool settings are
# passed through these environment variables instead. They must be set in the
# process that opens the database: export them before `mlflow server`, or set
# them before the first MLflow call that uses a database tracking URI.
os.environ["MLFLOW_SQLALCHEMYSTORE_POOLCLASS"] = "QueuePool"
os.environ["MLFLOW_SQLALCHEMYSTORE_POOL_SIZE"] = "10"
os.environ["MLFLOW_SQLALCHEMYSTORE_MAX_OVERFLOW"] = "20"
# NOTE(review): there is no variable for pool_pre_ping; MLflow's engine appears
# to enable it by default - confirm for your MLflow version.

批量写入优化 #

python
import time

import mlflow
from mlflow.entities import Metric
from mlflow.tracking import MlflowClient

client = MlflowClient()

with mlflow.start_run() as run:
    # mlflow.log_metric() in a loop issues one request per value; log_batch()
    # sends all values together. Metric(key, value, timestamp_ms, step).
    timestamp = int(time.time() * 1000)
    metrics = [Metric("metric_name", i * 0.1, timestamp, i) for i in range(1000)]
    client.log_batch(run.info.run_id, metrics=metrics)

异步日志 #

python
import mlflow
from concurrent.futures import ThreadPoolExecutor


class AsyncMLflowLogger:
    """Log metrics and params of one MLflow run from a background thread pool.

    The ``*_async`` methods return immediately. ``finish()`` waits for every
    queued call, ends the run (FAILED if any call raised, else FINISHED) and
    then re-raises the first error, so failures are never silently lost.

    Note: recent MLflow versions (2.8+) also offer built-in asynchronous
    logging via ``synchronous=False`` on the logging functions.
    """

    def __init__(self, max_workers=4):
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.run_id = None
        # Futures of every submitted call; inspected in finish().
        self._futures = []

    def start_run(self):
        """Start an MLflow run and remember its id for the queued calls."""
        run = mlflow.start_run()
        self.run_id = run.info.run_id
        return run

    def _submit(self, fn, *args, **kwargs):
        """Queue ``fn(*args, **kwargs, run_id=self.run_id)`` on the pool."""
        # Target the run explicitly so the call does not depend on which run
        # (if any) is considered active inside the worker thread.
        kwargs["run_id"] = self.run_id
        self._futures.append(self.executor.submit(fn, *args, **kwargs))

    def log_metric_async(self, key, value, step=None):
        """Queue ``mlflow.log_metric(key, value, step=step)`` for this run."""
        self._submit(mlflow.log_metric, key, value, step=step)

    def log_params_async(self, params):
        """Queue ``mlflow.log_params(params)`` for this run."""
        self._submit(mlflow.log_params, params)

    def finish(self):
        """Drain the pool, end the run, and re-raise the first queued error.

        The pool is drained *before* end_run() so no queued call can write to
        a run that has already been ended.
        """
        self.executor.shutdown(wait=True)
        errors = []
        for future in self._futures:
            error = future.exception()
            if error is not None:
                errors.append(error)
        self._futures.clear()
        mlflow.end_run(status="FAILED" if errors else "FINISHED")
        if errors:
            raise errors[0]


logger = AsyncMLflowLogger()
logger.start_run()
logger.log_params_async({"lr": 0.01, "epochs": 100})
logger.finish()

插件开发 #

创建自定义 Artifact Store #

python
from mlflow.store.artifact.artifact_repo import ArtifactRepository


class CustomArtifactRepository(ArtifactRepository):
    """Skeleton artifact repository for a custom storage backend.

    Replace each ``NotImplementedError`` with calls to your storage. Stubs
    raise instead of silently doing nothing, so a missing implementation is
    visible immediately.
    """

    def __init__(self, artifact_uri, *args, **kwargs):
        # Forward any extra constructor arguments (newer MLflow versions may
        # pass more than artifact_uri, e.g. a tracking URI) to the base class.
        super().__init__(artifact_uri, *args, **kwargs)

    def log_artifact(self, local_file, artifact_path=None):
        """Upload one local file under ``artifact_path``."""
        raise NotImplementedError

    def log_artifacts(self, local_dir, artifact_path=None):
        """Upload the contents of a local directory under ``artifact_path``."""
        raise NotImplementedError

    def list_artifacts(self, path=None):
        """Return a list of ``mlflow.entities.FileInfo`` directly under ``path``."""
        raise NotImplementedError

    def _download_file(self, remote_file_path, local_path):
        """Download a single remote file to ``local_path``.

        ``ArtifactRepository`` declares this method abstract and its inherited
        ``download_artifacts`` is built on top of it (together with
        ``list_artifacts``), so implement this instead of overriding
        ``download_artifacts``.
        """
        raise NotImplementedError

    def delete_artifacts(self, artifact_path=None):
        """Delete artifacts under ``artifact_path``."""
        raise NotImplementedError

注册插件 #

python
# Recommended: publish the repository through a package entry point so MLflow
# discovers it automatically, e.g. in pyproject.toml:
#
#   [project.entry-points."mlflow.artifact_repository"]
#   custom = "my_package.store:CustomArtifactRepository"
#
# For in-process registration, use the registry instance MLflow consults. It
# is module-private (_artifact_repository_registry). Its keys are bare URI
# schemes as parsed from the artifact URI ("custom"), not "custom://".
from mlflow.store.artifact.artifact_repository_registry import (
    _artifact_repository_registry,
)


def get_custom_artifact_repository(artifact_uri, *args, **kwargs):
    """Registry factory: build the repository, forwarding any extra arguments."""
    return CustomArtifactRepository(artifact_uri, *args, **kwargs)


_artifact_repository_registry.register("custom", get_custom_artifact_repository)

创建自定义 Model Flavor #

python
import mlflow.pyfunc
import pandas as pd


class CustomModelWrapper(mlflow.pyfunc.PythonModel):
    """pyfunc wrapper around any object exposing ``predict(array_like)``."""

    def __init__(self, model):
        # The wrapped model is stored on the wrapper, so it is serialized
        # together with it when the wrapper is logged.
        self.model = model

    def load_context(self, context):
        # Nothing extra to load: the model is carried by the wrapper itself.
        pass

    def predict(self, context, model_input, params=None):
        """Predict on ``model_input``.

        DataFrames are converted to their underlying array first; any other
        input is passed to the wrapped model unchanged. ``params`` is accepted
        so MLflow can forward inference parameters; this wrapper ignores them.
        """
        if isinstance(model_input, pd.DataFrame):
            model_input = model_input.values
        return self.model.predict(model_input)


def log_custom_model(model, artifact_path, **kwargs):
    """Log ``model`` as a pyfunc model; extra kwargs go to ``log_model``.

    Returns whatever ``mlflow.pyfunc.log_model`` returns (model info).
    """
    return mlflow.pyfunc.log_model(
        artifact_path=artifact_path,
        python_model=CustomModelWrapper(model),
        **kwargs
    )

监控与告警 #

Prometheus 集成 #

python
from prometheus_client import Counter, Histogram, start_http_server
import mlflow

# Prediction metrics, served over HTTP for Prometheus to scrape.
PREDICTION_COUNT = Counter('mlflow_predictions_total', 'Total predictions')
PREDICTION_LATENCY = Histogram('mlflow_prediction_latency_seconds', 'Prediction latency')

# NOTE: 9090 is also the Prometheus server's usual port; change this if both
# run on the same host.
METRICS_PORT = 9090
start_http_server(METRICS_PORT)


def predict_with_metrics(model, data):
    """Call ``model.predict(data)``, counting the call and timing all of it."""
    with PREDICTION_LATENCY.time():
        PREDICTION_COUNT.inc()
        return model.predict(data)

健康检查端点 #

python
from fastapi import FastAPI
from fastapi.responses import JSONResponse
import mlflow

app = FastAPI()


def _check_tracking_backend():
    """Issue a minimal query against the tracking backend; raises on failure."""
    mlflow.search_experiments(max_results=1)


@app.get("/health")
def health_check():
    """Liveness probe: 200 if the tracking backend answers, else 503 + error."""
    try:
        _check_tracking_backend()
        return {"status": "healthy"}
    except Exception as e:
        # A dict body must go through JSONResponse: plain Response expects
        # str/bytes content and cannot encode a dict.
        return JSONResponse(
            content={"status": "unhealthy", "error": str(e)},
            status_code=503,
        )


@app.get("/ready")
def readiness_check():
    """Readiness probe: 200 if the tracking backend answers, else 503."""
    try:
        _check_tracking_backend()
        return {"status": "ready"}
    except Exception:
        return JSONResponse(
            content={"status": "not ready"},
            status_code=503,
        )

告警配置示例 #

yaml
# Prometheus alerting rules for the MLflow server and prediction latency.
groups:
  - name: mlflow_alerts
    rules:
      # Scrape target(s) of job "mlflow" unreachable for at least 1 minute.
      - alert: MLflowServerDown
        expr: up{job="mlflow"} == 0
        for: 1m
        labels:
          severity: critical
        annotations:
          summary: "MLflow server is down"

      # p95 of mlflow_prediction_latency_seconds above 1 s for 5 minutes.
      - alert: HighLatency
        expr: histogram_quantile(0.95, rate(mlflow_prediction_latency_seconds_bucket[5m])) > 1
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "High prediction latency detected"

安全配置 #

HTTPS 配置 #

nginx
server {
    listen 443 ssl;
    server_name mlflow.example.com;

    ssl_certificate /path/to/cert.pem;
    ssl_certificate_key /path/to/key.pem;
    # Accept only TLS 1.2 and TLS 1.3.
    ssl_protocols TLSv1.2 TLSv1.3;

    # nginx's default 1 MB request-body limit returns HTTP 413 for larger
    # artifact uploads proxied through the tracking server; 0 disables it.
    client_max_body_size 0;

    location / {
        # mlflow_backend must be an upstream defined in the same http context.
        proxy_pass http://mlflow_backend;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
    }
}

网络安全 #

yaml
apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
  name: mlflow-network-policy
spec:
  podSelector:
    matchLabels:
      app: mlflow
  policyTypes:
  - Ingress
  - Egress
  ingress:
  # Only pods in the "production" namespace may reach the server on 5000.
  # kubernetes.io/metadata.name is set automatically on every namespace
  # (Kubernetes 1.21+), unlike a hand-maintained "name" label.
  - from:
    - namespaceSelector:
        matchLabels:
          kubernetes.io/metadata.name: production
    ports:
    - protocol: TCP
      port: 5000
  egress:
  # Backend store (PostgreSQL) in the "database" namespace.
  - to:
    - namespaceSelector:
        matchLabels:
          kubernetes.io/metadata.name: database
    ports:
    - protocol: TCP
      port: 5432
  # DNS: once Egress is listed in policyTypes, everything not allowed here is
  # blocked, including name lookups; without this rule the pod cannot
  # resolve the database or artifact-store hostnames.
  - to:
    - namespaceSelector:
        matchLabels:
          kubernetes.io/metadata.name: kube-system
    ports:
    - protocol: UDP
      port: 53
    - protocol: TCP
      port: 53
  # Artifact store (S3 over HTTPS). Narrow the CIDR to your endpoint if possible.
  - to:
    - ipBlock:
        cidr: 0.0.0.0/0
    ports:
    - protocol: TCP
      port: 443

备份与恢复 #

数据库备份 #

bash
# Dump to a temporary name and rename only on success, so a failed or
# interrupted pg_dump never leaves an empty/truncated file that looks like a
# valid backup.
backup_file="mlflow_backup_$(date +%Y%m%d).sql"
if pg_dump -h localhost -U mlflow -d mlflow -f "${backup_file}.partial"; then
    mv -- "${backup_file}.partial" "$backup_file"
else
    rm -f -- "${backup_file}.partial"
    printf 'pg_dump failed; %s was not written\n' "$backup_file" >&2
fi

工件备份 #

bash
# Copy the artifact tree into a date-stamped prefix of the backup bucket.
backup_date=$(date +%Y%m%d)
aws s3 sync \
    "s3://my-bucket/mlruns" \
    "s3://backup-bucket/mlruns-backup-${backup_date}"

恢复流程 #

bash
# Date stamp of the backup to restore (export restore_date to override).
: "${restore_date:=20240101}"

# By default psql keeps executing after a failed statement. ON_ERROR_STOP
# makes it stop and exit non-zero; --single-transaction then rolls the whole
# restore back instead of leaving a half-restored database.
# Artifacts are synced back only after the metadata restore has succeeded.
if psql -h localhost -U mlflow -d mlflow \
    -v ON_ERROR_STOP=1 --single-transaction \
    -f "mlflow_backup_${restore_date}.sql"; then
    aws s3 sync "s3://backup-bucket/mlruns-backup-${restore_date}" s3://my-bucket/mlruns
else
    printf 'database restore failed; artifacts were not synced\n' >&2
fi

配置最佳实践 #

环境变量配置 #

bash
# MLflow-related environment variables for processes started from this shell.
# NOTE(review): MLFLOW_S3_ENDPOINT_URL is usually only needed for
# S3-compatible stores (e.g. MinIO); confirm it is wanted for AWS S3 itself.
export \
    MLFLOW_TRACKING_URI="http://localhost:5000" \
    MLFLOW_REGISTRY_URI="http://localhost:5000" \
    MLFLOW_S3_ENDPOINT_URL="https://s3.amazonaws.com" \
    MLFLOW_EXPERIMENT_NAME="my-experiment" \
    MLFLOW_RUN_NAME="my-run"

配置文件 #

yaml
# Application-level settings for an MLflow deployment.
# NOTE(review): MLflow does not appear to read a file like this natively;
# the code that loads it must map these keys onto tracking URIs and
# `mlflow server` options - confirm.
# ${MLFLOW_DB_PASSWORD} is plain text to YAML; it is only replaced if the
# loader performs environment-variable substitution.
mlflow:
  tracking_uri: http://localhost:5000
  registry_uri: http://localhost:5000
  experiment_name: my-experiment

  backend_store:
    type: postgresql
    host: localhost
    port: 5432
    database: mlflow
    user: mlflow
    password: "${MLFLOW_DB_PASSWORD}"

  artifact_store: {type: s3, bucket: my-bucket, prefix: mlruns}

  server: {host: 0.0.0.0, port: 5000, workers: 4}

日志配置 #

python
import logging

# Root logger: application logs go to mlflow.log and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('mlflow.log'),
        logging.StreamHandler()
    ]
)

logger = logging.getLogger('mlflow')
# When imported, MLflow gives the "mlflow" logger its own stderr handler and
# turns propagation off, so its records never reach the root handlers above
# (nothing from MLflow would land in mlflow.log). Remove any handlers attached
# to it and re-enable propagation so MLflow logs use the root handlers/format.
# Run this after `import mlflow`, because importing it re-applies that setup.
for handler in list(logger.handlers):
    logger.removeHandler(handler)
logger.propagate = True

故障排查 #

常见问题 #

text
┌─────────────────────────────────────────────────────────────┐
│                    常见问题排查                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 连接超时                                                 │
│     ─────────────────────────────────────────────────────   │
│     - 检查网络连接                                          │
│     - 检查防火墙设置                                        │
│     - 增加 timeout 设置                                     │
│                                                             │
│  2. 认证失败                                                 │
│     ─────────────────────────────────────────────────────   │
│     - 检查用户名密码                                        │
│     - 检查 token 是否过期                                   │
│     - 检查权限配置                                          │
│                                                             │
│  3. 存储问题                                                 │
│     ─────────────────────────────────────────────────────   │
│     - 检查存储空间                                          │
│     - 检查访问权限                                          │
│     - 检查连接配置                                          │
│                                                             │
│  4. 性能问题                                                 │
│     ─────────────────────────────────────────────────────   │
│     - 检查数据库索引                                        │
│     - 检查网络延迟                                          │
│     - 检查资源使用                                          │
│                                                             │
└─────────────────────────────────────────────────────────────┘

调试模式 #

python
import mlflow
import logging

logging.basicConfig(level=logging.DEBUG)
# basicConfig only changes the root logger. MLflow's "mlflow" logger keeps its
# own level (INFO) and handler, so lower it explicitly to see MLflow's DEBUG
# messages.
logging.getLogger("mlflow").setLevel(logging.DEBUG)

mlflow.set_tracking_uri("http://localhost:5000")

with mlflow.start_run():
    mlflow.log_param("debug", True)

下一步 #

现在你已经掌握了 MLflow 的高级配置,接下来学习 PyTorch 集成,了解如何在实际项目中使用 MLflow!

最后更新:2026-04-04