中间件 #
什么是中间件? #
中间件是一个函数:它在每个请求到达路由处理器之前执行,并在路由处理器生成响应之后、响应返回给客户端之前再次执行。
text
┌─────────────────────────────────────────────────────────────┐
│ 中间件工作流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 客户端 ────> 中间件1 ────> 中间件2 ────> 路由处理器 │
│ │
│ 客户端 <──── 中间件1 <──── 中间件2 <──── 路由处理器 │
│ │
│ 请求阶段:依次执行 │
│ 响应阶段:逆序执行 │
│ │
└─────────────────────────────────────────────────────────────┘
基本用法 #
创建中间件 #
python
import time

from fastapi import FastAPI, Request

app = FastAPI()


@app.middleware('http')
async def add_process_time_header(request: Request, call_next):
    """Add an ``X-Process-Time`` header holding the handling time in seconds."""
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # durations; time.time() follows the wall clock and can jump mid-request.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers['X-Process-Time'] = str(process_time)
    return response
中间件顺序 #
注意:后注册的中间件位于外层——请求阶段最先执行,响应阶段最后执行。因此在下例中,middleware2 虽然定义在后,却最先打印 Before、最后打印 After。
python
@app.middleware('http')
async def middleware1(request: Request, call_next):
    """First-registered middleware; it runs inside the later-registered middleware2."""
    print('Middleware 1 - Before')
    reply = await call_next(request)
    print('Middleware 1 - After')
    return reply
@app.middleware('http')
async def middleware2(request: Request, call_next):
    """Last-registered middleware, hence the outermost layer: first in, last out."""
    print('Middleware 2 - Before')
    reply = await call_next(request)
    print('Middleware 2 - After')
    return reply
# Printed order (the last-registered middleware wraps the earlier ones):
# Middleware 2 - Before
# Middleware 1 - Before
# Middleware 1 - After
# Middleware 2 - After
内置中间件 #
CORS 中间件 #
python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# A wildcard origin must not be combined with allow_credentials=True.
# Browsers reject 'Access-Control-Allow-Origin: *' on credentialed requests.
# Starlette works around that by echoing the caller's Origin back, which
# would let ANY site make cookie-bearing requests.
# To allow credentials, list the trusted origins explicitly
# (see the detailed configuration below).
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=False,
    allow_methods=['*'],
    allow_headers=['*'],
)
CORS 配置详解 #
python
# Explicit origin allow-list. Credentials may be enabled because no wildcard is used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        'https://example.com',
        'https://www.example.com',
    ],
    allow_methods=['GET', 'POST', 'PUT', 'DELETE'],
    allow_headers=['*'],
    allow_credentials=True,
    # Response headers that browser-side JavaScript is allowed to read.
    expose_headers=['X-Total-Count'],
    # How long (in seconds) a browser may cache the preflight result.
    max_age=600,
)
HTTPS 重定向 #
python
from fastapi import FastAPI
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware

# Import FastAPI here too, so the snippet is self-contained like the others
# that create their own `app`.
app = FastAPI()
# Redirect requests arriving over plain HTTP to their HTTPS equivalent.
app.add_middleware(HTTPSRedirectMiddleware)
Trusted Host #
python
from fastapi.middleware.trustedhost import TrustedHostMiddleware

# Host header values the app is willing to serve; '*.example.com' is a
# wildcard pattern covering subdomains of example.com.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=[
        'example.com',
        '*.example.com',
    ],
)
GZip 压缩 #
python
from fastapi.middleware.gzip import GZipMiddleware

# Responses smaller than `minimum_size` bytes are left uncompressed.
app.add_middleware(
    GZipMiddleware,
    minimum_size=1000,
)
自定义中间件 #
类方式 #
python
from starlette.middleware.base import BaseHTTPMiddleware


class CustomMiddleware(BaseHTTPMiddleware):
    """Class-based middleware that stamps a fixed header on every response."""

    async def dispatch(self, request: Request, call_next):
        # Let the rest of the stack build the response first...
        downstream = await call_next(request)
        # ...then decorate it on the way back out.
        outgoing_headers = downstream.headers
        outgoing_headers['Custom-Header'] = 'value'
        return downstream


app.add_middleware(CustomMiddleware)
函数方式 #
python
@app.middleware('http')
async def custom_middleware(request: Request, call_next):
    """Function-based middleware: add an ``X-Custom`` header to every response."""
    outgoing = await call_next(request)
    outgoing.headers['X-Custom'] = 'value'
    return outgoing
常用中间件示例 #
请求日志 #
python
import time
import logging
from fastapi import FastAPI, Request

logger = logging.getLogger(__name__)
app = FastAPI()


@app.middleware('http')
async def log_requests(request: Request, call_next):
    """Log method, path, status code and elapsed milliseconds for every request."""
    # perf_counter() is monotonic, so the duration is not skewed by wall-clock
    # adjustments the way time.time() differences can be.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = (time.perf_counter() - start_time) * 1000
    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info(
        '%s %s - %s - %.2fms',
        request.method,
        request.url.path,
        response.status_code,
        process_time,
    )
    return response
请求计时 #
python
@app.middleware('http')
async def add_process_time(request: Request, call_next):
    """Expose the downstream handling time, in seconds, as ``X-Process-Time``."""
    t0 = time.perf_counter()
    response = await call_next(request)
    elapsed = time.perf_counter() - t0
    response.headers['X-Process-Time'] = format(elapsed, '.4f')
    return response
请求 ID #
python
import uuid


@app.middleware('http')
async def add_request_id(request: Request, call_next):
    """Tag each request with a fresh UUID4.

    The ID is stored on ``request.state.request_id`` for handlers and echoed
    back in the ``X-Request-ID`` response header.
    """
    rid = f'{uuid.uuid4()}'
    request.state.request_id = rid
    outgoing = await call_next(request)
    outgoing.headers['X-Request-ID'] = rid
    return outgoing
IP 限制 #
python
import ipaddress

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse

app = FastAPI()

# Entries may be single addresses or CIDR networks.
ALLOWED_IPS = ['127.0.0.1', '192.168.1.0/24']
# Parse once at import time. A bare address becomes a one-host network
# (e.g. 127.0.0.1/32), so membership works uniformly for both forms.
# A plain string `in` test could never match the CIDR entry.
_ALLOWED_NETWORKS = [ipaddress.ip_network(entry) for entry in ALLOWED_IPS]


def _is_ip_allowed(host: str) -> bool:
    """Return True if *host* is an IP address inside one of the allowed networks."""
    try:
        address = ipaddress.ip_address(host)
    except ValueError:
        # Not an IP literal (e.g. a hostname): never allowed.
        return False
    # Mixed IPv4/IPv6 membership tests simply evaluate to False.
    return any(address in network for network in _ALLOWED_NETWORKS)


@app.middleware('http')
async def ip_filter(request: Request, call_next):
    """Reject requests whose client address is not in ALLOWED_IPS with a 403."""
    # request.client can be None when the server supplies no peer address.
    client_ip = request.client.host if request.client else None
    if client_ip is None or not _is_ip_allowed(client_ip):
        # Return the response directly. An HTTPException raised inside
        # middleware bypasses FastAPI's exception handlers and surfaces as a 500.
        return JSONResponse(status_code=403, content={'detail': 'IP not allowed'})
    return await call_next(request)
速率限制 #
python
import time
from collections import defaultdict, deque

from fastapi.responses import JSONResponse

# Sliding-window limit: at most RATE_LIMIT requests per client IP within any
# RATE_WINDOW_SECONDS-second window.
RATE_LIMIT = 100
RATE_WINDOW_SECONDS = 60

# Maps client IP to the time.monotonic() timestamps of its recent requests,
# oldest first.
# NOTE: this lives in process memory only. It is not shared between worker
# processes, and entries for clients that stop sending requests are never evicted.
request_counts = defaultdict(deque)


@app.middleware('http')
async def rate_limit(request: Request, call_next):
    """Answer 429 once a client exceeds RATE_LIMIT requests in the current window."""
    # request.client can be None when the server supplies no peer address.
    # Such requests share a single bucket.
    client_ip = request.client.host if request.client else 'unknown'
    # monotonic() is immune to wall-clock adjustments.
    # `import time` (not `from time import time`) avoids shadowing the
    # `time` module other middlewares rely on.
    now = time.monotonic()
    timestamps = request_counts[client_ip]
    # Timestamps are appended in order, so expired ones are always at the left.
    while timestamps and now - timestamps[0] >= RATE_WINDOW_SECONDS:
        timestamps.popleft()
    if len(timestamps) >= RATE_LIMIT:
        # Return the response directly. An HTTPException raised inside
        # middleware bypasses FastAPI's exception handlers and surfaces as a 500.
        return JSONResponse(status_code=429, content={'detail': 'Too many requests'})
    timestamps.append(now)
    return await call_next(request)
异常处理 #
python
import logging

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)
app = FastAPI()


@app.middleware('http')
async def catch_exceptions(request: Request, call_next):
    """Turn any unhandled exception from downstream into a generic JSON 500.

    The client only sees a generic message. The traceback is logged so the
    failure is not silently lost.
    """
    try:
        return await call_next(request)
    except Exception:
        # logger.exception records the active exception's traceback.
        logger.exception(
            'Unhandled error while processing %s %s',
            request.method,
            request.url.path,
        )
        return JSONResponse(
            status_code=500,
            content={'detail': 'Internal server error'},
        )
认证中间件 #
python
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse

app = FastAPI()

PUBLIC_PATHS = ['/login', '/register', '/docs', '/openapi.json']


def _unauthorized(detail: str) -> JSONResponse:
    """Build a 401 response that advertises the Bearer scheme."""
    return JSONResponse(
        status_code=401,
        content={'detail': detail},
        headers={'WWW-Authenticate': 'Bearer'},
    )


@app.middleware('http')
async def auth_middleware(request: Request, call_next):
    """Require a valid ``Authorization: Bearer <token>`` header outside PUBLIC_PATHS.

    On success the user returned by ``verify_token`` is stored on
    ``request.state.user``. Failures get a 401 JSON response.
    """
    if request.url.path in PUBLIC_PATHS:
        return await call_next(request)

    authorization = request.headers.get('Authorization')
    if not authorization:
        # Responses are returned, not raised: an HTTPException raised inside
        # middleware bypasses FastAPI's exception handlers and becomes a 500.
        return _unauthorized('Not authenticated')

    # Accept only the Bearer scheme (case-insensitive). A blind
    # .replace('Bearer ', '') would pass e.g. 'Basic ...' through as a token.
    scheme, _, token = authorization.partition(' ')
    token = token.strip()
    if scheme.lower() != 'bearer' or not token:
        return _unauthorized('Not authenticated')

    # NOTE(review): verify_token is not defined in this snippet; it is assumed
    # to return a falsy value for invalid tokens.
    user = verify_token(token)
    if not user:
        return _unauthorized('Invalid token')
    request.state.user = user
    return await call_next(request)
中间件最佳实践 #
1. 保持简单 #
python
@app.middleware('http')
async def simple_middleware(request: Request, call_next):
    """Minimal pass-through: forward the request and return the response untouched."""
    response = await call_next(request)
    return response
2. 避免阻塞 #
python
@app.middleware('http')
async def non_blocking_middleware(request: Request, call_next):
    """Perform extra I/O in middleware without blocking the event loop.

    ``some_async_operation`` stands in for any awaitable I/O call; it is not
    defined in this snippet.
    """
    # Await async I/O here. A synchronous blocking call (time.sleep, a sync
    # HTTP or DB client, ...) would stall the event loop for every request.
    # The unused `result` binding from the original is dropped.
    await some_async_operation()
    return await call_next(request)
3. 使用 request.state #
python
@app.middleware('http')
async def add_user_to_state(request: Request, call_next):
    """Attach the current user to ``request.state`` so route handlers can read it."""
    # NOTE(review): get_current_user is defined elsewhere and called without
    # await, so it must be a plain (non-async) function. Confirm.
    current_user = get_current_user(request)
    request.state.user = current_user
    return await call_next(request)
@app.get('/me')
def get_me(request: Request):
    """Return the user object a middleware stored on ``request.state``."""
    state = request.state
    return state.user
4. 条件执行 #
python
@app.middleware('http')
async def conditional_middleware(request: Request, call_next):
    """Add ``X-API-Version`` only to responses for paths under ``/api/``."""
    response = await call_next(request)
    # Every request is forwarded exactly once; only the header is conditional.
    if request.url.path.startswith('/api/'):
        response.headers['X-API-Version'] = '1.0'
    return response
完整示例 #
python
import time
import uuid
import logging
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# A wildcard origin must not be combined with allow_credentials=True.
# Browsers reject 'Access-Control-Allow-Origin: *' on credentialed requests.
# Starlette works around that by echoing the caller's Origin back, which
# would let ANY site make cookie-bearing requests.
# To allow credentials, list the trusted origins explicitly instead of '*'.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=False,
    allow_methods=['*'],
    allow_headers=['*'],
)
# Responses smaller than 1000 bytes are left uncompressed.
app.add_middleware(GZipMiddleware, minimum_size=1000)
@app.middleware('http')
async def add_request_id(request: Request, call_next):
    """Give every request a UUID4, exposed on request.state and in X-Request-ID."""
    rid = f'{uuid.uuid4()}'
    request.state.request_id = rid
    outgoing = await call_next(request)
    outgoing.headers['X-Request-ID'] = rid
    return outgoing
@app.middleware('http')
async def log_requests(request: Request, call_next):
    """Log the start and completion of each request.

    The elapsed time in milliseconds is also reported in ``X-Process-Time``.
    """
    # perf_counter() is monotonic, so durations are not skewed by wall-clock
    # adjustments the way time.time() differences can be.
    start_time = time.perf_counter()
    # Lazy %-style arguments: messages are only formatted if INFO is enabled.
    logger.info('Request started: %s %s', request.method, request.url.path)
    response = await call_next(request)
    process_time = (time.perf_counter() - start_time) * 1000
    logger.info(
        'Request completed: %s %s - %s - %.2fms',
        request.method,
        request.url.path,
        response.status_code,
        process_time,
    )
    response.headers['X-Process-Time'] = f'{process_time:.2f}ms'
    return response
@app.middleware('http')
async def error_handler(request: Request, call_next):
    """Convert exceptions escaping the inner stack into JSON responses.

    This middleware is registered last, so it wraps the other middlewares
    registered above.
    """
    try:
        return await call_next(request)
    except HTTPException as exc:
        # HTTPExceptions raised by route handlers are already turned into
        # responses by FastAPI's exception handlers, which sit inside the user
        # middleware stack. Only ones raised by inner middleware arrive here.
        # Re-raising would surface them as a 500, so render their intended status.
        return JSONResponse(
            status_code=exc.status_code,
            content={'detail': exc.detail},
            headers=exc.headers,
        )
    except Exception as exc:
        # logger.exception also records the traceback, which logger.error dropped.
        logger.exception('Unhandled error: %s', exc)
        return JSONResponse(
            status_code=500,
            content={'detail': 'Internal server error'},
        )
@app.get('/')
def read_root():
    """Return a static greeting for the root path."""
    greeting = {'message': 'Hello World'}
    return greeting
@app.get('/items/{item_id}')
def read_item(item_id: int):
    """Echo the ``item_id`` path parameter (declared as ``int``) back to the caller."""
    return dict(item_id=item_id)
下一步 #
现在你已经掌握了中间件,接下来学习「安全认证」,了解 FastAPI 的安全机制!
最后更新:2026-03-29