DRF 自定义序列化 #

一、自定义字段 #

1.1 创建自定义字段类 #

python
from rest_framework import serializers

class UpperCaseField(serializers.CharField):
    """CharField whose value is upper-cased both on output and on input.

    Both directions delegate to ``CharField`` first, so its standard
    behaviour (rejecting invalid types, ``trim_whitespace``, conversion to
    ``str``) is kept; only the final ``.upper()`` is added on top.
    """

    def to_representation(self, value):
        # CharField.to_representation converts to str first, so non-str model
        # values do not raise AttributeError on .upper().
        return super().to_representation(value).upper()

    def to_internal_value(self, data):
        # Let CharField reject invalid input types (e.g. bool, dict) with a
        # ValidationError instead of crashing on data.upper().
        return super().to_internal_value(data).upper()

class ArticleSerializer(serializers.Serializer):
    title = UpperCaseField(max_length=200)

1.2 自定义数字字段 #

python
class PercentageField(serializers.FloatField):
    """Float field exposed to clients as a percentage.

    A stored 0.8567 is rendered as ``'85.67%'``. On input, ``'85.67%'``,
    ``' 85.67 % '`` and ``85.67`` are all converted back to 0.8567.
    """

    def to_representation(self, value):
        return f'{value * 100:.2f}%'

    def to_internal_value(self, data):
        if isinstance(data, str):
            # Tolerate whitespace around the number and the percent sign.
            data = data.strip().rstrip('%').strip()
        # FloatField reports unparsable input as a ValidationError rather than
        # letting float() raise a ValueError (which would surface as HTTP 500).
        return super().to_internal_value(data) / 100

class ReportSerializer(serializers.Serializer):
    completion_rate = PercentageField()

1.3 自定义JSON字段 #

python
import json

class JSONField(serializers.Field):
    """Field for an attribute that holds JSON as text.

    It json.dumps() values on input and json.loads() them on output.
    NOTE(review): presumably backed by a text column on the model; confirm.

    Output: JSON text is decoded into Python data. Text that is not valid
    JSON is returned unchanged.
    Input: strings are kept as-is (treated as already-encoded JSON text).
    Every other value is encoded with json.dumps().
    """

    def to_representation(self, value):
        if isinstance(value, str):
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return value
        return value

    def to_internal_value(self, data):
        if isinstance(data, str):
            return data
        # Encode every non-string JSON value, not only dict/list. A bool such
        # as True passed through raw would become the non-JSON text 'True' if
        # stringified for storage, and would no longer round-trip through
        # to_representation().
        try:
            return json.dumps(data)
        except (TypeError, ValueError) as exc:
            # Values that are not JSON-serializable (possible when data is
            # built in code rather than parsed from a request body).
            raise serializers.ValidationError('无法序列化为JSON。') from exc

class ConfigSerializer(serializers.Serializer):
    settings = JSONField()

1.4 自定义枚举字段 #

python
from enum import Enum

class EnumField(serializers.ChoiceField):
    """ChoiceField backed by a Python ``Enum`` class.

    Clients send and receive the enum member's value. Internally the field
    produces the enum member itself.
    """

    def __init__(self, enum_class, **kwargs):
        self.enum_class = enum_class
        super().__init__(
            [(member.value, member.name) for member in enum_class],
            **kwargs,
        )

    def to_representation(self, value):
        # Values that are not members of enum_class are passed through as-is.
        return value.value if isinstance(value, self.enum_class) else value

    def to_internal_value(self, data):
        try:
            member = self.enum_class(data)
        except ValueError:
            # fail() raises a ValidationError with the 'invalid_choice' message.
            self.fail('invalid_choice', input=data)
        return member

class Status(Enum):
    DRAFT = 'draft'
    PUBLISHED = 'published'
    ARCHIVED = 'archived'

class ArticleSerializer(serializers.Serializer):
    status = EnumField(Status)

二、SerializerMethodField #

2.1 基本用法 #

python
class ArticleSerializer(serializers.ModelSerializer):
    word_count = serializers.SerializerMethodField()
    reading_time = serializers.SerializerMethodField()

    # Assumed reading speed (words per minute) used by get_reading_time.
    # Override in a subclass to tune the estimate.
    words_per_minute = 200

    class Meta:
        model = Article
        fields = '__all__'

    @staticmethod
    def _count_words(obj):
        """Return the number of whitespace-separated words in ``obj.content``.

        Empty or NULL content counts as 0 words instead of raising.
        NOTE: split() counts whitespace-delimited tokens, so CJK text without
        spaces yields very few "words".
        """
        return len((obj.content or '').split())

    def get_word_count(self, obj):
        return self._count_words(obj)

    def get_reading_time(self, obj):
        """Estimated reading time in whole minutes (floor), never less than 1."""
        return max(1, self._count_words(obj) // self.words_per_minute)

2.2 复杂计算 #

python
class ArticleSerializer(serializers.ModelSerializer):
    popularity_score = serializers.SerializerMethodField()

    class Meta:
        model = Article
        fields = '__all__'

    def get_popularity_score(self, obj):
        """Weighted popularity: 0.1 per view, 2 per like, 5 per comment.

        A missing ``views`` value or a missing relation counts as zero. The
        result is rounded to 2 decimal places.
        NOTE(review): each relation is counted with its own query per article;
        consider annotating the queryset for list endpoints.
        """
        def related_count(name):
            return getattr(obj, name).count() if hasattr(obj, name) else 0

        total = (
            (obj.views or 0) * 0.1
            + related_count('likes') * 2
            + related_count('comments') * 5
        )
        return round(total, 2)

2.3 条件字段 #

python
class ArticleSerializer(serializers.ModelSerializer):
    can_edit = serializers.SerializerMethodField()

    class Meta:
        model = Article
        fields = '__all__'

    def get_can_edit(self, obj):
        """True when the requesting user is the article's author or is staff.

        Anonymous users, and serializers used without a request in the
        context, always get False.
        """
        request = self.context.get('request')
        if not request or not request.user.is_authenticated:
            return False
        user = request.user
        return obj.author == user or user.is_staff

2.4 格式化输出 #

python
class ArticleSerializer(serializers.ModelSerializer):
    created_at_formatted = serializers.SerializerMethodField()
    summary = serializers.SerializerMethodField()

    class Meta:
        model = Article
        fields = '__all__'

    def get_created_at_formatted(self, obj):
        """Creation time formatted like '2024年01月31日 08:05'.

        NOTE(review): formats the stored datetime as-is; with USE_TZ=True this
        is presumably UTC rather than local time. Confirm whether conversion
        is wanted.
        """
        display_format = '%Y年%m月%d日 %H:%M'
        return obj.created_at.strftime(display_format)

    def get_summary(self, obj):
        """First 100 characters of the content, with '...' if it was cut."""
        text = obj.content
        if len(text) <= 100:
            return text
        return text[:100] + '...'

三、重写序列化方法 #

3.1 to_representation #

自定义输出格式:

python
class ArticleSerializer(serializers.ModelSerializer):
    class Meta:
        model = Article
        fields = '__all__'

    def to_representation(self, instance):
        """Replace the author/category keys with small nested objects."""
        data = super().to_representation(instance)
        data['author'] = self._author_summary(instance.author)
        data['category'] = self._category_summary(instance.category)
        return data

    @staticmethod
    def _author_summary(author):
        """Return id/username/avatar for *author*, or None when there is none."""
        if author is None:
            return None
        # getattr: a user model without an avatar field yields avatar None
        # instead of raising AttributeError.
        avatar = getattr(author, 'avatar', None)
        return {
            'id': author.id,
            'username': author.username,
            'avatar': avatar.url if avatar else None,
        }

    @staticmethod
    def _category_summary(category):
        """Return id/name for *category*, or None when the article has none."""
        if not category:
            return None
        return {
            'id': category.id,
            'name': category.name,
        }

3.2 to_internal_value #

自定义输入处理:

python
class ArticleSerializer(serializers.ModelSerializer):
    class Meta:
        model = Article
        fields = '__all__'

    def to_internal_value(self, data):
        """Normalize raw input before the standard field validation runs.

        * ``title``: surrounding whitespace is stripped (string values only).
        * ``tags``: a single comma-separated string such as ``'a, b,c'`` is
          turned into a list of names. Empty items are dropped.

        The caller's ``data`` is never modified in place. ``request.data`` can
        be an immutable ``QueryDict`` (form/multipart requests), and mutating
        it would raise.
        """
        import copy

        if not isinstance(data, dict):
            # Not a dict: let DRF produce its standard "invalid data" error.
            return super().to_internal_value(data)

        # copy.copy() gives a shallow, writable copy for both dict and
        # QueryDict. QueryDict.copy() would deep-copy every value, including
        # uploaded files.
        data = copy.copy(data)

        title = data.get('title')
        if isinstance(title, str):
            data['title'] = title.strip()

        tags = data.get('tags')
        # Repeated form values (tags=a&tags=b) are already a list of names;
        # only split when a single comma-separated string was sent.
        multi_valued = hasattr(data, 'getlist') and len(data.getlist('tags')) > 1
        if isinstance(tags, str) and not multi_valued:
            names = [name.strip() for name in tags.split(',') if name.strip()]
            if hasattr(data, 'setlist'):
                # QueryDict: store the names as its multi-value list.
                data.setlist('tags', names)
            else:
                data['tags'] = names

        return super().to_internal_value(data)

3.3 create方法 #

自定义创建逻辑:

python
class ArticleSerializer(serializers.ModelSerializer):
    class Meta:
        model = Article
        fields = '__all__'

    def create(self, validated_data):
        """Create an article, set its author from the request and attach tags.

        ``tags`` in validated_data may hold Tag instances (what the default
        ModelSerializer relation field produces) or tag names (created if
        missing). Everything runs in one transaction, so a failure while
        attaching tags does not leave a half-created article behind.
        """
        from django.db import transaction

        tags_data = validated_data.pop('tags', [])

        request = self.context.get('request')
        if request:
            # The request's user replaces any author value in the input.
            validated_data['author'] = request.user

        with transaction.atomic():
            article = Article.objects.create(**validated_data)
            tags = [self._resolve_tag(item) for item in tags_data]
            if tags:
                # A single add() call instead of one call per tag.
                article.tags.add(*tags)

        return article

    @staticmethod
    def _resolve_tag(item):
        """Return *item* if it is already a Tag, else get or create a Tag by name."""
        if isinstance(item, Tag):
            return item
        tag, _ = Tag.objects.get_or_create(name=item)
        return tag

3.4 update方法 #

自定义更新逻辑:

python
class ArticleSerializer(serializers.ModelSerializer):
    class Meta:
        model = Article
        fields = '__all__'

    def update(self, instance, validated_data):
        """Update scalar fields and, when ``tags`` was sent, replace the tags.

        ``tags`` may hold Tag instances or tag names (created if missing).
        The field updates and the tag replacement share one transaction.
        """
        from django.db import transaction

        tags_data = validated_data.pop('tags', None)

        with transaction.atomic():
            for attr, value in validated_data.items():
                setattr(instance, attr, value)
            instance.save()

            # None means "tags not sent" (e.g. a partial update): leave them as is.
            if tags_data is not None:
                # set() only adds/removes the difference; an empty list clears all.
                instance.tags.set([self._resolve_tag(item) for item in tags_data])

        return instance

    @staticmethod
    def _resolve_tag(item):
        """Return *item* if it is already a Tag, else get or create a Tag by name."""
        if isinstance(item, Tag):
            return item
        tag, _ = Tag.objects.get_or_create(name=item)
        return tag

四、动态字段 #

4.1 动态指定或排除字段 #

python
class DynamicFieldsModelSerializer(serializers.ModelSerializer):
    """ModelSerializer that accepts ``fields`` / ``exclude`` keyword arguments.

    ``fields`` keeps only the named fields. ``exclude`` then drops the named
    fields. Each argument may be an iterable of names or a comma-separated
    string such as ``'id,title'`` (e.g. taken straight from a query
    parameter). Unknown names are ignored.
    """

    def __init__(self, *args, **kwargs):
        fields = self._as_name_set(kwargs.pop('fields', None))
        exclude = self._as_name_set(kwargs.pop('exclude', None))
        super().__init__(*args, **kwargs)

        if fields is not None:
            for field_name in set(self.fields) - fields:
                self.fields.pop(field_name)

        if exclude is not None:
            for field_name in exclude:
                self.fields.pop(field_name, None)

    @staticmethod
    def _as_name_set(names):
        """Normalize *names* to a set of field names (None stays None)."""
        if names is None:
            return None
        if isinstance(names, str):
            # A bare string would otherwise be iterated character by character.
            names = names.split(',')
        return {name.strip() for name in names if name.strip()}

class ArticleSerializer(DynamicFieldsModelSerializer):
    class Meta:
        model = Article
        fields = '__all__'

使用:

python
# Serialize only the listed fields.
serializer = ArticleSerializer(article, fields=['id', 'title'])
# Serialize every field except the listed ones.
serializer = ArticleSerializer(article, exclude=['content'])

4.2 基于用户权限 #

python
class ArticleSerializer(serializers.ModelSerializer):
    class Meta:
        model = Article
        fields = '__all__'

    def __init__(self, *args, **kwargs):
        """Expose ``internal_notes`` only to staff users."""
        super().__init__(*args, **kwargs)
        request = self.context.get('request')
        user = getattr(request, 'user', None)

        # Fail closed: without a request (e.g. the serializer is used in a
        # background task or shell) staff status cannot be confirmed, so the
        # field is hidden as well.
        if not (user and user.is_staff):
            self.fields.pop('internal_notes', None)

4.3 基于动作 #

python
class ArticleSerializer(serializers.ModelSerializer):
    # (action, field hidden for that action) pairs, checked in order.
    _HIDDEN_FIELDS_BY_ACTION = (
        ('list', 'content'),
        ('create', 'views'),
    )

    class Meta:
        model = Article
        fields = '__all__'

    def __init__(self, *args, **kwargs):
        """Accept an optional ``action`` kwarg and hide fields not needed for it."""
        action = kwargs.pop('action', None)
        super().__init__(*args, **kwargs)

        for action_name, field_name in self._HIDDEN_FIELDS_BY_ACTION:
            if action == action_name:
                self.fields.pop(field_name, None)
                break

五、自定义关系字段 #

5.1 自定义外键表示 #

python
class CategoryField(serializers.RelatedField):
    """Foreign-key field rendered as ``{'id', 'name', 'slug'}``.

    Input may be the category id, or the same dict shape the field outputs
    (only its ``'id'`` key is used), so a response can be posted back as-is.
    """

    def to_representation(self, value):
        return {
            'id': value.id,
            'name': value.name,
            'slug': value.slug,
        }

    def to_internal_value(self, data):
        if isinstance(data, dict):
            data = data.get('id')
        # Look up through the field's own queryset so any restriction passed
        # via queryset=... is honoured.
        queryset = self.get_queryset()
        try:
            return queryset.get(id=data)
        except queryset.model.DoesNotExist:
            raise serializers.ValidationError('分类不存在')
        except (TypeError, ValueError):
            # Malformed id (e.g. 'abc' or a list for an integer primary key).
            raise serializers.ValidationError('无效的分类ID')

class ArticleSerializer(serializers.ModelSerializer):
    category = CategoryField(queryset=Category.objects.all())

    class Meta:
        model = Article
        fields = '__all__'

5.2 自定义多对多字段 #

python
class TagsField(serializers.RelatedField):
    """Tag relation rendered as ``{'id': ..., 'name': ...}`` objects.

    With ``many=True``, DRF wraps this field in a ``ManyRelatedField`` that
    calls these methods once per tag. Each call then receives a single Tag
    (output) or a single item (input). Used without ``many=True``, the field
    receives the whole related manager (output) or list (input); both forms
    are handled.

    Input items are a tag id (int) or a tag name (str, created if missing).
    """

    def to_representation(self, value):
        if hasattr(value, 'all'):
            # Related manager: the field was used without many=True.
            return [self._tag_to_dict(tag) for tag in value.all()]
        return self._tag_to_dict(value)

    def to_internal_value(self, data):
        if isinstance(data, list):
            # Whole list: the field was used without many=True.
            return [self._item_to_tag(item) for item in data]
        return self._item_to_tag(data)

    @staticmethod
    def _tag_to_dict(tag):
        """Return the output dict for one Tag."""
        return {'id': tag.id, 'name': tag.name}

    def _item_to_tag(self, item):
        """Resolve one input item (id or name) to a Tag instance."""
        queryset = self.get_queryset()
        # bool is a subclass of int; don't treat True/False as ids 1/0.
        if isinstance(item, int) and not isinstance(item, bool):
            try:
                return queryset.get(id=item)
            except queryset.model.DoesNotExist:
                raise serializers.ValidationError('标签不存在')
        if isinstance(item, str) and item.strip():
            tag, _ = queryset.get_or_create(name=item.strip())
            return tag
        raise serializers.ValidationError('无效的标签')

class ArticleSerializer(serializers.ModelSerializer):
    tags = TagsField(queryset=Tag.objects.all(), many=True)

    class Meta:
        model = Article
        fields = '__all__'

六、数据转换 #

6.1 字段值转换 #

python
class ArticleSerializer(serializers.ModelSerializer):
    status = serializers.SerializerMethodField()

    # Display metadata per status value, built once at class definition
    # instead of on every get_status() call.
    STATUS_DISPLAY = {
        'draft': {'value': 'draft', 'label': '草稿', 'color': 'gray'},
        'published': {'value': 'published', 'label': '已发布', 'color': 'green'},
        'archived': {'value': 'archived', 'label': '已归档', 'color': 'red'}
    }

    class Meta:
        model = Article
        fields = '__all__'

    def get_status(self, obj):
        """Return value/label/color for the article's status."""
        display = self.STATUS_DISPLAY.get(obj.status)
        if display is None:
            # Unknown status: still expose the raw value instead of an empty {}.
            return {'value': obj.status, 'label': obj.status, 'color': 'gray'}
        # Return a copy so changes to one response cannot alter the shared map.
        return dict(display)

6.2 时间格式化 #

python
class ArticleSerializer(serializers.ModelSerializer):
    # read_only=True: an explicitly declared field is required on input by
    # default, which would force clients to send created_at when creating.
    created_at = serializers.DateTimeField(format='%Y-%m-%d %H:%M:%S', read_only=True)
    created_at_human = serializers.SerializerMethodField()

    class Meta:
        model = Article
        fields = '__all__'

    def get_created_at_human(self, obj):
        """Human-friendly age of the article.

        Less than a day old: '<timesince>前'. Less than a week: 'N天前'.
        Otherwise: the date as YYYY-MM-DD.
        """
        from django.utils import timezone
        from django.utils.timesince import timesince

        created_at = obj.created_at
        diff = timezone.now() - created_at
        # <= 0 also covers timestamps slightly in the future (clock skew),
        # whose negative timedelta has days == -1 and would yield '-1天前'.
        if diff.days <= 0:
            return timesince(created_at) + '前'
        if diff.days < 7:
            return f'{diff.days}天前'
        # Show the date in the current time zone rather than the stored zone.
        # DRF's DateTimeField likewise converts aware values before formatting.
        if timezone.is_aware(created_at):
            created_at = timezone.localtime(created_at)
        return created_at.strftime('%Y-%m-%d')

6.3 图片URL处理 #

python
class ArticleSerializer(serializers.ModelSerializer):
    cover_image = serializers.SerializerMethodField()

    class Meta:
        model = Article
        fields = '__all__'

    def get_cover_image(self, obj):
        """Cover image URL: absolute when a request is in the context.

        Without a request the relative URL is returned. Returns None when the
        article has no cover image.
        """
        if not obj.cover_image:
            return None
        url = obj.cover_image.url
        request = self.context.get('request')
        return request.build_absolute_uri(url) if request else url

七、上下文传递 #

7.1 传递请求对象 #

python
class ArticleViewSet(viewsets.ModelViewSet):
    queryset = Article.objects.all()
    serializer_class = ArticleSerializer

    def get_serializer_context(self):
        """Extend the parent class's serializer context with a custom value."""
        return {
            **super().get_serializer_context(),
            'custom_data': 'some_value',
        }

7.2 在序列化器中使用 #

python
class ArticleSerializer(serializers.ModelSerializer):
    class Meta:
        model = Article
        fields = '__all__'

    def to_representation(self, instance):
        """Add ``is_liked`` for authenticated requesting users.

        Values the view puts into the context are read the same way as the
        request, e.g. ``self.context.get('custom_data')``.
        """
        data = super().to_representation(instance)
        request = self.context.get('request')

        if request and request.user.is_authenticated:
            data['is_liked'] = self._is_liked(instance, request.user)

        return data

    @staticmethod
    def _is_liked(instance, user):
        """Whether *user* has liked *instance*.

        Prefers an ``is_liked`` annotation on the instance when the view's
        queryset provides one; that avoids one query per article in lists.
        """
        annotated = getattr(instance, 'is_liked', None)
        if annotated is not None:
            return bool(annotated)
        return instance.likes.filter(user=user).exists()

八、性能优化 #

8.1 减少数据库查询 #

python
class ArticleSerializer(serializers.ModelSerializer):
    # read_only=True: ModelSerializer.create()/update() do not support
    # writable dotted sources such as 'author.username'. As a writable field
    # it would also be required on input.
    author_name = serializers.CharField(source='author.username', read_only=True)
    comment_count = serializers.SerializerMethodField()

    class Meta:
        model = Article
        fields = '__all__'

    def get_comment_count(self, obj):
        """Comment count, preferring a ``comment_count`` queryset annotation."""
        if hasattr(obj, 'comment_count'):
            return obj.comment_count
        # Fallback: one COUNT query per article.
        return obj.comments.count()

视图中使用注解:

python
from django.db.models import Count

class ArticleViewSet(viewsets.ModelViewSet):
    # select_related('author'): serializer fields that read author.username
    #   are served by the same query instead of one extra query per article.
    # annotate(comment_count=...): comment counts are computed in the same
    #   query and exposed as obj.comment_count.
    queryset = (
        Article.objects
        .select_related('author')
        .annotate(comment_count=Count('comments'))
    )
    serializer_class = ArticleSerializer

8.2 按需加载 #

python
class ArticleSerializer(serializers.ModelSerializer):
    comments = serializers.SerializerMethodField()

    class Meta:
        model = Article
        fields = '__all__'

    def get_comments(self, obj):
        """Nested comments, only when context['include_comments'] is truthy.

        Returns None otherwise, so the comments are not queried at all.
        """
        if not self.context.get('include_comments'):
            return None
        # Pass the context down so nested fields that depend on it (e.g. the
        # request for absolute URLs) behave as they do in this serializer.
        return CommentSerializer(
            obj.comments.all(), many=True, context=self.context
        ).data

8.3 缓存计算结果 #

python
from functools import cached_property

class ArticleSerializer(serializers.ModelSerializer):
    class Meta:
        model = Article
        fields = '__all__'

    @cached_property
    def _current_user(self):
        """The requesting user, or None when no request is in the context.

        Cached on the serializer instance by cached_property.
        """
        request = self.context.get('request')
        return getattr(request, 'user', None)

    def to_representation(self, instance):
        data = super().to_representation(instance)
        user = self._current_user
        # Require an authenticated user. Without this check, an article with
        # no author (None) would compare equal to a missing user (None).
        data['can_edit'] = bool(
            user is not None
            and user.is_authenticated
            and instance.author == user
        )
        return data

九、实际应用示例 #

9.1 用户信息序列化 #

python
class UserSerializer(serializers.ModelSerializer):
    full_name = serializers.SerializerMethodField()
    avatar_url = serializers.SerializerMethodField()
    stats = serializers.SerializerMethodField()

    class Meta:
        model = User
        fields = ['id', 'username', 'email', 'full_name', 'avatar_url', 'stats']

    def get_full_name(self, obj):
        """'first last' with outer spaces trimmed; the username if both are empty."""
        return f'{obj.first_name} {obj.last_name}'.strip() or obj.username

    def get_avatar_url(self, obj):
        """Absolute avatar URL with a request in context, else relative; None if unset."""
        request = self.context.get('request')
        if obj.avatar:
            url = obj.avatar.url
            return request.build_absolute_uri(url) if request else url
        return None

    def get_stats(self, obj):
        """Per-user counters: articles, comments and likes on the user's articles."""
        return {
            'articles_count': obj.articles.count(),
            'comments_count': obj.comments.count(),
            # Count, not Sum. Assumes Article.likes is a relation (it is used
            # as one via .likes.count()/.likes.filter()). Sum over a relation
            # adds up the related rows' primary keys instead of counting them.
            # If it were an integer counter column, Sum would be the right
            # aggregate. TODO confirm.
            'likes_received': obj.articles.aggregate(
                total=models.Count('likes')
            )['total'] or 0
        }

9.2 搜索结果序列化 #

python
class SearchResultSerializer(serializers.Serializer):
    type = serializers.CharField()
    id = serializers.IntegerField()
    title = serializers.CharField()
    highlight = serializers.SerializerMethodField()
    url = serializers.CharField()

    def get_highlight(self, obj):
        """Return the result's highlight snippet, or None when it has none.

        Results may be objects or plain dicts. The declared fields above read
        dict keys as well, so the highlight must support both shapes too.
        """
        from collections.abc import Mapping

        if isinstance(obj, Mapping):
            return obj.get('highlight')
        return getattr(obj, 'highlight', None)

class SearchSerializer(serializers.Serializer):
    query = serializers.CharField()
    total = serializers.IntegerField()
    results = SearchResultSerializer(many=True)
    facets = serializers.DictField()

十、总结 #

本章学习了自定义序列化的高级技巧：

  • 自定义字段：创建自定义字段类
  • SerializerMethodField：灵活添加计算字段
  • 重写方法：自定义序列化逻辑
  • 动态字段：根据条件动态调整字段
  • 自定义关系字段：定制外键和多对多字段的输入与输出
  • 数据转换：格式化字段值、时间和图片URL
  • 上下文传递：在视图与序列化器之间共享数据
  • 性能优化：减少数据库查询

掌握自定义序列化技巧，可以灵活处理各种复杂数据场景！

最后更新:2026-03-28