Production Deployment #

Deployment Overview #

text
┌─────────────────────────────────────────────────────────────┐
│             Production Deployment Architecture              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Load balancing layer                                       │
│  └── Nginx / ALB                                            │
│                                                             │
│  API gateway layer                                          │
│  ├── Authentication & authorization                         │
│  ├── Rate limiting                                          │
│  └── Request routing                                        │
│                                                             │
│  Service layer                                              │
│  ├── Model serving (vLLM/TGI)                               │
│  ├── Business services                                      │
│  └── Caching services                                       │
│                                                             │
│  Infrastructure                                             │
│  ├── Monitoring & alerting                                  │
│  ├── Log collection                                         │
│  └── Configuration management                               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Deployment Options #

vLLM Deployment #

python
from vllm import LLM, SamplingParams

class ProductionModel:
    def __init__(self, model_path, tensor_parallel_size=1):
        # Load the model once at startup; vLLM manages the KV cache
        # and continuous batching internally.
        self.llm = LLM(
            model=model_path,
            tensor_parallel_size=tensor_parallel_size,  # number of GPUs to shard across
            trust_remote_code=True,
            dtype="float16",
            gpu_memory_utilization=0.9  # fraction of GPU memory vLLM may reserve
        )
    
    def generate(
        self,
        prompts: list,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> list:
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens
        )
        
        outputs = self.llm.generate(prompts, sampling_params)
        
        return [output.outputs[0].text for output in outputs]

model = ProductionModel("models/merged", tensor_parallel_size=2)
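
vLLM's engine batches prompts continuously, so a single call with many prompts is usually far cheaper than many single-prompt calls. A minimal usage sketch (the prompts are illustrative):

python
# Batched inference: one call, many prompts.
prompts = [
    "Summarize the benefits of tensor parallelism.",
    "Explain what gpu_memory_utilization controls in vLLM.",
]
responses = model.generate(prompts, max_tokens=256, temperature=0.3)
for prompt, text in zip(prompts, responses):
    print(prompt, "->", text[:80])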

TGI Deployment #

bash
Start the TGI server (note: the local model directory must be mounted into the container so the path is visible inside it):

docker run --gpus all --shm-size 1g -p 8080:80 \
  -v $(pwd)/models:/models \
  ghcr.io/huggingface/text-generation-inference:latest \
  --model-id /models/merged \
  --num-shard 2 \
  --max-input-length 2048 \
  --max-total-tokens 4096 \
  --max-batch-size 32
python
import requests

class TGIClient:
    def __init__(self, base_url="http://localhost:8080"):
        self.base_url = base_url
    
    def generate(self, prompt, max_new_tokens=512, temperature=0.7):
        response = requests.post(
            f"{self.base_url}/generate",
            json={
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": max_new_tokens,
                    "temperature": temperature,
                    "do_sample": True
                }
            }
        )
        response.raise_for_status()  # surface HTTP errors instead of a KeyError below
        
        return response.json()["generated_text"]

client = TGIClient()
response = client.generate("Hello, please introduce yourself.")
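
TGI also exposes a streaming endpoint, /generate_stream, which emits server-sent events with one token per data: line. A sketch of a streaming client under that format:

python
import json
import requests

def generate_stream(prompt, base_url="http://localhost:8080", max_new_tokens=512):
    # Consume server-sent events from TGI's /generate_stream endpoint.
    with requests.post(
        f"{base_url}/generate_stream",
        json={"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}},
        stream=True,
    ) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if line.startswith(b"data:"):
                payload = json.loads(line[len(b"data:"):])
                yield payload["token"]["text"]

for token in generate_stream("Hello, please introduce yourself."):
    print(token, end="", flush=True)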

Service Architecture #

FastAPI Service #

python
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import asyncio
from datetime import datetime
import logging

app = FastAPI(
    title="LLM API",
    description="LLM inference service",
    version="1.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten to known origins in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class GenerateRequest(BaseModel):
    prompt: str
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
    stream: bool = False

class GenerateResponse(BaseModel):
    text: str
    tokens_generated: int
    latency_ms: float

model = None

@app.on_event("startup")
async def startup():
    global model
    model = ProductionModel("models/merged")
    logging.info("Model loaded successfully")

@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    start_time = datetime.now()
    
    # The vLLM call is synchronous; run it in a worker thread so the
    # event loop stays free to serve other requests.
    results = await asyncio.to_thread(
        model.generate,
        [request.prompt],
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p
    )
    
    latency = (datetime.now() - start_time).total_seconds() * 1000
    
    return GenerateResponse(
        text=results[0],
        tokens_generated=len(results[0].split()),  # whitespace split: an approximation, not a true token count
        latency_ms=latency
    )

@app.post("/generate/stream")
async def generate_stream(request: GenerateRequest):
    async def stream_generator():
        # Placeholder token stream; wire this to the model's streaming API in production.
        for token in ["Hello", ", ", "world", "!"]:
            yield f"data: {token}\n\n"
            await asyncio.sleep(0.1)
    
    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream"
    )

@app.get("/health")
async def health():
    return {"status": "healthy", "model_loaded": model is not None}

Authentication & Authorization #

python
from fastapi import Security, HTTPException, Depends
from fastapi.security import APIKeyHeader
from starlette.status import HTTP_403_FORBIDDEN

api_key_header = APIKeyHeader(name="X-API-Key")

async def get_api_key(api_key: str = Security(api_key_header)):
    # In production, load keys from a secret store or environment, not source code.
    valid_keys = {"your-api-key-1", "your-api-key-2"}
    
    if api_key not in valid_keys:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Invalid API Key"
        )
    
    return api_key

@app.post("/generate")
async def generate(
    request: GenerateRequest,
    api_key: str = Depends(get_api_key)
):
    pass  # generation logic as in the unauthenticated endpoint above
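
Callers then pass the key in the X-API-Key header; a quick client-side check (the key value is illustrative):

python
import requests

# Call the protected endpoint with the API key header.
response = requests.post(
    "http://localhost:8000/generate",
    headers={"X-API-Key": "your-api-key-1"},
    json={"prompt": "Hello", "max_tokens": 64},
)
print(response.status_code)  # 403 if the key is missing or invalid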

Rate Limiting #

python
from fastapi import FastAPI, Request, Depends
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
import redis.asyncio as redis

@app.on_event("startup")
async def init_rate_limiter():
    # Counters live in Redis, so limits hold across all replicas.
    redis_client = redis.from_url("redis://localhost:6379")
    await FastAPILimiter.init(redis_client)

@app.post(
    "/generate",
    dependencies=[Depends(RateLimiter(times=100, seconds=60))]
)
async def generate(request: GenerateRequest):
    pass  # generation logic as above; beyond 100 requests per minute the limiter returns HTTP 429
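
Clients should treat HTTP 429 as a signal to back off rather than a hard error; a minimal retry sketch (URL and retry counts are illustrative):

python
import time
import requests

def generate_with_retry(prompt, url="http://localhost:8000/generate", retries=5):
    # Back off exponentially when the rate limiter returns HTTP 429.
    for attempt in range(retries):
        response = requests.post(url, json={"prompt": prompt})
        if response.status_code != 429:
            response.raise_for_status()
            return response.json()
        time.sleep(2 ** attempt)  # 1s, 2s, 4s, ...
    raise RuntimeError("rate-limited after all retries")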

Monitoring & Alerting #

Prometheus Metrics #

python
from prometheus_client import Counter, Histogram, Gauge, generate_latest, CONTENT_TYPE_LATEST
from fastapi import Response
from datetime import datetime

REQUEST_COUNT = Counter(
    'llm_requests_total',
    'Total LLM requests',
    ['method', 'endpoint', 'status']
)

REQUEST_LATENCY = Histogram(
    'llm_request_latency_seconds',
    'Request latency',
    ['method', 'endpoint']
)

ACTIVE_REQUESTS = Gauge(
    'llm_active_requests',
    'Active requests'
)

@app.middleware("http")
async def add_metrics(request: Request, call_next):
    ACTIVE_REQUESTS.inc()
    start_time = datetime.now()
    
    try:
        response = await call_next(request)
    finally:
        # Decrement even if the handler raises, so the gauge cannot drift upward.
        ACTIVE_REQUESTS.dec()
    
    latency = (datetime.now() - start_time).total_seconds()
    
    REQUEST_COUNT.labels(
        method=request.method,
        endpoint=request.url.path,
        status=response.status_code
    ).inc()
    
    REQUEST_LATENCY.labels(
        method=request.method,
        endpoint=request.url.path
    ).observe(latency)
    
    return response

@app.get("/metrics")
async def metrics():
    return Response(
        content=generate_latest(),
        media_type=CONTENT_TYPE_LATEST  # Prometheus exposition content type
    )
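
A quick smoke test is to scrape the endpoint and look for the counters defined above (assuming the service listens on localhost:8000):

python
import requests

# Fetch the metrics exposition and confirm our series appear.
body = requests.get("http://localhost:8000/metrics").text
for line in body.splitlines():
    if line.startswith("llm_requests_total") or line.startswith("llm_active_requests"):
        print(line)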

Log Collection #

python
import logging
from logging.handlers import RotatingFileHandler
import json
from datetime import datetime, timezone

class JSONFormatter(logging.Formatter):
    def format(self, record):
        log_entry = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno
        }
        
        if hasattr(record, 'request_id'):
            log_entry['request_id'] = record.request_id
        
        return json.dumps(log_entry)

# Attach the JSON formatter to the handler before installing it. Module-level
# loggers have no handlers of their own, so configure the root handler here;
# setting the formatter via logger.handlers[0] would raise IndexError.
handler = RotatingFileHandler(
    'logs/app.log',
    maxBytes=100*1024*1024,  # rotate at 100 MB
    backupCount=10
)
handler.setFormatter(JSONFormatter())

logging.basicConfig(level=logging.INFO, handlers=[handler])

logger = logging.getLogger(__name__)
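
The request_id branch in the formatter expects the ID to arrive via logging's extra mechanism; a sketch of a middleware that generates one ID per request (the middleware and header names are illustrative):

python
import uuid
from fastapi import Request

@app.middleware("http")
async def add_request_id(request: Request, call_next):
    # One ID per request, threaded into log records via `extra`.
    request_id = str(uuid.uuid4())
    logger.info("request received", extra={"request_id": request_id})
    response = await call_next(request)
    response.headers["X-Request-ID"] = request_id  # expose the ID to callers
    return response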

Alerting Configuration #

yaml
prometheus/alertmanager/alertmanager.yml:

global:
  resolve_timeout: 5m

route:
  group_by: ['alertname']
  group_wait: 10s
  group_interval: 10s
  repeat_interval: 1h
  receiver: 'default'

receivers:
  - name: 'default'
    email_configs:
      - to: 'team@example.com'
        from: 'alertmanager@example.com'
        smarthost: 'smtp.example.com:587'

prometheus/alerts.yml:

groups:
  - name: llm_alerts
    rules:
      - alert: HighLatency
        expr: histogram_quantile(0.95, rate(llm_request_latency_seconds_bucket[5m])) > 5
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "p95 request latency above 5s for 5 minutes"
          
      - alert: HighErrorRate
        expr: rate(llm_requests_total{status=~"5.."}[5m]) / rate(llm_requests_total[5m]) > 0.1
        for: 5m
        labels:
          severity: critical
        annotations:
          summary: "5xx responses above 10% of requests for 5 minutes"

Scaling Out #

Kubernetes Deployment #

yaml
k8s/deployment.yaml:

apiVersion: apps/v1
kind: Deployment
metadata:
  name: llm-api
spec:
  replicas: 3
  selector:
    matchLabels:
      app: llm-api
  template:
    metadata:
      labels:
        app: llm-api
    spec:
      containers:
      - name: llm-api
        image: llm-api:latest
        ports:
        - containerPort: 8000
        resources:
          limits:
            nvidia.com/gpu: 1
            memory: "16Gi"
          requests:
            nvidia.com/gpu: 1
            memory: "8Gi"
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 5
          periodSeconds: 5
        env:
        - name: MODEL_PATH
          value: "/models/merged"
        volumeMounts:
        - name: model-storage
          mountPath: /models
      volumes:
      - name: model-storage
        persistentVolumeClaim:
          claimName: model-pvc
yaml
k8s/service.yaml:

apiVersion: v1
kind: Service
metadata:
  name: llm-api-service
spec:
  selector:
    app: llm-api
  ports:
  - port: 80
    targetPort: 8000
  type: LoadBalancer
yaml
k8s/hpa.yaml:

apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: llm-api-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: llm-api
  minReplicas: 2
  maxReplicas: 10
  metrics:
  - type: Resource
    resource:
      name: cpu
      target:
        type: Utilization
        averageUtilization: 70
  - type: Resource
    resource:
      name: memory
      target:
        type: Utilization
        averageUtilization: 80

Docker Compose #

yaml
docker-compose.yml:

version: '3.8'

services:
  llm-api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - MODEL_PATH=/models/merged
      - CUDA_VISIBLE_DEVICES=0
    volumes:
      - ./models:/models
      - ./logs:/app/logs
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    depends_on:
      - redis
      - prometheus
  
  redis:
    image: redis:alpine
    ports:
      - "6379:6379"
  
  prometheus:
    image: prom/prometheus
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus:/etc/prometheus
  
  grafana:
    image: grafana/grafana
    ports:
      - "3000:3000"
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=admin
    volumes:
      - grafana-storage:/var/lib/grafana

volumes:
  grafana-storage:

Performance Optimization #

Batching Optimization #

python
import asyncio
from collections import deque

class BatchProcessor:
    def __init__(self, model, max_batch_size=32, max_wait_time=0.1):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time  # max seconds a request waits for batch-mates
        self.queue = deque()
        self.lock = asyncio.Lock()
    
    async def generate(self, prompt: str) -> str:
        future = asyncio.get_running_loop().create_future()
        
        async with self.lock:
            self.queue.append((prompt, future))
            
            if len(self.queue) >= self.max_batch_size:
                # A full batch is ready: flush immediately.
                await self._process_batch()
            elif len(self.queue) == 1:
                # First request in an empty queue: start one flush timer,
                # rather than one timer per request.
                asyncio.create_task(self._wait_and_process())
        
        return await future
    
    async def _wait_and_process(self):
        await asyncio.sleep(self.max_wait_time)
        async with self.lock:
            if self.queue:
                await self._process_batch()
    
    async def _process_batch(self):
        # Called with self.lock held; for simplicity the lock stays held during
        # generation, so new requests queue up behind the running batch.
        batch = []
        futures = []
        
        while self.queue and len(batch) < self.max_batch_size:
            prompt, future = self.queue.popleft()
            batch.append(prompt)
            futures.append(future)
        
        if batch:
            # Run the blocking model call in a worker thread so the event
            # loop is not frozen while the GPU works.
            results = await asyncio.to_thread(
                self.model.generate,
                batch
            )
            
            for future, result in zip(futures, results):
                future.set_result(result)

batch_processor = BatchProcessor(model)

@app.post("/generate")
async def generate(request: GenerateRequest):
    result = await batch_processor.generate(request.prompt)
    return {"text": result}
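
Concurrent callers now share batches automatically. A standalone check (run it in a process where batch_processor is constructed, outside the server, since it starts its own event loop):

python
import asyncio

async def demo():
    # 64 concurrent requests collapse into at most a few model calls.
    prompts = [f"Question {i}: what is batching?" for i in range(64)]
    results = await asyncio.gather(*(batch_processor.generate(p) for p in prompts))
    print(len(results), "responses")

asyncio.run(demo())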

Cache Optimization #

python
import hashlib

class SemanticCache:
    def __init__(self, similarity_threshold=0.95):
        # Note: this class only hits on byte-identical prompts (MD5 match).
        # similarity_threshold is reserved for the embedding-based variant
        # sketched after this section.
        self.cache = {}
        self.similarity_threshold = similarity_threshold
    
    def get(self, prompt: str):
        prompt_hash = hashlib.md5(prompt.encode()).hexdigest()
        
        if prompt_hash in self.cache:
            return self.cache[prompt_hash]
        
        return None
    
    def set(self, prompt: str, response: str):
        prompt_hash = hashlib.md5(prompt.encode()).hexdigest()
        self.cache[prompt_hash] = response

cache = SemanticCache()

@app.post("/generate")
async def generate(request: GenerateRequest):
    cached = cache.get(request.prompt)
    if cached:
        return {"text": cached, "cached": True}
    
    result = model.generate([request.prompt])[0]
    cache.set(request.prompt, result)
    
    return {"text": result, "cached": False}
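
A genuinely semantic cache compares embeddings rather than hashes, so paraphrased prompts can hit too. A sketch assuming the sentence-transformers package and the all-MiniLM-L6-v2 model (both external assumptions; the linear scan should become a vector index at scale):

python
import numpy as np
from sentence_transformers import SentenceTransformer  # assumed dependency

class EmbeddingCache:
    def __init__(self, similarity_threshold=0.95):
        self.encoder = SentenceTransformer("all-MiniLM-L6-v2")
        self.entries = []  # list of (embedding, response)
        self.similarity_threshold = similarity_threshold
    
    def _embed(self, text):
        vec = self.encoder.encode(text)
        return vec / np.linalg.norm(vec)  # normalize so dot product = cosine similarity
    
    def get(self, prompt):
        query = self._embed(prompt)
        for emb, response in self.entries:
            if float(np.dot(query, emb)) >= self.similarity_threshold:
                return response  # close enough: treat as a hit
        return None
    
    def set(self, prompt, response):
        self.entries.append((self._embed(prompt), response))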

Operational Practices #

Health Checks #

python
import torch

@app.get("/health/detailed")
async def detailed_health():
    gpu_available = torch.cuda.is_available()
    
    checks = {
        "model": model is not None,
        "gpu": gpu_available,
        "memory": {
            # GiB held by tensors vs. reserved by the CUDA allocator
            "allocated": torch.cuda.memory_allocated() / 1024**3 if gpu_available else 0,
            "reserved": torch.cuda.memory_reserved() / 1024**3 if gpu_available else 0
        }
    }
    
    all_healthy = checks["model"] and checks["gpu"]
    
    return {
        "status": "healthy" if all_healthy else "unhealthy",
        "checks": checks
    }
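
A small external watchdog can poll this endpoint and page when the report degrades; a minimal sketch (URL and interval are illustrative):

python
import time
import requests

def watch_health(url="http://localhost:8000/health/detailed", interval=30):
    # Poll the detailed health endpoint and flag any unhealthy report.
    while True:
        try:
            report = requests.get(url, timeout=5).json()
            if report["status"] != "healthy":
                print("ALERT:", report["checks"])  # hook up a real pager here
        except requests.RequestException as exc:
            print("ALERT: health endpoint unreachable:", exc)
        time.sleep(interval)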

Graceful Shutdown #

python
import signal

shutdown_event = asyncio.Event()

def signal_handler(signum, frame):
    # Flag shutdown so handlers can stop accepting new work. Note that
    # uvicorn installs its own SIGTERM/SIGINT handlers; registering these
    # manually only applies when you control the process directly.
    logging.info("Received shutdown signal")
    shutdown_event.set()

signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)

@app.on_event("shutdown")
async def shutdown():
    # Runs after the server stops accepting connections; finish in-flight
    # work and release resources here rather than waiting on the event,
    # which may never be set and would hang shutdown.
    logging.info("Shutting down gracefully...")
    shutdown_event.set()
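
Draining in-flight requests takes one more ingredient: a counter of active requests that shutdown waits on. A sketch (the 30-second budget is illustrative):

python
import asyncio

in_flight = 0

@app.middleware("http")
async def track_in_flight(request: Request, call_next):
    # Count requests so shutdown can wait for them to complete.
    global in_flight
    in_flight += 1
    try:
        return await call_next(request)
    finally:
        in_flight -= 1

@app.on_event("shutdown")
async def drain():
    # Wait up to 30 seconds for in-flight requests to finish.
    for _ in range(300):
        if in_flight == 0:
            break
        await asyncio.sleep(0.1)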

Best Practices #

text
Deployment best practices:

1. Model optimization
   ├── Merge LoRA weights
   ├── Quantize for deployment
   └── Batch inference

2. Service design
   ├── Health checks
   ├── Graceful shutdown
   └── Error handling

3. Monitoring & alerting
   ├── Performance monitoring
   ├── Error tracking
   └── Resource monitoring

4. Security
   ├── Authentication & authorization
   ├── Rate limiting
   └── Input validation

5. High availability
   ├── Multi-replica deployment
   ├── Load balancing
   └── Failure recovery

Summary #

Congratulations on completing the Fine-tuning Complete Guide! You now have the full picture, from core concepts through production deployment. Keep practicing and keep optimizing on your way to becoming a fine-tuning expert.

Last updated: 2026-04-05