依赖注入 #

什么是依赖注入? #

依赖注入(Dependency Injection)是一种设计模式,用于将组件的依赖关系从组件内部移到外部管理。

text
┌─────────────────────────────────────────────────────────────┐
│                    依赖注入的概念                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   传统方式:                                                 │
│   def handler():                                            │
│       db = Database()        # 在函数内部创建依赖            │
│       user = db.get_user()                                  │
│       return user                                           │
│                                                             │
│   依赖注入:                                                 │
│   def handler(db: Database = Depends(get_db)):              │
│       user = db.get_user()   # 依赖由外部提供                │
│       return user                                           │
│                                                             │
│   优势:                                                    │
│   ✅ 解耦:组件不负责创建依赖                                │
│   ✅ 可测试:易于 mock 依赖                                  │
│   ✅ 可复用:依赖逻辑可共享                                  │
│   ✅ 清晰:依赖关系明确                                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘

基本用法 #

创建依赖 #

python
from fastapi import FastAPI, Depends

app = FastAPI()

def common_parameters(q: str | None = None, skip: int = 0, limit: int = 100):
    return {'q': q, 'skip': skip, 'limit': limit}

# GET /items/: receives the dict produced by common_parameters through
# Depends and returns it unchanged.
@app.get('/items/')
def read_items(commons: dict = Depends(common_parameters)):
    return commons

# GET /users/: declares the same common_parameters dependency as another
# route, so the parameter-handling logic is written only once.
@app.get('/users/')
def read_users(commons: dict = Depends(common_parameters)):
    return commons

类作为依赖 #

python
from fastapi import FastAPI, Depends

app = FastAPI()

class CommonParams:
    def __init__(self, q: str | None = None, skip: int = 0, limit: int = 100):
        self.q = q
        self.skip = skip
        self.limit = limit

# GET /items/: the CommonParams instance is supplied through Depends; its
# three attributes are echoed back as a dict.
@app.get('/items/')
def read_items(commons: CommonParams = Depends(CommonParams)):
    payload = {'q': commons.q, 'skip': commons.skip, 'limit': commons.limit}
    return payload

依赖类型 #

函数依赖 #

python
def get_db():
    """Yield a session created by ``SessionLocal`` and close it afterwards."""
    db = SessionLocal()
    try:
        # The yielded object is what a dependent receives via Depends(get_db).
        yield db
    finally:
        # Runs when the generator is finalized, with or without an error.
        db.close()

# GET /users/: queries every User through the session supplied by get_db.
@app.get('/users/')
def read_users(db: Session = Depends(get_db)):
    return db.query(User).all()

类依赖 #

python
class Pagination:
    """Class-based dependency carrying page-based pagination values.

    Attributes set by the constructor:
        page: The ``page`` argument (default ``1``).
        size: The ``size`` argument (default ``10``).
        offset: ``(page - 1) * size``; ``0`` for the default first page.
    """

    def __init__(self, page: int = 1, size: int = 10):
        self.page, self.size = page, size
        # Number of rows that precede the requested page.
        self.offset = size * (page - 1)

# GET /items/: Depends() without an argument takes the dependency from the
# parameter's type annotation (Pagination); the handler echoes its values.
@app.get('/items/')
def read_items(pagination: Pagination = Depends()):
    summary = {
        'page': pagination.page,
        'size': pagination.size,
        'offset': pagination.offset,
    }
    return summary

生成器依赖 #

python
def get_db():
    """Yield a session created by ``SessionLocal`` and close it afterwards."""
    db = SessionLocal()
    try:
        # The yielded object is injected wherever Depends(get_db) is declared.
        yield db
    finally:
        # Cleanup step: always close the session once the generator finishes.
        db.close()

def get_current_user(db: Session = Depends(get_db), token: str = Header(...)):
    """Return the user that ``verify_token`` resolves for ``token``.

    Args:
        db: Session supplied by the ``get_db`` generator dependency.
        token: Value of the required ``token`` header.

    Raises:
        HTTPException: 401 when ``verify_token`` returns a falsy value.
    """
    user = verify_token(db, token)
    if user:
        return user
    raise HTTPException(status_code=401, detail='Invalid token')

依赖缓存 #

默认缓存 #

python
# GET /items/: returns the dict produced by the common_parameters dependency.
@app.get('/items/')
def read_items(commons: dict = Depends(common_parameters)):
    return commons

# GET /users/: returns the dict produced by the common_parameters dependency.
@app.get('/users/')
def read_users(commons: dict = Depends(common_parameters)):
    return commons
# 注意:缓存只在"同一个请求"内生效——如果同一请求的依赖树中多次声明了
# common_parameters,它只会被调用一次;/items/ 与 /users/ 是两个不同的请求,
# 各自都会重新计算一次

禁用缓存 #

python
# GET /items/: use_cache=False asks FastAPI not to reuse an already computed
# value of common_parameters within the same request (see the FastAPI
# documentation for Depends).
@app.get('/items/')
def read_items(commons: dict = Depends(common_parameters, use_cache=False)):
    return commons

嵌套依赖 #

依赖链 #

python
from fastapi import FastAPI, Depends, HTTPException, Header

app = FastAPI()

def get_token(authorization: str = Header(...)):
    """Extract the bearer token from the required ``authorization`` header.

    Args:
        authorization: Raw header value, expected to start with ``'Bearer '``.

    Returns:
        The header value with the leading ``'Bearer '`` removed.

    Raises:
        HTTPException: 401 when the value does not start with ``'Bearer '``.
    """
    prefix = 'Bearer '
    if not authorization.startswith(prefix):
        raise HTTPException(status_code=401, detail='Invalid authorization')
    # Strip only the leading scheme: str.replace would also delete any later
    # 'Bearer ' occurrence inside the token itself.
    return authorization.removeprefix(prefix)

def get_current_user(token: str = Depends(get_token)):
    """Return the user that ``verify_token`` resolves for ``token``.

    Args:
        token: Value supplied by the ``get_token`` dependency.

    Raises:
        HTTPException: 401 when ``verify_token`` returns a falsy value.
    """
    user = verify_token(token)
    if user:
        return user
    raise HTTPException(status_code=401, detail='Invalid token')

# GET /users/me: depends on get_current_user, which in turn depends on
# get_token; the resolved user is returned as the response.
@app.get('/users/me')
def read_users_me(current_user: User = Depends(get_current_user)):
    return current_user

依赖树 #

text
┌─────────────────────────────────────────────────────────────┐
│                    依赖树结构                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   read_users_me                                             │
│        │                                                    │
│        └── get_current_user                                 │
│                  │                                          │
│                  └── get_token                              │
│                         │                                   │
│                         └── Header                          │
│                                                             │
│   FastAPI 会按顺序解析整个依赖树                             │
│                                                             │
└─────────────────────────────────────────────────────────────┘

路径操作装饰器依赖 #

单个路径操作的依赖 #

python
from fastapi import Depends, FastAPI, Header, HTTPException

app = FastAPI()

def verify_token(token: str = Header(...)):
    """Reject the request unless the required ``token`` header is ``'secret'``.

    Raises:
        HTTPException: 401 when the header value differs from ``'secret'``.
    """
    if token != 'secret':
        raise HTTPException(status_code=401, detail='Invalid token')

# GET /items/: verify_token is attached through the decorator's
# `dependencies` list; the handler itself takes no parameters, so the
# dependency is used only for its check, not for a return value.
@app.get('/items/', dependencies=[Depends(verify_token)])
def read_items():
    return [{'item': 'Item 1'}, {'item': 'Item 2'}]

路由器依赖 #

python
from fastapi import FastAPI, Depends, APIRouter

app = FastAPI()
router = APIRouter(dependencies=[Depends(verify_token)])

# GET /items/ on the router: the router was created with
# dependencies=[Depends(verify_token)], so no per-route declaration is needed.
@router.get('/items/')
def read_items():
    return [{'item': 'Item 1'}]

# GET /users/ on the router: registered on the same router as the other
# routes, which was created with dependencies=[Depends(verify_token)].
@router.get('/users/')
def read_users():
    return [{'user': 'User 1'}]

app.include_router(router)

应用级依赖 #

python
app = FastAPI(dependencies=[Depends(verify_token)])

# GET /items/: declares no dependency of its own; verify_token is attached
# at application level through FastAPI(dependencies=[...]).
@app.get('/items/')
def read_items():
    return [{'item': 'Item 1'}]

全局依赖 #

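使用中间件处理所有请求 #

注意:下面的示例使用的是中间件(middleware),而不是 Depends。中间件同样作用于每一个请求,但它不属于依赖注入系统;真正的全局依赖请参见上文的"应用级依赖"(FastAPI(dependencies=[...]))。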

python
from fastapi import FastAPI, Request
from time import time

async def add_process_time(request: Request, call_next):
    """HTTP middleware that reports the handling time in ``X-Process-Time``.

    Args:
        request: The incoming request, forwarded unchanged to ``call_next``.
        call_next: Awaitable callable that produces the response.

    Returns:
        The response from ``call_next`` with an added ``X-Process-Time``
        header holding the elapsed seconds as a string.
    """
    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump when the system clock is adjusted (unlike time.time()).
    from time import perf_counter

    started = perf_counter()
    response = await call_next(request)
    response.headers['X-Process-Time'] = str(perf_counter() - started)
    return response

app = FastAPI()
app.middleware('http')(add_process_time)

可调用实例 #

call 方法 #

python
class FixedContentQueryChecker:
    """Callable object that checks whether a query contains a fixed string.

    The string to look for is given once at construction time; calling the
    instance performs the check.
    """

    def __init__(self, fixed_content: str):
        self.fixed_content = fixed_content

    def __call__(self, q: str = ''):
        """Return ``True`` when ``q`` is non-empty and contains the fixed string."""
        # An empty query never matches, whatever the fixed string is.
        return bool(q) and self.fixed_content in q

checker = FixedContentQueryChecker('bar')

# GET /items/: the dependency is the `checker` instance itself (a callable
# object), and its boolean result is returned in the response.
@app.get('/items/')
def read_items(has_fixed_content: bool = Depends(checker)):
    return {'has_fixed_content': has_fixed_content}

子依赖 #

yield 依赖 #

python
from fastapi import FastAPI, Depends

app = FastAPI()

class DBSession:
    """Stand-in for a database session that only prints lifecycle messages."""

    def __init__(self):
        # Printed when the object is created.
        print('Creating DB connection')
    
    def close(self):
        """Print the closing message; no real resource is released here."""
        print('Closing DB connection')

def get_db():
    """Yield a new ``DBSession`` and call its ``close()`` afterwards."""
    db = DBSession()
    try:
        # The yielded object is injected wherever Depends(get_db) is declared.
        yield db
    finally:
        # Cleanup step: runs when the generator finishes, error or not.
        db.close()

# GET /items/: receives the DBSession yielded by get_db; the handler does
# not use it and returns a fixed message.
@app.get('/items/')
def read_items(db: DBSession = Depends(get_db)):
    return {'message': 'Items'}
# 请求结束后自动调用 db.close()

多个 yield 依赖 #

python
def get_db():
    """Yield a new ``DBSession`` and call its ``close()`` afterwards."""
    db = DBSession()
    try:
        # The yielded object is injected wherever Depends(get_db) is declared.
        yield db
    finally:
        # Cleanup step: runs when the generator finishes, error or not.
        db.close()

def get_cache():
    """Yield a new ``CacheClient`` and call its ``disconnect()`` afterwards."""
    cache = CacheClient()
    try:
        # The yielded object is injected wherever Depends(get_cache) is declared.
        yield cache
    finally:
        # Cleanup step: runs when the generator finishes, error or not.
        cache.disconnect()

# GET /items/: declares two independent yield-based dependencies (get_db and
# get_cache); the handler does not use them and returns a fixed message.
@app.get('/items/')
def read_items(
    db: DBSession = Depends(get_db),
    cache: CacheClient = Depends(get_cache)
):
    return {'message': 'Items'}

实战示例 #

数据库依赖 #

python
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session

DATABASE_URL = 'sqlite:///./test.db'

engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

app = FastAPI()

def get_db():
    """Yield a session from ``SessionLocal`` and close it afterwards."""
    db = SessionLocal()
    try:
        # The yielded session is injected wherever Depends(get_db) is declared.
        yield db
    finally:
        # Always close the session, even if the dependent code raised.
        db.close()

# GET /users/: queries every User through the session supplied by get_db.
@app.get('/users/')
def read_users(db: Session = Depends(get_db)):
    return db.query(User).all()

认证依赖 #

python
from fastapi import FastAPI, Depends, HTTPException, Header
from pydantic import BaseModel

app = FastAPI()

# Pydantic model for a user record: an integer id and a username.
class User(BaseModel):
    id: int
    username: str

# In-memory stand-in for a user store: maps a bearer token string to the
# User it identifies.
fake_users_db = {
    'token123': User(id=1, username='john'),
    'token456': User(id=2, username='jane')
}

async def get_current_user(authorization: str = Header(...)):
    """Resolve the user for the bearer token in the ``authorization`` header.

    Args:
        authorization: Raw header value, expected to start with ``'Bearer '``.

    Returns:
        The ``User`` stored in ``fake_users_db`` under the extracted token.

    Raises:
        HTTPException: 401 when the header does not start with ``'Bearer '``
            or when the token is not a key of ``fake_users_db``.
    """
    prefix = 'Bearer '
    if not authorization.startswith(prefix):
        raise HTTPException(status_code=401, detail='Invalid authorization header')

    # Strip only the leading scheme: str.replace would also delete any later
    # 'Bearer ' occurrence inside the token itself.
    token = authorization.removeprefix(prefix)

    # Single lookup instead of a membership test followed by indexing.
    user = fake_users_db.get(token)
    if user is None:
        raise HTTPException(status_code=401, detail='Invalid token')

    return user

async def get_current_active_user(current_user: User = Depends(get_current_user)):
    """Return the user resolved by ``get_current_user`` unchanged.

    No additional check is performed here; the function only adds one more
    level to the dependency chain.
    """
    return current_user

# GET /users/me: returns the user produced by the get_current_active_user
# dependency chain.
@app.get('/users/me')
def read_users_me(current_user: User = Depends(get_current_active_user)):
    return current_user

分页依赖 #

python
from fastapi import FastAPI, Depends, Query
from pydantic import BaseModel

app = FastAPI()

class Pagination:
    """Class-based dependency carrying validated pagination values.

    ``page`` is declared with ``Query(1, ge=1)`` and ``size`` with
    ``Query(10, ge=1, le=100)``. ``offset`` is derived as
    ``(page - 1) * size``.
    """

    def __init__(
        self,
        page: int = Query(1, ge=1),
        size: int = Query(10, ge=1, le=100)
    ):
        self.page, self.size = page, size
        # Number of rows that precede the requested page.
        self.offset = size * (page - 1)

# GET /items/: fetches one page of items using the offset and size computed
# by the Pagination dependency, and echoes the page information back.
@app.get('/items/')
def read_items(pagination: Pagination = Depends()):
    page_items = get_items_from_db(
        skip=pagination.offset,
        limit=pagination.size,
    )
    response_body = {
        'items': page_items,
        'page': pagination.page,
        'size': pagination.size,
    }
    return response_body

权限依赖 #

python
from enum import Enum

from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel  # required by the User model below

app = FastAPI()

# Closed set of roles; subclassing str makes each member usable as a string.
class Role(str, Enum):
    admin = 'admin'
    user = 'user'
    guest = 'guest'

# Pydantic model for a user record that carries one Role.
class User(BaseModel):
    id: int
    username: str
    role: Role

def require_role(required_role: Role):
    """Build a dependency that only lets through users with a given role.

    Args:
        required_role: Role the current user must have. Users whose role is
            ``Role.admin`` are always accepted.

    Returns:
        A dependency function that returns the current user, or raises
        ``HTTPException`` with status 403 when the role check fails.
    """
    def role_checker(current_user: User = Depends(get_current_user)):
        # Accept the required role itself, plus admin as a universal pass.
        if current_user.role == required_role or current_user.role == Role.admin:
            return current_user
        raise HTTPException(status_code=403, detail='Not enough permissions')
    return role_checker

# GET /admin/: require_role(Role.admin) is called at definition time and
# returns the actual dependency function passed to Depends.
@app.get('/admin/')
def admin_only(user: User = Depends(require_role(Role.admin))):
    return {'message': 'Admin area'}

依赖覆盖 #

测试时覆盖 #

python
from fastapi import FastAPI, Depends
from fastapi.testclient import TestClient

app = FastAPI()

def get_db():
    """Return the placeholder string that stands for the production database."""
    return 'production_db'

# GET /items/: returns whichever value the get_db dependency (or its
# override) provides, so a test can observe which one was used.
@app.get('/items/')
def read_items(db: str = Depends(get_db)):
    return {'db': db}

def override_get_db():
    """Return the placeholder string used in place of ``get_db`` during tests."""
    return 'test_db'

app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)

def test_read_items():
    """Check that GET /items/ is served by the overriding dependency."""
    response = client.get('/items/')
    # Check the status first so that a failed request is reported as such
    # rather than as a confusing JSON mismatch.
    assert response.status_code == 200
    assert response.json() == {'db': 'test_db'}

完整示例 #

python
from typing import Optional
from fastapi import FastAPI, Depends, HTTPException, Header, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session

app = FastAPI()

# Pydantic model for a user record; is_active defaults to True.
class User(BaseModel):
    id: int
    username: str
    email: str
    is_active: bool = True

class Pagination:
    """Class-based dependency carrying validated pagination values.

    ``page`` is declared with ``Query(1, ge=1)`` and ``size`` with
    ``Query(10, ge=1, le=100)``. ``offset`` is derived as
    ``(page - 1) * size``.
    """

    def __init__(
        self,
        page: int = Query(1, ge=1),
        size: int = Query(10, ge=1, le=100)
    ):
        self.page, self.size = page, size
        # Number of rows that precede the requested page.
        self.offset = size * (page - 1)

def get_db():
    """Yield a session from ``SessionLocal`` and close it afterwards."""
    db = SessionLocal()
    try:
        # The yielded session is injected wherever Depends(get_db) is declared.
        yield db
    finally:
        # Always close the session, even if the dependent code raised.
        db.close()

async def get_current_user(
    authorization: Optional[str] = Header(None)
) -> Optional[User]:
    """Resolve the user for an optional bearer token.

    Args:
        authorization: Raw ``authorization`` header value, or ``None`` when
            the header is absent.

    Returns:
        ``None`` when no header value was sent; otherwise the user returned
        by ``verify_token``.

    Raises:
        HTTPException: 401 when the header does not start with ``'Bearer '``
            or when ``verify_token`` returns a falsy value.
    """
    if not authorization:
        # Anonymous access is allowed: callers decide whether a user is required.
        return None

    prefix = 'Bearer '
    if not authorization.startswith(prefix):
        raise HTTPException(status_code=401, detail='Invalid authorization')

    # Strip only the leading scheme: str.replace would also delete any later
    # 'Bearer ' occurrence inside the token itself.
    token = authorization.removeprefix(prefix)
    user = verify_token(token)

    if not user:
        raise HTTPException(status_code=401, detail='Invalid token')

    return user

async def get_current_active_user(
    current_user: Optional[User] = Depends(get_current_user)
) -> User:
    """Require an authenticated, active user.

    Args:
        current_user: Result of ``get_current_user``; ``None`` when the
            request carried no authorization header.

    Returns:
        The authenticated user, guaranteed to have ``is_active`` set.

    Raises:
        HTTPException: 401 when there is no user, 400 when the user is
            inactive.
    """
    # get_current_user returns None for anonymous requests, so the parameter
    # is Optional and absence is tested explicitly.
    if current_user is None:
        raise HTTPException(status_code=401, detail='Not authenticated')

    if not current_user.is_active:
        raise HTTPException(status_code=400, detail='Inactive user')

    return current_user

# GET /items/: combines three dependencies - a database session, pagination
# values and an optional current user - and returns one page of Item rows.
@app.get('/items/')
def read_items(
    db: Session = Depends(get_db),
    pagination: Pagination = Depends(),
    current_user: Optional[User] = Depends(get_current_user)
):
    page_query = db.query(Item).offset(pagination.offset).limit(pagination.size)
    return {
        'items': page_query.all(),
        'page': pagination.page,
        'size': pagination.size,
        'user': current_user
    }

# GET /users/me: get_current_active_user raises unless there is an
# authenticated, active user, so the handler only returns that user.
@app.get('/users/me')
def read_users_me(current_user: User = Depends(get_current_active_user)):
    return current_user

下一步 #

现在你已经掌握了依赖注入,接下来学习 中间件,了解 FastAPI 的请求处理机制!

最后更新:2026-03-29