DSPy 最佳实践 #
项目结构 #
推荐目录结构 #
text
project/
├── dspy_modules/
│ ├── __init__.py
│ ├── signatures.py
│ ├── modules.py
│ └── retrievers.py
├── data/
│ ├── train.json
│ ├── val.json
│ └── test.json
├── configs/
│ ├── development.yaml
│ ├── production.yaml
│ └── testing.yaml
├── optimized/
│ └── saved_modules/
├── scripts/
│ ├── optimize.py
│ └── evaluate.py
├── tests/
│ ├── test_signatures.py
│ ├── test_modules.py
│ └── test_integration.py
├── main.py
└── requirements.txt
模块组织 #
python
# signatures.py
from dataclasses import dataclass  # NOTE(review): appears unused here — confirm before removing

import dspy


class QuestionAnswer(dspy.Signature):
    """根据上下文回答问题"""

    # DSPy sends the class docstring above as the task instruction and the
    # ``desc`` strings below as per-field hints, so they are kept verbatim.
    context = dspy.InputField(desc="相关上下文")
    question = dspy.InputField(desc="用户问题")
    answer = dspy.OutputField(desc="基于上下文的回答")


class Summarize(dspy.Signature):
    """生成文档摘要"""

    document = dspy.InputField(desc="原始文档")
    summary = dspy.OutputField(desc="简洁摘要")


class Classify(dspy.Signature):
    """文本分类"""

    text = dspy.InputField(desc="待分类文本")
    category = dspy.OutputField(desc="分类结果")
python
# modules.py
import dspy

from .signatures import QuestionAnswer, Summarize, Classify


class RAG(dspy.Module):
    """Retrieve-then-generate pipeline: fetch passages, answer with chain-of-thought."""

    def __init__(self, k=3):
        super().__init__()
        # k = number of passages pulled from the globally configured retriever.
        self.retrieve = dspy.Retrieve(k=k)
        self.generate = dspy.ChainOfThought(QuestionAnswer)

    def forward(self, question):
        context = self.retrieve(question).passages
        return self.generate(context=context, question=question)


class TextProcessor(dspy.Module):
    """Runs summarization and classification over the same input text."""

    def __init__(self):
        super().__init__()
        self.summarize = dspy.Predict(Summarize)
        self.classify = dspy.Predict(Classify)

    def forward(self, text):
        summary = self.summarize(document=text).summary
        category = self.classify(text=text).category
        return dspy.Prediction(summary=summary, category=category)
配置管理 #
配置文件 #
yaml
# configs/production.yaml
lm:
  model: "openai/gpt-4o-mini"
  api_key: "${OPENAI_API_KEY}"  # placeholder only; the real key is injected from the environment
  temperature: 0.0
  max_tokens: 1000

rm:
  type: "chroma"
  persist_directory: "./chroma_db"
  collection_name: "documents"

optimization:
  optimizer: "BootstrapFewShot"
  max_bootstrapped_demos: 4
  max_labeled_demos: 16

retrieval:
  k: 3
  rerank: true
配置加载 #
python
import yaml
import os

import dspy


def load_config(config_path):
    """Load a YAML config and overwrite the API key from the environment.

    The key always comes from ``OPENAI_API_KEY`` so the YAML file never
    has to contain a real secret.
    """
    with open(config_path) as f:
        config = yaml.safe_load(f)
    config['lm']['api_key'] = os.environ.get('OPENAI_API_KEY')
    return config


def setup_dspy(config):
    """Configure the global DSPy LM (and retriever, when one is defined)."""
    lm = dspy.LM(
        config['lm']['model'],
        api_key=config['lm']['api_key'],
        temperature=config['lm']['temperature'],
        # Previously ignored: honor the max_tokens limit from the config.
        max_tokens=config['lm'].get('max_tokens', 1000),
    )
    dspy.configure(lm=lm)
    if config.get('rm'):
        # NOTE(review): setup_retriever is not defined in this file; it must
        # be provided elsewhere — confirm before running.
        rm = setup_retriever(config['rm'])
        dspy.configure(rm=rm)


config = load_config('configs/production.yaml')
setup_dspy(config)
数据管理 #
数据加载 #
python
import json

from dspy import Example


def load_dataset(path):
    """Read a JSON list of records into DSPy Examples.

    Every key except ``answer`` becomes an input field; ``answer`` (when
    present) is left as the label.
    """
    with open(path, encoding='utf-8') as f:
        data = json.load(f)
    return [
        Example(**item).with_inputs(*(item.keys() - {'answer'}))
        for item in data
    ]


trainset = load_dataset('data/train.json')
valset = load_dataset('data/val.json')
testset = load_dataset('data/test.json')
数据验证 #
python
def validate_example(example, signature):
    """Check that *example* carries every input field the signature declares.

    Raises:
        ValueError: if a required input field is missing.
    """
    required_inputs = signature.input_fields.keys()
    for field in required_inputs:
        if not hasattr(example, field):
            raise ValueError(f"Missing required field: {field}")
    return True


def validate_dataset(dataset, signature):
    """Validate every example, printing each invalid one.

    Unlike a fail-fast loop, this reports *all* offending indices so a bad
    dataset can be fixed in one pass. Returns True only when every example
    is valid.
    """
    valid = True
    for i, example in enumerate(dataset):
        try:
            validate_example(example, signature)
        except ValueError as e:
            print(f"Invalid example at index {i}: {e}")
            valid = False
    return valid
性能优化 #
缓存策略 #
python
import dspy
from functools import lru_cache
import hashlib
class CachedModule:
    """Memoizing wrapper around a DSPy module (or any kwargs-only callable).

    Results are cached per unique set of keyword arguments, so repeated
    identical calls hit the underlying module only once. Assumes all kwarg
    values are hashable.
    """

    def __init__(self, module, cache_size=1000):
        self.module = module
        self.cache_size = cache_size
        self._setup_cache()

    def _setup_cache(self):
        # lru_cache needs hashable arguments, so the key is a sorted tuple
        # of (name, value) pairs rather than the raw kwargs dict.
        @lru_cache(maxsize=self.cache_size)
        def cached_forward(cache_key):
            # BUG FIX: the original did self.module(**cache_key) on an md5
            # hex-digest *string*, which raised at call time and also made
            # the real arguments unrecoverable. Rebuild the kwargs instead.
            return self.module(**dict(cache_key))

        self._cached_forward = cached_forward

    def __call__(self, **kwargs):
        return self._cached_forward(self._make_cache_key(kwargs))

    def _make_cache_key(self, kwargs):
        """Hashable, order-independent representation of the kwargs."""
        return tuple(sorted(kwargs.items()))
批处理优化 #
python
import dspy
from concurrent.futures import ThreadPoolExecutor
class BatchProcessor:
    """Fan out module calls over a thread pool, batch_size items at a time.

    Threads are appropriate here because LM calls are I/O-bound. Results
    are returned in the same order as *inputs_list*.
    """

    def __init__(self, module, batch_size=10, max_workers=4):
        self.module = module
        # BUG FIX: batch_size was accepted but never used — every input was
        # submitted at once. It now bounds the number of in-flight futures.
        self.batch_size = batch_size
        self.max_workers = max_workers

    def process(self, inputs_list):
        """Run the module on each kwargs dict in *inputs_list*; return results in order."""
        results = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for start in range(0, len(inputs_list), self.batch_size):
                batch = inputs_list[start:start + self.batch_size]
                futures = [
                    executor.submit(self.module, **inputs)
                    for inputs in batch
                ]
                results.extend(f.result() for f in futures)
        return results
连接池管理 #
python
import dspy
from contextlib import contextmanager
class ConnectionPool:
    """Skeleton connection pool: acquire/release via a context manager.

    ``_acquire``/``_release`` are deliberate placeholders (the tutorial only
    shows the structure); fill them in with real connection handling.
    """

    def __init__(self, max_connections=10):
        self.max_connections = max_connections  # not yet enforced by _acquire
        self.connections = []

    @contextmanager
    def get_connection(self):
        """Yield a connection, guaranteeing it is released even on error."""
        conn = self._acquire()
        try:
            yield conn
        finally:
            self._release(conn)

    def _acquire(self):
        # TODO: create or reuse a connection, honoring max_connections.
        pass

    def _release(self, conn):
        # TODO: return conn to self.connections.
        pass
错误处理 #
统一错误处理 #
python
import dspy
from typing import Optional
from dataclasses import dataclass
@dataclass
class ModuleResult:
    """Outcome envelope: exactly one of ``result``/``error`` is meaningful."""

    success: bool
    result: Optional[dspy.Prediction] = None
    error: Optional[str] = None
class SafeModule(dspy.Module):
    """Retry-on-exception wrapper that never lets errors escape.

    Returns a ModuleResult so callers branch on ``.success`` instead of
    wrapping every call in try/except.
    """

    def __init__(self, module, max_retries=3):
        super().__init__()
        self.module = module
        self.max_retries = max_retries

    def forward(self, **kwargs):
        for attempt in range(self.max_retries):
            try:
                result = self.module(**kwargs)
                return ModuleResult(success=True, result=result)
            except Exception as e:
                # The error surfaced to the caller is the last attempt's.
                error_msg = str(e)
                if attempt == self.max_retries - 1:
                    return ModuleResult(success=False, error=error_msg)
        # Only reachable when max_retries <= 0.
        return ModuleResult(success=False, error="Max retries exceeded")
降级策略 #
python
import dspy
class FallbackChain:
    """Try each module in order; the first success wins.

    Raises:
        RuntimeError: with every collected error message, when all modules fail.
    """

    def __init__(self, modules):
        self.modules = modules

    def __call__(self, **kwargs):
        errors = []
        for module in self.modules:
            try:
                return module(**kwargs)
            except Exception as e:
                # Remember why this tier failed, then fall through to the next.
                errors.append(str(e))
        raise RuntimeError(f"All modules failed: {errors}")
安全考虑 #
API 密钥管理 #
python
import os

from dotenv import load_dotenv

# Pull variables from a local .env file into the process environment.
load_dotenv()

API_KEY = os.environ.get('OPENAI_API_KEY')
if not API_KEY:
    # Fail fast at startup instead of on the first LM call.
    raise ValueError("OPENAI_API_KEY not found in environment")

# NOTE(review): assumes `import dspy` appears earlier in the real file.
lm = dspy.LM('openai/gpt-4o-mini', api_key=API_KEY)
输入验证 #
python
import dspy
import re
class SecureModule:
    """Input-validating wrapper: length cap plus prompt-injection pattern check.

    Raises ValueError before the wrapped module ever sees a suspicious or
    oversized string argument.
    """

    def __init__(self, module):
        self.module = module
        self.max_input_length = 10000
        # Case-insensitive markers of common prompt-injection attempts.
        self.blocked_patterns = [
            r'system\s*:',
            r'ignore\s+previous',
            r'disregard\s+instructions'
        ]

    def __call__(self, **kwargs):
        # Only string arguments are screened; other types pass through.
        for key, value in kwargs.items():
            if isinstance(value, str):
                self._validate_input(value)
        return self.module(**kwargs)

    def _validate_input(self, text):
        """Raise ValueError when *text* is too long or matches a blocked pattern."""
        if len(text) > self.max_input_length:
            raise ValueError(f"Input too long: {len(text)} > {self.max_input_length}")
        for pattern in self.blocked_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                raise ValueError(f"Blocked pattern detected: {pattern}")
输出过滤 #
python
import dspy
import re
class OutputFilter:
    """Redacts sensitive substrings from a module's prediction fields.

    Default patterns cover 16-digit card numbers, US-style SSNs, and email
    addresses; pass *sensitive_patterns* to override.
    """

    def __init__(self, module, sensitive_patterns=None):
        self.module = module
        self.sensitive_patterns = sensitive_patterns or [
            r'\b\d{16}\b',
            r'\b\d{3}-\d{2}-\d{4}\b',
            r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
        ]

    def __call__(self, **kwargs):
        result = self.module(**kwargs)
        return self._filter_output(result)

    def _filter_output(self, result):
        # Mutates the prediction in place: every string field is rewritten.
        # NOTE(review): assumes result exposes a ``_fields`` iterable of
        # attribute names — confirm against the prediction type used.
        for field in result._fields:
            value = getattr(result, field)
            if isinstance(value, str):
                setattr(result, field, self._redact_sensitive(value))
        return result

    def _redact_sensitive(self, text):
        """Replace every match of every sensitive pattern with '[REDACTED]'."""
        for pattern in self.sensitive_patterns:
            text = re.sub(pattern, '[REDACTED]', text)
        return text
监控与日志 #
结构化日志 #
python
import dspy
import logging
import json
from datetime import datetime
class StructuredLogger:
    """Emits one JSON object per module call, for machine-readable logs."""

    def __init__(self, name):
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.INFO)
        # BUG FIX: getLogger returns a shared instance, so re-creating this
        # wrapper used to stack duplicate handlers (each line logged twice).
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(message)s'))
            self.logger.addHandler(handler)

    def log_call(self, module_name, inputs, outputs, duration):
        """Log one module call; *duration* is in seconds."""
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'module': module_name,
            'inputs': inputs,
            'outputs': outputs,
            'duration_ms': duration * 1000
        }
        self.logger.info(json.dumps(log_entry))
class LoggedModule(dspy.Module):
    """Wraps a module so every forward() is logged with its latency."""

    def __init__(self, module, logger):
        super().__init__()
        self.module = module
        self.logger = logger

    def forward(self, **kwargs):
        import time

        start = time.time()
        result = self.module(**kwargs)
        duration = time.time() - start
        self.logger.log_call(
            self.module.__class__.__name__,
            kwargs,
            # NOTE(review): assumes the result exposes _asdict(); confirm
            # against the prediction type actually returned.
            result._asdict(),
            duration
        )
        return result
性能监控 #
python
import dspy
import time
from collections import defaultdict
class PerformanceMonitor:
    """In-memory per-module latency and success-rate metrics."""

    def __init__(self):
        self.metrics = defaultdict(list)

    def record(self, module_name, duration, success):
        """Append one observation (duration in seconds, success as bool)."""
        self.metrics[module_name].append({
            'duration': duration,
            'success': success,
            'timestamp': time.time()
        })

    def get_stats(self, module_name):
        """Aggregate stats for one module, or None if nothing was recorded."""
        records = self.metrics[module_name]
        if not records:
            return None
        # Sort once and reuse for both percentiles (was sorted twice).
        durations = sorted(r['duration'] for r in records)
        successes = [r['success'] for r in records]
        return {
            'count': len(records),
            'avg_duration': sum(durations) / len(durations),
            'success_rate': sum(successes) / len(successes),
            'p50_duration': durations[len(durations) // 2],
            'p99_duration': durations[int(len(durations) * 0.99)]
        }


monitor = PerformanceMonitor()
class MonitoredModule(dspy.Module):
    """Records duration and success of every forward() into a monitor."""

    def __init__(self, module, monitor):
        super().__init__()
        self.module = module
        self.monitor = monitor

    def forward(self, **kwargs):
        start = time.time()
        success = True
        try:
            return self.module(**kwargs)
        except Exception:
            success = False
            raise  # metrics recorded in finally; the caller still sees the error
        finally:
            self.monitor.record(
                self.module.__class__.__name__,
                time.time() - start,
                success
            )
部署策略 #
Docker 部署 #
dockerfile
# Minimal runtime image for the DSPy service.
FROM python:3.11-slim
WORKDIR /app
# Install dependencies first so this layer stays cached across code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
# Empty default on purpose: inject the real key at `docker run -e OPENAI_API_KEY=...`.
ENV OPENAI_API_KEY=""
ENV DSPY_CACHE_DIR="/app/cache"
CMD ["python", "main.py"]
API 服务 #
python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import dspy
app = FastAPI()

# Configure the LM once at import time; handlers reuse the global state.
lm = dspy.LM('openai/gpt-4o-mini')
dspy.configure(lm=lm)

# NOTE(review): RAG is expected to come from dspy_modules — confirm import.
rag = RAG()
rag.load('optimized/rag.json')  # restore the optimized prompts/demos


class QueryRequest(BaseModel):
    question: str


class QueryResponse(BaseModel):
    answer: str


@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    """Answer a question with the optimized RAG pipeline."""
    try:
        result = rag(question=request.question)
        return QueryResponse(answer=result.answer)
    except Exception as e:
        # Surface the failure as a 500 carrying the original message.
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health():
    """Liveness probe for load balancers / orchestrators."""
    return {"status": "healthy"}
负载均衡 #
python
import dspy
from typing import List
import random
class LoadBalancedLM:
    """Distributes calls uniformly at random over several (model, key) pairs."""

    def __init__(self, models: List[str], api_keys: List[str]):
        # Pairs are matched positionally; both lists must have equal length.
        self.models = list(zip(models, api_keys))

    def get_lm(self):
        """Pick one (model, key) pair at random and build an LM for it."""
        model, api_key = random.choice(self.models)
        return dspy.LM(model, api_key=api_key)

    def __call__(self, prompt, **kwargs):
        lm = self.get_lm()
        return lm(prompt, **kwargs)
# Example wiring. NOTE(review): key1/key2 are not defined in this snippet —
# load them from the environment in real code.
balanced_lm = LoadBalancedLM(
models=['openai/gpt-4o-mini', 'openai/gpt-3.5-turbo'],
api_keys=[key1, key2]
)
dspy.configure(lm=balanced_lm)
测试策略 #
单元测试 #
python
import pytest
import dspy
from dspy_modules.signatures import QuestionAnswer
from dspy_modules.modules import RAG
class TestSignatures:
    """Static checks on declared signature fields — no LM calls needed."""

    def test_question_answer_fields(self):
        assert 'context' in QuestionAnswer.input_fields
        assert 'question' in QuestionAnswer.input_fields
        assert 'answer' in QuestionAnswer.output_fields


class TestModules:
    """Behavioral checks; these hit the configured LM/retriever."""

    @pytest.fixture
    def rag(self):
        return RAG(k=3)

    def test_rag_forward(self, rag):
        result = rag(question="测试问题")
        assert hasattr(result, 'answer')
        assert result.answer is not None
集成测试 #
python
import pytest
import dspy
from dspy_modules import RAG
from dspy import Example
class TestIntegration:
    """End-to-end tests requiring a real LM; run sparingly (they cost money)."""

    @pytest.fixture
    def setup_dspy(self):
        lm = dspy.LM('openai/gpt-4o-mini')
        dspy.configure(lm=lm)

    @pytest.fixture
    def test_data(self):
        return [
            Example(question="测试问题1", answer="测试答案1").with_inputs("question"),
            Example(question="测试问题2", answer="测试答案2").with_inputs("question"),
        ]

    def test_end_to_end(self, setup_dspy, test_data):
        rag = RAG()
        for example in test_data:
            result = rag(question=example.question)
            assert result.answer is not None
成本优化 #
Token 计数 #
python
import dspy
import tiktoken
class TokenCounter:
    """Counts tokens with tiktoken and estimates OpenAI API cost in USD."""

    def __init__(self, model="gpt-4o-mini"):
        self.encoding = tiktoken.encoding_for_model(model)
        self.total_tokens = 0  # running total across count_tokens calls

    def count_tokens(self, text):
        """Return the token count of *text* and add it to the running total."""
        n = len(self.encoding.encode(text))
        # BUG FIX: total_tokens was initialized but never accumulated.
        self.total_tokens += n
        return n

    def estimate_cost(self, input_tokens, output_tokens, model="gpt-4o-mini"):
        """USD estimate; unknown models fall back to gpt-4o-mini pricing."""
        # Prices are per token (USD per 1M tokens divided out).
        prices = {
            "gpt-4o-mini": {"input": 0.15 / 1_000_000, "output": 0.60 / 1_000_000},
            "gpt-4o": {"input": 2.50 / 1_000_000, "output": 10.00 / 1_000_000},
        }
        price = prices.get(model, prices["gpt-4o-mini"])
        return (
            input_tokens * price["input"] +
            output_tokens * price["output"]
        )
成本控制 #
python
import dspy
class CostControlledModule:
    """Refuses calls once an estimated daily spend budget is exhausted."""

    def __init__(self, module, max_cost_per_day=10.0):
        self.module = module
        self.max_cost = max_cost_per_day
        self.current_cost = 0.0  # estimated USD spent so far

    def __call__(self, **kwargs):
        estimated_cost = self._estimate_cost(kwargs)
        if self.current_cost + estimated_cost > self.max_cost:
            raise RuntimeError(f"Daily cost limit exceeded: {self.current_cost} + {estimated_cost} > {self.max_cost}")
        result = self.module(**kwargs)
        # Charge the budget only after the call actually succeeds.
        self.current_cost += estimated_cost
        return result

    def _estimate_cost(self, kwargs):
        # Flat placeholder estimate; replace with token-based pricing.
        return 0.001
总结 #
DSPy 是一个强大的框架。遵循这些最佳实践,你将获得:
- 组织良好的项目结构 - 提高代码可维护性
- 完善的配置管理 - 支持多环境部署
- 有效的性能优化 - 降低成本、提高响应速度
- 健壮的错误处理 - 提高系统稳定性
- 严格的安全措施 - 保护敏感信息
- 全面的监控日志 - 便于问题排查
- 灵活的部署策略 - 适应不同场景
- 完善的测试覆盖 - 保证代码质量
继续探索 DSPy 的更多功能,构建更强大的 LLM 应用!
最后更新:2026-03-30