DRF 自定义认证 #
一、自定义认证基础 #
1.1 认证类结构 #
python
from rest_framework import authentication, exceptions
from django.contrib.auth.models import User
class CustomAuthentication(authentication.BaseAuthentication):
    """Minimal template for a custom DRF authentication class."""

    def authenticate(self, request):
        """Authenticate ``request``.

        Contract: return a ``(user, auth)`` tuple on success, raise
        ``exceptions.AuthenticationFailed`` when credentials are invalid, or
        return ``None`` to let the next configured authenticator try.
        """
        # Template placeholder: no credentials are inspected, so always defer.
        result = None
        return result
1.2 认证流程 #
text
请求 → 认证类.authenticate()
│
├── 成功:返回(user, auth)
├── 失败:抛出AuthenticationFailed
└── 跳过:返回None(交由下一个认证类继续尝试)
二、API Key认证 #
2.1 模型定义 #
python
from django.db import models
from django.contrib.auth.models import User
import secrets
class APIKey(models.Model):
    """An API key issued to a user, with optional expiry and usage tracking."""

    # Deleting the user deletes their keys; reverse accessor is user.api_keys.
    user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='api_keys')
    # Random URL-safe token from secrets.token_urlsafe (32 random bytes by
    # default, ~43 characters), so it fits within max_length=64.
    key = models.CharField(max_length=64, unique=True, default=secrets.token_urlsafe)
    # Human-readable label chosen by the owner.
    name = models.CharField(max_length=100)
    is_active = models.BooleanField(default=True)
    created_at = models.DateTimeField(auto_now_add=True)
    # NULL means the key never expires.
    expires_at = models.DateTimeField(null=True, blank=True)
    # Timestamp of the most recent successful authentication with this key.
    last_used = models.DateTimeField(null=True, blank=True)

    class Meta:
        verbose_name = 'API Key'
        verbose_name_plural = 'API Keys'

    def __str__(self):
        owner = self.user.username
        return '{} - {}'.format(self.name, owner)
2.2 认证类实现 #
python
from rest_framework import authentication, exceptions
from django.utils import timezone
from .models import APIKey
class APIKeyAuthentication(authentication.BaseAuthentication):
    """Authenticate requests that send ``Authorization: ApiKey <key>``.

    Requests without an Authorization header, or using another scheme, are
    skipped (``None``) so other authenticators can handle them.
    """

    keyword = 'ApiKey'

    def authenticate(self, request):
        """Return ``(user, api_key)`` for a valid key, ``None`` if not our scheme.

        Raises:
            exceptions.AuthenticationFailed: the header uses our scheme but is
                malformed, or the key is unknown / disabled / expired, or the
                owning user is inactive.
        """
        auth = request.META.get('HTTP_AUTHORIZATION', '')
        parts = auth.split()
        # Check the scheme BEFORE the length: a malformed header for some other
        # scheme (e.g. "Bearer" alone) must not block the other authenticators.
        # Auth schemes are case-insensitive (RFC 7235).
        if not parts or parts[0].lower() != self.keyword.lower():
            return None
        if len(parts) != 2:
            raise exceptions.AuthenticationFailed('无效的认证格式')
        return self.authenticate_credentials(parts[1])

    def authenticate_credentials(self, key):
        """Validate ``key`` and return ``(user, api_key)``; record last use."""
        try:
            api_key = APIKey.objects.select_related('user').get(key=key)
        except APIKey.DoesNotExist:
            raise exceptions.AuthenticationFailed('无效的API Key')
        if not api_key.is_active:
            raise exceptions.AuthenticationFailed('API Key已禁用')
        # One timestamp for both the expiry check and the usage record.
        now = timezone.now()
        if api_key.expires_at and api_key.expires_at < now:
            raise exceptions.AuthenticationFailed('API Key已过期')
        # A deactivated account must not keep working through its API keys.
        if not api_key.user.is_active:
            raise exceptions.AuthenticationFailed('用户已禁用')
        api_key.last_used = now
        api_key.save(update_fields=['last_used'])
        return (api_key.user, api_key)

    def authenticate_header(self, request):
        # A non-empty value makes DRF respond 401 with WWW-Authenticate
        # instead of 403 when authentication is missing or fails.
        return self.keyword
2.3 使用API Key #
python
class ArticleViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for articles, usable only by callers with a valid API key."""

    queryset = Article.objects.all()
    serializer_class = ArticleSerializer

    # Accept only the API-key scheme, and require that it resolved to a user.
    authentication_classes = [APIKeyAuthentication]
    permission_classes = [IsAuthenticated]
2.4 请求示例 #
bash
curl -X GET http://api.example.com/api/articles/ \
-H "Authorization: ApiKey your-api-key-here"
三、OAuth2认证 #
3.1 安装依赖 #
bash
pip install django-oauth-toolkit
3.2 配置 #
python
# settings.py: register django-oauth-toolkit and make its DRF
# authentication class the project-wide default.
INSTALLED_APPS = [
    # ... keep the project's existing apps here ...
    'oauth2_provider',
]

REST_FRAMEWORK = {
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'oauth2_provider.contrib.rest_framework.OAuth2Authentication',
    ],
}
3.3 自定义OAuth认证 #
python
from rest_framework import authentication, exceptions
from oauth2_provider.models import AccessToken
class CustomOAuthAuthentication(authentication.BaseAuthentication):
    """Authenticate with a django-oauth-toolkit ``AccessToken``.

    The token is read from ``Authorization: Bearer <token>`` or, failing that,
    from the ``access_token`` query parameter.
    """

    www_authenticate_realm = 'api'

    def authenticate(self, request):
        """Return ``(user, access_token)`` or ``None`` when no token is sent.

        Raises:
            exceptions.AuthenticationFailed: unknown or expired token, or the
                token's user is inactive.
        """
        token = self.get_token(request)
        if not token:
            return None
        try:
            access_token = AccessToken.objects.select_related('user', 'application').get(
                token=token
            )
        except AccessToken.DoesNotExist:
            raise exceptions.AuthenticationFailed('无效的访问令牌')
        if access_token.is_expired():
            raise exceptions.AuthenticationFailed('访问令牌已过期')
        user = access_token.user
        # The token's user can be None (e.g. client-credentials tokens, whose
        # user FK is nullable); only reject tokens of deactivated accounts.
        if user is not None and not user.is_active:
            raise exceptions.AuthenticationFailed('用户已禁用')
        return (user, access_token)

    def authenticate_header(self, request):
        # Non-empty value -> DRF responds 401 (with WWW-Authenticate) not 403.
        return f'Bearer realm="{self.www_authenticate_realm}"'

    def get_token(self, request):
        """Extract the bearer token, or ``None`` if the header is malformed."""
        parts = request.META.get('HTTP_AUTHORIZATION', '').split()
        # Auth schemes are case-insensitive (RFC 7235).
        if parts and parts[0].lower() == 'bearer':
            return parts[1] if len(parts) == 2 else None
        # NOTE(review): tokens in the query string tend to leak into access
        # logs and Referer headers; consider dropping this fallback.
        return request.GET.get('access_token')
四、第三方登录 #
4.1 微信登录 #
python
import requests
from rest_framework import authentication, exceptions
from django.conf import settings
from django.contrib.auth.models import User


class WeChatAuthentication(authentication.BaseAuthentication):
    """Authenticate a WeChat mini-program user from an ``X-WeChat-Code`` header.

    The login code is exchanged for an ``openid`` via ``jscode2session`` and
    mapped to a local user named ``wx_<openid>``, created on first login.

    NOTE(review): WeChat login codes are short-lived and single-use, so this
    scheme only fits a login endpoint that then issues a session/token; verify
    against the WeChat documentation before using it per request.
    """

    # Seconds to wait for the WeChat API before treating the login as failed.
    timeout = 5

    def authenticate(self, request):
        code = request.META.get('HTTP_X_WECHAT_CODE')
        if not code:
            return None
        openid = self.get_openid(code)
        if not openid:
            raise exceptions.AuthenticationFailed('微信认证失败')
        try:
            user = User.objects.get(username=f'wx_{openid}')
        except User.DoesNotExist:
            user = self.create_user(openid)
        if not user.is_active:
            raise exceptions.AuthenticationFailed('用户已禁用')
        return (user, None)

    def get_openid(self, code):
        """Exchange ``code`` for an openid; return ``None`` on any failure."""
        url = 'https://api.weixin.qq.com/sns/jscode2session'
        params = {
            'appid': settings.WECHAT_APPID,
            'secret': settings.WECHAT_SECRET,
            'js_code': code,
            'grant_type': 'authorization_code'
        }
        try:
            # Without a timeout a stalled upstream would hang the worker.
            response = requests.get(url, params=params, timeout=self.timeout)
            data = response.json()
        except (requests.RequestException, ValueError):
            # Network failure or non-JSON body: report auth failure, not a 500.
            return None
        if not isinstance(data, dict):
            return None
        # Error responses carry no openid, so this yields None for them.
        return data.get('openid')

    def create_user(self, openid):
        """Create the local account for ``openid`` (no usable password)."""
        user = User.objects.create_user(
            username=f'wx_{openid}',
        )
        return user
4.2 GitHub登录 #
python
import requests
from rest_framework import authentication, exceptions
from django.contrib.auth.models import User
class GitHubAuthentication(authentication.BaseAuthentication):
    """Authenticate with a GitHub access token sent in ``X-GitHub-Token``.

    The token is checked against ``GET https://api.github.com/user`` and the
    GitHub account is mapped to a local user named ``gh_<github id>``.
    """

    # Seconds to wait for the GitHub API before treating the login as failed.
    timeout = 5

    def authenticate(self, request):
        token = request.META.get('HTTP_X_GITHUB_TOKEN')
        if not token:
            return None
        user_data = self.get_github_user(token)
        if not user_data:
            raise exceptions.AuthenticationFailed('GitHub认证失败')
        user = self.get_or_create_user(user_data)
        if not user.is_active:
            raise exceptions.AuthenticationFailed('用户已禁用')
        return (user, None)

    def get_github_user(self, token):
        """Return the GitHub user profile dict, or ``None`` on any failure."""
        try:
            response = requests.get(
                'https://api.github.com/user',
                headers={'Authorization': f'token {token}'},
                timeout=self.timeout,
            )
        except requests.RequestException:
            return None
        if response.status_code != 200:
            return None
        try:
            data = response.json()
        except ValueError:
            return None
        # The numeric id is the only stable key used to map the account.
        if not isinstance(data, dict) or 'id' not in data:
            return None
        return data

    def get_or_create_user(self, user_data):
        """Fetch or create the local user mapped to ``user_data['id']``."""
        github_id = user_data['id']
        username = f'gh_{github_id}'
        try:
            return User.objects.get(username=username)
        except User.DoesNotExist:
            # GitHub sends null (not a missing key) for an unset name/email,
            # so `.get(key, '')` alone would return None and `None[:30]` fails.
            return User.objects.create_user(
                username=username,
                email=user_data.get('email') or '',
                first_name=(user_data.get('name') or '')[:30]
            )
五、多因素认证 #
5.1 OTP模型 #
python
from django.db import models
from django.contrib.auth.models import User
import pyotp
import qrcode
from io import BytesIO
class UserOTP(models.Model):
    """Per-user TOTP secret used as a second authentication factor."""

    user = models.OneToOneField(User, on_delete=models.CASCADE)
    # Base32 secret from pyotp.random_base32 (32 characters by default).
    secret = models.CharField(max_length=32, default=pyotp.random_base32)
    # Whether OTP verification is switched on for this user.
    is_enabled = models.BooleanField(default=False)

    def get_totp(self):
        """Return a ``pyotp.TOTP`` generator bound to this user's secret."""
        return pyotp.TOTP(self.secret)

    def verify_otp(self, otp):
        """Return True if ``otp`` is valid now.

        ``valid_window=1`` also accepts codes from the adjacent time steps to
        tolerate small clock drift between client and server.
        """
        return self.get_totp().verify(otp, valid_window=1)

    def get_qr_code(self, issuer_name='MyApp'):
        """Render the provisioning URI as a PNG QR code and return its bytes.

        Args:
            issuer_name: label shown by authenticator apps (default 'MyApp').
        """
        totp = self.get_totp()
        uri = totp.provisioning_uri(
            # An empty email would leave the account label in the otpauth://
            # URI blank, so fall back to the username.
            name=self.user.email or self.user.username,
            issuer_name=issuer_name
        )
        qr = qrcode.QRCode(version=1, box_size=10, border=5)
        qr.add_data(uri)
        qr.make(fit=True)
        img = qr.make_image(fill_color='black', back_color='white')
        buffer = BytesIO()
        img.save(buffer, format='PNG')
        return buffer.getvalue()
5.2 OTP认证 #
python
from rest_framework import authentication, exceptions
from .models import UserOTP
class OTPAuthentication(authentication.BaseAuthentication):
    """Second-factor check: verify an ``X-OTP`` header for a known user.

    The user comes from ``request._cached_user``. This is presumably populated
    only after Django's own authentication (e.g. the session) has already
    resolved ``request.user`` (confirm for your middleware setup). Users
    without an enabled ``UserOTP`` are skipped.
    """

    def authenticate(self, request):
        user = getattr(request, '_cached_user', None)
        otp = request.META.get('HTTP_X_OTP')
        # The cached user may be an AnonymousUser, which is truthy; passing it
        # to the UserOTP lookup below would raise instead of skipping.
        if user is None or not user.is_authenticated or not otp:
            return None
        try:
            user_otp = UserOTP.objects.get(user=user, is_enabled=True)
        except UserOTP.DoesNotExist:
            return None
        if not user_otp.verify_otp(otp):
            raise exceptions.AuthenticationFailed('无效的OTP')
        return (user, None)
六、IP白名单认证 #
6.1 实现 #
python
import ipaddress

from rest_framework import authentication, exceptions
from django.contrib.auth.models import User


class IPWhitelistAuthentication(authentication.BaseAuthentication):
    """Authenticate by client IP and map allowed requests to a shared user.

    Requests from an address outside ``WHITELIST`` are rejected with
    ``AuthenticationFailed``. Allowed requests authenticate as the
    ``api_user`` account, which is created on first use.
    """

    # Single addresses ('127.0.0.1') and CIDR networks ('192.168.1.0/24').
    WHITELIST = [
        '127.0.0.1',
        '192.168.1.0/24',
    ]
    # Number of trusted reverse proxies in front of the application. With the
    # default of 0, X-Forwarded-For is ignored: any client can set that header
    # and would otherwise claim a whitelisted address.
    NUM_PROXIES = 0

    def authenticate(self, request):
        ip = self.get_client_ip(request)
        if not self.is_ip_allowed(ip):
            raise exceptions.AuthenticationFailed('IP不在白名单中')
        user, _created = User.objects.get_or_create(username='api_user')
        return (user, None)

    def get_client_ip(self, request):
        """Return the client address, trusting X-Forwarded-For only behind proxies."""
        remote_addr = request.META.get('REMOTE_ADDR')
        if self.NUM_PROXIES <= 0:
            return remote_addr
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR', '')
        addrs = [addr.strip() for addr in x_forwarded_for.split(',') if addr.strip()]
        if not addrs:
            return remote_addr
        # Each trusted proxy appends the address it received the request from,
        # so the real client sits NUM_PROXIES entries from the right; entries
        # further left are client-controlled.
        return addrs[-min(self.NUM_PROXIES, len(addrs))]

    def is_ip_allowed(self, ip):
        """Return True if ``ip`` is a valid address inside any WHITELIST entry."""
        if not ip:
            return False
        try:
            address = ipaddress.ip_address(ip)
        except ValueError:
            # Malformed address: deny instead of failing with a 500.
            return False
        # ip_network() treats a bare address as a /32 (or /128) network, so one
        # membership test covers both single IPs and CIDR ranges.
        return any(
            address in ipaddress.ip_network(allowed, strict=False)
            for allowed in self.WHITELIST
        )
七、签名认证 #
7.1 实现 #
python
import hashlib
import hmac
import time
from rest_framework import authentication, exceptions
from django.contrib.auth.models import User
class SignatureAuthentication(authentication.BaseAuthentication):
    """HMAC-SHA256 request-signature authentication.

    The client sends three headers:

    * ``X-Access-Key``: identifies the user via ``profile.access_key``.
    * ``X-Timestamp``: integer Unix seconds.
    * ``X-Signature``: hex HMAC-SHA256, keyed with ``profile.secret_key``, of
      the bytes ``method + path + raw body + timestamp``.
    """

    # Maximum accepted difference, in seconds, between client and server clocks.
    max_clock_skew = 300

    def authenticate(self, request):
        """Return ``(user, None)`` for a valid signature, ``None`` if unsigned.

        Raises:
            exceptions.AuthenticationFailed: stale or invalid timestamp,
                unknown access key, inactive user, or signature mismatch.
        """
        timestamp = request.META.get('HTTP_X_TIMESTAMP')
        signature = request.META.get('HTTP_X_SIGNATURE')
        access_key = request.META.get('HTTP_X_ACCESS_KEY')
        if not all([timestamp, signature, access_key]):
            return None
        if not self.is_timestamp_valid(timestamp):
            raise exceptions.AuthenticationFailed('请求已过期')
        try:
            user = User.objects.get(profile__access_key=access_key)
        except User.DoesNotExist:
            raise exceptions.AuthenticationFailed('无效的Access Key')
        if not user.is_active:
            raise exceptions.AuthenticationFailed('用户已禁用')
        # NOTE(review): request.path excludes the query string, so query
        # parameters are not covered by the signature, and a captured request
        # can be replayed within the timestamp window.
        expected_signature = self.generate_signature(
            request.method,
            request.path,
            request.body,
            timestamp,
            user.profile.secret_key
        )
        # Compare as bytes: compare_digest() raises TypeError for non-ASCII str
        # input, and the signature header is client-controlled.
        if not hmac.compare_digest(signature.encode(), expected_signature.encode()):
            raise exceptions.AuthenticationFailed('签名验证失败')
        return (user, None)

    def is_timestamp_valid(self, timestamp):
        """Return True if ``timestamp`` is an integer within the allowed skew."""
        try:
            ts = int(timestamp)
        except (TypeError, ValueError):
            # Non-numeric header: treat as invalid instead of raising a 500.
            return False
        return abs(time.time() - ts) < self.max_clock_skew

    def generate_signature(self, method, path, body, timestamp, secret_key):
        """Return the hex HMAC-SHA256 of ``method + path + body + timestamp``.

        ``body`` is used as raw bytes (``str`` is UTF-8 encoded, ``None`` means
        empty). Formatting bytes with an f-string would sign the Python repr
        ``"b'...'"`` instead, which non-Python clients cannot reproduce.
        """
        if body is None:
            body = b''
        elif isinstance(body, str):
            body = body.encode()
        message = f'{method}{path}'.encode() + body + str(timestamp).encode()
        return hmac.new(
            secret_key.encode(),
            message,
            hashlib.sha256
        ).hexdigest()
八、组合认证 #
8.1 多认证方式 #
python
class MultiAuthentication(authentication.BaseAuthentication):
    """Try several authenticators in order and return the first success.

    An authenticator that raises ``AuthenticationFailed`` stops the chain.
    One that returns ``None`` passes control to the next.
    """

    def __init__(self):
        self.authenticators = [
            TokenAuthentication(),
            JWTAuthentication(),
            APIKeyAuthentication(),
        ]

    def authenticate(self, request):
        for authenticator in self.authenticators:
            result = authenticator.authenticate(request)
            if result is not None:
                return result
        return None

    def authenticate_header(self, request):
        # DRF uses this value for WWW-Authenticate. The inherited default
        # returns None, which turns "not authenticated" into 403 instead of
        # 401, so delegate to the first authenticator that provides one.
        for authenticator in self.authenticators:
            header = authenticator.authenticate_header(request)
            if header:
                return header
        return None
8.2 优先级认证 #
python
class PriorityAuthentication(authentication.BaseAuthentication):
    """Pick one authenticator based on which credentials the request carries.

    A Bearer ``Authorization`` header goes to JWT authentication. Otherwise an
    ``X-API-Key`` header is validated as an API key.
    """

    def authenticate(self, request):
        if request.META.get('HTTP_AUTHORIZATION', '').startswith('Bearer'):
            return JWTAuthentication().authenticate(request)
        api_key = request.META.get('HTTP_X_API_KEY')
        if api_key:
            # APIKeyAuthentication.authenticate() only reads the Authorization
            # header, so it would ignore X-API-Key and return None; validate the
            # header value directly instead.
            return APIKeyAuthentication().authenticate_credentials(api_key)
        return None
九、认证中间件 #
9.1 自定义中间件 #
python
class AuthenticationMiddleware:
    """Tag responses with the ID of the authenticated user.

    After the downstream handler has produced the response, an ``X-User-ID``
    header is added when the request carries an authenticated ``user``.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        response = self.get_response(request)
        missing = object()
        current_user = getattr(request, 'user', missing)
        # Guard clause: no user attribute, or an anonymous/unauthenticated one.
        if current_user is missing or not current_user.is_authenticated:
            return response
        response['X-User-ID'] = str(current_user.id)
        return response
十、总结 #
本章学习了自定义认证:
- API Key认证:实现API Key认证方案
- OAuth认证:集成OAuth2认证
- 第三方登录:微信、GitHub登录
- 多因素认证:OTP认证
- IP白名单认证:按客户端IP限制访问
- 签名认证:请求签名验证
- 组合认证:按顺序或按优先级组合多种认证方式
- 认证中间件:在响应中附加认证用户信息
自定义认证让我们能够实现各种复杂的认证需求!
最后更新:2026-03-28