记忆与持久化 #
概述 #
记忆和持久化是构建长时间运行 Agent 的关键能力。LangGraph 通过 Checkpointer 机制自动保存和恢复状态，支持跨会话的记忆、故障恢复和人机交互。
text
┌─────────────────────────────────────────────────────────────┐
│ 记忆与持久化架构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 短期记忆(Working Memory): │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 当前会话的状态 │ │
│ │ - 对话历史 │ │
│ │ - 中间结果 │ │
│ │ - 当前上下文 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ 长期记忆(Long-term Memory): │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 跨会话的持久化存储 │ │
│ │ - 用户偏好 │ │
│ │ - 历史交互 │ │
│ │ - 学习到的知识 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ Checkpointer: │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 自动保存执行状态 │ │
│ │ - 状态快照 │ │
│ │ - 执行历史 │ │
│ │ - 恢复点 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
Checkpointer #
Checkpointer 是 LangGraph 持久化的核心组件,负责保存和恢复图的状态。
内存 Checkpointer #
适用于开发和测试:
python
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
# In-memory checkpointer; the surrounding text recommends it for development
# and testing because saved state is lost on restart.
checkpointer = MemorySaver()
# NOTE: `graph` is not defined in this snippet; it is assumed to be a
# StateGraph built elsewhere.
app = graph.compile(checkpointer=checkpointer)
SQLite Checkpointer #
适用于本地持久化:
python
from langgraph.checkpoint.sqlite import SqliteSaver
# SqliteSaver's constructor expects an open sqlite3 connection, not a path.
# The `from_conn_string` factory opens the database file and yields the saver
# as a context manager, closing the connection when the block exits.
with SqliteSaver.from_conn_string("checkpoints.db") as checkpointer:
    # `graph` is assumed to be a StateGraph built elsewhere. The compiled app
    # must be used inside this block, while the connection is still open.
    app = graph.compile(checkpointer=checkpointer)
PostgreSQL Checkpointer #
适用于生产环境:
python
from langgraph.checkpoint.postgres import PostgresSaver
conn_string = "postgresql://user:pass@localhost/db"
# PostgresSaver's constructor takes a connection object, not a URL string.
# `from_conn_string` opens the connection and yields the saver as a context
# manager.
with PostgresSaver.from_conn_string(conn_string) as checkpointer:
    # Create the checkpoint tables; required the first time the database is used.
    checkpointer.setup()
    # `graph` is assumed to be a StateGraph built elsewhere. Use `app` inside
    # this block, while the connection is open.
    app = graph.compile(checkpointer=checkpointer)
Redis Checkpointer #
适用于分布式环境:
python
from langgraph.checkpoint.redis import RedisSaver
# Build the saver through `from_conn_string` so the Redis connection is
# released when the block exits.
with RedisSaver.from_conn_string("redis://localhost:6379") as checkpointer:
    # Create the indices the saver relies on before its first use.
    checkpointer.setup()
    # `graph` is assumed to be a StateGraph built elsewhere.
    app = graph.compile(checkpointer=checkpointer)
Checkpointer 对比 #
text
┌─────────────────────────────────────────────────────────────┐
│ Checkpointer 对比 │
├─────────────────────────────────────────────────────────────┤
│ │
│ MemorySaver: │
│ ✅ 简单易用 │
│ ✅ 无需配置 │
│ ❌ 重启后丢失 │
│ 适用:开发、测试 │
│ │
│ SqliteSaver: │
│ ✅ 本地持久化 │
│ ✅ 无需额外服务 │
│ ⚠️ 单机限制 │
│ 适用:小型应用、本地开发 │
│ │
│ PostgresSaver: │
│ ✅ 生产级 │
│ ✅ 支持并发 │
│ ⚠️ 需要数据库 │
│ 适用:生产环境 │
│ │
│ RedisSaver: │
│ ✅ 高性能 │
│ ✅ 分布式支持 │
│ ⚠️ 需要 Redis │
│ 适用:高并发、分布式 │
│ │
└─────────────────────────────────────────────────────────────┘
使用 Checkpointer #
基本使用 #
python
from langgraph.checkpoint.memory import MemorySaver
checkpointer = MemorySaver()
# `graph` is assumed to be a StateGraph built elsewhere.
app = graph.compile(checkpointer=checkpointer)
# Both invocations below pass the same thread_id, so the second call runs
# against the state saved by the first.
config = {"configurable": {"thread_id": "conversation-1"}}
result1 = app.invoke(
    {"messages": [("user", "My name is Alice")]},
    config
)
result2 = app.invoke(
    {"messages": [("user", "What's my name?")]},
    config
)
# Print the content of the last message in the returned state.
print(result2["messages"][-1].content)
# Example output: "Your name is Alice."
Thread ID #
Thread ID 用于区分不同的会话:
python
# Two configs with different thread ids: each invocation is checkpointed
# under its own conversation.
config1 = {"configurable": {"thread_id": "user-123"}}
config2 = {"configurable": {"thread_id": "user-456"}}
# `[...]` is a placeholder for real messages.
app.invoke({"messages": [...]}, config1)
app.invoke({"messages": [...]}, config2)
获取当前状态 #
python
config = {"configurable": {"thread_id": "conversation-1"}}
# `[...]` is a placeholder for real messages.
result = app.invoke({"messages": [...]}, config)
# Read back the latest snapshot saved for this thread and show its values.
current_state = app.get_state(config)
print(current_state.values)
状态快照结构 #
python
# Fetch the latest snapshot for the thread in `config`; the attribute
# accesses below only illustrate which fields are available.
snapshot = app.get_state(config)
snapshot.values  # current state values
snapshot.next  # the next node(s) to execute
snapshot.config  # configuration information
snapshot.metadata  # metadata (step, time, etc.)
snapshot.created_at  # creation time
snapshot.parent_config  # configuration of the parent state
获取状态历史 #
python
# Walk every snapshot saved for the thread and print its step, values and
# creation time, one per line.
history = app.get_state_history(config)
for snapshot in history:
    print(
        f"Step: {snapshot.metadata['step']}",
        f"State: {snapshot.values}",
        f"Time: {snapshot.created_at}",
        sep="\n",
    )
更新状态 #
python
# Write an update into the saved state of the thread identified by `config`.
app.update_state(
    config,
    {"messages": [("user", "Override message")]}
)
从历史状态恢复 #
python
# Materialize the history so it can be indexed.
history = list(app.get_state_history(config))
# NOTE(review): history is presumably ordered newest-first, which would make
# history[-5] the fifth-oldest snapshot - confirm. Raises IndexError when
# fewer than five snapshots exist.
target_snapshot = history[-5]
# Invoking with `None` as input and the snapshot's own config resumes from
# that checkpoint instead of starting a new run.
result = app.invoke(None, target_snapshot.config)
记忆类型 #
短期记忆 #
短期记忆保存在当前会话的状态中:
python
from typing import TypedDict, Annotated
from langgraph.graph.message import add_messages
class State(TypedDict):
    """Per-conversation working state (short-term memory)."""

    # Conversation history; updates are merged by the add_messages reducer.
    messages: Annotated[list, add_messages]
    # Free-form context string for the current turn.
    current_context: str
    # Scratch space for intermediate results.
    temp_data: dict
def agent_node(state: State):
    """Read short-term memory from the state and return a message update.

    Both lookups fall back to an empty value when the key is absent. The
    returned message list holds a literal `...` placeholder for a real reply.
    """
    ctx = state.get("current_context", "")
    scratch = state.get("temp_data", {})
    update = {"messages": [...]}
    return update
长期记忆 #
长期记忆需要额外的存储机制:
python
from typing import TypedDict
from langgraph.store.memory import InMemoryStore
class UserProfile(TypedDict):
    """Shape of the per-user record kept in long-term memory."""

    # Identifier of the user the record belongs to.
    user_id: str
    # Stored user preferences.
    preferences: dict
    # Past interactions.
    history: list
store = InMemoryStore()

# Key under which each user's memory document is stored in their namespace.
_USER_MEMORY_KEY = "profile"


def save_user_memory(user_id: str, data: dict):
    """Persist `data` as the memory document of `user_id`.

    The store API takes a namespace tuple, a key inside that namespace, and
    the value, so the key is passed separately from the namespace.
    """
    store.put(("users", user_id), _USER_MEMORY_KEY, data)


def get_user_memory(user_id: str) -> dict:
    """Return the memory document saved for `user_id`, or `{}` if none exists.

    `store.get` returns an item wrapper (or None when nothing is stored);
    the saved dict is its `value` attribute.
    """
    item = store.get(("users", user_id), _USER_MEMORY_KEY)
    return item.value if item is not None else {}
使用 Store #
python
from langgraph.store.memory import InMemoryStore
store = InMemoryStore()


def agent_node(state: State, config: RunnableConfig):
    """Answer the conversation and record the reply in the user's long-term memory.

    The user id is read from `config["configurable"]["user_id"]`. The store
    is addressed with a namespace tuple plus a separate key, as its `get` and
    `put` methods require.

    NOTE: `RunnableConfig`, `State` and `llm` are not defined in this snippet
    and must be provided by the surrounding module.
    """
    user_id = config["configurable"]["user_id"]
    namespace = ("users", user_id)
    # `get` returns an item wrapper or None; the stored dict is `item.value`.
    item = store.get(namespace, "profile")
    # Loaded for illustration only; it is not yet fed into the prompt.
    user_memory = item.value if item is not None else {}
    response = llm.invoke(state["messages"])
    # NOTE(review): the raw response object is stored; a persistent store
    # would presumably need a serializable value - confirm.
    store.put(namespace, "profile", {"last_interaction": response})
    return {"messages": [response]}
记忆模式 #
1. 对话记忆 #
python
from langgraph.graph import MessagesState
def chat_node(state: MessagesState):
    """Send the message history to the model and return its reply as an update.

    `llm` is expected to be defined by the surrounding module.
    """
    reply = llm.invoke(state["messages"])
    return {"messages": [reply]}
# START and END are used below but are not imported anywhere else in the file.
from langgraph.graph import END, START

# Single-node graph: START -> chat -> END.
graph = StateGraph(MessagesState)
graph.add_node("chat", chat_node)
graph.add_edge(START, "chat")
graph.add_edge("chat", END)
# Compile with an in-memory checkpointer so state is saved per thread.
app = graph.compile(checkpointer=MemorySaver())
# Both calls pass the same thread_id, so the second runs against the state
# saved by the first.
config = {"configurable": {"thread_id": "chat-1"}}
app.invoke({"messages": [("user", "Hi, I'm Bob")]}, config)
app.invoke({"messages": [("user", "What's my name?")]}, config)
2. 摘要记忆 #
python
from typing import TypedDict, Annotated
from operator import add
class State(TypedDict):
    """Conversation state that carries a running summary."""

    # Conversation history; updates are merged by the add_messages reducer.
    messages: Annotated[list, add_messages]
    # Text summary of the conversation.
    summary: str
def summarize(state: State):
    """Summarize the conversation once it holds more than 10 messages.

    Returns an empty update when the history has 10 messages or fewer.
    Otherwise asks `llm` (defined by the surrounding module) for a summary
    and returns it together with an empty message list.

    NOTE(review): `messages` uses the add_messages reducer in this State, so
    returning an empty list presumably appends nothing rather than clearing
    the history - confirm, and emit explicit removals if clearing is intended.
    """
    history = state["messages"]
    if len(history) <= 10:
        return {}
    digest = llm.invoke(f"Summarize: {history}")
    return {"summary": digest.content, "messages": []}
def agent_node(state: State):
    """Prompt the model with the stored summary followed by the messages.

    Falls back to the text 'None' when the state has no summary. `llm` is
    expected to be defined by the surrounding module.
    """
    prompt = (
        f"Summary: {state.get('summary', 'None')}\n\n"
        f"Recent messages: {state['messages']}"
    )
    reply = llm.invoke(prompt)
    return {"messages": [reply]}
3. 实体记忆 #
python
from typing import TypedDict
class State(TypedDict):
    """Conversation state that tracks entities mentioned so far."""

    # Conversation history. No reducer is attached here.
    messages: list
    # Entities extracted from the conversation.
    entities: dict
def extract_entities(state: State):
    """Ask the model to extract entities from the latest message.

    The model's return value is stored under `entities` as-is.
    NOTE(review): the State declares `entities` as a dict, while `llm.invoke`
    presumably returns a message object - confirm whether parsing is needed.
    """
    latest_text = state["messages"][-1].content
    extracted = llm.invoke(f"Extract entities from: {latest_text}")
    return {"entities": extracted}
def agent_node(state: State):
    """Prepend the known entities to the messages and query the model.

    `llm` is expected to be defined by the surrounding module.
    """
    known = f"Known entities: {state['entities']}"
    prompt = [known] + state["messages"]
    reply = llm.invoke(prompt)
    return {"messages": [reply]}
4. 知识图谱记忆 #
python
from typing import TypedDict
class State(TypedDict):
    """Conversation state that accumulates a knowledge graph."""

    # Conversation history.
    messages: list
    # Knowledge graph built up from the conversation.
    knowledge_graph: dict
def update_knowledge(state: State):
    """Fold the facts from the latest message into the knowledge graph.

    Starts from the graph already in the state (an empty dict if absent) and
    applies `add_to_graph` once per extracted fact. `extract_facts` and
    `add_to_graph` are expected to be defined by the surrounding module.
    """
    facts = extract_facts(state["messages"][-1])
    kg = state.get("knowledge_graph", {})
    for item in facts:
        kg = add_to_graph(kg, item)
    return {"knowledge_graph": kg}
持久化工作流 #
故障恢复 #
python
from langgraph.checkpoint.sqlite import SqliteSaver
# SqliteSaver's constructor expects an open sqlite3 connection, not a path.
# `from_conn_string` opens the database file and yields the saver as a
# context manager.
with SqliteSaver.from_conn_string("checkpoints.db") as checkpointer:
    # `graph` and `config` are assumed to be defined elsewhere.
    app = graph.compile(checkpointer=checkpointer)
    # Everything that touches `app` stays inside the `with` block so the
    # database connection is still open during recovery.
    try:
        result = app.invoke({"messages": [...]}, config)
    except Exception as e:
        # Deliberately broad: this is the demo's top-level boundary. Report
        # the failure, then show the last checkpoint saved for the thread.
        print(f"Error: {e}")
        state = app.get_state(config)
        print(f"Last saved state: {state.values}")
暂停和继续 #
python
config = {"configurable": {"thread_id": "long-task"}}
# `[...]` is a placeholder for real messages.
result = app.invoke({"messages": [...]}, config)
# Inspect the saved snapshot; `next` names what would run next.
state = app.get_state(config)
print(f"Paused at: {state.next}")
# Invoking with `None` as input continues from the saved state rather than
# starting over with new input.
result = app.invoke(None, config)
时间旅行 #
python
# Materialize the history so it can be listed and indexed.
history = list(app.get_state_history(config))
# Print each snapshot's position in the list next to its step number.
for i, snapshot in enumerate(history):
    print(f"{i}: Step {snapshot.metadata['step']}")
# Pick the snapshot at index 3 (IndexError if fewer than four exist) and
# re-run from it by passing its own config with no new input.
target = history[3]
result = app.invoke(None, target.config)
记忆管理 #
清理旧状态 #
python
from datetime import datetime, timedelta
def _parse_created_at(value):
    """Normalize a snapshot timestamp to a timezone-aware datetime.

    Accepts an ISO-8601 string, a datetime, or None. Naive datetimes are
    interpreted as local time. Returns None when `value` is None.
    """
    if value is None:
        return None
    if isinstance(value, str):
        # `fromisoformat` rejects a trailing "Z" before Python 3.11.
        value = datetime.fromisoformat(value.replace("Z", "+00:00"))
    if value.tzinfo is None:
        value = value.astimezone()
    return value


def cleanup_old_states(app, config, days=7):
    """Delete every snapshot of a thread that is older than `days` days.

    Args:
        app: Object exposing `get_state_history(config)` and
            `delete_state(config)`.
        config: Config identifying the thread whose history is scanned.
        days: Age threshold in days; older snapshots are deleted.

    Snapshot timestamps may be ISO-8601 strings, so they are parsed before
    comparison; comparing a string with a datetime raises TypeError.
    Snapshots without a timestamp are kept.

    NOTE(review): compiled LangGraph apps may not expose `delete_state` -
    confirm that the object passed as `app` provides it.
    """
    from datetime import timezone

    cutoff = datetime.now(timezone.utc) - timedelta(days=days)
    # Materialize the history first so deleting does not disturb iteration.
    for snapshot in list(app.get_state_history(config)):
        created_at = _parse_created_at(snapshot.created_at)
        if created_at is not None and created_at < cutoff:
            app.delete_state(snapshot.config)
状态压缩 #
python
def compress_state(state: State) -> dict:
    """Collapse all but the last 10 messages into a summary once there are more than 20.

    With 20 messages or fewer the state is returned unchanged. Otherwise the
    older messages are passed to `summarize_messages` (defined by the
    surrounding module) and only the 10 most recent are kept.
    """
    history = state["messages"]
    if len(history) <= 20:
        return state
    older, recent = history[:-10], history[-10:]
    return {"messages": recent, "summary": summarize_messages(older)}
记忆限制 #
python
from typing import TypedDict, Annotated
def limit_messages(messages: list, max_count: int = 50) -> list:
    """Return at most the last `max_count` messages.

    Args:
        messages: Message list, oldest first.
        max_count: Maximum number of messages to keep. Zero or a negative
            value keeps nothing.

    Returns:
        `messages` itself when it is within the limit, otherwise a new list
        holding its last `max_count` items.
    """
    # `messages[-0:]` is the whole list, so a zero limit needs its own branch.
    if max_count <= 0:
        return []
    if len(messages) > max_count:
        return messages[-max_count:]
    return messages
class State(TypedDict):
    """Conversation state whose message history is capped in length."""

    # Reducer: concatenate the old and new lists, then trim the result with
    # limit_messages using its default limit.
    messages: Annotated[list, lambda old, new: limit_messages(old + new)]
高级持久化 #
自定义 Checkpointer #
python
from langgraph.checkpoint.base import BaseCheckpointSaver
from typing import Any, Optional
class CustomCheckpointer(BaseCheckpointSaver):
    """Checkpointer sketch that delegates storage to a pluggable backend.

    The backend must provide `load(thread_id)` and `save(thread_id,
    checkpoint)`. One checkpoint is kept per thread id.

    NOTE(review): this is a simplified illustration. The base class
    presumably expects further methods (tuple-based get/put with metadata,
    pending writes, listing) before a graph can use it - confirm against the
    BaseCheckpointSaver reference.
    """

    def __init__(self, storage_backend):
        # Let the base class initialize its own state before adding ours.
        super().__init__()
        self.storage = storage_backend

    def get(self, config: dict) -> Optional[dict]:
        """Load the checkpoint stored for the thread id in `config`."""
        thread_id = config["configurable"]["thread_id"]
        return self.storage.load(thread_id)

    def put(self, config: dict, checkpoint: dict) -> None:
        """Save `checkpoint` under the thread id in `config`."""
        thread_id = config["configurable"]["thread_id"]
        self.storage.save(thread_id, checkpoint)
分片存储 #
python
def get_sharded_checkpointer(user_id: str, num_shards: int = 10):
    """Return the checkpointer for the database shard that owns `user_id`.

    Args:
        user_id: User whose checkpoints are being routed.
        num_shards: Number of database shards; defaults to 10.

    The shard is derived from a SHA-256 digest rather than the built-in
    `hash()`. String hashes are salted per interpreter process, so `hash()`
    would route the same user to different shards after a restart.

    NOTE(review): PostgresSaver is constructed directly from a URL string as
    in the original example; its constructor may require a connection object
    or the `from_conn_string` factory - confirm.
    """
    import hashlib

    digest = hashlib.sha256(user_id.encode("utf-8")).digest()
    shard = int.from_bytes(digest[:8], "big") % num_shards
    return PostgresSaver(f"postgresql://db{shard}/checkpoints")
加密存储 #
python
from cryptography.fernet import Fernet
class EncryptedCheckpointer:
    """Encrypt state with Fernet before handing it to a storage backend.

    The backend must provide `save(thread_id, data)` and `load(thread_id)`.
    State is serialized as its Python literal text, so it round-trips only
    when it consists of plain literals (dicts, lists, tuples, strings,
    numbers, booleans, None).
    """

    def __init__(self, backend, key):
        self.backend = backend
        self.cipher = Fernet(key)

    def save(self, thread_id: str, state: dict):
        """Serialize, encrypt and store `state` under `thread_id`."""
        encrypted = self.cipher.encrypt(str(state).encode())
        self.backend.save(thread_id, encrypted)

    def load(self, thread_id: str) -> dict:
        """Load, decrypt and parse the state stored under `thread_id`.

        Raises:
            ValueError or SyntaxError: if the decrypted payload is not a
                plain Python literal.
        """
        import ast

        encrypted = self.backend.load(thread_id)
        plaintext = self.cipher.decrypt(encrypted).decode()
        # SECURITY: the original used eval() here, which executes arbitrary
        # code if the stored payload or the key is ever compromised.
        # literal_eval reads the same literal format without executing it.
        return ast.literal_eval(plaintext)
最佳实践 #
1. 选择合适的 Checkpointer #
text
┌─────────────────────────────────────────────────────────────┐
│ Checkpointer 选择指南 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 开发阶段: │
│ → MemorySaver │
│ │
│ 单机部署: │
│ → SqliteSaver │
│ │
│ 生产环境: │
│ → PostgresSaver 或 RedisSaver │
│ │
│ 高并发: │
│ → RedisSaver + 分片 │
│ │
└─────────────────────────────────────────────────────────────┘
2. 合理设计 Thread ID #
python
import uuid
def get_thread_id(user_id: str, session_type: str) -> str:
    """Build a thread id of the form `<user_id>:<session_type>:<random uuid4>`.

    Every call yields a new id, because the last segment is freshly generated.
    """
    unique_suffix = str(uuid.uuid4())
    return ":".join((user_id, session_type, unique_suffix))
# Build a run config whose thread id encodes the user and session type.
# get_thread_id appends a random UUID, so each evaluation of this literal
# starts a new thread.
config = {
    "configurable": {
        "thread_id": get_thread_id("user-123", "chat")
    }
}
3. 定期清理 #
python
import asyncio
async def cleanup_task():
    """Run `cleanup_old_checkpoints` after every 3600-second sleep, forever.

    The first cleanup happens only after the first sleep completes.
    NOTE(review): `cleanup_old_checkpoints` is not defined in this snippet.
    It is called synchronously, so it presumably blocks the event loop while
    it runs - confirm that this is acceptable.
    """
    while True:
        await asyncio.sleep(3600)
        cleanup_old_checkpoints()
4. 监控存储 #
python
def monitor_storage():
    """Raise an alert when storage usage exceeds WARNING_THRESHOLD.

    `get_storage_size`, `get_thread_count`, `alert` and `WARNING_THRESHOLD`
    are expected to be defined by the surrounding module. The thread count
    is fetched but not used in the check.
    """
    used_bytes = get_storage_size()
    thread_count = get_thread_count()
    if not used_bytes > WARNING_THRESHOLD:
        return
    alert("Storage usage high")
下一步 #
现在你已经掌握了记忆与持久化，接下来学习“人机交互”，了解如何让人类参与 Agent 的决策！
最后更新:2026-03-30