DSPy 检索器 #
什么是检索器? #
检索器(Retriever)是 DSPy 中用于从知识库获取相关内容的组件。它是构建 RAG(Retrieval-Augmented Generation)应用的核心组件。
text
┌─────────────────────────────────────────────────────────────┐
│ 检索器的作用 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 知识检索:从大量文档中找到与问题相关的内容 │
│ 2. 上下文增强:为 LLM 提供准确的背景信息 │
│ 3. 减少幻觉:基于事实生成回答 │
│ 4. 知识更新:无需重新训练即可更新知识 │
│ │
└─────────────────────────────────────────────────────────────┘
RAG 架构概览 #
text
┌─────────────────────────────────────────────────────────────┐
│ RAG 基本架构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ 用户 │ │ 检索 │ │ 生成 │ │
│ │ 问题 │────>│ 模块 │────>│ 模块 │ │
│ └─────────┘ └─────────┘ └─────────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────┐ ┌─────────┐ │
│ │ 知识库 │ │ 回答 │ │
│ │ 向量库 │ │ │ │
│ └─────────┘ └─────────┘ │
│ │
│ 流程: │
│ 1. 用户提问 │
│ 2. 检索器从知识库获取相关文档 │
│ 3. 生成模块基于检索结果生成回答 │
│ │
└─────────────────────────────────────────────────────────────┘
内置检索器 #
ColBERTv2 #
ColBERTv2 是 DSPy 默认支持的检索器,使用延迟交互模型。
python
import dspy

# Point at a running ColBERTv2 retrieval server (HTTP search endpoint).
rm = dspy.ColBERTv2(url='http://localhost:8893/api/search')
# Register the retriever globally so dspy.Retrieve picks it up.
dspy.configure(rm=rm)

# Fetch the top-3 passages for a query; results.passages is a list of strings.
retrieve = dspy.Retrieve(k=3)
results = retrieve("什么是机器学习?")
print(results.passages)
配置参数 #
python
retrieve = dspy.Retrieve(
    k=3,  # number of passages to return per query
    # NOTE(review): current dspy releases document only `k` on dspy.Retrieve;
    # a `filter` kwarg is not part of its public signature — verify before use.
    filter={"category": "tech"}
)
向量数据库集成 #
ChromaDB #
python
import dspy
import chromadb
from chromadb.utils import embedding_functions

# Persistent on-disk Chroma store using the default embedding function.
client = chromadb.PersistentClient(path="./chroma_db")
default_embedder = embedding_functions.DefaultEmbeddingFunction()
collection = client.get_or_create_collection(
    name="documents",
    embedding_function=default_embedder,
)


class ChromaRM(dspy.Retrieve):
    """DSPy retriever backed by a ChromaDB collection."""

    def __init__(self, collection, k=3):
        super().__init__(k=k)
        self.collection = collection

    def forward(self, query):
        """Return the top-k passages for *query* as a dspy.Prediction."""
        hits = self.collection.query(query_texts=[query], n_results=self.k)
        # Chroma returns one document list per query text; we sent exactly one.
        return dspy.Prediction(passages=hits['documents'][0])


dspy.configure(rm=ChromaRM(collection, k=3))
Pinecone #
python
import dspy
import pinecone

pinecone.init(api_key="your-api-key", environment="us-west1-gcp")
index = pinecone.Index("documents")


class PineconeRM(dspy.Retrieve):
    """DSPy retriever backed by a Pinecone index and a caller-supplied embedder."""

    def __init__(self, index, embed_func, k=3):
        super().__init__(k=k)
        self.index = index
        self.embed_func = embed_func

    def forward(self, query):
        """Embed *query*, run a vector search, and return the matched passage texts."""
        response = self.index.query(
            vector=self.embed_func(query),
            top_k=self.k,
            include_metadata=True,
        )
        # Passage text is stored in each match's metadata under 'text'.
        texts = [match['metadata']['text'] for match in response['matches']]
        return dspy.Prediction(passages=texts)


dspy.configure(rm=PineconeRM(index, my_embed_func, k=3))
Weaviate #
python
import dspy
import weaviate

client = weaviate.Client("http://localhost:8080")


class WeaviateRM(dspy.Retrieve):
    """DSPy retriever backed by a Weaviate class, using near-text semantic search."""

    def __init__(self, client, class_name, k=3):
        super().__init__(k=k)
        self.client = client
        self.class_name = class_name

    def forward(self, query):
        """Run a near-text query and return the matched objects' 'content' fields."""
        builder = self.client.query.get(self.class_name, ["content"])
        builder = builder.with_near_text({"concepts": [query]})
        response = builder.with_limit(self.k).do()
        matches = response['data']['Get'][self.class_name]
        return dspy.Prediction(passages=[obj['content'] for obj in matches])


dspy.configure(rm=WeaviateRM(client, "Document", k=3))
FAISS #
python
import dspy
import faiss
import numpy as np


class FAISSRM(dspy.Retrieve):
    """DSPy retriever over a FAISS index paired with a parallel document list."""

    def __init__(self, index, documents, embed_func, k=3):
        super().__init__(k=k)
        self.index = index
        self.documents = documents
        self.embed_func = embed_func

    def forward(self, query):
        """Embed *query*, search the index, and map row ids back to documents."""
        query_matrix = np.array([self.embed_func(query)])
        _distances, row_ids = self.index.search(query_matrix, self.k)
        # search() returns one id row per query vector; we sent exactly one.
        return dspy.Prediction(passages=[self.documents[i] for i in row_ids[0]])


dspy.configure(rm=FAISSRM(index, documents, embed_func, k=3))
构建文档索引 #
文档预处理 #
python
def preprocess_documents(raw_docs):
    """Split each raw document into overlapping word-level chunks.

    Returns a flat list of chunk strings across all documents.
    """
    processed = []
    for doc in raw_docs:
        processed.extend(split_into_chunks(doc, chunk_size=500, overlap=50))
    return processed


def split_into_chunks(text, chunk_size=500, overlap=50):
    """Split *text* into chunks of at most *chunk_size* words, with *overlap*
    words shared between consecutive chunks.

    Raises:
        ValueError: if chunk_size <= 0 or not 0 <= overlap < chunk_size.
            (Previously overlap > chunk_size made the range step negative and
            silently returned [] — dropping all data.)
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap < 0 or overlap >= chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    words = text.split()
    chunks = []
    step = chunk_size - overlap
    for start in range(0, len(words), step):
        chunks.append(' '.join(words[start:start + chunk_size]))
        # Stop once a chunk reaches the end of the text: any later window would
        # be a strict sub-span of this one (pure overlap, no new content).
        if start + chunk_size >= len(words):
            break
    return chunks
添加文档到索引 #
python
def index_documents(collection, documents):
    """Add *documents* to a ChromaDB-style collection with stable string ids.

    Each document may be a plain string or an object exposing a `.source`
    attribute. BUG FIX: the original unconditionally read `doc.source`, which
    raised AttributeError for the plain-string chunks that
    preprocess_documents produces; plain strings now get "unknown" as source.
    """
    for i, doc in enumerate(documents):
        collection.add(
            ids=[f"doc_{i}"],
            documents=[str(doc)],
            metadatas=[{"source": getattr(doc, "source", "unknown")}],
        )
构建 RAG 应用 #
基础 RAG #
python
import dspy

# Configure the language model and retrieval model globally for all modules.
lm = dspy.LM('openai/gpt-4o-mini')
rm = dspy.ColBERTv2(url='http://localhost:8893/api/search')
dspy.configure(lm=lm, rm=rm)


class GenerateAnswer(dspy.Signature):
    # NOTE: the docstring and desc strings below are runtime prompt text sent
    # to the LM, which is why they are kept in Chinese and left untouched.
    """根据上下文回答问题"""
    context = dspy.InputField(desc="相关上下文")
    question = dspy.InputField(desc="用户问题")
    answer = dspy.OutputField(desc="基于上下文的回答")


class RAG(dspy.Module):
    """Minimal retrieve-then-generate pipeline."""

    def __init__(self, k=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=k)  # uses the globally configured rm
        self.generate = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        # The top-k retrieved passages become the generation context.
        context = self.retrieve(question).passages
        return self.generate(context=context, question=question)


rag = RAG()
result = rag(question="什么是机器学习?")
print(result.answer)
多跳检索 #
python
import dspy


class GenerateQuery(dspy.Signature):
    # Docstrings on Signatures are runtime prompt text (kept in Chinese).
    """生成搜索查询"""
    context = dspy.InputField()
    question = dspy.InputField()
    query = dspy.OutputField()


class GenerateAnswer(dspy.Signature):
    """根据上下文回答问题"""
    context = dspy.InputField()
    question = dspy.InputField()
    answer = dspy.OutputField()
class MultiHopRAG(dspy.Module):
    """Multi-hop RAG: iteratively writes search queries, accumulating passages
    across hops before answering."""

    def __init__(self, passages_per_hop=3, max_hops=2):
        super().__init__()
        self.max_hops = max_hops
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_query = dspy.ChainOfThought(GenerateQuery)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        gathered = []
        for _hop in range(self.max_hops):
            # Each hop sees everything retrieved so far when writing its query.
            next_query = self.generate_query(
                context=gathered, question=question
            ).query
            gathered += self.retrieve(next_query).passages
        return self.generate_answer(context=gathered, question=question)
# Two hops let the second query build on entities surfaced by the first.
multi_hop = MultiHopRAG(passages_per_hop=3, max_hops=2)
result = multi_hop(question="谁发明了 Python,它有什么特点?")
带来源引用的 RAG #
python
import dspy


class GenerateAnswerWithSource(dspy.Signature):
    # Docstring and desc strings are runtime prompt text (kept in Chinese).
    """根据上下文回答问题并引用来源"""
    context = dspy.InputField(desc="相关上下文,每段带有编号")
    question = dspy.InputField(desc="用户问题")
    answer = dspy.OutputField(desc="回答内容")
    sources = dspy.OutputField(desc="引用的上下文编号列表")


class SourceRAG(dspy.Module):
    """RAG pipeline that also reports which numbered passages the answer cites."""

    def __init__(self, k=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=k)
        self.generate = dspy.ChainOfThought(GenerateAnswerWithSource)

    def forward(self, question):
        passages = self.retrieve(question).passages
        context = self._format_context(passages)
        result = self.generate(context=context, question=question)
        # BUG FIX: result.sources is free-form LM text (e.g. "[1, 3]"), not a
        # list of ints — iterating it directly yielded characters and
        # `passages[i-1]` raised TypeError; i=0 also aliased passages[-1].
        indices = self._parse_sources(result.sources)
        cited_passages = [
            passages[i - 1] for i in indices if 1 <= i <= len(passages)
        ]
        return dspy.Prediction(
            answer=result.answer,
            sources=result.sources,
            cited_passages=cited_passages,
        )

    def _parse_sources(self, sources):
        """Extract 1-based citation numbers from the model output (free-form
        text or list), deduplicated while preserving first-seen order."""
        import re
        seen = set()
        indices = []
        for token in re.findall(r"\d+", str(sources)):
            number = int(token)
            if number not in seen:
                seen.add(number)
                indices.append(number)
        return indices

    def _format_context(self, passages):
        """Number each passage so the LM can cite it as [1], [2], ..."""
        return "\n\n".join(
            f"[{i+1}] {p}" for i, p in enumerate(passages)
        )


source_rag = SourceRAG()
result = source_rag(question="什么是深度学习?")
print(f"回答: {result.answer}")
print(f"引用来源: {result.sources}")
检索优化 #
查询重写 #
python
import dspy


class RewriteQuery(dspy.Signature):
    # Docstring and desc are runtime prompt text (kept in Chinese).
    """优化搜索查询"""
    original_query = dspy.InputField()
    rewritten_query = dspy.OutputField(desc="更精确的搜索查询")


class OptimizedRAG(dspy.Module):
    """RAG variant that first rewrites the user question into a sharper search query."""

    def __init__(self, k=3):
        super().__init__()
        self.rewrite = dspy.Predict(RewriteQuery)
        self.retrieve = dspy.Retrieve(k=k)
        # GenerateAnswer is the signature defined earlier in this document.
        self.generate = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        # Retrieve with the rewritten query, but answer the original question.
        rewritten = self.rewrite(original_query=question).rewritten_query
        context = self.retrieve(rewritten).passages
        return self.generate(context=context, question=question)
混合检索 #
python
import dspy


class HybridRetriever(dspy.Retrieve):
    """Fuses a dense and a sparse retriever by interleaving their ranked results.

    `alpha` is kept for interface compatibility / future score-weighted fusion;
    with passage-only results (no scores available) it currently has no effect.
    """

    def __init__(self, dense_rm, sparse_rm, k=3, alpha=0.5):
        super().__init__(k=k)
        self.dense_rm = dense_rm
        self.sparse_rm = sparse_rm
        self.alpha = alpha

    def forward(self, query):
        dense_results = self.dense_rm(query).passages
        sparse_results = self.sparse_rm(query).passages
        combined = self._merge_results(dense_results, sparse_results)
        return dspy.Prediction(passages=combined[:self.k])

    def _merge_results(self, dense, sparse):
        """Round-robin merge (dense first at each rank), deduplicated.

        BUG FIX: the previous zip()-based merge silently dropped every result
        past the length of the shorter list whenever the two retrievers
        returned different numbers of passages.
        """
        seen = set()
        merged = []
        for rank in range(max(len(dense), len(sparse))):
            for ranked_list in (dense, sparse):
                if rank < len(ranked_list):
                    passage = ranked_list[rank]
                    if passage not in seen:
                        seen.add(passage)
                        merged.append(passage)
        return merged
重排序 #
python
import dspy


class Rerank(dspy.Signature):
    # Docstring and desc strings are runtime prompt text (kept in Chinese).
    """对检索结果进行重排序"""
    query = dspy.InputField()
    passages = dspy.InputField(desc="候选段落列表")
    ranked_passages = dspy.OutputField(desc="按相关性排序的段落")


class RerankedRAG(dspy.Module):
    """Over-retrieves k candidates, LLM-reranks them, answers from the top final_k."""

    def __init__(self, k=5, final_k=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=k)
        self.rerank = dspy.Predict(Rerank)
        # GenerateAnswer is the signature defined earlier in this document.
        self.generate = dspy.ChainOfThought(GenerateAnswer)
        self.final_k = final_k

    def forward(self, question):
        passages = self.retrieve(question).passages
        ranked = self.rerank(
            query=question,
            passages=passages
        ).ranked_passages
        # NOTE(review): ranked_passages is a string-typed LM output by default,
        # so ranked[:self.final_k] slices *characters*, not passages — parse
        # the output into a list (or type the field) before slicing. TODO confirm.
        top_passages = ranked[:self.final_k]
        return self.generate(context=top_passages, question=question)
评估检索质量 #
检索评估指标 #
python
def retrieval_recall(retrieved, relevant, k=None):
    """Fraction of the relevant set found in the (optionally top-k) retrieved list.

    FIXES: returns 0.0 instead of raising ZeroDivisionError when *relevant* is
    empty, and `k is not None` (rather than truthiness) so k=0 genuinely means
    "top zero" instead of being silently ignored.
    """
    if k is not None:
        retrieved = retrieved[:k]
    if not relevant:
        return 0.0
    return len(set(retrieved) & set(relevant)) / len(relevant)


def retrieval_precision(retrieved, relevant, k=None):
    """Fraction of the (optionally top-k) retrieved list that is relevant.

    Returns 0.0 when nothing was retrieved instead of dividing by zero.
    """
    if k is not None:
        retrieved = retrieved[:k]
    if not retrieved:
        return 0.0
    return len(set(retrieved) & set(relevant)) / len(retrieved)


def mean_reciprocal_rank(results_list):
    """Mean of 1/rank of the first relevant hit over (retrieved, relevant) pairs.

    A query with no relevant hit contributes 0.0; an empty input returns 0.0
    (previously a ZeroDivisionError).
    """
    if not results_list:
        return 0.0
    scores = []
    for retrieved, relevant in results_list:
        for rank, doc in enumerate(retrieved, start=1):
            if doc in relevant:
                scores.append(1.0 / rank)
                break
        else:
            scores.append(0.0)
    return sum(scores) / len(scores)
端到端评估 #
python
from dspy.evaluate import Evaluate
def rag_metric(example, pred, trace=None):
    """End-to-end RAG metric: the gold answer must appear in the prediction,
    and any cited sources must all belong to the gold source set."""
    answer_ok = example.answer.lower() in pred.answer.lower()
    if not hasattr(pred, 'sources'):
        return answer_ok
    sources_ok = all(cited in example.sources for cited in pred.sources)
    return answer_ok and sources_ok
# `testset` is a list of dspy.Example items (not shown here); `rag` is the
# RAG module built in the earlier section.
evaluator = Evaluate(
    devset=testset,
    metric=rag_metric,
    num_threads=4  # evaluate examples in parallel
)
score = evaluator(rag)
检索器配置最佳实践 #
1. 选择合适的 k 值 #
python
class AdaptiveRAG(dspy.Module):
    """RAG that over-retrieves max_k passages, then keeps only those that look
    relevant to the question.

    BUG FIX: forward() called self._filter_relevant, which was never defined
    (AttributeError at runtime); a minimal token-overlap implementation is
    provided below.
    """

    def __init__(self, min_k=2, max_k=10):
        super().__init__()
        self.min_k = min_k
        self.max_k = max_k
        self.retrieve = dspy.Retrieve(k=max_k)
        # GenerateAnswer is the signature defined earlier in this document.
        self.generate = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        all_passages = self.retrieve(question).passages
        relevant_passages = self._filter_relevant(all_passages, question)
        return self.generate(context=relevant_passages, question=question)

    def _filter_relevant(self, passages, question):
        """Keep passages sharing at least one lowercased token with the question.

        Falls back to the top min_k passages when the filter keeps too few,
        and never returns more than max_k.
        NOTE: whitespace tokenization is a rough heuristic — for CJK text a
        proper tokenizer would be needed.
        """
        question_tokens = set(question.lower().split())
        kept = [
            p for p in passages
            if question_tokens & set(str(p).lower().split())
        ]
        if len(kept) < self.min_k:
            kept = list(passages[:self.min_k])
        return kept[:self.max_k]
2. 文档分块策略 #
python
def smart_chunk(document, max_size=500, overlap=50):
    """Greedily pack whole paragraphs into chunks of at most max_size characters.

    Splits on blank-line paragraph boundaries only; a document that already
    fits is returned as a single chunk. (`overlap` is accepted for API parity
    with word-level chunking but paragraphs are not overlapped here.)
    """
    if len(document) <= max_size:
        return [document]
    chunks = []
    buffer = []
    buffered_len = 0
    for paragraph in document.split('\n\n'):
        size = len(paragraph)
        # Flush the buffer when adding this paragraph would overflow it —
        # but never flush an empty buffer (an oversized paragraph still
        # becomes its own chunk).
        if buffer and buffered_len + size > max_size:
            chunks.append('\n\n'.join(buffer))
            buffer = [paragraph]
            buffered_len = size
        else:
            buffer.append(paragraph)
            buffered_len += size
    if buffer:
        chunks.append('\n\n'.join(buffer))
    return chunks
3. 缓存策略 #
python
import dspy
from functools import lru_cache
class CachedRetriever(dspy.Retrieve):
    """Wraps another retriever and memoizes its results per exact query string."""

    def __init__(self, rm, cache_size=1000):
        super().__init__(k=rm.k)
        self.rm = rm
        # Wrapping the *bound* method gives each instance its own LRU cache,
        # sidestepping the classic lru_cache-on-a-class-method leak.
        self._memoized = lru_cache(maxsize=cache_size)(self._lookup)

    def forward(self, query):
        # Cache hits return the same Prediction object — treat it as read-only.
        return self._memoized(query)

    def _lookup(self, query):
        return self.rm(query)
完整 RAG 示例 #
python
import dspy
from dspy import Example
from dspy.teleprompt import BootstrapFewShot

# Configure the language model and the ColBERTv2 retrieval server globally.
lm = dspy.LM('openai/gpt-4o-mini')
rm = dspy.ColBERTv2(url='http://localhost:8893/api/search')
dspy.configure(lm=lm, rm=rm)


class GenerateAnswer(dspy.Signature):
    # Docstring and desc strings are runtime prompt text (kept in Chinese).
    """根据上下文回答问题"""
    context = dspy.InputField(desc="相关上下文")
    question = dspy.InputField(desc="用户问题")
    answer = dspy.OutputField(desc="简洁准确的回答")


class RAG(dspy.Module):
    """Minimal retrieve-then-generate pipeline."""

    def __init__(self, k=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=k)  # uses the globally configured rm
        self.generate = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        context = self.retrieve(question).passages
        return self.generate(context=context, question=question)


# Tiny training set; with_inputs() marks which fields count as inputs.
trainset = [
    Example(question="什么是机器学习?", answer="机器学习是人工智能的一个分支...").with_inputs("question"),
    Example(question="深度学习的优势是什么?", answer="深度学习可以自动学习特征...").with_inputs("question"),
]


def validate_answer(example, pred, trace=None):
    # Loose containment check: the gold answer must appear in the prediction.
    return example.answer.lower() in pred.answer.lower()


# Bootstrap few-shot demonstrations from the train set and compile the module.
optimizer = BootstrapFewShot(metric=validate_answer, max_bootstrapped_demos=3)
optimized_rag = optimizer.compile(RAG(), trainset=trainset)

result = optimized_rag(question="神经网络的基本原理是什么?")
print(result.answer)
下一步 #
现在你已经掌握了检索器的使用方法,接下来学习 高级主题,了解 DSPy 的更多高级功能!
最后更新:2026-03-30