工具(Tools) #

工具是代理与外部世界交互的桥梁。通过工具,LLM 可以搜索网络、查询数据库、调用 API、执行代码等,极大地扩展了模型的能力边界。

工具系统概览 #

text
┌─────────────────────────────────────────────────────────────┐
│                    工具的作用                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  没有工具的 LLM:                                            │
│  用户: "北京今天天气怎么样?"                                │
│  AI: "抱歉,我无法获取实时天气信息..."                       │
│                                                             │
│  有工具的 LLM:                                              │
│  用户: "北京今天天气怎么样?"                                │
│  AI: [调用 get_weather("北京")]                             │
│  AI: "北京今天晴,温度 25°C,适合出行。"                     │
│                                                             │
│  工具让 LLM 能够:                                           │
│  ✅ 访问实时信息                                            │
│  ✅ 执行计算和代码                                          │
│  ✅ 操作外部系统                                            │
│  ✅ 查询数据库                                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

工具定义方式 #

1. @tool 装饰器(推荐) #

最简单的工具定义方式:

python
from langchain_core.tools import tool

@tool
def get_weather(city: str) -> str:
    """获取指定城市的天气信息"""
    weather_data = {
        "北京": "晴,25°C",
        "上海": "多云,28°C",
        "广州": "雨,30°C"
    }
    return weather_data.get(city, f"未找到 {city} 的天气信息")

@tool
def calculate(expression: str) -> float:
    """计算数学表达式
    
    Args:
        expression: 数学表达式,如 "1+2*3"
    """
    try:
        return eval(expression)
    except:
        return "计算错误"

2. 带完整文档的工具 #

python
@tool
def search_database(
    query: str,
    table: str = "users",
    limit: int = 10
) -> str:
    """搜索数据库
    
    根据查询条件搜索指定数据表。
    支持模糊匹配和多条件筛选。
    
    Args:
        query: 搜索关键词或 SQL WHERE 条件
        table: 要搜索的表名,默认为 users
        limit: 返回结果的最大数量,默认为 10
    
    Returns:
        匹配的记录列表,JSON 格式
    
    Example:
        search_database("name LIKE '%张%'", "customers", 5)
    """
    return f"在 {table} 表中搜索: {query},返回 {limit} 条结果"

3. StructuredTool #

用于复杂参数结构:

python
from langchain.tools import StructuredTool
from pydantic import BaseModel, Field

class EmailInput(BaseModel):
    """发送邮件的参数"""
    to: str = Field(description="收件人邮箱地址")
    subject: str = Field(description="邮件主题")
    body: str = Field(description="邮件正文")
    cc: list[str] | None = Field(default=None, description="抄送列表")
    priority: str = Field(default="normal", description="优先级: high, normal, low")

def send_email_func(to: str, subject: str, body: str, cc: list = None, priority: str = "normal") -> str:
    return f"邮件已发送至 {to},主题: {subject},优先级: {priority}"

send_email = StructuredTool(
    name="send_email",
    description="发送电子邮件",
    func=send_email_func,
    args_schema=EmailInput
)

4. BaseTool 类 #

用于复杂逻辑的工具:

python
from langchain.tools import BaseTool
from typing import Optional, Type
from pydantic import BaseModel

class SearchInput(BaseModel):
    query: str
    max_results: int = 5

class WebSearchTool(BaseTool):
    name = "web_search"
    description = "搜索网络获取信息"
    args_schema: Type[BaseModel] = SearchInput
    
    def _run(self, query: str, max_results: int = 5) -> str:
        """同步执行"""
        return f"搜索 '{query}',返回 {max_results} 条结果"
    
    async def _arun(self, query: str, max_results: int = 5) -> str:
        """异步执行"""
        return self._run(query, max_results)
text
┌─────────────────────────────────────────────────────────────┐
│                    工具定义方式对比                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  @tool 装饰器                                               │
│  ✅ 最简单,推荐使用                                        │
│  ✅ 自动从函数签名推断参数                                  │
│  ✅ 自动从 docstring 提取描述                               │
│                                                             │
│  StructuredTool                                             │
│  ✅ 支持复杂参数结构                                        │
│  ✅ 使用 Pydantic 验证                                      │
│  ✅ 更详细的参数描述                                        │
│                                                             │
│  BaseTool 类                                                │
│  ✅ 完全自定义                                              │
│  ✅ 支持异步                                                │
│  ✅ 复杂内部逻辑                                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘

内置工具 #

搜索工具 #

python
from langchain_community.tools import DuckDuckGoSearchRun

# DuckDuckGo 搜索(免费,无需 API Key)
search = DuckDuckGoSearchRun()
result = search.run("LangChain 教程")

# Google 搜索(需要 API Key)
from langchain_community.utilities import GoogleSearchAPIWrapper
from langchain_community.tools import GoogleSearchRun

google_search = GoogleSearchRun(
    api_wrapper=GoogleSearchAPIWrapper()
)

# Bing 搜索
from langchain_community.utilities import BingSearchAPIWrapper
from langchain_community.tools import BingSearchRun

Python 执行 #

python
from langchain_experimental.tools import PythonREPLTool

python_repl = PythonREPLTool()
result = python_repl.run("""
import math
print(math.sqrt(16))
print([x**2 for x in range(5)])
""")

文件操作 #

python
from langchain_community.tools.file_management import (
    ReadFileTool,
    WriteFileTool,
    ListDirectoryTool,
    CopyFileTool,
    MoveFileTool,
    DeleteFileTool,
)

read_file = ReadFileTool()
write_file = WriteFileTool()
list_dir = ListDirectoryTool()

数据库工具 #

python
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import (
    QuerySQLDataBaseTool,
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLCheckerTool,
)

db = SQLDatabase.from_uri("sqlite:///mydb.db")

# 执行 SQL 查询
query_tool = QuerySQLDataBaseTool(db=db)

# 获取表结构信息
info_tool = InfoSQLDatabaseTool(db=db)

# 列出所有表
list_tool = ListSQLDatabaseTool(db=db)

# 检查 SQL 语法
checker_tool = QuerySQLCheckerTool(db=db, llm=model)

Shell 命令 #

python
from langchain_community.tools import ShellTool

shell = ShellTool()
result = shell.run("ls -la")

Wikipedia #

python
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper

wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
result = wikipedia.run("Python programming language")

自定义工具开发 #

API 调用工具 #

python
import requests
from langchain_core.tools import tool

@tool
def call_api(
    endpoint: str,
    method: str = "GET",
    params: dict = None,
    data: dict = None
) -> str:
    """调用 REST API
    
    Args:
        endpoint: API 端点 URL
        method: HTTP 方法 (GET, POST, PUT, DELETE)
        params: URL 查询参数
        data: 请求体数据
    """
    try:
        response = requests.request(
            method=method,
            url=endpoint,
            params=params,
            json=data,
            timeout=10
        )
        return response.text
    except Exception as e:
        return f"API 调用失败: {str(e)}"

数据库查询工具 #

python
import sqlite3
from langchain_core.tools import tool

@tool
def query_sqlite(database: str, query: str) -> str:
    """执行 SQLite 查询
    
    Args:
        database: 数据库文件路径
        query: SQL 查询语句
    """
    try:
        conn = sqlite3.connect(database)
        cursor = conn.cursor()
        cursor.execute(query)
        
        if query.strip().upper().startswith("SELECT"):
            results = cursor.fetchall()
            return str(results)
        else:
            conn.commit()
            return f"执行成功,影响 {cursor.rowcount} 行"
            
    except Exception as e:
        return f"查询失败: {str(e)}"
    finally:
        conn.close()

文件处理工具 #

python
import os
from langchain_core.tools import tool
from typing import List

@tool
def list_files(directory: str, pattern: str = "*") -> List[str]:
    """列出目录中的文件
    
    Args:
        directory: 目录路径
        pattern: 文件匹配模式,如 "*.py"
    """
    import glob
    files = glob.glob(os.path.join(directory, pattern))
    return [os.path.basename(f) for f in files]

@tool
def read_file_content(file_path: str, encoding: str = "utf-8") -> str:
    """读取文件内容
    
    Args:
        file_path: 文件路径
        encoding: 文件编码
    """
    try:
        with open(file_path, "r", encoding=encoding) as f:
            return f.read()
    except Exception as e:
        return f"读取失败: {str(e)}"

@tool
def write_file_content(file_path: str, content: str, mode: str = "w") -> str:
    """写入文件内容
    
    Args:
        file_path: 文件路径
        content: 要写入的内容
        mode: 写入模式 (w: 覆盖, a: 追加)
    """
    try:
        with open(file_path, mode, encoding="utf-8") as f:
            f.write(content)
        return f"写入成功: {file_path}"
    except Exception as e:
        return f"写入失败: {str(e)}"

邮件发送工具 #

python
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from langchain_core.tools import tool

@tool
def send_email(
    to: str,
    subject: str,
    body: str,
    smtp_server: str = "smtp.gmail.com",
    smtp_port: int = 587,
    from_email: str = None,
    password: str = None
) -> str:
    """发送电子邮件
    
    Args:
        to: 收件人邮箱
        subject: 邮件主题
        body: 邮件正文
        smtp_server: SMTP 服务器地址
        smtp_port: SMTP 端口
        from_email: 发件人邮箱
        password: 邮箱密码或应用专用密码
    """
    try:
        msg = MIMEMultipart()
        msg['From'] = from_email
        msg['To'] = to
        msg['Subject'] = subject
        msg.attach(MIMEText(body, 'plain'))
        
        server = smtplib.SMTP(smtp_server, smtp_port)
        server.starttls()
        server.login(from_email, password)
        server.send_message(msg)
        server.quit()
        
        return f"邮件已发送至 {to}"
    except Exception as e:
        return f"发送失败: {str(e)}"

日历工具 #

python
from datetime import datetime, timedelta
from langchain_core.tools import tool
from typing import List, Dict

# 简单的内存日历
calendar_events: List[Dict] = []

@tool
def add_event(
    title: str,
    start_time: str,
    duration_minutes: int = 60,
    description: str = ""
) -> str:
    """添加日历事件
    
    Args:
        title: 事件标题
        start_time: 开始时间 (格式: YYYY-MM-DD HH:MM)
        duration_minutes: 持续时间(分钟)
        description: 事件描述
    """
    try:
        start = datetime.strptime(start_time, "%Y-%m-%d %H:%M")
        end = start + timedelta(minutes=duration_minutes)
        
        event = {
            "title": title,
            "start": start,
            "end": end,
            "description": description
        }
        calendar_events.append(event)
        
        return f"事件已添加: {title} ({start_time})"
    except Exception as e:
        return f"添加失败: {str(e)}"

@tool
def list_events(date: str = None) -> str:
    """列出日历事件
    
    Args:
        date: 日期 (格式: YYYY-MM-DD),不指定则列出所有
    """
    if date:
        target_date = datetime.strptime(date, "%Y-%m-%d").date()
        events = [
            e for e in calendar_events
            if e["start"].date() == target_date
        ]
    else:
        events = calendar_events
    
    if not events:
        return "没有找到事件"
    
    result = []
    for e in events:
        result.append(
            f"{e['start'].strftime('%Y-%m-%d %H:%M')} - {e['title']}"
        )
    return "\n".join(result)

工具组合 #

工具集 #

python
from langchain_core.tools import tool

class ToolSet:
    """工具集合"""
    
    @staticmethod
    @tool
    def get_time() -> str:
        """获取当前时间"""
        from datetime import datetime
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    
    @staticmethod
    @tool
    def get_date() -> str:
        """获取当前日期"""
        from datetime import datetime
        return datetime.now().strftime("%Y-%m-%d")
    
    @staticmethod
    @tool
    def get_weekday() -> str:
        """获取当前星期几"""
        from datetime import datetime
        weekdays = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
        return weekdays[datetime.now().weekday()]

# 获取所有工具
tools = [
    ToolSet.get_time,
    ToolSet.get_date,
    ToolSet.get_weekday
]

工具链 #

python
from langchain_core.tools import tool

@tool
def search_web(query: str) -> str:
    """搜索网络"""
    return f"搜索结果: {query}"

@tool
def summarize_text(text: str) -> str:
    """总结文本"""
    return f"摘要: {text[:100]}..."

@tool
def translate_text(text: str, target_lang: str = "英文") -> str:
    """翻译文本"""
    return f"翻译({target_lang}): {text}"

# 工具可以组合使用
# Agent 会自动决定调用顺序
tools = [search_web, summarize_text, translate_text]

工具配置 #

工具名称和描述 #

python
@tool
def my_tool(input: str) -> str:
    """工具描述
    
    详细说明工具的功能和使用方法。
    这段描述会被 LLM 用来决定是否使用该工具。
    """
    return input

# 也可以手动设置
my_tool.name = "custom_name"
my_tool.description = "自定义描述"

返回类型 #

python
from typing import Union
from langchain_core.tools import tool

@tool
def flexible_tool(input: str) -> Union[str, dict, list]:
    """返回多种类型的工具"""
    if input.startswith("json:"):
        return {"result": input[5:]}
    elif input.startswith("list:"):
        return input[5:].split(",")
    else:
        return input

最佳实践 #

1. 清晰的描述 #

python
# 好的描述
@tool
def search_product(query: str, category: str = None) -> str:
    """搜索产品数据库
    
    根据关键词搜索产品。支持按类别过滤。
    返回匹配的产品列表,包含名称、价格、库存等信息。
    
    Args:
        query: 搜索关键词
        category: 可选的产品类别
    """
    pass

# 不好的描述
@tool
def search(q: str) -> str:
    """搜索"""
    pass

2. 错误处理 #

python
@tool
def safe_api_call(endpoint: str) -> str:
    """安全的 API 调用"""
    try:
        response = requests.get(endpoint, timeout=10)
        response.raise_for_status()
        return response.text
    except requests.Timeout:
        return "请求超时,请稍后重试"
    except requests.RequestException as e:
        return f"请求失败: {str(e)}"

3. 参数验证 #

python
from pydantic import BaseModel, Field, validator

class TransferInput(BaseModel):
    from_account: str = Field(description="转出账户")
    to_account: str = Field(description="转入账户")
    amount: float = Field(description="转账金额", gt=0)
    
    @validator('amount')
    def validate_amount(cls, v):
        if v <= 0:
            raise ValueError('金额必须大于 0')
        return v

@tool(args_schema=TransferInput)
def transfer_money(from_account: str, to_account: str, amount: float) -> str:
    """转账"""
    return f"已从 {from_account} 转账 {amount} 元到 {to_account}"

4. 异步支持 #

python
import aiohttp
from langchain_core.tools import tool

@tool
async def async_api_call(url: str) -> str:
    """异步 API 调用"""
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return await response.text()

下一步 #

最后更新:2026-03-30