中间件 #

什么是中间件? #

中间件是一个函数:它在每个请求到达路由处理器之前执行,并在路由处理器生成响应之后、响应返回给客户端之前再次执行。

text
┌─────────────────────────────────────────────────────────────┐
│                    中间件工作流程                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   客户端 ────> 中间件1 ────> 中间件2 ────> 路由处理器        │
│                                                             │
│   客户端 <──── 中间件1 <──── 中间件2 <──── 路由处理器        │
│                                                             │
│   请求阶段:依次执行                                         │
│   响应阶段:逆序执行                                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘

基本用法 #

创建中间件 #

python
import time

from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware('http')
async def add_process_time_header(request: Request, call_next):
    """Measure how long the downstream stack takes and report it in a header.

    The elapsed time, in seconds, is written to the ``X-Process-Time``
    header of the response produced by ``call_next``.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock, which can be adjusted and distort (even negate) durations.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers['X-Process-Time'] = str(process_time)
    return response

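中间件顺序 #

使用 `@app.middleware('http')` 注册时,后注册的中间件位于外层:它最先处理请求、最后处理响应。因此在下例中,middleware2 的输出包裹着 middleware1 的输出。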

python
@app.middleware('http')
async def middleware1(request: Request, call_next):
    """Print a marker on the way in and on the way out of this layer."""
    print('Middleware 1 - Before')
    # Code after this await runs once the inner layers have produced a response.
    downstream_response = await call_next(request)
    print('Middleware 1 - After')
    return downstream_response

@app.middleware('http')
async def middleware2(request: Request, call_next):
    """Same as middleware1, but registered afterwards."""
    print('Middleware 2 - Before')
    downstream_response = await call_next(request)
    print('Middleware 2 - After')
    return downstream_response

# Output order (the middleware registered last is the outermost layer):
# Middleware 2 - Before
# Middleware 1 - Before
# Middleware 1 - After
# Middleware 2 - After

内置中间件 #

CORS 中间件 #

python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# Allow any origin, but WITHOUT credentials. Combining allow_origins=['*'] with
# allow_credentials=True would let any website make credentialed (cookie /
# Authorization) requests; FastAPI's CORS docs require explicit origins when
# credentials are allowed. List the trusted origins instead if you need them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=False,
    allow_methods=['*'],
    allow_headers=['*'],
)

CORS 配置详解 #

python
app.add_middleware(
    CORSMiddleware,
    # Only these exact origins may call the API from a browser.
    allow_origins=[
        'https://example.com',
        'https://www.example.com',
    ],
    allow_methods=['GET', 'POST', 'PUT', 'DELETE'],
    allow_headers=['*'],
    # Credentials are acceptable here because the origins are listed explicitly.
    allow_credentials=True,
    # Response headers that browser-side JavaScript is permitted to read.
    expose_headers=['X-Total-Count'],
    # Preflight cache lifetime (Access-Control-Max-Age, seconds per the CORS spec).
    max_age=600,
)

HTTPS 重定向 #

python
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware

# Create the application, then wrap it so insecure requests are redirected to HTTPS.
# NOTE(review): behind a TLS-terminating proxy the app may see plain http;
# confirm the server trusts forwarded headers, otherwise requests can loop.
app = FastAPI()
app.add_middleware(HTTPSRedirectMiddleware)

Trusted Host #

python
from fastapi.middleware.trustedhost import TrustedHostMiddleware

# Only requests whose Host matches one of these patterns are accepted;
# the '*.' entry covers subdomains of example.com.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=[
        'example.com',
        '*.example.com',
    ],
)

GZip 压缩 #

python
from fastapi.middleware.gzip import GZipMiddleware

# Gzip-compress responses; bodies below minimum_size (bytes) are sent uncompressed.
app.add_middleware(
    GZipMiddleware,
    minimum_size=1000,
)

自定义中间件 #

类方式 #

python
from starlette.middleware.base import BaseHTTPMiddleware

class CustomMiddleware(BaseHTTPMiddleware):
    """Class-based middleware that stamps every response with a fixed header."""

    async def dispatch(self, request: Request, call_next):
        outgoing = await call_next(request)
        outgoing.headers['Custom-Header'] = 'value'
        return outgoing


app.add_middleware(CustomMiddleware)

函数方式 #

python
@app.middleware('http')
async def custom_middleware(request: Request, call_next):
    """Function-based equivalent: add an ``X-Custom`` header to each response."""
    resp = await call_next(request)
    resp.headers['X-Custom'] = 'value'
    return resp

常用中间件示例 #

请求日志 #

python
import time
import logging
from fastapi import FastAPI, Request

logger = logging.getLogger(__name__)

app = FastAPI()

@app.middleware('http')
async def log_requests(request: Request, call_next):
    """Log method, path, status code and handling time for every request.

    The duration covers everything downstream of this middleware and is
    logged in milliseconds (two decimals) at INFO level.
    """
    # Monotonic clock: time.time() can jump when the system clock is adjusted.
    start_time = time.perf_counter()

    response = await call_next(request)

    process_time = (time.perf_counter() - start_time) * 1000

    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info(
        '%s %s - %s - %.2fms',
        request.method,
        request.url.path,
        response.status_code,
        process_time,
    )

    return response

请求计时 #

python
@app.middleware('http')
async def add_process_time(request: Request, call_next):
    """Report downstream handling time (seconds, 4 decimals) as X-Process-Time."""
    begin = time.perf_counter()
    response = await call_next(request)
    elapsed = time.perf_counter() - begin
    response.headers['X-Process-Time'] = format(elapsed, '.4f')
    return response

请求 ID #

python
import uuid

@app.middleware('http')
async def add_request_id(request: Request, call_next):
    """Tag each request with a fresh UUID4 and echo it back in X-Request-ID."""
    rid = str(uuid.uuid4())
    # Make the id available to route handlers through request.state.
    request.state.request_id = rid
    response = await call_next(request)
    response.headers['X-Request-ID'] = rid
    return response

IP 限制 #

python
import ipaddress

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse

app = FastAPI()

# Entries may be single addresses or CIDR networks.
ALLOWED_IPS = ['127.0.0.1', '192.168.1.0/24']

# Parsed once at import time: a bare address becomes a single-host network,
# and strict=False tolerates host bits set in a CIDR entry.
_ALLOWED_NETWORKS = [ipaddress.ip_network(entry, strict=False) for entry in ALLOWED_IPS]

def _is_allowed_ip(host):
    """Return True if *host* parses as an IP address inside an allowed network."""
    if not host:
        return False
    try:
        address = ipaddress.ip_address(host)
    except ValueError:
        # Anything that is not a valid IP address is never allowed.
        return False
    return any(address in network for network in _ALLOWED_NETWORKS)

@app.middleware('http')
async def ip_filter(request: Request, call_next):
    """Reject requests whose client IP is not covered by ALLOWED_IPS.

    A plain ``client_ip in ALLOWED_IPS`` string test can never match the
    CIDR entry, so addresses are compared against parsed networks instead.

    The 403 is returned as a response rather than raised: an HTTPException
    raised inside an ``@app.middleware('http')`` function is not handled by
    FastAPI's exception handlers and would surface as a 500.
    """
    # request.client is Optional in Starlette; treat a missing peer as not allowed.
    client_ip = request.client.host if request.client else None

    if not _is_allowed_ip(client_ip):
        return JSONResponse(status_code=403, content={'detail': 'IP not allowed'})

    return await call_next(request)

速率限制 #

python
import time
from collections import defaultdict

from fastapi.responses import JSONResponse

# Sliding window: at most RATE_LIMIT requests per client IP within any
# RATE_WINDOW-second span.
RATE_LIMIT = 100
RATE_WINDOW = 60

# client IP -> time.monotonic() timestamps of its requests inside the window.
# NOTE: this state is per process (each worker keeps its own), and entries for
# clients that never return are not evicted.
request_counts = defaultdict(list)

@app.middleware('http')
async def rate_limit(request: Request, call_next):
    """Answer 429 once a client exceeds RATE_LIMIT requests per RATE_WINDOW seconds.

    Uses ``import time`` rather than ``from time import time`` so the module
    name ``time`` keeps working for other middlewares in the same file, and a
    monotonic clock so wall-clock adjustments cannot distort the window.
    The 429 is returned (not raised): HTTPException raised in an http
    middleware bypasses FastAPI's handlers and would become a 500.
    """
    # request.client is Optional in Starlette; clients without one share a bucket.
    client_ip = request.client.host if request.client else None
    current_time = time.monotonic()

    recent = [
        t for t in request_counts[client_ip]
        if current_time - t < RATE_WINDOW
    ]
    request_counts[client_ip] = recent

    if len(recent) >= RATE_LIMIT:
        # Seconds until the oldest request in the window drops out of it.
        retry_after = int(RATE_WINDOW - (current_time - recent[0])) + 1
        return JSONResponse(
            status_code=429,
            content={'detail': 'Too many requests'},
            headers={'Retry-After': str(retry_after)},
        )

    recent.append(current_time)

    return await call_next(request)

异常处理 #

python
import logging

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)

app = FastAPI()

@app.middleware('http')
async def catch_exceptions(request: Request, call_next):
    """Turn any unhandled downstream exception into a generic 500 JSON reply.

    The client only sees a generic message, but the exception is logged with
    its traceback so it is not silently lost.
    """
    try:
        return await call_next(request)
    except Exception:
        # logger.exception logs at ERROR level and includes the traceback.
        logger.exception(
            'Unhandled error while processing %s %s',
            request.method,
            request.url.path,
        )
        return JSONResponse(
            status_code=500,
            content={'detail': 'Internal server error'}
        )

认证中间件 #

python
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse

app = FastAPI()

PUBLIC_PATHS = ['/login', '/register', '/docs', '/openapi.json']

def _unauthorized(detail):
    """Build a 401 JSON response that advertises the Bearer scheme."""
    return JSONResponse(
        status_code=401,
        content={'detail': detail},
        headers={'WWW-Authenticate': 'Bearer'},
    )

@app.middleware('http')
async def auth_middleware(request: Request, call_next):
    """Require a valid ``Authorization: Bearer <token>`` header on non-public paths.

    On success the verified user is stored on ``request.state.user``.
    401 responses are returned rather than raised: an HTTPException raised in
    an http middleware bypasses FastAPI's exception handlers and becomes a 500.

    NOTE(review): ``verify_token`` is defined elsewhere; it is assumed to
    return a user object on success and a falsy value otherwise. Browser CORS
    preflight (OPTIONS) requests carry no Authorization header, so confirm
    CORS handling runs outside this middleware if CORS is used.
    """
    if request.url.path in PUBLIC_PATHS:
        return await call_next(request)

    authorization = request.headers.get('Authorization')

    if not authorization:
        return _unauthorized('Not authenticated')

    # Parse "<scheme> <token>" explicitly instead of str.replace, which would
    # also strip 'Bearer ' from inside the token and accept a missing scheme.
    scheme, _, token = authorization.partition(' ')
    token = token.strip()
    if scheme.lower() != 'bearer' or not token:
        return _unauthorized('Invalid token')

    user = verify_token(token)

    if not user:
        return _unauthorized('Invalid token')

    request.state.user = user

    return await call_next(request)

中间件最佳实践 #

1. 保持简单 #

python
@app.middleware('http')
async def simple_middleware(request: Request, call_next):
    """Minimal pass-through: forward the request and return the response unchanged."""
    response = await call_next(request)
    return response

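2. 避免阻塞 #

中间件会在每个请求上运行。其中的 I/O 应通过 `await` 调用异步接口;不要在 `async` 中间件里执行 `time.sleep()` 或同步的数据库/HTTP 调用,否则会阻塞整个事件循环。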

python
@app.middleware('http')
async def non_blocking_middleware(request: Request, call_next):
    """Perform extra work with ``await`` so the event loop is not blocked.

    ``some_async_operation`` is a placeholder (not defined in this snippet)
    for any awaitable, e.g. an async database or HTTP call. A synchronous
    blocking call here would hold up every request served by the event loop.
    """
    # The result is not needed here, so it is awaited without binding a name.
    await some_async_operation()
    response = await call_next(request)
    return response

3. 使用 request.state #

python
@app.middleware('http')
async def add_user_to_state(request: Request, call_next):
    """Resolve the current user once and attach it to ``request.state``."""
    # get_current_user is defined elsewhere (not shown in this snippet).
    current_user = get_current_user(request)
    request.state.user = current_user
    return await call_next(request)

@app.get('/me')
def get_me(request: Request):
    """Return the user object that add_user_to_state stored on the request."""
    state = request.state
    return state.user

4. 条件执行 #

python
@app.middleware('http')
async def conditional_middleware(request: Request, call_next):
    """Add an ``X-API-Version`` header, but only for paths under ``/api/``."""
    response = await call_next(request)
    # Every request is forwarded; only the header depends on the path.
    if request.url.path.startswith('/api/'):
        response.headers['X-API-Version'] = '1.0'
    return response

完整示例 #

python
import time
import uuid
import logging
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Wildcard origin, so credentials must stay disabled: with allow_credentials=True
# any website could make credentialed requests. FastAPI's CORS docs require
# explicit origins when credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=False,
    allow_methods=['*'],
    allow_headers=['*'],
)

app.add_middleware(GZipMiddleware, minimum_size=1000)

@app.middleware('http')
async def add_request_id(request: Request, call_next):
    """Assign a UUID4 to the request and return it in the X-Request-ID header."""
    request.state.request_id = request_id = str(uuid.uuid4())
    outgoing = await call_next(request)
    outgoing.headers['X-Request-ID'] = request_id
    return outgoing

@app.middleware('http')
async def log_requests(request: Request, call_next):
    """Log the start and completion of each request and expose its duration.

    The elapsed time is written to the ``X-Process-Time`` response header in
    milliseconds, formatted with two decimals and an ``ms`` suffix.
    """
    # Monotonic clock: time.time() can jump when the system clock is adjusted.
    start_time = time.perf_counter()

    # Lazy %-style arguments: formatting only happens if INFO is enabled.
    logger.info('Request started: %s %s', request.method, request.url.path)

    response = await call_next(request)

    process_time = (time.perf_counter() - start_time) * 1000
    logger.info(
        'Request completed: %s %s - %s - %.2fms',
        request.method,
        request.url.path,
        response.status_code,
        process_time,
    )

    response.headers['X-Process-Time'] = f'{process_time:.2f}ms'

    return response

@app.middleware('http')
async def error_handler(request: Request, call_next):
    """Safety net around the downstream stack.

    Because it is registered last among the ``@app.middleware`` functions, it
    is the outermost of them and also sees exceptions from those middlewares.
    """
    try:
        return await call_next(request)
    except HTTPException as exc:
        # FastAPI's exception handlers only cover route code; an HTTPException
        # raised inside a middleware would otherwise end up as a bare 500.
        return JSONResponse(
            status_code=exc.status_code,
            content={'detail': exc.detail},
            headers=exc.headers,
        )
    except Exception as exc:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception('Unhandled error: %s', exc)
        return JSONResponse(
            status_code=500,
            content={'detail': 'Internal server error'}
        )

@app.get('/')
def read_root():
    """Root endpoint returning a fixed greeting message."""
    return dict(message='Hello World')

@app.get('/items/{item_id}')
def read_item(item_id: int):
    """Echo back the ``item_id`` path parameter (declared as ``int``)."""
    return dict(item_id=item_id)

下一步 #

现在你已经掌握了中间件,接下来学习「安全认证」,了解 FastAPI 的安全机制!

最后更新:2026-03-29