Production Deployment #
Deployment Overview #
text
┌──────────────────────────────────────────────────────────┐
│            Production Deployment Architecture            │
├──────────────────────────────────────────────────────────┤
│                                                          │
│  Load Balancing Layer                                    │
│  └── Nginx / ALB                                         │
│                                                          │
│  API Gateway Layer                                       │
│  ├── Authentication & Authorization                      │
│  ├── Rate Limiting                                       │
│  └── Request Routing                                     │
│                                                          │
│  Service Layer                                           │
│  ├── Model Serving (vLLM / TGI)                          │
│  ├── Business Services                                   │
│  └── Caching                                             │
│                                                          │
│  Infrastructure                                          │
│  ├── Monitoring & Alerting                               │
│  ├── Log Collection                                      │
│  └── Configuration Management                            │
│                                                          │
└──────────────────────────────────────────────────────────┘
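The load-balancing layer at the top is usually just a reverse proxy in front of several API replicas. As a hedged illustration (the upstream addresses, port, and timeout values below are assumptions, not part of this guide's stack), a minimal Nginx configuration might look like this:
text
upstream llm_api {
    # Two API replicas behind the proxy; addresses are placeholders
    least_conn;
    server 10.0.0.11:8000;
    server 10.0.0.12:8000;
}

server {
    listen 80;

    location / {
        proxy_pass http://llm_api;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        # LLM generations can be slow; raise the read timeout
        proxy_read_timeout 300s;
        # Keep the SSE streaming endpoint from being buffered
        proxy_buffering off;
    }
}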
Deployment Options #
vLLM Deployment #
python
from vllm import LLM, SamplingParams

class ProductionModel:
    def __init__(self, model_path, tensor_parallel_size=1):
        self.llm = LLM(
            model=model_path,
            tensor_parallel_size=tensor_parallel_size,
            trust_remote_code=True,
            dtype="float16",
            gpu_memory_utilization=0.9  # reserve 90% of GPU memory for the KV cache and weights
        )

    def generate(
        self,
        prompts: list,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> list:
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens
        )
        outputs = self.llm.generate(prompts, sampling_params)
        return [output.outputs[0].text for output in outputs]

# Shard the merged model across 2 GPUs
model = ProductionModel("models/merged", tensor_parallel_size=2)
TGI Deployment #
bash
# Start the TGI server (the local model directory must be mounted into the container)
docker run --gpus all --shm-size 1g -p 8080:80 \
    -v $(pwd)/models:/models \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id /models/merged \
    --num-shard 2 \
    --max-input-length 2048 \
    --max-total-tokens 4096 \
    --max-batch-size 32
python
import requests

class TGIClient:
    def __init__(self, base_url="http://localhost:8080"):
        self.base_url = base_url

    def generate(self, prompt, max_new_tokens=512, temperature=0.7):
        response = requests.post(
            f"{self.base_url}/generate",
            json={
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": max_new_tokens,
                    "temperature": temperature,
                    "do_sample": True
                }
            }
        )
        response.raise_for_status()
        return response.json()["generated_text"]

client = TGIClient()
response = client.generate("你好,请介绍一下自己。")
Service Architecture #
FastAPI Service #
python
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import asyncio
from datetime import datetime
import logging

app = FastAPI(
    title="LLM API",
    description="LLM inference service",
    version="1.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class GenerateRequest(BaseModel):
    prompt: str
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
    stream: bool = False

class GenerateResponse(BaseModel):
    text: str
    tokens_generated: int
    latency_ms: float

model = None

@app.on_event("startup")
async def startup():
    global model
    model = ProductionModel("models/merged")
    logging.info("Model loaded successfully")

@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    start_time = datetime.now()
    results = model.generate(
        [request.prompt],
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p
    )
    latency = (datetime.now() - start_time).total_seconds() * 1000

    return GenerateResponse(
        text=results[0],
        tokens_generated=len(results[0].split()),  # rough whitespace-based count
        latency_ms=latency
    )

@app.post("/generate/stream")
async def generate_stream(request: GenerateRequest):
    async def stream_generator():
        # Placeholder token stream for demonstration; wire this to the
        # model's streaming interface in a real deployment
        for token in ["你", "好", ",", "世", "界"]:
            yield f"data: {token}\n\n"
            await asyncio.sleep(0.1)

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream"
    )

@app.get("/health")
async def health():
    return {"status": "healthy", "model_loaded": model is not None}
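Once the service is running (on localhost:8000 in this sketch), the endpoints can be smoke-tested with curl:
bash
# Non-streaming generation
curl -X POST http://localhost:8000/generate \
    -H "Content-Type: application/json" \
    -d '{"prompt": "你好,请介绍一下自己。", "max_tokens": 256}'

# Streaming generation (Server-Sent Events); -N disables output buffering
curl -N -X POST http://localhost:8000/generate/stream \
    -H "Content-Type: application/json" \
    -d '{"prompt": "你好", "stream": true}'

# Liveness check
curl http://localhost:8000/health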
Authentication & Authorization #
python
from fastapi import Depends, Security, HTTPException
from fastapi.security import APIKeyHeader
from starlette.status import HTTP_403_FORBIDDEN

api_key_header = APIKeyHeader(name="X-API-Key")

async def get_api_key(api_key: str = Security(api_key_header)):
    # In production, load keys from a secret store instead of hard-coding them
    valid_keys = {"your-api-key-1", "your-api-key-2"}
    if api_key not in valid_keys:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Invalid API Key"
        )
    return api_key

@app.post("/generate")
async def generate(
    request: GenerateRequest,
    api_key: str = Depends(get_api_key)
):
    ...  # same generation logic as above, now gated by the API key
Rate Limiting #
python
from fastapi import Depends, FastAPI, Request
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
import redis.asyncio as redis

@app.on_event("startup")
async def init_rate_limiter():
    redis_client = redis.from_url("redis://localhost:6379")
    await FastAPILimiter.init(redis_client)

# Allow at most 100 requests per client per 60 seconds
@app.post(
    "/generate",
    dependencies=[Depends(RateLimiter(times=100, seconds=60))]
)
async def generate(request: GenerateRequest):
    ...  # generation logic as above
Monitoring & Alerting #
Prometheus Monitoring #
python
from prometheus_client import Counter, Histogram, Gauge, generate_latest
from fastapi import Response

REQUEST_COUNT = Counter(
    'llm_requests_total',
    'Total LLM requests',
    ['method', 'endpoint', 'status']
)
REQUEST_LATENCY = Histogram(
    'llm_request_latency_seconds',
    'Request latency',
    ['method', 'endpoint']
)
ACTIVE_REQUESTS = Gauge(
    'llm_active_requests',
    'Active requests'
)

@app.middleware("http")
async def add_metrics(request: Request, call_next):
    ACTIVE_REQUESTS.inc()
    start_time = datetime.now()

    response = await call_next(request)

    latency = (datetime.now() - start_time).total_seconds()
    REQUEST_COUNT.labels(
        method=request.method,
        endpoint=request.url.path,
        status=response.status_code
    ).inc()
    REQUEST_LATENCY.labels(
        method=request.method,
        endpoint=request.url.path
    ).observe(latency)
    ACTIVE_REQUESTS.dec()

    return response

@app.get("/metrics")
async def metrics():
    return Response(
        content=generate_latest(),
        media_type="text/plain"
    )
Log Collection #
python
import json
import logging
import os
from datetime import datetime
from logging.handlers import RotatingFileHandler

class JSONFormatter(logging.Formatter):
    def format(self, record):
        log_entry = {
            "timestamp": datetime.utcnow().isoformat(),
            "level": record.levelname,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno
        }
        if hasattr(record, 'request_id'):
            log_entry['request_id'] = record.request_id
        return json.dumps(log_entry)

os.makedirs("logs", exist_ok=True)

# Attach the JSON formatter to the rotating file handler before installing it,
# so every record is written as structured JSON
file_handler = RotatingFileHandler(
    'logs/app.log',
    maxBytes=100 * 1024 * 1024,  # 100 MB per file
    backupCount=10
)
file_handler.setFormatter(JSONFormatter())

logging.basicConfig(level=logging.INFO, handlers=[file_handler])
logger = logging.getLogger(__name__)
Alerting Configuration #
yaml
# prometheus/alertmanager/alertmanager.yml
global:
  resolve_timeout: 5m

route:
  group_by: ['alertname']
  group_wait: 10s
  group_interval: 10s
  repeat_interval: 1h
  receiver: 'default'

receivers:
  - name: 'default'
    email_configs:
      - to: 'team@example.com'
        from: 'alertmanager@example.com'
        smarthost: 'smtp.example.com:587'

# prometheus/alerts.yml
groups:
  - name: llm_alerts
    rules:
      - alert: HighLatency
        expr: histogram_quantile(0.95, rate(llm_request_latency_seconds_bucket[5m])) > 5
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "High request latency"
      - alert: HighErrorRate
        expr: rate(llm_requests_total{status=~"5.."}[5m]) / rate(llm_requests_total[5m]) > 0.1
        for: 5m
        labels:
          severity: critical
        annotations:
          summary: "High error rate"
Scaling Out #
Kubernetes Deployment #
yaml
# k8s/deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: llm-api
spec:
  replicas: 3
  selector:
    matchLabels:
      app: llm-api
  template:
    metadata:
      labels:
        app: llm-api
    spec:
      containers:
        - name: llm-api
          image: llm-api:latest
          ports:
            - containerPort: 8000
          resources:
            limits:
              nvidia.com/gpu: 1
              memory: "16Gi"
            requests:
              nvidia.com/gpu: 1
              memory: "8Gi"
          livenessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 30
            periodSeconds: 10
          readinessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 5
            periodSeconds: 5
          env:
            - name: MODEL_PATH
              value: "/models/merged"
          volumeMounts:
            - name: model-storage
              mountPath: /models
      volumes:
        - name: model-storage
          persistentVolumeClaim:
            claimName: model-pvc
yaml
# k8s/service.yaml
apiVersion: v1
kind: Service
metadata:
  name: llm-api-service
spec:
  selector:
    app: llm-api
  ports:
    - port: 80
      targetPort: 8000
  type: LoadBalancer
yaml
# k8s/hpa.yaml
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: llm-api-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: llm-api
  minReplicas: 2
  maxReplicas: 10
  metrics:
    - type: Resource
      resource:
        name: cpu
        target:
          type: Utilization
          averageUtilization: 70
    - type: Resource
      resource:
        name: memory
        target:
          type: Utilization
          averageUtilization: 80
Docker Compose #
yaml
# docker-compose.yml
version: '3.8'
services:
  llm-api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - MODEL_PATH=/models/merged
      - CUDA_VISIBLE_DEVICES=0
    volumes:
      - ./models:/models
      - ./logs:/app/logs
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    depends_on:
      - redis
      - prometheus
  redis:
    image: redis:alpine
    ports:
      - "6379:6379"
  prometheus:
    image: prom/prometheus
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus:/etc/prometheus
  grafana:
    image: grafana/grafana
    ports:
      - "3000:3000"
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=admin
    volumes:
      - grafana-storage:/var/lib/grafana
volumes:
  grafana-storage:
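The llm-api service above is built from the local directory (build: .). A possible Dockerfile for that image is sketched below; the base image, dependency list, and main:app module name are assumptions about the project layout:
dockerfile
# Sketch of the image built by "build: ." above
FROM nvidia/cuda:12.1.1-runtime-ubuntu22.04

WORKDIR /app

RUN apt-get update && \
    apt-get install -y --no-install-recommends python3 python3-pip && \
    rm -rf /var/lib/apt/lists/*

# Install the service's Python dependencies (assumed requirements.txt)
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 8000
# Assumes the FastAPI app object lives in main.py as "app"
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]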
Performance Optimization #
Batching Optimization #
python
from collections import deque
import asyncio

class BatchProcessor:
    """Collects concurrent requests into batches to improve GPU utilization."""

    def __init__(self, model, max_batch_size=32, max_wait_time=0.1):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.queue = deque()
        self.lock = asyncio.Lock()

    async def generate(self, prompt: str) -> str:
        future = asyncio.get_running_loop().create_future()
        async with self.lock:
            self.queue.append((prompt, future))
            if len(self.queue) >= self.max_batch_size:
                # Queue is full: flush immediately
                await self._process_batch()
            else:
                # Otherwise wait briefly for more requests to arrive
                asyncio.create_task(self._wait_and_process())
        return await future

    async def _wait_and_process(self):
        await asyncio.sleep(self.max_wait_time)
        async with self.lock:
            if self.queue:
                await self._process_batch()

    async def _process_batch(self):
        # Called with the lock held
        batch = []
        futures = []
        while self.queue and len(batch) < self.max_batch_size:
            prompt, future = self.queue.popleft()
            batch.append(prompt)
            futures.append(future)

        if batch:
            # Run the blocking generate call off the event loop
            results = await asyncio.to_thread(self.model.generate, batch)
            for future, result in zip(futures, results):
                future.set_result(result)

batch_processor = BatchProcessor(model)

@app.post("/generate")
async def generate(request: GenerateRequest):
    result = await batch_processor.generate(request.prompt)
    return {"text": result}
Caching Optimization #
python
import hashlib

class SemanticCache:
    """Simplified response cache keyed on an exact prompt hash.

    A true semantic cache would compare prompt embeddings against
    similarity_threshold instead of requiring an exact match.
    """

    def __init__(self, similarity_threshold=0.95):
        self.cache = {}
        self.similarity_threshold = similarity_threshold

    def get(self, prompt: str):
        prompt_hash = hashlib.md5(prompt.encode()).hexdigest()
        return self.cache.get(prompt_hash)

    def set(self, prompt: str, response: str):
        prompt_hash = hashlib.md5(prompt.encode()).hexdigest()
        self.cache[prompt_hash] = response

cache = SemanticCache()

@app.post("/generate")
async def generate(request: GenerateRequest):
    cached = cache.get(request.prompt)
    if cached:
        return {"text": cached, "cached": True}

    result = model.generate([request.prompt])[0]
    cache.set(request.prompt, result)
    return {"text": result, "cached": False}
Operations Practices #
Health Checks #
python
@app.get("/health/detailed")
async def detailed_health():
checks = {
"model": model is not None,
"gpu": torch.cuda.is_available(),
"memory": {
"allocated": torch.cuda.memory_allocated() / 1024**3,
"reserved": torch.cuda.memory_reserved() / 1024**3
}
}
all_healthy = all(checks.values()) if isinstance(list(checks.values())[0], bool) else checks["model"] and checks["gpu"]
return {
"status": "healthy" if all_healthy else "unhealthy",
"checks": checks
}
Graceful Shutdown #
python
import signal

shutdown_event = asyncio.Event()

def signal_handler(signum, frame):
    # Mark the service as shutting down so new work can be refused
    logging.info("Received shutdown signal")
    shutdown_event.set()

signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)

@app.on_event("shutdown")
async def shutdown():
    logging.info("Shutting down gracefully...")
    # Drain in-flight requests and release resources (model, connections) here
Best Practices #
text
Deployment best practices:

1. Model optimization
   ├── Merge LoRA weights (see the sketch after this list)
   ├── Quantize for deployment
   └── Batch inference

2. Service design
   ├── Health checks
   ├── Graceful shutdown
   └── Error handling

3. Monitoring & alerting
   ├── Performance monitoring
   ├── Error tracking
   └── Resource monitoring

4. Security
   ├── Authentication & authorization
   ├── Rate limiting
   └── Input validation

5. High availability
   ├── Multi-replica deployment
   ├── Load balancing
   └── Failure recovery
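For the first item, merging LoRA weights before deployment, here is a minimal sketch using PEFT. The base-model and adapter paths are placeholders; the output path matches the models/merged directory used throughout this chapter:
python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model and attach the trained LoRA adapter
# ("models/base" and "models/lora-adapter" are placeholder paths)
base = AutoModelForCausalLM.from_pretrained("models/base", torch_dtype="auto")
model = PeftModel.from_pretrained(base, "models/lora-adapter")

# Fold the adapter weights into the base weights and drop the PEFT wrappers,
# so the result can be served directly by vLLM or TGI
merged = model.merge_and_unload()

merged.save_pretrained("models/merged")
AutoTokenizer.from_pretrained("models/base").save_pretrained("models/merged")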
Summary #
Congratulations on completing the Fine-tuning Complete Guide! You now have the full body of knowledge, from foundational concepts through production deployment. Keep practicing, keep optimizing, and become a fine-tuning expert!
Last updated: 2026-04-05