DSPy 最佳实践 #

项目结构 #

推荐目录结构 #

text
project/
├── dspy_modules/
│   ├── __init__.py
│   ├── signatures.py
│   ├── modules.py
│   └── retrievers.py
├── data/
│   ├── train.json
│   ├── val.json
│   └── test.json
├── configs/
│   ├── development.yaml
│   ├── production.yaml
│   └── testing.yaml
├── optimized/
│   └── saved_modules/
├── scripts/
│   ├── optimize.py
│   └── evaluate.py
├── tests/
│   ├── test_signatures.py
│   ├── test_modules.py
│   └── test_integration.py
├── main.py
└── requirements.txt

模块组织 #

python
class signatures.py
from dataclasses import dataclass
import dspy

class QuestionAnswer(dspy.Signature):
    """根据上下文回答问题"""
    context = dspy.InputField(desc="相关上下文")
    question = dspy.InputField(desc="用户问题")
    answer = dspy.OutputField(desc="基于上下文的回答")

class Summarize(dspy.Signature):
    """生成文档摘要"""
    document = dspy.InputField(desc="原始文档")
    summary = dspy.OutputField(desc="简洁摘要")

class Classify(dspy.Signature):
    """文本分类"""
    text = dspy.InputField(desc="待分类文本")
    category = dspy.OutputField(desc="分类结果")
python
class modules.py
import dspy
from .signatures import QuestionAnswer, Summarize, Classify

class RAG(dspy.Module):
    def __init__(self, k=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=k)
        self.generate = dspy.ChainOfThought(QuestionAnswer)
    
    def forward(self, question):
        context = self.retrieve(question).passages
        return self.generate(context=context, question=question)

class TextProcessor(dspy.Module):
    def __init__(self):
        super().__init__()
        self.summarize = dspy.Predict(Summarize)
        self.classify = dspy.Predict(Classify)
    
    def forward(self, text):
        summary = self.summarize(document=text).summary
        category = self.classify(text=text).category
        return dspy.Prediction(summary=summary, category=category)

配置管理 #

配置文件 #

yaml
class configs/production.yaml
lm:
  model: "openai/gpt-4o-mini"
  api_key: "${OPENAI_API_KEY}"
  temperature: 0.0
  max_tokens: 1000

rm:
  type: "chroma"
  persist_directory: "./chroma_db"
  collection_name: "documents"

optimization:
  optimizer: "BootstrapFewShot"
  max_bootstrapped_demos: 4
  max_labeled_demos: 16

retrieval:
  k: 3
  rerank: true

配置加载 #

python
import yaml
import os
import dspy

def load_config(config_path):
    with open(config_path) as f:
        config = yaml.safe_load(f)
    
    config['lm']['api_key'] = os.environ.get('OPENAI_API_KEY')
    
    return config

def setup_dspy(config):
    lm = dspy.LM(
        config['lm']['model'],
        api_key=config['lm']['api_key'],
        temperature=config['lm']['temperature']
    )
    dspy.configure(lm=lm)
    
    if config.get('rm'):
        rm = setup_retriever(config['rm'])
        dspy.configure(rm=rm)

config = load_config('configs/production.yaml')
setup_dspy(config)

数据管理 #

数据加载 #

python
import json
from dspy import Example

def load_dataset(path):
    with open(path) as f:
        data = json.load(f)
    
    return [
        Example(**item).with_inputs(*item.keys() - {'answer'})
        for item in data
    ]

trainset = load_dataset('data/train.json')
valset = load_dataset('data/val.json')
testset = load_dataset('data/test.json')

数据验证 #

python
def validate_example(example, signature):
    required_inputs = signature.input_fields.keys()
    for field in required_inputs:
        if not hasattr(example, field):
            raise ValueError(f"Missing required field: {field}")
    return True

def validate_dataset(dataset, signature):
    for i, example in enumerate(dataset):
        try:
            validate_example(example, signature)
        except ValueError as e:
            print(f"Invalid example at index {i}: {e}")
            return False
    return True

性能优化 #

缓存策略 #

python
import dspy
from functools import lru_cache
import hashlib

class CachedModule:
    def __init__(self, module, cache_size=1000):
        self.module = module
        self.cache_size = cache_size
        self._setup_cache()
    
    def _setup_cache(self):
        @lru_cache(maxsize=self.cache_size)
        def cached_forward(cache_key):
            return self.module(**cache_key)
        
        self._cached_forward = cached_forward
    
    def __call__(self, **kwargs):
        cache_key = self._make_cache_key(kwargs)
        return self._cached_forward(cache_key)
    
    def _make_cache_key(self, kwargs):
        key_str = str(sorted(kwargs.items()))
        return hashlib.md5(key_str.encode()).hexdigest()

批处理优化 #

python
import dspy
from concurrent.futures import ThreadPoolExecutor

class BatchProcessor:
    def __init__(self, module, batch_size=10, max_workers=4):
        self.module = module
        self.batch_size = batch_size
        self.max_workers = max_workers
    
    def process(self, inputs_list):
        results = []
        
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [
                executor.submit(self.module, **inputs)
                for inputs in inputs_list
            ]
            results = [f.result() for f in futures]
        
        return results

连接池管理 #

python
import dspy
from contextlib import contextmanager

class ConnectionPool:
    def __init__(self, max_connections=10):
        self.max_connections = max_connections
        self.connections = []
    
    @contextmanager
    def get_connection(self):
        conn = self._acquire()
        try:
            yield conn
        finally:
            self._release(conn)
    
    def _acquire(self):
        pass
    
    def _release(self, conn):
        pass

错误处理 #

统一错误处理 #

python
import dspy
from typing import Optional
from dataclasses import dataclass

@dataclass
class ModuleResult:
    success: bool
    result: Optional[dspy.Prediction] = None
    error: Optional[str] = None

class SafeModule(dspy.Module):
    def __init__(self, module, max_retries=3):
        super().__init__()
        self.module = module
        self.max_retries = max_retries
    
    def forward(self, **kwargs):
        for attempt in range(self.max_retries):
            try:
                result = self.module(**kwargs)
                return ModuleResult(success=True, result=result)
            except Exception as e:
                error_msg = str(e)
                if attempt == self.max_retries - 1:
                    return ModuleResult(success=False, error=error_msg)
        
        return ModuleResult(success=False, error="Max retries exceeded")

降级策略 #

python
import dspy

class FallbackChain:
    def __init__(self, modules):
        self.modules = modules
    
    def __call__(self, **kwargs):
        errors = []
        
        for module in self.modules:
            try:
                return module(**kwargs)
            except Exception as e:
                errors.append(str(e))
                continue
        
        raise RuntimeError(f"All modules failed: {errors}")

安全考虑 #

API 密钥管理 #

python
import os
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.environ.get('OPENAI_API_KEY')
if not API_KEY:
    raise ValueError("OPENAI_API_KEY not found in environment")

lm = dspy.LM('openai/gpt-4o-mini', api_key=API_KEY)

输入验证 #

python
import dspy
import re

class SecureModule:
    def __init__(self, module):
        self.module = module
        self.max_input_length = 10000
        self.blocked_patterns = [
            r'system\s*:',
            r'ignore\s+previous',
            r'disregard\s+instructions'
        ]
    
    def __call__(self, **kwargs):
        for key, value in kwargs.items():
            if isinstance(value, str):
                self._validate_input(value)
        
        return self.module(**kwargs)
    
    def _validate_input(self, text):
        if len(text) > self.max_input_length:
            raise ValueError(f"Input too long: {len(text)} > {self.max_input_length}")
        
        for pattern in self.blocked_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                raise ValueError(f"Blocked pattern detected: {pattern}")

输出过滤 #

python
import dspy
import re

class OutputFilter:
    def __init__(self, module, sensitive_patterns=None):
        self.module = module
        self.sensitive_patterns = sensitive_patterns or [
            r'\b\d{16}\b',
            r'\b\d{3}-\d{2}-\d{4}\b',
            r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
        ]
    
    def __call__(self, **kwargs):
        result = self.module(**kwargs)
        return self._filter_output(result)
    
    def _filter_output(self, result):
        for field in result._fields:
            value = getattr(result, field)
            if isinstance(value, str):
                filtered = self._redact_sensitive(value)
                setattr(result, field, filtered)
        return result
    
    def _redact_sensitive(self, text):
        for pattern in self.sensitive_patterns:
            text = re.sub(pattern, '[REDACTED]', text)
        return text

监控与日志 #

结构化日志 #

python
import dspy
import logging
import json
from datetime import datetime

class StructuredLogger:
    def __init__(self, name):
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.INFO)
        
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(message)s'))
        self.logger.addHandler(handler)
    
    def log_call(self, module_name, inputs, outputs, duration):
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'module': module_name,
            'inputs': inputs,
            'outputs': outputs,
            'duration_ms': duration * 1000
        }
        self.logger.info(json.dumps(log_entry))

class LoggedModule(dspy.Module):
    def __init__(self, module, logger):
        super().__init__()
        self.module = module
        self.logger = logger
    
    def forward(self, **kwargs):
        import time
        start = time.time()
        
        result = self.module(**kwargs)
        
        duration = time.time() - start
        self.logger.log_call(
            self.module.__class__.__name__,
            kwargs,
            result._asdict(),
            duration
        )
        
        return result

性能监控 #

python
import dspy
import time
from collections import defaultdict

class PerformanceMonitor:
    def __init__(self):
        self.metrics = defaultdict(list)
    
    def record(self, module_name, duration, success):
        self.metrics[module_name].append({
            'duration': duration,
            'success': success,
            'timestamp': time.time()
        })
    
    def get_stats(self, module_name):
        records = self.metrics[module_name]
        if not records:
            return None
        
        durations = [r['duration'] for r in records]
        successes = [r['success'] for r in records]
        
        return {
            'count': len(records),
            'avg_duration': sum(durations) / len(durations),
            'success_rate': sum(successes) / len(successes),
            'p50_duration': sorted(durations)[len(durations) // 2],
            'p99_duration': sorted(durations)[int(len(durations) * 0.99)]
        }

monitor = PerformanceMonitor()

class MonitoredModule(dspy.Module):
    def __init__(self, module, monitor):
        super().__init__()
        self.module = module
        self.monitor = monitor
    
    def forward(self, **kwargs):
        start = time.time()
        success = True
        
        try:
            result = self.module(**kwargs)
            return result
        except Exception as e:
            success = False
            raise
        finally:
            duration = time.time() - start
            self.monitor.record(
                self.module.__class__.__name__,
                duration,
                success
            )

部署策略 #

Docker 部署 #

dockerfile
FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

ENV OPENAI_API_KEY=""
ENV DSPY_CACHE_DIR="/app/cache"

CMD ["python", "main.py"]

API 服务 #

python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import dspy

app = FastAPI()

lm = dspy.LM('openai/gpt-4o-mini')
dspy.configure(lm=lm)

rag = RAG()
rag.load('optimized/rag.json')

class QueryRequest(BaseModel):
    question: str

class QueryResponse(BaseModel):
    answer: str

@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    try:
        result = rag(question=request.question)
        return QueryResponse(answer=result.answer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health():
    return {"status": "healthy"}

负载均衡 #

python
import dspy
from typing import List
import random

class LoadBalancedLM:
    def __init__(self, models: List[str], api_keys: List[str]):
        self.models = list(zip(models, api_keys))
    
    def get_lm(self):
        model, api_key = random.choice(self.models)
        return dspy.LM(model, api_key=api_key)
    
    def __call__(self, prompt, **kwargs):
        lm = self.get_lm()
        return lm(prompt, **kwargs)

balanced_lm = LoadBalancedLM(
    models=['openai/gpt-4o-mini', 'openai/gpt-3.5-turbo'],
    api_keys=[key1, key2]
)
dspy.configure(lm=balanced_lm)

测试策略 #

单元测试 #

python
import pytest
import dspy
from dspy_modules.signatures import QuestionAnswer
from dspy_modules.modules import RAG

class TestSignatures:
    def test_question_answer_fields(self):
        assert 'context' in QuestionAnswer.input_fields
        assert 'question' in QuestionAnswer.input_fields
        assert 'answer' in QuestionAnswer.output_fields

class TestModules:
    @pytest.fixture
    def rag(self):
        return RAG(k=3)
    
    def test_rag_forward(self, rag):
        result = rag(question="测试问题")
        assert hasattr(result, 'answer')
        assert result.answer is not None

集成测试 #

python
import pytest
import dspy
from dspy_modules import RAG
from dspy import Example

class TestIntegration:
    @pytest.fixture
    def setup_dspy(self):
        lm = dspy.LM('openai/gpt-4o-mini')
        dspy.configure(lm=lm)
    
    @pytest.fixture
    def test_data(self):
        return [
            Example(question="测试问题1", answer="测试答案1").with_inputs("question"),
            Example(question="测试问题2", answer="测试答案2").with_inputs("question"),
        ]
    
    def test_end_to_end(self, setup_dspy, test_data):
        rag = RAG()
        for example in test_data:
            result = rag(question=example.question)
            assert result.answer is not None

成本优化 #

Token 计数 #

python
import dspy
import tiktoken

class TokenCounter:
    def __init__(self, model="gpt-4o-mini"):
        self.encoding = tiktoken.encoding_for_model(model)
        self.total_tokens = 0
    
    def count_tokens(self, text):
        return len(self.encoding.encode(text))
    
    def estimate_cost(self, input_tokens, output_tokens, model="gpt-4o-mini"):
        prices = {
            "gpt-4o-mini": {"input": 0.15 / 1_000_000, "output": 0.60 / 1_000_000},
            "gpt-4o": {"input": 2.50 / 1_000_000, "output": 10.00 / 1_000_000},
        }
        
        price = prices.get(model, prices["gpt-4o-mini"])
        return (
            input_tokens * price["input"] +
            output_tokens * price["output"]
        )

成本控制 #

python
import dspy

class CostControlledModule:
    def __init__(self, module, max_cost_per_day=10.0):
        self.module = module
        self.max_cost = max_cost_per_day
        self.current_cost = 0.0
    
    def __call__(self, **kwargs):
        estimated_cost = self._estimate_cost(kwargs)
        
        if self.current_cost + estimated_cost > self.max_cost:
            raise RuntimeError(f"Daily cost limit exceeded: {self.current_cost} + {estimated_cost} > {self.max_cost}")
        
        result = self.module(**kwargs)
        self.current_cost += estimated_cost
        
        return result
    
    def _estimate_cost(self, kwargs):
        return 0.001

总结 #

DSPy 是一个强大的框架,通过遵循这些最佳实践,你可以:

  1. 组织良好的项目结构 - 提高代码可维护性
  2. 完善的配置管理 - 支持多环境部署
  3. 有效的性能优化 - 降低成本、提高响应速度
  4. 健壮的错误处理 - 提高系统稳定性
  5. 严格的安全措施 - 保护敏感信息
  6. 全面的监控日志 - 便于问题排查
  7. 灵活的部署策略 - 适应不同场景
  8. 完善的测试覆盖 - 保证代码质量

继续探索 DSPy 的更多功能,构建更强大的 LLM 应用!

最后更新:2026-03-30