DRF 自定义认证 #

一、自定义认证基础 #

1.1 认证类结构 #

python
from rest_framework import authentication, exceptions
from django.contrib.auth.models import User

class CustomAuthentication(authentication.BaseAuthentication):
    """Minimal skeleton of a custom DRF authentication class."""

    def authenticate(self, request):
        """
        Authentication hook: return a ``(user, auth)`` tuple or ``None``.

        This skeleton always returns ``None``.
        """
        return None

1.2 认证流程 #

text
请求 → 认证类.authenticate() → 返回(user, auth)
              │
              ├── 成功:返回(user, auth)
              ├── 失败:抛出AuthenticationFailed
              └── 跳过:返回None

二、API Key认证 #

2.1 模型定义 #

python
from django.db import models
from django.contrib.auth.models import User
import secrets

class APIKey(models.Model):
    """API key record owned by a user."""

    # Owning user. Keys are removed together with the user (CASCADE) and are
    # reachable from the user side as ``user.api_keys``.
    user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='api_keys')
    # Unique key value. The default is the *callable* ``secrets.token_urlsafe``,
    # so a fresh value is generated per instance rather than once at import.
    key = models.CharField(max_length=64, unique=True, default=secrets.token_urlsafe)
    # Label for the key; also shown by ``__str__``.
    name = models.CharField(max_length=100)
    # Enabled/disabled flag, enabled by default.
    is_active = models.BooleanField(default=True)
    # Set automatically when the row is first created.
    created_at = models.DateTimeField(auto_now_add=True)
    # Optional expiry time; NULL is allowed.
    expires_at = models.DateTimeField(null=True, blank=True)
    # Optional time of last use; NULL is allowed.
    last_used = models.DateTimeField(null=True, blank=True)

    class Meta:
        verbose_name = 'API Key'
        verbose_name_plural = 'API Keys'

    def __str__(self):
        """Return ``"<name> - <username>"`` for this key."""
        return f'{self.name} - {self.user.username}'

2.2 认证类实现 #

python
from rest_framework import authentication, exceptions
from django.utils import timezone
from .models import APIKey

class APIKeyAuthentication(authentication.BaseAuthentication):
    """Authenticate requests that send ``Authorization: ApiKey <key>``.

    NOTE(review): this class does not define ``authenticate_header()``; per
    the DRF authentication guide that makes unauthenticated requests receive
    403 rather than 401. Left unchanged to keep existing responses stable.
    """

    keyword = 'ApiKey'

    def authenticate(self, request):
        """Return ``(user, api_key)`` or ``None`` if the header is not ours.

        Returns:
            ``None`` when the ``Authorization`` header is absent, blank, or
            uses a scheme other than ``keyword`` (so that other authenticators
            get a chance to handle the request); otherwise the result of
            ``authenticate_credentials``.

        Raises:
            AuthenticationFailed: the header uses our scheme but is not
                exactly ``<keyword> <key>``.
        """
        auth = request.META.get('HTTP_AUTHORIZATION', '')
        parts = auth.split()

        # Check the scheme *before* validating the shape. Otherwise a header
        # meant for another authenticator (e.g. a bare ``Basic`` or a
        # three-part value) would be rejected here instead of skipped.
        if not parts or parts[0] != self.keyword:
            return None

        if len(parts) != 2:
            raise exceptions.AuthenticationFailed('无效的认证格式')

        return self.authenticate_credentials(parts[1])

    def authenticate_credentials(self, key):
        """Look up ``key`` and return ``(user, api_key)``.

        Side effect: stamps ``last_used`` on the key and saves that field.

        Raises:
            AuthenticationFailed: the key does not exist, is disabled, or has
                passed its ``expires_at``.
        """
        try:
            api_key = APIKey.objects.select_related('user').get(key=key)
        except APIKey.DoesNotExist:
            raise exceptions.AuthenticationFailed('无效的API Key')

        if not api_key.is_active:
            raise exceptions.AuthenticationFailed('API Key已禁用')

        # Use one timestamp for both the expiry check and the usage stamp.
        now = timezone.now()

        if api_key.expires_at and api_key.expires_at < now:
            raise exceptions.AuthenticationFailed('API Key已过期')

        api_key.last_used = now
        # Only write the column we changed.
        api_key.save(update_fields=['last_used'])

        return (api_key.user, api_key)

2.3 使用API Key #

python
class ArticleViewSet(viewsets.ModelViewSet):
    # Documented with comments rather than a docstring: DRF uses a view's
    # docstring as its description, so adding one would change output.

    # Data source and serialization for the article endpoints.
    serializer_class = ArticleSerializer
    queryset = Article.objects.all()

    # Callers must be authenticated, and only API-key authentication is tried.
    permission_classes = [IsAuthenticated]
    authentication_classes = [APIKeyAuthentication]

2.4 请求示例 #

bash
curl -X GET http://api.example.com/api/articles/ \
  -H "Authorization: ApiKey your-api-key-here"

三、OAuth2认证 #

3.1 安装依赖 #

bash
pip install django-oauth-toolkit

3.2 配置 #

python
INSTALLED_APPS = [
    # `...` is presumably a placeholder for the apps the project already lists.
    ...
    # App provided by the django-oauth-toolkit package.
    'oauth2_provider',
]

REST_FRAMEWORK = {
    # Authentication classes DRF uses for views that do not declare their own.
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'oauth2_provider.contrib.rest_framework.OAuth2Authentication',
    ],
}

3.3 自定义OAuth认证 #

python
from rest_framework import authentication, exceptions
from oauth2_provider.models import AccessToken

class CustomOAuthAuthentication(authentication.BaseAuthentication):
    """Authenticate a request against a stored OAuth2 ``AccessToken``."""

    def authenticate(self, request):
        """Return ``(user, access_token)`` or ``None`` when no token is sent.

        Raises:
            AuthenticationFailed: the token is unknown or has expired.
        """
        token = self.get_token(request)

        if not token:
            return None

        try:
            access_token = AccessToken.objects.select_related('user', 'application').get(
                token=token
            )
        except AccessToken.DoesNotExist:
            raise exceptions.AuthenticationFailed('无效的访问令牌')

        if access_token.is_expired():
            raise exceptions.AuthenticationFailed('访问令牌已过期')

        return (access_token.user, access_token)

    def get_token(self, request):
        """Extract the bearer token from the request.

        Returns:
            The token from an ``Authorization: Bearer <token>`` header, or the
            ``access_token`` query parameter when no such header is present.
            ``None`` (or an empty value) when neither carries a token.
        """
        auth = request.META.get('HTTP_AUTHORIZATION', '')

        if auth.startswith('Bearer '):
            # Split on any run of whitespace: ``split(' ')[1]`` yields '' for
            # "Bearer  <token>" (two spaces), which silently downgraded a
            # credentialed request to anonymous.
            parts = auth.split()
            return parts[1] if len(parts) > 1 else None

        # NOTE(security): tokens in the query string tend to end up in access
        # logs, browser history and Referer headers. Kept for compatibility;
        # consider removing this fallback.
        return request.GET.get('access_token')

四、第三方登录 #

4.1 微信登录 #

python
import requests
from rest_framework import authentication, exceptions
from django.contrib.auth.models import User

class WeChatAuthentication(authentication.BaseAuthentication):
    """Authenticate via a WeChat login code sent in the ``X-WeChat-Code`` header.

    The code is exchanged for an ``openid`` through WeChat's
    ``jscode2session`` endpoint; the matching local user is named
    ``wx_<openid>`` and is created on first sight.
    """

    # Seconds to wait for the WeChat API. ``requests`` has no default timeout,
    # so without this a stalled upstream would hang the worker indefinitely.
    REQUEST_TIMEOUT = 5

    def authenticate(self, request):
        """Return ``(user, None)``, or ``None`` when no code header is sent.

        Raises:
            AuthenticationFailed: WeChat did not return an ``openid`` for the
                code (invalid code, upstream error, or unreadable response).
        """
        code = request.META.get('HTTP_X_WECHAT_CODE')

        if not code:
            return None

        openid = self.get_openid(code)

        if not openid:
            raise exceptions.AuthenticationFailed('微信认证失败')

        try:
            user = User.objects.get(username=f'wx_{openid}')
        except User.DoesNotExist:
            user = self.create_user(openid)

        return (user, None)

    def get_openid(self, code):
        """Exchange ``code`` for an ``openid``; return ``None`` on any failure."""
        # Imported here because the surrounding snippet never imports
        # ``settings``; referencing it unimported raised NameError.
        from django.conf import settings

        url = 'https://api.weixin.qq.com/sns/jscode2session'
        params = {
            'appid': settings.WECHAT_APPID,
            'secret': settings.WECHAT_SECRET,
            'js_code': code,
            'grant_type': 'authorization_code'
        }

        try:
            response = requests.get(url, params=params, timeout=self.REQUEST_TIMEOUT)
            data = response.json()
        except (requests.RequestException, ValueError):
            # Network failure, timeout, or a body that is not JSON: report
            # "no openid" so the caller answers with an auth failure, not a 500.
            return None

        if not isinstance(data, dict):
            return None

        return data.get('openid')

    def create_user(self, openid):
        """Create and return the local user for ``openid``."""
        user = User.objects.create_user(
            username=f'wx_{openid}',
        )
        return user

4.2 GitHub登录 #

python
import requests
from rest_framework import authentication, exceptions
from django.contrib.auth.models import User

class GitHubAuthentication(authentication.BaseAuthentication):
    """Authenticate via a GitHub token sent in the ``X-GitHub-Token`` header.

    The token is used to fetch the GitHub profile; the matching local user is
    named ``gh_<github id>`` and is created on first sight.
    """

    # Seconds to wait for the GitHub API. ``requests`` has no default timeout,
    # so without this a stalled upstream would hang the worker indefinitely.
    REQUEST_TIMEOUT = 5

    def authenticate(self, request):
        """Return ``(user, None)``, or ``None`` when no token header is sent.

        Raises:
            AuthenticationFailed: the profile could not be fetched with the
                supplied token.
        """
        token = request.META.get('HTTP_X_GITHUB_TOKEN')

        if not token:
            return None

        user_data = self.get_github_user(token)

        if not user_data:
            raise exceptions.AuthenticationFailed('GitHub认证失败')

        user = self.get_or_create_user(user_data)

        return (user, None)

    def get_github_user(self, token):
        """Fetch the profile for ``token``; return ``None`` on any failure."""
        try:
            response = requests.get(
                'https://api.github.com/user',
                headers={'Authorization': f'token {token}'},
                timeout=self.REQUEST_TIMEOUT,
            )
        except requests.RequestException:
            # Network failure or timeout: treat as "not authenticated".
            return None

        if response.status_code != 200:
            return None

        try:
            return response.json()
        except ValueError:
            # A 200 response whose body is not valid JSON.
            return None

    def get_or_create_user(self, user_data):
        """Return the local user for the profile, creating it if needed.

        Raises:
            KeyError: ``user_data`` has no ``'id'`` entry.
        """
        github_id = user_data['id']
        username = f'gh_{github_id}'

        try:
            return User.objects.get(username=username)
        except User.DoesNotExist:
            # ``name`` and ``email`` may be present with a null value, in
            # which case ``.get(key, '')`` returns None and slicing it raised
            # TypeError. ``or ''`` covers both the missing and the null case.
            return User.objects.create_user(
                username=username,
                email=user_data.get('email') or '',
                first_name=(user_data.get('name') or '')[:30]
            )

五、多因素认证 #

5.1 OTP模型 #

python
from django.db import models
from django.contrib.auth.models import User
import pyotp
import qrcode
from io import BytesIO

class UserOTP(models.Model):
    """Per-user TOTP secret plus a flag saying whether OTP is switched on."""

    user = models.OneToOneField(User, on_delete=models.CASCADE)
    secret = models.CharField(max_length=32, default=pyotp.random_base32)
    is_enabled = models.BooleanField(default=False)

    def get_totp(self):
        """Build a ``pyotp.TOTP`` object from this row's secret."""
        return pyotp.TOTP(self.secret)

    def verify_otp(self, otp):
        """Check ``otp`` against the secret, using ``valid_window=1``."""
        generator = self.get_totp()
        return generator.verify(otp, valid_window=1)

    def get_qr_code(self):
        """Return PNG bytes of a QR code encoding the provisioning URI.

        The URI is labelled with the user's email and the issuer ``MyApp``.
        """
        provisioning_uri = self.get_totp().provisioning_uri(
            name=self.user.email,
            issuer_name='MyApp'
        )

        code = qrcode.QRCode(version=1, box_size=10, border=5)
        code.add_data(provisioning_uri)
        code.make(fit=True)

        # Render into an in-memory buffer and hand back the raw bytes.
        png = BytesIO()
        picture = code.make_image(fill_color='black', back_color='white')
        picture.save(png, format='PNG')

        return png.getvalue()

5.2 OTP认证 #

python
from rest_framework import authentication, exceptions
from .models import UserOTP

class OTPAuthentication(authentication.BaseAuthentication):
    """Verify a one-time password sent in the ``X-OTP`` header.

    NOTE(review): the user is read from the private ``_cached_user``
    attribute, presumably populated by Django's auth middleware once
    ``request.user`` has been evaluated - confirm it is set at this point.
    NOTE(review): a user with OTP enabled who omits the header is skipped
    (``None``), not rejected, so this class alone does not enforce OTP.
    """

    def authenticate(self, request):
        """Return ``(user, None)`` when the supplied OTP is valid.

        Returns:
            ``None`` when there is no authenticated cached user, no OTP
            header, or the user has no enabled ``UserOTP`` row.

        Raises:
            AuthenticationFailed: an OTP was supplied but does not verify.
        """
        user = getattr(request, '_cached_user', None)
        otp = request.META.get('HTTP_X_OTP')

        if not user or not otp:
            return None

        # An anonymous user object is truthy, so it passed the check above and
        # was then used in the ORM lookup below. Skip anything that is not an
        # authenticated user.
        if not getattr(user, 'is_authenticated', False):
            return None

        try:
            user_otp = UserOTP.objects.get(user=user, is_enabled=True)
        except UserOTP.DoesNotExist:
            return None

        if not user_otp.verify_otp(otp):
            raise exceptions.AuthenticationFailed('无效的OTP')

        return (user, None)

六、IP白名单认证 #

6.1 实现 #

python
from rest_framework import authentication, exceptions
from django.contrib.auth.models import User

class IPWhitelistAuthentication(authentication.BaseAuthentication):
    """Authenticate every request from a whitelisted IP as ``api_user``."""

    # Entries are either a literal address or a CIDR network.
    WHITELIST = [
        '127.0.0.1',
        '192.168.1.0/24',
    ]

    def authenticate(self, request):
        """Return ``(api_user, None)`` for whitelisted clients.

        The shared ``api_user`` account is created on first use.

        Raises:
            AuthenticationFailed: the client IP is not in ``WHITELIST``.
        """
        ip = self.get_client_ip(request)

        if not self.is_ip_allowed(ip):
            raise exceptions.AuthenticationFailed('IP不在白名单中')

        return (User.objects.get_or_create(username='api_user')[0], None)

    def get_client_ip(self, request):
        """Return the client IP: first ``X-Forwarded-For`` entry, else ``REMOTE_ADDR``.

        NOTE(security): ``X-Forwarded-For`` is supplied by the client unless a
        trusted proxy overwrites it, so a caller can claim a whitelisted
        address. Only rely on it behind a proxy that sets the header itself.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            # Entries are comma-separated and may carry surrounding spaces.
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')

    def is_ip_allowed(self, ip):
        """Return True if ``ip`` matches any ``WHITELIST`` entry.

        Literal entries are compared as strings; CIDR entries are matched by
        network membership. A missing or malformed ``ip`` never matches a
        CIDR entry (previously it raised ``ValueError``).
        """
        import ipaddress

        # Parse once, up front. Unparseable input (None, garbage from a
        # spoofed header) becomes "no address" instead of an exception.
        try:
            address = ipaddress.ip_address(ip)
        except (TypeError, ValueError):
            address = None

        for allowed in self.WHITELIST:
            if '/' in allowed:
                if address is not None and address in ipaddress.ip_network(allowed):
                    return True
            elif ip == allowed:
                return True

        return False

七、签名认证 #

7.1 实现 #

python
import hashlib
import hmac
import time
from rest_framework import authentication, exceptions
from django.contrib.auth.models import User

class SignatureAuthentication(authentication.BaseAuthentication):
    """Authenticate requests signed with HMAC-SHA256.

    Reads the ``X-Timestamp``, ``X-Signature`` and ``X-Access-Key`` headers.
    The signed message is ``method + path + body + timestamp``, keyed with
    the secret stored on the user's profile.
    """

    def authenticate(self, request):
        """Return ``(user, None)`` for a correctly signed request.

        Returns:
            ``None`` when any of the three headers is missing.

        Raises:
            AuthenticationFailed: the timestamp is malformed or outside the
                allowed window, the access key is unknown, or the signature
                does not match.
        """
        timestamp = request.META.get('HTTP_X_TIMESTAMP')
        signature = request.META.get('HTTP_X_SIGNATURE')
        access_key = request.META.get('HTTP_X_ACCESS_KEY')

        if not all([timestamp, signature, access_key]):
            return None

        if not self.is_timestamp_valid(timestamp):
            raise exceptions.AuthenticationFailed('请求已过期')

        try:
            user = User.objects.get(profile__access_key=access_key)
        except User.DoesNotExist:
            raise exceptions.AuthenticationFailed('无效的Access Key')

        expected_signature = self.generate_signature(
            request.method,
            request.path,
            request.body,
            timestamp,
            user.profile.secret_key
        )

        # Compare as bytes: ``compare_digest`` raises TypeError for a ``str``
        # containing non-ASCII characters, which a client controls here.
        if not hmac.compare_digest(signature.encode(), expected_signature.encode()):
            raise exceptions.AuthenticationFailed('签名验证失败')

        return (user, None)

    def is_timestamp_valid(self, timestamp):
        """Return True if ``timestamp`` (epoch seconds) is within 300s of now.

        A value that is not an integer is treated as invalid rather than
        raising.
        """
        try:
            sent_at = int(timestamp)
        except (TypeError, ValueError):
            return False
        return abs(time.time() - sent_at) < 300

    def generate_signature(self, method, path, body, timestamp, secret_key):
        """Return the hex HMAC-SHA256 of ``method + path + body + timestamp``.

        ``body`` may be ``bytes`` (as ``request.body`` is) or text. Bytes are
        signed as-is. Formatting them into an f-string, as before, signed the
        Python repr (``b'...'``) instead of the payload, so clients must sign
        the raw body bytes. Text input produces the same result as before.
        """
        if isinstance(body, (bytes, bytearray, memoryview)):
            raw_body = bytes(body)
        else:
            raw_body = str(body).encode()

        message = b''.join([
            f'{method}{path}'.encode(),
            raw_body,
            f'{timestamp}'.encode(),
        ])
        return hmac.new(
            secret_key.encode(),
            message,
            hashlib.sha256
        ).hexdigest()

八、组合认证 #

8.1 多认证方式 #

python
class MultiAuthentication(authentication.BaseAuthentication):
    """Try several authenticators in order and use the first that succeeds."""

    def __init__(self):
        # Order matters: earlier entries are consulted first.
        token_auth = TokenAuthentication()
        jwt_auth = JWTAuthentication()
        api_key_auth = APIKeyAuthentication()
        self.authenticators = [token_auth, jwt_auth, api_key_auth]

    def authenticate(self, request):
        """Return the first truthy result from the authenticators, else ``None``.

        Exceptions raised by an authenticator are not caught here.
        """
        for backend in self.authenticators:
            outcome = backend.authenticate(request)
            if not outcome:
                continue
            return outcome
        return None

8.2 优先级认证 #

python
class PriorityAuthentication(authentication.BaseAuthentication):
    """Pick one authenticator based on which credentials the request carries."""

    def authenticate(self, request):
        """Delegate to JWT or API-key authentication, or return ``None``.

        A ``Bearer`` Authorization header selects ``JWTAuthentication``.
        Otherwise an ``ApiKey`` Authorization header, or the presence of an
        ``X-API-Key`` header, selects ``APIKeyAuthentication``.
        """
        auth = request.META.get('HTTP_AUTHORIZATION', '')

        if auth.startswith('Bearer'):
            return JWTAuthentication().authenticate(request)

        # APIKeyAuthentication reads its key from ``Authorization: ApiKey ...``
        # and never looks at X-API-Key, so gating only on X-API-Key meant a
        # normal API-key request was never routed to it. The X-API-Key
        # trigger is kept so existing callers behave as before.
        api_key_prefix = f'{APIKeyAuthentication.keyword} '
        if auth.startswith(api_key_prefix) or request.META.get('HTTP_X_API_KEY'):
            return APIKeyAuthentication().authenticate(request)

        return None

九、认证中间件 #

9.1 自定义中间件 #

python
class AuthenticationMiddleware:
    """Middleware that adds an ``X-User-ID`` header for authenticated users."""

    def __init__(self, get_response):
        # Next callable in the middleware chain.
        self.get_response = get_response

    def __call__(self, request):
        """Run the rest of the chain, then tag the response with the user id."""
        response = self.get_response(request)

        # Requests without a ``user`` attribute are passed through untouched.
        if not hasattr(request, 'user'):
            return response

        if request.user.is_authenticated:
            user_id = request.user.id
            response['X-User-ID'] = str(user_id)

        return response

十、总结 #

本章学习了自定义认证:

  • API Key认证:实现API Key认证方案
  • OAuth认证:集成OAuth2认证
  • 第三方登录:微信、GitHub登录
  • 多因素认证:OTP认证
  • IP白名单认证:按来源IP限制访问
  • 签名认证:请求签名验证
  • 组合认证与认证中间件:组合多种认证方式,并在响应中附加用户信息

自定义认证让我们能够实现各种复杂的认证需求!

最后更新:2026-03-28