文档问答系统 #
项目概述 #
本章将构建一个完整的企业级文档问答系统,支持多种文档格式、增量更新、高效检索等功能。
text
┌─────────────────────────────────────────────────────────────┐
│ 文档问答系统架构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 用户界面层 │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Web UI │ CLI │ API │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 应用层 │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Query Engine │ Chat Engine │ Agent │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 数据层 │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Document Loader │ Index │ Vector Store │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 存储层 │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ File System │ Database │ Cloud Storage │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
项目结构 #
text
qa_system/
├── config/
│ └── settings.py
├── data/
│ └── documents/
├── src/
│ ├── __init__.py
│ ├── document_processor.py
│ ├── index_manager.py
│ ├── query_engine.py
│ └── utils.py
├── storage/
│ └── chroma/
├── main.py
└── requirements.txt
核心代码实现 #
配置管理 #
python
import os

from pydantic_settings import BaseSettings

from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding


class Settings(BaseSettings):
    """Application configuration, resolved from environment variables and .env.

    pydantic-settings reads env vars and the env_file itself on
    instantiation, so plain defaults suffice; the original `os.getenv()`
    call in the field default ran at import time and was redundant.
    """

    openai_api_key: str = ""          # set via OPENAI_API_KEY env var or .env
    llm_model: str = "gpt-4o-mini"
    embed_model: str = "text-embedding-3-small"
    chunk_size: int = 512             # tokens per chunk
    chunk_overlap: int = 50           # token overlap between adjacent chunks
    data_dir: str = "./data/documents"
    storage_dir: str = "./storage/chroma"
    similarity_top_k: int = 5         # nodes returned per retrieval

    class Config:
        env_file = ".env"


settings = Settings()


def get_llm():
    """Return the configured chat LLM client."""
    # Pass the key explicitly so a key that only lives in .env (i.e. never
    # exported as a process env var) still reaches the OpenAI client;
    # `or None` falls back to the client's own env-var lookup.
    return OpenAI(model=settings.llm_model, api_key=settings.openai_api_key or None)


def get_embed_model():
    """Return the configured embedding model client."""
    return OpenAIEmbedding(model=settings.embed_model, api_key=settings.openai_api_key or None)
文档处理器 #
python
import os
from datetime import datetime  # fixed: add_metadata used datetime without importing it
from typing import List

from llama_index.core import SimpleDirectoryReader, Document
from llama_index.core.node_parser import SentenceSplitter


class DocumentProcessor:
    """Loads documents from disk and splits them into retrieval nodes."""

    def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
        self.splitter = SentenceSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )

    def load_documents(self, directory: str) -> List[Document]:
        """Recursively load supported files (.pdf/.txt/.md/.docx) from *directory*.

        Raises:
            FileNotFoundError: if *directory* does not exist.
        """
        if not os.path.exists(directory):
            raise FileNotFoundError(f"目录不存在: {directory}")
        reader = SimpleDirectoryReader(
            input_dir=directory,
            recursive=True,
            required_exts=[".pdf", ".txt", ".md", ".docx"],
        )
        documents = reader.load_data(show_progress=True)
        print(f"加载了 {len(documents)} 个文档")
        return documents

    def process_documents(self, documents: List[Document]):
        """Split *documents* into chunk nodes using the sentence splitter."""
        nodes = self.splitter.get_nodes_from_documents(
            documents,
            show_progress=True,
        )
        print(f"生成了 {len(nodes)} 个节点")
        return nodes

    def add_metadata(self, documents: List[Document]) -> List[Document]:
        """Attach file name, file type and a processing timestamp to each document.

        Mutates the documents in place and returns the same list.
        """
        # One timestamp for the whole batch; the original crashed here with
        # NameError because `datetime` was never imported in this module.
        processed_at = datetime.now().isoformat()
        for doc in documents:
            file_path = doc.metadata.get("file_path", "")
            file_name = os.path.basename(file_path)
            file_ext = os.path.splitext(file_name)[1].lower()
            doc.metadata["file_name"] = file_name
            doc.metadata["file_type"] = file_ext
            doc.metadata["processed_at"] = processed_at
        return documents
索引管理器 #
python
from typing import List, Optional

import chromadb
from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
from llama_index.vector_stores.chroma import ChromaVectorStore

# Fixed: get_llm/get_embed_model were called below but never imported.
from config.settings import get_llm, get_embed_model


class IndexManager:
    """Owns the Chroma-backed vector index: build, load, insert, delete."""

    def __init__(
        self,
        storage_dir: str,
        collection_name: str = "qa_system",
    ):
        self.storage_dir = storage_dir
        self.collection_name = collection_name
        self.index: Optional[VectorStoreIndex] = None
        self.db: Optional[chromadb.PersistentClient] = None
        # Fixed: declared here so build_index() can detect a missing
        # initialize() call instead of raising AttributeError.
        self.storage_context: Optional[StorageContext] = None

    def initialize(self):
        """Configure the global models and open (or create) the Chroma collection."""
        Settings.llm = get_llm()
        Settings.embed_model = get_embed_model()
        self.db = chromadb.PersistentClient(path=self.storage_dir)
        chroma_collection = self.db.get_or_create_collection(
            self.collection_name
        )
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store
        )

    def build_index(self, documents: List[Document]) -> VectorStoreIndex:
        """Embed *documents* and build a fresh vector index over them."""
        if self.storage_context is None:
            self.initialize()
        self.index = VectorStoreIndex.from_documents(
            documents,
            storage_context=self.storage_context,
            show_progress=True,
        )
        print("索引构建完成")
        return self.index

    def load_index(self) -> VectorStoreIndex:
        """Re-open an index previously persisted in the Chroma collection.

        Propagates chromadb's error if the collection was never built.
        """
        if self.db is None:
            self.initialize()
        chroma_collection = self.db.get_collection(self.collection_name)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        self.index = VectorStoreIndex.from_vector_store(vector_store)
        print("索引加载完成")
        return self.index

    def add_documents(self, documents: List[Document]):
        """Incrementally insert documents into the live index."""
        if self.index is None:
            raise ValueError("索引未初始化")
        for doc in documents:
            self.index.insert(doc)
        print(f"添加了 {len(documents)} 个文档")

    def delete_document(self, doc_id: str):
        """Remove a source document (and all of its nodes) by ref-doc id."""
        if self.index is None:
            raise ValueError("索引未初始化")
        self.index.delete_ref_doc(doc_id)
        print(f"删除了文档: {doc_id}")

    def get_stats(self) -> dict:
        """Return collection size info, or {} before initialization."""
        if self.db is None:
            return {}
        collection = self.db.get_collection(self.collection_name)
        return {
            "document_count": collection.count(),
            "collection_name": self.collection_name,
        }
查询引擎 #
python
from typing import List, Optional

from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
# Fixed import path: SentenceTransformerRerank is exported from
# llama_index.core.postprocessor (requires sentence-transformers installed);
# the original path `llama_index.postprocessor.sentence_transformer_rerank`
# does not resolve in llama-index 0.10+.
from llama_index.core.postprocessor import SentenceTransformerRerank


class QAQueryEngine:
    """Retrieve -> optionally rerank -> synthesize an answer with cited sources."""

    # Max characters of each source node included in a query() result.
    PREVIEW_LEN = 200

    def __init__(
        self,
        index: VectorStoreIndex,
        similarity_top_k: int = 5,
        use_reranker: bool = True,
    ):
        self.index = index
        self.similarity_top_k = similarity_top_k
        self.use_reranker = use_reranker
        self._setup_engine()

    def _setup_engine(self):
        """Assemble retriever, optional reranker and synthesizer into one engine."""
        # Over-retrieve (2x) so the cross-encoder has candidates to discard;
        # the reranker trims the list back down to similarity_top_k.
        retriever = self.index.as_retriever(
            similarity_top_k=self.similarity_top_k * 2,
        )
        response_synthesizer = get_response_synthesizer(
            response_mode="compact",
        )
        node_postprocessors = []
        if self.use_reranker:
            reranker = SentenceTransformerRerank(
                model="cross-encoder/ms-marco-MiniLM-L-6-v2",
                top_n=self.similarity_top_k,
            )
            node_postprocessors.append(reranker)
        self.query_engine = RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
            node_postprocessors=node_postprocessors,
        )

    def query(self, question: str) -> dict:
        """Answer *question*; returns {"answer": str, "sources": [dict, ...]}."""
        response = self.query_engine.query(question)
        sources = []
        for node in response.source_nodes:
            text = node.node.text
            # Fixed: only append an ellipsis when the preview was actually
            # truncated (the original added "..." unconditionally).
            if len(text) > self.PREVIEW_LEN:
                preview = text[:self.PREVIEW_LEN] + "..."
            else:
                preview = text
            sources.append({
                "content": preview,
                "score": node.score,
                "metadata": node.node.metadata,
            })
        return {
            "answer": str(response),
            "sources": sources,
        }

    def query_stream(self, question: str):
        """Yield answer text incrementally (this path skips the reranker)."""
        streaming_engine = self.index.as_query_engine(
            streaming=True,
            similarity_top_k=self.similarity_top_k,
        )
        response = streaming_engine.query(question)
        for text in response.response_gen:
            yield text
主程序 #
python
import os
from datetime import datetime
from typing import Optional  # fixed: Optional was used below but never imported

from config.settings import settings
from src.document_processor import DocumentProcessor
from src.index_manager import IndexManager
from src.query_engine import QAQueryEngine


class QASystem:
    """Facade wiring document processing, indexing and querying together."""

    def __init__(self):
        self.doc_processor = DocumentProcessor(
            chunk_size=settings.chunk_size,
            chunk_overlap=settings.chunk_overlap,
        )
        self.index_manager = IndexManager(
            storage_dir=settings.storage_dir,
        )
        self.query_engine: Optional[QAQueryEngine] = None

    def initialize(self, rebuild: bool = False):
        """Build (rebuild=True) or load the index, then create the query engine."""
        self.index_manager.initialize()
        if rebuild:
            print("重新构建索引...")
            documents = self.doc_processor.load_documents(settings.data_dir)
            documents = self.doc_processor.add_metadata(documents)
            index = self.index_manager.build_index(documents)
        else:
            print("加载现有索引...")
            index = self.index_manager.load_index()
        self.query_engine = QAQueryEngine(
            index=index,
            similarity_top_k=settings.similarity_top_k,
        )
        print("系统初始化完成")

    def ask(self, question: str) -> dict:
        """Synchronous Q&A; raises ValueError before initialize()."""
        if self.query_engine is None:
            raise ValueError("系统未初始化")
        return self.query_engine.query(question)

    def ask_stream(self, question: str):
        """Streaming Q&A generator; raises ValueError before initialize()."""
        if self.query_engine is None:
            raise ValueError("系统未初始化")
        return self.query_engine.query_stream(question)

    def add_document(self, file_path: str):
        """Ingest a single file into the live index.

        Fixed: the original loaded os.path.dirname(file_path), silently
        ingesting every sibling file in that directory — wrong for the API
        upload path, which writes into a shared temp directory.
        """
        from llama_index.core import SimpleDirectoryReader  # local: only needed here
        documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
        documents = self.doc_processor.add_metadata(documents)
        self.index_manager.add_documents(documents)

    def get_stats(self) -> dict:
        """Expose index/collection statistics."""
        return self.index_manager.get_stats()
def main():
    """Interactive command-line REPL over the QA system."""
    qa_system = QASystem()
    qa_system.initialize(rebuild=False)
    print("\n=== 文档问答系统 ===")
    print("输入 'quit' 退出")
    print("输入 'stats' 查看统计信息")
    print("输入 'rebuild' 重建索引\n")
    while True:
        try:
            question = input("问题: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Fixed: Ctrl-D / Ctrl-C previously crashed with a traceback.
            break
        if question.lower() == "quit":
            break
        if question.lower() == "stats":
            stats = qa_system.get_stats()
            print(f"\n统计信息: {stats}\n")
            continue
        if question.lower() == "rebuild":
            qa_system.initialize(rebuild=True)
            continue
        if not question:
            continue
        print("\n回答: ", end="")
        result = qa_system.ask(question)
        print(result["answer"])
        print("\n来源:")
        for i, source in enumerate(result["sources"][:3]):
            print(f"  [{i+1}] {source['metadata'].get('file_name', 'unknown')}")
            # Fixed: score is None when no reranker assigns one; formatting
            # None with :.4f raised TypeError.
            score = source["score"]
            if score is not None:
                print(f"      相似度: {score:.4f}")
        print()


if __name__ == "__main__":
    main()
API 服务 #
FastAPI 实现 #
python
import os
import tempfile
from typing import List, Optional

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from main import QASystem  # fixed: QASystem was used below but never imported

app = FastAPI(title="文档问答系统 API")

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# wide open — restrict origins before exposing this beyond local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Single shared system instance, created once on startup.
qa_system: Optional[QASystem] = None


class QueryRequest(BaseModel):
    question: str
    stream: bool = False  # reserved; /query currently always responds in full


class QueryResponse(BaseModel):
    answer: str
    sources: List[dict]


class StatsResponse(BaseModel):
    document_count: int
    collection_name: str


@app.on_event("startup")
async def startup():
    """Build and load the QA system once per server process."""
    global qa_system
    qa_system = QASystem()
    qa_system.initialize()


@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    """Answer a question and return the answer with its source snippets."""
    if qa_system is None:
        raise HTTPException(status_code=500, detail="系统未初始化")
    result = qa_system.ask(request.question)
    return QueryResponse(**result)


@app.get("/stats", response_model=StatsResponse)
async def stats():
    """Return index statistics (document count, collection name)."""
    if qa_system is None:
        raise HTTPException(status_code=500, detail="系统未初始化")
    stats = qa_system.get_stats()
    return StatsResponse(**stats)


@app.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Persist the uploaded file to a temp path and ingest it into the index."""
    if qa_system is None:
        raise HTTPException(status_code=500, detail="系统未初始化")
    # Fixed: keep the original extension on the temp file — a suffix-less
    # temp file is rejected by the loader's required_exts filter, so uploads
    # silently indexed nothing.
    suffix = os.path.splitext(file.filename or "")[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name
    try:
        qa_system.add_document(tmp_path)
        return {"message": "文档上传成功", "filename": file.filename}
    finally:
        os.unlink(tmp_path)


@app.post("/rebuild")
async def rebuild_index():
    """Rebuild the whole index from the configured data directory."""
    if qa_system is None:
        raise HTTPException(status_code=500, detail="系统未初始化")
    qa_system.initialize(rebuild=True)
    return {"message": "索引重建成功"}
部署建议 #
Docker 部署 #
dockerfile
# Slim Python base image keeps the final image small.
FROM python:3.11-slim

WORKDIR /app

# Copy and install requirements first so Docker's layer cache skips the
# (slow) pip install when only application source files change.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Application source goes in last — changes here invalidate only this layer.
COPY . .

EXPOSE 8000
# Serve the FastAPI app; bind 0.0.0.0 so the port is reachable from outside
# the container.
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8000"]
docker-compose.yml #
yaml
# Compose definition for the document QA API service.
# NOTE: the top-level `version` key is ignored by Compose v2+ but is kept
# for compatibility with older tooling.
version: '3.8'

services:
  qa-api:
    build: .
    ports:
      - "8000:8000"              # host:container
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
    volumes:
      - ./data:/app/data         # source documents
      - ./storage:/app/storage   # persisted Chroma index
    restart: unless-stopped
下一步 #
完成文档问答系统后,接下来学习 聊天机器人 构建多轮对话系统!
最后更新:2026-03-30