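Memory and Persistence

Overview

Memory and persistence are essential for building long-running agents. LangGraph saves and restores state automatically through its checkpointer mechanism. This gives an agent memory that carries over between invocations, lets it recover from failures, and makes human-in-the-loop interaction possible.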

```text
┌─────────────────────────────────────────────────────────────┐
│             Memory and Persistence Architecture             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   Short-term memory (working memory):                       │
│   ┌─────────────────────────────────────────────────────┐   │
│   │  State of the current thread                        │   │
│   │  - Conversation history                             │   │
│   │  - Intermediate results                             │   │
│   │  - Current context                                  │   │
│   └─────────────────────────────────────────────────────┘   │
│                                                             │
│   Long-term memory:                                         │
│   ┌─────────────────────────────────────────────────────┐   │
│   │  Persistent storage shared across threads           │   │
│   │  - User preferences                                 │   │
│   │  - Past interactions                                │   │
│   │  - Learned knowledge                                │   │
│   └─────────────────────────────────────────────────────┘   │
│                                                             │
│   Checkpointer:                                             │
│   ┌─────────────────────────────────────────────────────┐   │
│   │  Saves execution state automatically                │   │
│   │  - State snapshots                                  │   │
│   │  - Execution history                                │   │
│   │  - Recovery points                                  │   │
│   └─────────────────────────────────────────────────────┘   │
│                                                             │
└─────────────────────────────────────────────────────────────┘
```

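Checkpointer

The checkpointer is the core component of LangGraph persistence. It saves a checkpoint of the graph state at every step and restores it when the same thread is invoked again.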
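To make the mechanism concrete, here is a toy, standard-library-only sketch of what a checkpointer does conceptually: it stores a copy of the state after each step, keyed by `thread_id`, and hands back the latest copy on request. `ToyCheckpointer` and its methods are hypothetical names for illustration; LangGraph's real savers implement `BaseCheckpointSaver`, whose interface is different.

```python
import copy
from collections import defaultdict
from typing import Any, Optional


class ToyCheckpointer:
    """Hypothetical illustration only; not LangGraph's BaseCheckpointSaver interface."""

    def __init__(self) -> None:
        # thread_id -> list of saved states, oldest first
        self._threads: dict[str, list[dict[str, Any]]] = defaultdict(list)

    def put(self, thread_id: str, state: dict[str, Any]) -> int:
        # Deep-copy so later mutation of `state` cannot alter saved history
        self._threads[thread_id].append(copy.deepcopy(state))
        return len(self._threads[thread_id]) - 1  # index of this checkpoint

    def get_latest(self, thread_id: str) -> Optional[dict[str, Any]]:
        history = self._threads.get(thread_id)
        return copy.deepcopy(history[-1]) if history else None

    def list_history(self, thread_id: str) -> list[dict[str, Any]]:
        # Newest first, the same order LangGraph's get_state_history uses
        return [copy.deepcopy(s) for s in reversed(self._threads.get(thread_id, []))]


saver = ToyCheckpointer()
saver.put("conversation-1", {"messages": ["My name is Alice"]})
saver.put("conversation-1", {"messages": ["My name is Alice", "Hi Alice!"]})
saver.put("conversation-2", {"messages": ["Hello"]})

print(saver.get_latest("conversation-1"))  # {'messages': ['My name is Alice', 'Hi Alice!']}
print(saver.get_latest("missing"))         # None
```

The deep copies matter: a saved checkpoint should not change when the live state is mutated later, and each thread only ever sees its own list.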

In-Memory Checkpointer

Suitable for development and testing (state lives in process memory):

```python
from langgraph.checkpoint.memory import MemorySaver

# `graph` is a StateGraph builder defined elsewhere
checkpointer = MemorySaver()
app = graph.compile(checkpointer=checkpointer)
```

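SQLite Checkpointer

Suitable for local persistence (requires the `langgraph-checkpoint-sqlite` package):

```python
from langgraph.checkpoint.sqlite import SqliteSaver

# The constructor expects a sqlite3.Connection; from_conn_string opens the file
# and closes it when the block exits
with SqliteSaver.from_conn_string("checkpoints.db") as checkpointer:
    app = graph.compile(checkpointer=checkpointer)
    # Invoke `app` inside this block while the connection is open
```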

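PostgreSQL Checkpointer

Suitable for production (requires the `langgraph-checkpoint-postgres` package):

```python
from langgraph.checkpoint.postgres import PostgresSaver

conn_string = "postgresql://user:pass@localhost/db"
with PostgresSaver.from_conn_string(conn_string) as checkpointer:
    checkpointer.setup()  # creates the checkpoint tables on first use
    app = graph.compile(checkpointer=checkpointer)
```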

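Redis Checkpointer

Suitable for distributed deployments (requires the separate `langgraph-checkpoint-redis` package):

```python
from langgraph.checkpoint.redis import RedisSaver

with RedisSaver.from_conn_string("redis://localhost:6379") as checkpointer:
    checkpointer.setup()  # creates the required Redis indices
    app = graph.compile(checkpointer=checkpointer)
```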

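Checkpointer Comparison

```text
┌─────────────────────────────────────────────────────────────┐
│                   Checkpointer Comparison                   │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  MemorySaver:                                               │
│  ✅ Simple to use                                           │
│  ✅ No configuration needed                                 │
│  ❌ Lost on restart                                         │
│  Best for: development, testing                             │
│                                                             │
│  SqliteSaver:                                               │
│  ✅ Local persistence                                       │
│  ✅ No extra service required                               │
│  ⚠️ Single machine only                                     │
│  Best for: small apps, local development                    │
│                                                             │
│  PostgresSaver:                                             │
│  ✅ Production-grade                                        │
│  ✅ Handles concurrent access                               │
│  ⚠️ Requires a PostgreSQL server                            │
│  Best for: production                                       │
│                                                             │
│  RedisSaver:                                                │
│  ✅ High performance                                        │
│  ✅ Suited to distributed setups                            │
│  ⚠️ Requires Redis                                          │
│  Best for: high concurrency, distributed                    │
│                                                             │
└─────────────────────────────────────────────────────────────┘
```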

Using a Checkpointer

Basic Usage

```python
from langgraph.checkpoint.memory import MemorySaver

checkpointer = MemorySaver()
app = graph.compile(checkpointer=checkpointer)  # `graph` is built elsewhere

# Every invocation with the same thread_id continues the same conversation
config = {"configurable": {"thread_id": "conversation-1"}}

result1 = app.invoke(
    {"messages": [("user", "My name is Alice")]},
    config
)

result2 = app.invoke(
    {"messages": [("user", "What's my name?")]},
    config
)

print(result2["messages"][-1].content)
# e.g. "Your name is Alice." (exact wording depends on the model)
```

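Thread ID

The thread ID separates conversations: checkpoints are stored per thread, so each ID below has its own independent history:

```python
config1 = {"configurable": {"thread_id": "user-123"}}
config2 = {"configurable": {"thread_id": "user-456"}}

# Each call reads and writes only its own thread's checkpoints
app.invoke({"messages": [...]}, config1)
app.invoke({"messages": [...]}, config2)
```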

Getting the Current State

```python
config = {"configurable": {"thread_id": "conversation-1"}}

result = app.invoke({"messages": [...]}, config)

current_state = app.get_state(config)
print(current_state.values)
```

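State Snapshot Structure

```python
snapshot = app.get_state(config)

snapshot.values         # current state values
snapshot.next           # tuple of node(s) to run next; empty when the run has finished
snapshot.config         # config identifying this checkpoint
snapshot.metadata       # metadata such as the step number and source
snapshot.created_at     # creation timestamp (an ISO 8601 string)
snapshot.parent_config  # config of the previous checkpoint
```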

Getting the State History

```python
history = app.get_state_history(config)  # iterator of snapshots, newest first

for snapshot in history:
    print(f"Step: {snapshot.metadata['step']}")
    print(f"State: {snapshot.values}")
    print(f"Time: {snapshot.created_at}")
```

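Updating the State

```python
# Values pass through the state's reducers: with add_messages, this message
# is appended to the history rather than replacing it
app.update_state(
    config,
    {"messages": [("user", "Override message")]}
)
```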
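Because `update_state` routes values through the state's reducers, just as it does with a node's return value, a key with an appending reducer grows instead of being replaced. The sketch below imitates that rule in plain Python. `apply_update` and `REDUCERS` are hypothetical, and `operator.add` (list concatenation) stands in for `add_messages`, which additionally replaces messages that share an ID.

```python
import operator
from typing import Any, Callable

# Per-key reducers; keys without one are simply overwritten.
# operator.add (list concatenation) stands in for add_messages here.
REDUCERS: dict[str, Callable[[Any, Any], Any]] = {"messages": operator.add}


def apply_update(state: dict[str, Any], update: dict[str, Any]) -> dict[str, Any]:
    """Return a new state with `update` merged in via each key's reducer."""
    new_state = dict(state)
    for key, value in update.items():
        reducer = REDUCERS.get(key)
        if reducer is not None and key in new_state:
            new_state[key] = reducer(new_state[key], value)
        else:
            new_state[key] = value
    return new_state


state = {"messages": ["hi"], "summary": "old"}
state = apply_update(state, {"messages": ["Override message"], "summary": "new"})
print(state)  # {'messages': ['hi', 'Override message'], 'summary': 'new'}
```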

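Resuming from a Past State

```python
history = list(app.get_state_history(config))  # newest checkpoint first
target_snapshot = history[-5]  # the fifth-oldest checkpoint

# Passing None as input resumes execution from that checkpoint
result = app.invoke(None, target_snapshot.config)
```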

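Memory Types

Short-Term Memory

Short-term memory lives in the state of the current thread, which the checkpointer saves:

```python
from typing import Annotated, TypedDict
from langgraph.graph.message import add_messages

class State(TypedDict):
    messages: Annotated[list, add_messages]  # conversation history, appended on each update
    current_context: str
    temp_data: dict                          # intermediate results

def agent_node(state: State):
    # Everything in `state` is checkpointed per thread
    context = state.get("current_context", "")
    temp = state.get("temp_data", {})
    return {"messages": [...]}
```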

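Long-Term Memory

Long-term memory must outlive a single thread, so it needs a separate storage mechanism such as a store:

```python
from typing import TypedDict
from langgraph.store.memory import InMemoryStore

class UserProfile(TypedDict):
    user_id: str
    preferences: dict
    history: list

store = InMemoryStore()

def save_user_memory(user_id: str, data: dict):
    # put(namespace, key, value); "profile" is an illustrative key name
    store.put(("users", user_id), "profile", data)

def get_user_memory(user_id: str) -> dict:
    item = store.get(("users", user_id), "profile")  # an Item, or None if absent
    return item.value if item else {}
```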
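Store items are addressed by a namespace tuple plus a key, which makes it easy to group memories per user and list them later. The dict-backed `ToyStore` below is a hypothetical sketch of that addressing scheme only; it is not LangGraph's `InMemoryStore`, whose interface is richer.

```python
from typing import Any, Optional

Namespace = tuple[str, ...]


class ToyStore:
    """Hypothetical dict-backed store: values are addressed by (namespace, key)."""

    def __init__(self) -> None:
        self._data: dict[tuple[Namespace, str], dict[str, Any]] = {}

    def put(self, namespace: Namespace, key: str, value: dict[str, Any]) -> None:
        self._data[(namespace, key)] = dict(value)

    def get(self, namespace: Namespace, key: str) -> Optional[dict[str, Any]]:
        value = self._data.get((namespace, key))
        return dict(value) if value is not None else None

    def search(self, prefix: Namespace) -> list[tuple[Namespace, str, dict[str, Any]]]:
        # Every entry whose namespace starts with `prefix`, sorted for stable output
        return sorted(
            (ns, key, dict(value))
            for (ns, key), value in self._data.items()
            if ns[: len(prefix)] == prefix
        )


store = ToyStore()
store.put(("users", "alice"), "profile", {"language": "en"})
store.put(("users", "bob"), "profile", {"language": "zh"})
store.put(("teams", "core"), "info", {"size": 3})
print([ns for ns, _, _ in store.search(("users",))])  # [('users', 'alice'), ('users', 'bob')]
```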

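Using a Store

```python
from langchain_core.runnables import RunnableConfig
from langgraph.graph import MessagesState
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore

store = InMemoryStore()

# A node that declares a `store` parameter receives the store passed to
# graph.compile(..., store=store)
def agent_node(state: MessagesState, config: RunnableConfig, *, store: BaseStore):
    user_id = config["configurable"]["user_id"]
    namespace = ("users", user_id)

    user_memory = store.get(namespace, "last_interaction")  # Item or None

    response = llm.invoke(state["messages"])  # `llm`: a chat model defined elsewhere

    # Store values must be dicts, so save the message text, not the message object
    store.put(namespace, "last_interaction", {"content": response.content})

    return {"messages": [response]}
```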

Memory Patterns

1. Conversation Memory

```python
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph

def chat_node(state: MessagesState):
    response = llm.invoke(state["messages"])  # `llm`: a chat model defined elsewhere
    return {"messages": [response]}

graph = StateGraph(MessagesState)
graph.add_node("chat", chat_node)
graph.add_edge(START, "chat")
graph.add_edge("chat", END)

app = graph.compile(checkpointer=MemorySaver())

config = {"configurable": {"thread_id": "chat-1"}}

app.invoke({"messages": [("user", "Hi, I'm Bob")]}, config)
# Same thread, so the model also sees the earlier "Hi, I'm Bob" turn
app.invoke({"messages": [("user", "What's my name?")]}, config)
```

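2. Summary Memory

```python
from typing import Annotated, TypedDict

from langchain_core.messages import RemoveMessage
from langgraph.graph.message import add_messages

class State(TypedDict):
    messages: Annotated[list, add_messages]
    summary: str

def summarize(state: State):
    if len(state["messages"]) > 10:
        summary = llm.invoke(f"Summarize: {state['messages']}")  # `llm` defined elsewhere
        # Returning "messages": [] is a no-op under add_messages; RemoveMessage
        # deletes by ID (here: everything except the last two messages)
        delete = [RemoveMessage(id=m.id) for m in state["messages"][:-2]]
        return {"summary": summary.content, "messages": delete}
    return {}

def agent_node(state: State):
    context = f"Summary: {state.get('summary', 'None')}\n\n"
    context += f"Recent messages: {state['messages']}"
    response = llm.invoke(context)
    return {"messages": [response]}
```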

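3. Entity Memory

```python
from typing import Annotated, TypedDict
from langgraph.graph.message import add_messages

class State(TypedDict):
    messages: Annotated[list, add_messages]  # append instead of overwrite
    entities: str  # entity notes extracted by the LLM, kept as plain text here

def extract_entities(state: State):
    text = state["messages"][-1].content
    # `llm` is a chat model defined elsewhere; invoke() returns a message
    result = llm.invoke(f"Extract entities from: {text}")
    return {"entities": result.content}

def agent_node(state: State):
    entity_context = f"Known entities: {state.get('entities', '')}"
    # Supply the entity notes as a system message ahead of the conversation
    response = llm.invoke([("system", entity_context)] + state["messages"])
    return {"messages": [response]}
```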
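If entities are instead kept as a mapping (name to description), a plain `dict` field is replaced wholesale by every update, so entities found in earlier turns are lost. One option, sketched here under that mapping assumption, is a custom merge reducer attached through `Annotated`; `merge_entities` is a hypothetical name, and the reducer is ordinary Python, so it can be exercised on its own.

```python
from typing import Annotated, TypedDict


def merge_entities(existing: dict, new: dict) -> dict:
    """Reducer: keep earlier entities; newer descriptions win on conflicts."""
    merged = dict(existing or {})
    merged.update(new or {})
    return merged


class EntityState(TypedDict):
    # LangGraph reads the second Annotated argument as the key's reducer
    entities: Annotated[dict, merge_entities]


turn1 = merge_entities({}, {"Alice": "user"})
turn2 = merge_entities(turn1, {"Paris": "city", "Alice": "user, lives in Paris"})
print(turn2)  # {'Alice': 'user, lives in Paris', 'Paris': 'city'}
```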

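4. Knowledge Graph Memory

```python
from typing import Annotated, TypedDict
from langgraph.graph.message import add_messages

class State(TypedDict):
    messages: Annotated[list, add_messages]
    knowledge_graph: dict

def update_knowledge(state: State):
    # extract_facts and add_to_graph are hypothetical helpers
    new_facts = extract_facts(state["messages"][-1])
    graph = state.get("knowledge_graph", {})
    for fact in new_facts:
        graph = add_to_graph(graph, fact)
    return {"knowledge_graph": graph}
```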
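`extract_facts` and `add_to_graph` are not defined above. Assuming facts arrive as `(subject, relation, object)` triples, one possible `add_to_graph` keeps an adjacency dict and returns a new dict instead of mutating the one taken from the state. The representation is an assumption for illustration; LangGraph does not prescribe one.

```python
Fact = tuple[str, str, str]  # (subject, relation, object)
Graph = dict[str, list[tuple[str, str]]]  # subject -> [(relation, object), ...]


def add_to_graph(graph: Graph, fact: Fact) -> Graph:
    """Return a copy of `graph` with the triple added as an edge of its subject."""
    subject, relation, obj = fact
    new_graph = {node: list(edges) for node, edges in graph.items()}  # don't mutate state
    edges = new_graph.setdefault(subject, [])
    if (relation, obj) not in edges:  # skip duplicate facts
        edges.append((relation, obj))
    return new_graph


kg: Graph = {}
for fact in [("Alice", "lives_in", "Paris"), ("Alice", "likes", "tea"), ("Alice", "lives_in", "Paris")]:
    kg = add_to_graph(kg, fact)
print(kg)  # {'Alice': [('lives_in', 'Paris'), ('likes', 'tea')]}
```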

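Persistent Workflows

Fault Recovery

```python
from langgraph.checkpoint.sqlite import SqliteSaver

with SqliteSaver.from_conn_string("checkpoints.db") as checkpointer:
    app = graph.compile(checkpointer=checkpointer)

    try:
        result = app.invoke({"messages": [...]}, config)
    except Exception as e:
        print(f"Error: {e}")
        # Checkpoints written before the failure are still in the database
        state = app.get_state(config)
        print(f"Last saved state: {state.values}")
        # After fixing the cause, resume from the last checkpoint with:
        # result = app.invoke(None, config)
```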
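Recovery rests on one property: progress is saved after every step, so a re-run can skip work that already finished. The standard-library sketch below imitates that with a hypothetical `run_with_checkpoints` helper and a dict standing in for durable storage; it illustrates the idea and is not how LangGraph is implemented internally.

```python
from typing import Callable

Step = Callable[[dict], dict]


def run_with_checkpoints(steps: list[tuple[str, Step]], checkpoints: dict, thread_id: str) -> dict:
    """Run steps in order, saving after each; a re-run resumes after the last saved step."""
    saved = checkpoints.get(thread_id, {"done": 0, "state": {}})
    state = dict(saved["state"])
    for index in range(saved["done"], len(steps)):
        _name, step = steps[index]
        state = {**state, **step(state)}  # may raise; earlier steps are already saved
        checkpoints[thread_id] = {"done": index + 1, "state": dict(state)}
    return state


calls = {"fetch": 0, "flaky": 0}


def fetch(state: dict) -> dict:
    calls["fetch"] += 1
    return {"data": [1, 2, 3]}


def flaky(state: dict) -> dict:
    calls["flaky"] += 1
    if calls["flaky"] == 1:
        raise RuntimeError("transient failure")
    return {"total": sum(state["data"])}


checkpoints: dict = {}
steps = [("fetch", fetch), ("flaky", flaky)]
try:
    run_with_checkpoints(steps, checkpoints, "job-1")
except RuntimeError:
    pass  # the first attempt fails after "fetch" was saved
result = run_with_checkpoints(steps, checkpoints, "job-1")
print(result, calls)  # {'data': [1, 2, 3], 'total': 6} {'fetch': 1, 'flaky': 2}
```

On the second run `fetch` is not executed again; only the failed step is retried.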

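Pausing and Resuming

```python
# Assumes the graph was compiled with a breakpoint, for example
# app = graph.compile(checkpointer=checkpointer, interrupt_before=["review"])
# ("review" is a hypothetical node name); without one, invoke runs to the end
config = {"configurable": {"thread_id": "long-task"}}

result = app.invoke({"messages": [...]}, config)

state = app.get_state(config)
print(f"Paused at: {state.next}")

# Passing None continues from where the thread stopped
result = app.invoke(None, config)
```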

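Time Travel

```python
history = list(app.get_state_history(config))  # newest checkpoint first

for i, snapshot in enumerate(history):
    print(f"{i}: Step {snapshot.metadata['step']}")

# Re-run from the fourth-newest checkpoint; the new run forks from that point
target = history[3]
result = app.invoke(None, target.config)
```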

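Memory Management

Cleaning Up Old State

The compiled graph has no `delete_state` method, and a snapshot's `created_at` is an ISO 8601 string rather than a `datetime`. Recent checkpointer versions provide `delete_thread(thread_id)`, which removes all checkpoints of a thread, so cleanup is typically done per thread:

```python
from datetime import datetime, timedelta, timezone

def cleanup_stale_thread(app, checkpointer, thread_id: str, days: int = 7) -> bool:
    """Delete a whole thread if its newest checkpoint is older than `days`."""
    config = {"configurable": {"thread_id": thread_id}}
    snapshot = app.get_state(config)
    if snapshot.created_at is None:  # the thread has no checkpoints
        return False

    cutoff = datetime.now(timezone.utc) - timedelta(days=days)
    # created_at is an ISO 8601 string with a UTC offset; parse it before comparing
    if datetime.fromisoformat(snapshot.created_at) < cutoff:
        checkpointer.delete_thread(thread_id)
        return True
    return False
```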

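State Compression

```python
from langchain_core.messages import RemoveMessage

# Assumes `messages` uses the add_messages reducer
def compress_state(state: State) -> dict:
    messages = state["messages"]
    if len(messages) > 20:
        summary = summarize_messages(messages[:-10])  # hypothetical helper
        return {
            # Returning messages[-10:] would not shrink the list (add_messages
            # merges by ID); RemoveMessage deletes the older ones instead
            "messages": [RemoveMessage(id=m.id) for m in messages[:-10]],
            "summary": summary
        }
    return {}  # nothing to update
```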
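The decision of what to fold into the summary can be tested without an LLM. In this sketch, messages are plain strings, `summarize` is any callable, and `fake_summarize` is a hypothetical stand-in for a model call; the thresholds 20 and 10 come from the example above. This version extends the previous summary rather than replacing it, so facts from earlier compressions are kept.

```python
from typing import Callable


def compress(messages: list[str], summary: str, summarize: Callable[[str, list[str]], str],
             max_messages: int = 20, keep_last: int = 10) -> tuple[list[str], str]:
    """Fold older messages into a running summary once the list exceeds max_messages."""
    if len(messages) <= max_messages:
        return messages, summary  # nothing to compress
    older, recent = messages[:-keep_last], messages[-keep_last:]
    # Extend the existing summary instead of discarding it
    return recent, summarize(summary, older)


def fake_summarize(previous: str, msgs: list[str]) -> str:
    # Stand-in for an LLM call: records how many messages were folded in
    note = f"{len(msgs)} msgs"
    return f"{previous}; {note}" if previous else note


history = [f"m{i}" for i in range(25)]
history, summary = compress(history, "", fake_summarize)
print(len(history), summary)  # 10 15 msgs
```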

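Limiting Memory

```python
from typing import Annotated, TypedDict
from langgraph.graph.message import add_messages

def limit_messages(messages: list, max_count: int = 50) -> list:
    if len(messages) > max_count:
        return messages[-max_count:]
    return messages

class State(TypedDict):
    # Merge with add_messages semantics first, then keep only the newest 50
    messages: Annotated[list, lambda old, new: limit_messages(add_messages(old, new))]
```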
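A cap can also be written as a reducer factory, which keeps the limit configurable while still giving LangGraph a two-argument function. `make_capped_reducer` is a hypothetical helper; it does plain list concatenation, so unlike `add_messages` it does not merge messages by ID.

```python
from typing import Annotated, Callable, TypedDict


def make_capped_reducer(max_count: int) -> Callable[[list, list], list]:
    """Build a two-argument reducer that appends, then keeps the newest max_count items."""
    def reducer(old: list, new: list) -> list:
        combined = list(old or []) + list(new or [])
        # Guard: combined[-0:] would return the whole list, not an empty one
        return combined[-max_count:] if max_count > 0 else []
    return reducer


class CappedState(TypedDict):
    # LangGraph would call this as reducer(old_value, update) on every write
    messages: Annotated[list, make_capped_reducer(50)]


capped = make_capped_reducer(3)
state: list = []
for update in (["a", "b"], ["c"], ["d", "e"]):
    state = capped(state, update)
print(state)  # ['c', 'd', 'e']
```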

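Advanced Persistence

Custom Checkpointer

A custom saver subclasses `BaseCheckpointSaver` and implements `get_tuple`, `list`, `put`, and `put_writes` (plus async counterparts if the graph runs asynchronously):

```python
from typing import Optional

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    BaseCheckpointSaver, ChannelVersions, Checkpoint, CheckpointMetadata, CheckpointTuple,
)

class CustomCheckpointer(BaseCheckpointSaver):
    # Skeleton only: storage_backend and its load/save methods are hypothetical
    def __init__(self, storage_backend):
        super().__init__()
        self.storage = storage_backend

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        thread_id = config["configurable"]["thread_id"]
        return self.storage.load(thread_id)  # assumed to return a CheckpointTuple or None

    def put(self, config: RunnableConfig, checkpoint: Checkpoint,
            metadata: CheckpointMetadata, new_versions: ChannelVersions) -> RunnableConfig:
        thread_id = config["configurable"]["thread_id"]
        self.storage.save(thread_id, checkpoint, metadata)
        # Return the config that identifies the checkpoint just written
        return {"configurable": {
            "thread_id": thread_id,
            "checkpoint_ns": config["configurable"].get("checkpoint_ns", ""),
            "checkpoint_id": checkpoint["id"],
        }}

    # A complete saver must also implement list() and put_writes()
```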

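Sharded Storage

```python
import hashlib
from langgraph.checkpoint.postgres import PostgresSaver

def get_sharded_checkpointer(user_id: str):
    # hash() on str is salted per process; a SHA-256 digest keeps each user on one shard
    shard = int(hashlib.sha256(user_id.encode()).hexdigest(), 16) % 10
    return PostgresSaver.from_conn_string(f"postgresql://db{shard}/checkpoints")  # use in `with`
```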

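Encrypted Storage

```python
import json
from cryptography.fernet import Fernet

class EncryptedCheckpointer:
    # `backend` is hypothetical: save(thread_id, bytes) and load(thread_id) -> bytes
    def __init__(self, backend, key: bytes):
        self.backend = backend
        self.cipher = Fernet(key)  # key from Fernet.generate_key()

    def save(self, thread_id: str, state: dict):
        # JSON instead of str(): round-trips safely and needs no eval();
        # `state` must be JSON-serializable
        encrypted = self.cipher.encrypt(json.dumps(state).encode())
        self.backend.save(thread_id, encrypted)

    def load(self, thread_id: str) -> dict:
        encrypted = self.backend.load(thread_id)
        return json.loads(self.cipher.decrypt(encrypted).decode())
```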

Best Practices

1. Choose the Right Checkpointer

```text
┌─────────────────────────────────────────────────────────────┐
│                   Choosing a Checkpointer                   │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Development:                                               │
│  → MemorySaver                                              │
│                                                             │
│  Single-machine deployment:                                 │
│  → SqliteSaver                                              │
│                                                             │
│  Production:                                                │
│  → PostgresSaver or RedisSaver                              │
│                                                             │
│  High concurrency:                                          │
│  → RedisSaver + sharding                                    │
│                                                             │
└─────────────────────────────────────────────────────────────┘
```

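2. Design Thread IDs Deliberately

```python
import uuid

def get_thread_id(user_id: str, session_type: str) -> str:
    # The random suffix makes every call a NEW thread: create the ID once per
    # session, store it, and reuse it on each turn, or earlier memory is not found
    return f"{user_id}:{session_type}:{uuid.uuid4()}"

config = {
    "configurable": {
        "thread_id": get_thread_id("user-123", "chat")
    }
}
```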
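For a conversation to continue across requests, the ID has to be generated once and then looked up. `ThreadRegistry` is a hypothetical in-memory sketch of that bookkeeping; a real deployment would keep the mapping in a database.

```python
import uuid


class ThreadRegistry:
    """Hypothetical helper: one stable thread_id per (user, session type) pair."""

    def __init__(self) -> None:
        # In production this mapping would live in a database, not in memory
        self._ids: dict[tuple[str, str], str] = {}

    @staticmethod
    def _make_id(user_id: str, session_type: str) -> str:
        return f"{user_id}:{session_type}:{uuid.uuid4()}"

    def get_or_create(self, user_id: str, session_type: str) -> str:
        key = (user_id, session_type)
        if key not in self._ids:
            self._ids[key] = self._make_id(user_id, session_type)
        return self._ids[key]

    def new_session(self, user_id: str, session_type: str) -> str:
        # Explicitly start a fresh, empty thread for this user and session type
        self._ids[(user_id, session_type)] = self._make_id(user_id, session_type)
        return self._ids[(user_id, session_type)]


registry = ThreadRegistry()
config = {"configurable": {"thread_id": registry.get_or_create("user-123", "chat")}}
```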

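3. Clean Up Regularly

```python
import asyncio

async def cleanup_task():
    while True:
        await asyncio.sleep(3600)  # once an hour
        # cleanup_old_checkpoints is a hypothetical blocking function; run it in a
        # worker thread so it does not stall the event loop
        await asyncio.to_thread(cleanup_old_checkpoints)
```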
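A variant of that loop that can be stopped cleanly, while still keeping the blocking cleanup call off the event loop, is sketched below. `periodic_cleanup` and `demo` are hypothetical names, and the `cleanup` callable stands in for whatever deletion routine you use.

```python
import asyncio
from typing import Callable


async def periodic_cleanup(cleanup: Callable[[], None], interval_s: float, stop: asyncio.Event) -> int:
    """Call a blocking `cleanup` every `interval_s` seconds until `stop` is set; return the run count."""
    runs = 0
    while not stop.is_set():
        try:
            # Sleep for the interval, but wake up immediately if `stop` is set
            await asyncio.wait_for(stop.wait(), timeout=interval_s)
        except asyncio.TimeoutError:
            # Interval elapsed: run the blocking call in a worker thread
            await asyncio.to_thread(cleanup)
            runs += 1
    return runs


async def demo(duration_s: float = 0.1) -> tuple[int, int]:
    calls: list[int] = []
    stop = asyncio.Event()
    task = asyncio.create_task(periodic_cleanup(lambda: calls.append(1), 0.01, stop))
    await asyncio.sleep(duration_s)
    stop.set()  # the loop exits after its current wait or cleanup call
    runs = await task
    return runs, len(calls)
```

Run it with `asyncio.run(demo())`; in a service, the same `stop` event would be set during shutdown.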

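4. Monitor Storage

```python
def monitor_storage():
    # get_storage_size, get_thread_count, alert and WARNING_THRESHOLD are hypothetical
    total_size = get_storage_size()
    total_threads = get_thread_count()

    if total_size > WARNING_THRESHOLD:
        alert(f"Storage usage high: {total_size} bytes across {total_threads} threads")
```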
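For the SQLite backend, the two hypothetical helpers could be approximated as below. This assumes the checkpoint database has a table named `checkpoints` with a `thread_id` column, which matches the schema `SqliteSaver` creates in the versions we are aware of; check it against your installed version. The threshold is a placeholder value.

```python
import os
import sqlite3

WARNING_THRESHOLD = 500 * 1024 * 1024  # placeholder: 500 MiB


def storage_stats(db_path: str) -> tuple[int, int]:
    """Return (file size in bytes, number of distinct thread_ids) for a SQLite checkpoint DB."""
    size = os.path.getsize(db_path)  # main file only; -wal/-shm side files are not counted
    conn = sqlite3.connect(db_path)
    try:
        (threads,) = conn.execute("SELECT COUNT(DISTINCT thread_id) FROM checkpoints").fetchone()
    finally:
        conn.close()
    return size, threads


def check_storage(db_path: str, threshold: int = WARNING_THRESHOLD) -> list[str]:
    """Return warning messages instead of alerting, so callers decide how to report."""
    size, threads = storage_stats(db_path)
    warnings = []
    if size > threshold:
        warnings.append(f"Storage usage high: {size} bytes across {threads} threads")
    return warnings
```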

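Next Steps

Now that you have a handle on memory and persistence, move on to Human-in-the-Loop to learn how to bring people into an agent's decisions!

Last updated: 2026-03-30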