DSPy 检索器 #

什么是检索器? #

检索器(Retriever)是 DSPy 中用于从知识库获取相关内容的组件。它是构建 RAG(Retrieval-Augmented Generation)应用的核心组件。

text
┌─────────────────────────────────────────────────────────────┐
│                    检索器的作用                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 知识检索:从大量文档中找到与问题相关的内容               │
│  2. 上下文增强:为 LLM 提供准确的背景信息                   │
│  3. 减少幻觉:基于事实生成回答                              │
│  4. 知识更新:无需重新训练即可更新知识                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘

RAG 架构概览 #

text
┌─────────────────────────────────────────────────────────────┐
│                    RAG 基本架构                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   ┌─────────┐     ┌─────────┐     ┌─────────┐              │
│   │  用户   │     │  检索   │     │  生成   │              │
│   │  问题   │────>│  模块   │────>│  模块   │              │
│   └─────────┘     └─────────┘     └─────────┘              │
│                        │               │                    │
│                        ▼               ▼                    │
│                   ┌─────────┐     ┌─────────┐              │
│                   │ 知识库  │     │  回答   │              │
│                   │ 向量库  │     │         │              │
│                   └─────────┘     └─────────┘              │
│                                                             │
│   流程:                                                    │
│   1. 用户提问                                               │
│   2. 检索器从知识库获取相关文档                             │
│   3. 生成模块基于检索结果生成回答                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

内置检索器 #

ColBERTv2 #

ColBERTv2 是 DSPy 默认支持的检索器,使用延迟交互模型。

python
import dspy

# Point DSPy at a running ColBERTv2 search server over HTTP.
rm = dspy.ColBERTv2(url='http://localhost:8893/api/search')
# Register the retriever globally so dspy.Retrieve uses it by default.
dspy.configure(rm=rm)

# k=3: return the top-3 passages per query.
retrieve = dspy.Retrieve(k=3)
results = retrieve("什么是机器学习?")
print(results.passages)

配置参数 #

python
retrieve = dspy.Retrieve(
    k=3,
    # NOTE(review): whether a `filter` kwarg is honored depends on the
    # configured retrieval model's implementation — dspy.Retrieve itself
    # does not document it; confirm against the RM in use.
    filter={"category": "tech"}
)

向量数据库集成 #

ChromaDB #

python
import dspy
import chromadb
from chromadb.utils import embedding_functions

# Persistent Chroma client: data lives on disk under ./chroma_db.
client = chromadb.PersistentClient(path="./chroma_db")
embedding_func = embedding_functions.DefaultEmbeddingFunction()

collection = client.get_or_create_collection(
    name="documents",
    embedding_function=embedding_func
)

class ChromaRM(dspy.Retrieve):
    # Adapter exposing a ChromaDB collection as a DSPy retriever.
    def __init__(self, collection, k=3):
        super().__init__(k=k)
        self.collection = collection

    def forward(self, query):
        # Chroma embeds the query text itself via the collection's
        # embedding function, so raw text (not a vector) is passed.
        results = self.collection.query(
            query_texts=[query],
            n_results=self.k
        )
        # query() returns one list per input query; we sent exactly one,
        # so take element 0.
        passages = results['documents'][0]
        return dspy.Prediction(passages=passages)

dspy.configure(rm=ChromaRM(collection, k=3))

Pinecone #

python
import dspy
import pinecone

# NOTE(review): this is the legacy pinecone client API (pinecone.init /
# pinecone.Index); current pinecone-client versions use Pinecone(api_key=...).
pinecone.init(api_key="your-api-key", environment="us-west1-gcp")
index = pinecone.Index("documents")

class PineconeRM(dspy.Retrieve):
    # Adapter exposing a Pinecone index as a DSPy retriever.
    # embed_func: callable mapping a query string to an embedding vector;
    # its dimensionality must match the index.
    def __init__(self, index, embed_func, k=3):
        super().__init__(k=k)
        self.index = index
        self.embed_func = embed_func

    def forward(self, query):
        # Pinecone searches by vector, so embed the query text first.
        query_vector = self.embed_func(query)
        results = self.index.query(
            vector=query_vector,
            top_k=self.k,
            include_metadata=True
        )
        # Assumes each vector was upserted with its source text stored
        # under metadata['text'] — TODO confirm against the indexing code.
        passages = [m['metadata']['text'] for m in results['matches']]
        return dspy.Prediction(passages=passages)

dspy.configure(rm=PineconeRM(index, my_embed_func, k=3))

Weaviate #

python
import dspy
import weaviate

client = weaviate.Client("http://localhost:8080")

class WeaviateRM(dspy.Retrieve):
    # Adapter exposing a Weaviate class (collection) as a DSPy retriever.
    # NOTE(review): this is the weaviate-client v3 GraphQL builder API;
    # v4 clients use client.collections.<name>.query instead.
    def __init__(self, client, class_name, k=3):
        super().__init__(k=k)
        self.client = client
        self.class_name = class_name

    def forward(self, query):
        # Semantic nearText search over the "content" property, limited to k.
        results = (
            self.client.query
            .get(self.class_name, ["content"])
            .with_near_text({"concepts": [query]})
            .with_limit(self.k)
            .do()
        )
        # GraphQL response shape: data -> Get -> <ClassName> -> list of objects.
        passages = [r['content'] for r in results['data']['Get'][self.class_name]]
        return dspy.Prediction(passages=passages)

dspy.configure(rm=WeaviateRM(client, "Document", k=3))

FAISS #

python
import dspy
import faiss
import numpy as np

class FAISSRM(dspy.Retrieve):
    """DSPy retriever backed by an in-memory FAISS index.

    `documents` holds raw texts in the same order as the vectors stored in
    `index`; `embed_func` maps a query string to an embedding vector.
    """

    def __init__(self, index, documents, embed_func, k=3):
        super().__init__(k=k)
        self.index = index
        self.documents = documents
        self.embed_func = embed_func

    def forward(self, query):
        # FAISS expects a 2-D batch of query vectors, hence the wrapping list.
        vec = self.embed_func(query)
        _, hit_ids = self.index.search(np.array([vec]), self.k)
        # Row 0 corresponds to our single query; map ids back to their texts.
        hits = [self.documents[doc_id] for doc_id in hit_ids[0]]
        return dspy.Prediction(passages=hits)

dspy.configure(rm=FAISSRM(index, documents, embed_func, k=3))

构建文档索引 #

文档预处理 #

python
def preprocess_documents(raw_docs):
    """Split every raw document into overlapping chunks and return the
    flattened list of all chunks (500-word windows, 50-word overlap)."""
    all_chunks = []
    for document in raw_docs:
        all_chunks += split_into_chunks(document, chunk_size=500, overlap=50)
    return all_chunks

def split_into_chunks(text, chunk_size=500, overlap=50):
    """Split *text* into word-based chunks of up to *chunk_size* words,
    with consecutive chunks sharing *overlap* words.

    Returns a list of chunk strings ([] for empty/whitespace-only text).

    Raises ValueError when overlap >= chunk_size or chunk_size <= 0: the
    original silently returned [] for overlap > chunk_size (negative range
    step) and crashed with an opaque range() error for overlap == chunk_size.
    It also emitted redundant trailing chunks fully contained in the
    previous chunk once the window passed the end of the text; the loop now
    stops as soon as the window covers the final word.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap < 0 or overlap >= chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    words = text.split()
    chunks = []
    step = chunk_size - overlap
    for i in range(0, len(words), step):
        chunks.append(' '.join(words[i:i + chunk_size]))
        # Window reached the end — any further chunk would be a strict
        # subset of this one.
        if i + chunk_size >= len(words):
            break
    return chunks

添加文档到索引 #

python
def index_documents(collection, documents):
    """Add *documents* to a ChromaDB-style collection, one entry per doc,
    with sequential ids ("doc_0", "doc_1", ...).

    The original read `doc.source` unconditionally, which raises
    AttributeError for plain-string chunks (the output of
    preprocess_documents); fall back to "unknown" when the attribute is
    absent so string documents index cleanly.
    """
    for i, doc in enumerate(documents):
        collection.add(
            ids=[f"doc_{i}"],
            documents=[doc],
            # Plain strings carry no provenance; record a placeholder.
            metadatas=[{"source": getattr(doc, "source", "unknown")}]
        )

构建 RAG 应用 #

基础 RAG #

python
import dspy

lm = dspy.LM('openai/gpt-4o-mini')
rm = dspy.ColBERTv2(url='http://localhost:8893/api/search')
dspy.configure(lm=lm, rm=rm)

# NOTE: a dspy.Signature's docstring and field `desc` strings are sent to
# the LM as prompt instructions at runtime, so the Chinese text below is
# behavioral and intentionally left untranslated.
# (Docstring: "Answer the question based on the context.")
class GenerateAnswer(dspy.Signature):
    """根据上下文回答问题"""
    context = dspy.InputField(desc="相关上下文")
    question = dspy.InputField(desc="用户问题")
    answer = dspy.OutputField(desc="基于上下文的回答")

class RAG(dspy.Module):
    """Minimal retrieve-then-generate pipeline: fetch k passages for the
    question, then answer them with a chain-of-thought generator."""

    def __init__(self, k=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=k)
        self.generate = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        retrieved = self.retrieve(question)
        return self.generate(context=retrieved.passages, question=question)

rag = RAG()
result = rag(question="什么是机器学习?")
print(result.answer)

多跳检索 #

python
import dspy

# NOTE: Signature docstrings are runtime prompt text and are intentionally
# left in Chinese.
# (Docstring: "Generate a search query.")
class GenerateQuery(dspy.Signature):
    """生成搜索查询"""
    context = dspy.InputField()
    question = dspy.InputField()
    query = dspy.OutputField()

# (Docstring: "Answer the question based on the context.")
class GenerateAnswer(dspy.Signature):
    """根据上下文回答问题"""
    context = dspy.InputField()
    question = dspy.InputField()
    answer = dspy.OutputField()

class MultiHopRAG(dspy.Module):
    """Iterative RAG: on each hop, rewrite the question into a fresh search
    query (conditioned on everything gathered so far), retrieve more
    passages, and finally answer over the accumulated context."""

    def __init__(self, passages_per_hop=3, max_hops=2):
        super().__init__()
        self.max_hops = max_hops
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_query = dspy.ChainOfThought(GenerateQuery)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        gathered = []
        for _ in range(self.max_hops):
            # Each hop's query generation sees the passages collected so far.
            search_query = self.generate_query(
                context=gathered, question=question
            ).query
            gathered.extend(self.retrieve(search_query).passages)
        return self.generate_answer(context=gathered, question=question)

multi_hop = MultiHopRAG(passages_per_hop=3, max_hops=2)
result = multi_hop(question="谁发明了 Python,它有什么特点?")

带来源引用的 RAG #

python
import dspy

# NOTE: Signature docstring and desc strings are runtime prompt text (left
# in Chinese on purpose). This signature asks the LM to answer AND list the
# numbers of the context passages it relied on.
class GenerateAnswerWithSource(dspy.Signature):
    """根据上下文回答问题并引用来源"""
    context = dspy.InputField(desc="相关上下文,每段带有编号")
    question = dspy.InputField(desc="用户问题")
    answer = dspy.OutputField(desc="回答内容")
    sources = dspy.OutputField(desc="引用的上下文编号列表")

class SourceRAG(dspy.Module):
    """RAG that also reports which numbered passages the answer cites.

    The LM's `sources` output field arrives as free text (e.g. "[1, 3]" or
    "1,3"), not a list of ints. The original iterated `result.sources`
    directly: for a string this yields characters and `i <= len(passages)`
    raises TypeError, and the missing `i >= 1` guard allowed negative
    wrap-around indexing. The ids are now parsed defensively before
    indexing into the retrieved passages.
    """

    def __init__(self, k=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=k)
        self.generate = dspy.ChainOfThought(GenerateAnswerWithSource)

    def forward(self, question):
        passages = self.retrieve(question).passages
        context = self._format_context(passages)
        result = self.generate(context=context, question=question)

        source_ids = self._parse_sources(result.sources, len(passages))
        cited_passages = [passages[i - 1] for i in source_ids]

        return dspy.Prediction(
            answer=result.answer,
            sources=result.sources,
            cited_passages=cited_passages
        )

    def _parse_sources(self, sources, n_passages):
        # Accept either a list of ints or an LM-produced string; keep only
        # 1-based ids that actually reference a retrieved passage.
        import re
        if isinstance(sources, str):
            ids = [int(tok) for tok in re.findall(r'\d+', sources)]
        else:
            try:
                ids = [int(s) for s in sources]
            except (TypeError, ValueError):
                ids = []
        return [i for i in ids if 1 <= i <= n_passages]

    def _format_context(self, passages):
        # Number passages from 1 so the LM can cite them by index.
        return "\n\n".join(
            f"[{i+1}] {p}" for i, p in enumerate(passages)
        )

source_rag = SourceRAG()
result = source_rag(question="什么是深度学习?")
print(f"回答: {result.answer}")
print(f"引用来源: {result.sources}")

检索优化 #

查询重写 #

python
import dspy

# NOTE: Signature docstring/desc strings are runtime prompt text (kept in
# Chinese). Asks the LM to rewrite a raw user question into a sharper
# search query.
class RewriteQuery(dspy.Signature):
    """优化搜索查询"""
    original_query = dspy.InputField()
    rewritten_query = dspy.OutputField(desc="更精确的搜索查询")

class OptimizedRAG(dspy.Module):
    """RAG pipeline that rewrites the user question into a sharper search
    query before retrieval, then answers over the retrieved passages."""

    def __init__(self, k=3):
        super().__init__()
        self.rewrite = dspy.Predict(RewriteQuery)
        self.retrieve = dspy.Retrieve(k=k)
        self.generate = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        search_query = self.rewrite(original_query=question).rewritten_query
        passages = self.retrieve(search_query).passages
        # Only retrieval uses the rewritten query; answer generation still
        # sees the ORIGINAL question.
        return self.generate(context=passages, question=question)

混合检索 #

python
import dspy

class HybridRetriever(dspy.Retrieve):
    """Combine a dense and a sparse retriever by interleaving their ranked
    results (dense first at each rank), deduplicating, and truncating to k.

    NOTE(review): `alpha` is stored but never used — a weighted score
    fusion would need per-passage scores, which the passage lists alone do
    not carry; kept for interface compatibility.
    """

    def __init__(self, dense_rm, sparse_rm, k=3, alpha=0.5):
        super().__init__(k=k)
        self.dense_rm = dense_rm
        self.sparse_rm = sparse_rm
        self.alpha = alpha

    def forward(self, query):
        dense_results = self.dense_rm(query).passages
        sparse_results = self.sparse_rm(query).passages

        combined = self._merge_results(dense_results, sparse_results)
        return dspy.Prediction(passages=combined[:self.k])

    def _merge_results(self, dense, sparse):
        # Interleave by rank, skipping duplicates. The original used
        # zip(), which truncates to the SHORTER list and silently drops
        # the tail of the longer one; iterate up to the longer length.
        seen = set()
        merged = []
        for rank in range(max(len(dense), len(sparse))):
            for ranked_list in (dense, sparse):
                if rank < len(ranked_list):
                    passage = ranked_list[rank]
                    if passage not in seen:
                        merged.append(passage)
                        seen.add(passage)
        return merged

重排序 #

python
import dspy

# NOTE: Signature docstring/desc strings are runtime prompt text (kept in
# Chinese). Asks the LM to reorder candidate passages by relevance; the
# `ranked_passages` output arrives as LM-generated text, not a Python list.
class Rerank(dspy.Signature):
    """对检索结果进行重排序"""
    query = dspy.InputField()
    passages = dspy.InputField(desc="候选段落列表")
    ranked_passages = dspy.OutputField(desc="按相关性排序的段落")

class RerankedRAG(dspy.Module):
    """Retrieve a wide candidate set (k), have the LM rerank it, and answer
    over the top `final_k` passages.

    The reranker's `ranked_passages` output field comes back as LM text,
    so slicing it directly (as the original did with `ranked[:final_k]`)
    takes the first `final_k` CHARACTERS rather than passages. The output
    is normalized into a passage list first, falling back to the original
    retrieval order when parsing yields nothing.
    """

    def __init__(self, k=5, final_k=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=k)
        self.rerank = dspy.Predict(Rerank)
        self.generate = dspy.ChainOfThought(GenerateAnswer)
        self.final_k = final_k

    def forward(self, question):
        passages = self.retrieve(question).passages
        ranked = self.rerank(
            query=question,
            passages=passages
        ).ranked_passages

        ranked_list = self._as_passage_list(ranked, passages)
        top_passages = ranked_list[:self.final_k]
        return self.generate(context=top_passages, question=question)

    def _as_passage_list(self, ranked, fallback):
        # LM output is typically newline-separated text; split into lines.
        if isinstance(ranked, str):
            lines = [line.strip() for line in ranked.split('\n') if line.strip()]
            return lines if lines else list(fallback)
        if isinstance(ranked, (list, tuple)):
            return list(ranked)
        return list(fallback)

评估检索质量 #

检索评估指标 #

python
def retrieval_recall(retrieved, relevant, k=None):
    """Recall@k: fraction of *relevant* docs found among (top-k) *retrieved*.

    Args are sequences of doc ids; `k` optionally truncates the retrieved
    list first. Returns 0.0 when *relevant* is empty — the original
    divided unconditionally and raised ZeroDivisionError.
    """
    if k:
        retrieved = retrieved[:k]
    if not relevant:
        return 0.0
    return len(set(retrieved) & set(relevant)) / len(relevant)

def retrieval_precision(retrieved, relevant, k=None):
    """Precision@k: fraction of (top-k) *retrieved* docs that are relevant.

    Args are sequences of doc ids; `k` optionally truncates the retrieved
    list first. Returns 0.0 when *retrieved* is empty — the original
    divided unconditionally and raised ZeroDivisionError.
    """
    if k:
        retrieved = retrieved[:k]
    if not retrieved:
        return 0.0
    return len(set(retrieved) & set(relevant)) / len(retrieved)

def mean_reciprocal_rank(results_list):
    """Mean Reciprocal Rank over (retrieved, relevant) pairs.

    Each pair contributes 1/rank of the FIRST retrieved doc that is
    relevant, or 0.0 when none is. Returns 0.0 for an empty input list —
    the original raised ZeroDivisionError on the final average.
    """
    if not results_list:
        return 0.0
    scores = []
    for retrieved, relevant in results_list:
        # for/else: the else runs only when no relevant doc was found.
        for rank, doc in enumerate(retrieved, start=1):
            if doc in relevant:
                scores.append(1.0 / rank)
                break
        else:
            scores.append(0.0)
    return sum(scores) / len(scores)

端到端评估 #

python
from dspy.evaluate import Evaluate

def rag_metric(example, pred, trace=None):
    """Loose answer check plus optional source check.

    The gold answer must appear (case-insensitively) inside the predicted
    answer. When the prediction carries a `sources` attribute, every
    predicted source must also be among the gold sources.
    """
    answer_ok = example.answer.lower() in pred.answer.lower()
    if not hasattr(pred, 'sources'):
        return answer_ok
    sources_ok = all(src in example.sources for src in pred.sources)
    return answer_ok and sources_ok

# Run the metric over the dev set with 4 worker threads; `testset` must be
# defined elsewhere (a list of dspy.Example objects).
evaluator = Evaluate(
    devset=testset,
    metric=rag_metric,
    num_threads=4
)

# Returns the aggregate metric score for the given RAG program.
score = evaluator(rag)

检索器配置最佳实践 #

1. 选择合适的 k 值 #

python
class AdaptiveRAG(dspy.Module):
    """Retrieve a wide candidate set (max_k) and adaptively keep only the
    passages that look relevant to the question, bounded to at most max_k
    and at least min_k passages.

    The original called `self._filter_relevant` without ever defining it,
    so every forward() raised AttributeError; a simple lexical-overlap
    filter is provided here.
    """

    def __init__(self, min_k=2, max_k=10):
        super().__init__()
        self.min_k = min_k
        self.max_k = max_k
        self.retrieve = dspy.Retrieve(k=max_k)
        self.generate = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        all_passages = self.retrieve(question).passages
        relevant_passages = self._filter_relevant(all_passages, question)
        return self.generate(context=relevant_passages, question=question)

    def _filter_relevant(self, passages, question):
        # Rank passages by word overlap with the question and keep those
        # with any overlap, padding back up to min_k from the ranked list
        # when the filter is too aggressive.
        query_terms = set(question.lower().split())

        def overlap(passage):
            return len(query_terms & set(passage.lower().split()))

        ranked = sorted(passages, key=overlap, reverse=True)
        kept = [p for p in ranked if overlap(p) > 0]
        if len(kept) < self.min_k:
            kept = ranked[:self.min_k]
        return kept[:self.max_k]

2. 文档分块策略 #

python
def smart_chunk(document, max_size=500, overlap=50):
    """Greedily pack whole paragraphs (split on blank lines) into chunks
    whose accumulated paragraph length stays within *max_size* characters.

    A document already within *max_size* is returned as a single chunk.
    A single paragraph longer than *max_size* still becomes one oversized
    chunk — paragraphs are never split.

    NOTE(review): the `overlap` parameter is accepted but not applied;
    kept for interface compatibility with split_into_chunks.
    """
    if len(document) <= max_size:
        return [document]

    pieces = []
    buffer = []
    buffered_len = 0

    for paragraph in document.split('\n\n'):
        size = len(paragraph)
        # Flush the buffer when adding this paragraph would overflow it
        # (separator lengths are not counted, matching the size heuristic).
        if buffer and buffered_len + size > max_size:
            pieces.append('\n\n'.join(buffer))
            buffer = [paragraph]
            buffered_len = size
        else:
            buffer.append(paragraph)
            buffered_len += size

    if buffer:
        pieces.append('\n\n'.join(buffer))

    return pieces

3. 缓存策略 #

python
import dspy
from functools import lru_cache

class CachedRetriever(dspy.Retrieve):
    # Wraps another retriever and memoizes whole forward() calls per query
    # string via functools.lru_cache (bounded at cache_size entries).
    #
    # NOTE(review): the lru_cache wraps a BOUND method, so the cache holds a
    # reference to `self` (and transitively `rm`) for its lifetime — same
    # caveat as ruff B019, though here the cache is per-instance so it dies
    # with the instance. Assumes `rm` exposes a `.k` attribute, which
    # dspy.Retrieve subclasses set via super().__init__(k=k) — confirm for
    # custom RMs.
    def __init__(self, rm, cache_size=1000):
        super().__init__(k=rm.k)
        self.rm = rm
        self._cached_forward = lru_cache(maxsize=cache_size)(self._forward_impl)
    
    def forward(self, query):
        # Repeated queries return the SAME cached Prediction object (shared,
        # not copied) — callers must not mutate it.
        return self._cached_forward(query)
    
    def _forward_impl(self, query):
        return self.rm(query)

完整 RAG 示例 #

python
import dspy
from dspy import Example
from dspy.teleprompt import BootstrapFewShot

lm = dspy.LM('openai/gpt-4o-mini')
rm = dspy.ColBERTv2(url='http://localhost:8893/api/search')
dspy.configure(lm=lm, rm=rm)

# NOTE: the Signature docstring and field desc strings below are runtime
# prompt text, so the Chinese strings are intentionally left untranslated.
class GenerateAnswer(dspy.Signature):
    """根据上下文回答问题"""
    context = dspy.InputField(desc="相关上下文")
    question = dspy.InputField(desc="用户问题")
    answer = dspy.OutputField(desc="简洁准确的回答")

# Standard retrieve-then-generate pipeline: fetch k passages, then answer
# with a chain-of-thought generator.
class RAG(dspy.Module):
    def __init__(self, k=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=k)
        self.generate = dspy.ChainOfThought(GenerateAnswer)
    
    def forward(self, question):
        # Retrieved passages become the generator's context.
        context = self.retrieve(question).passages
        return self.generate(context=context, question=question)

# Tiny illustrative training set; with_inputs("question") marks which field
# is the model input (remaining fields are treated as labels).
trainset = [
    Example(question="什么是机器学习?", answer="机器学习是人工智能的一个分支...").with_inputs("question"),
    Example(question="深度学习的优势是什么?", answer="深度学习可以自动学习特征...").with_inputs("question"),
]

def validate_answer(example, pred, trace=None):
    """Loose containment check: the gold answer must appear
    (case-insensitively) inside the predicted answer."""
    gold = example.answer.lower()
    return gold in pred.answer.lower()

# Bootstrap up to 3 few-shot demos from trainset using the validation
# metric, producing a compiled (prompt-optimized) copy of the RAG program.
optimizer = BootstrapFewShot(metric=validate_answer, max_bootstrapped_demos=3)
optimized_rag = optimizer.compile(RAG(), trainset=trainset)

result = optimized_rag(question="神经网络的基本原理是什么?")
print(result.answer)

下一步 #

现在你已经掌握了检索器的使用方法,接下来学习 高级主题,了解 DSPy 的更多高级功能!

最后更新:2026-03-30