init
This commit is contained in:
118
backend/utils/serializers.py
Normal file
118
backend/utils/serializers.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
@Remark: 自定义序列化器
|
||||
"""
|
||||
from rest_framework import serializers
|
||||
from rest_framework.fields import empty
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from django.utils.functional import cached_property
|
||||
from rest_framework.utils.serializer_helpers import BindingDict
|
||||
|
||||
from system.models import User
|
||||
|
||||
|
||||
class CustomModelSerializer(ModelSerializer):
|
||||
"""
|
||||
增强DRF的ModelSerializer,可自动更新模型的审计字段记录
|
||||
(1)self.request能获取到rest_framework.request.Request对象
|
||||
"""
|
||||
# 修改人的审计字段名称, 默认modifier, 继承使用时可自定义覆盖
|
||||
modifier_field_id = 'modifier'
|
||||
modifier_name = serializers.SerializerMethodField(read_only=True)
|
||||
|
||||
def get_modifier_name(self, instance):
|
||||
if not hasattr(instance, 'modifier'):
|
||||
return None
|
||||
queryset = User.objects.filter(id=instance.modifier).values_list('name', flat=True).first()
|
||||
if queryset:
|
||||
return queryset
|
||||
return None
|
||||
|
||||
# 创建人的审计字段名称, 默认creator, 继承使用时可自定义覆盖
|
||||
creator_field_id = 'creator'
|
||||
# 添加默认时间返回格式
|
||||
create_time = serializers.DateTimeField(format="%Y-%m-%d %H:%M:%S", required=False, read_only=True)
|
||||
update_time = serializers.DateTimeField(format="%Y-%m-%d %H:%M:%S", required=False)
|
||||
|
||||
def __init__(self, instance=None, data=empty, request=None, **kwargs):
|
||||
super().__init__(instance, data, **kwargs)
|
||||
self.request: Request = request or self.context.get('request', None)
|
||||
|
||||
def save(self, **kwargs):
|
||||
return super().save(**kwargs)
|
||||
|
||||
def create(self, validated_data):
|
||||
if self.request:
|
||||
if self.modifier_field_id in self.fields.fields:
|
||||
validated_data[self.modifier_field_id] = self.get_request_username()
|
||||
if self.creator_field_id in self.fields.fields:
|
||||
validated_data[self.creator_field_id] = self.get_request_username()
|
||||
return super().create(validated_data)
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
if self.request:
|
||||
if hasattr(self.instance, self.modifier_field_id):
|
||||
self.instance.modifier = self.get_request_username()
|
||||
return super().update(instance, validated_data)
|
||||
|
||||
def get_request_username(self):
|
||||
if getattr(self.request, 'user', None):
|
||||
return getattr(self.request.user, 'username', None)
|
||||
return None
|
||||
|
||||
def get_request_name(self):
|
||||
if getattr(self.request, 'user', None):
|
||||
return getattr(self.request.user, 'name', None)
|
||||
return None
|
||||
|
||||
def get_request_user_id(self):
|
||||
if getattr(self.request, 'user', None):
|
||||
return getattr(self.request.user, 'id', None)
|
||||
return None
|
||||
|
||||
@cached_property
|
||||
def fields(self):
|
||||
fields = BindingDict(self)
|
||||
for key, value in self.get_fields().items():
|
||||
fields[key] = value
|
||||
|
||||
if not hasattr(self, '_context'):
|
||||
return fields
|
||||
is_root = self.root == self
|
||||
parent_is_list_root = self.parent == self.root and getattr(self.parent, 'many', False)
|
||||
if not (is_root or parent_is_list_root):
|
||||
return fields
|
||||
|
||||
try:
|
||||
request = self.request or self.context['request']
|
||||
except KeyError:
|
||||
return fields
|
||||
params = getattr(
|
||||
request, 'query_params', getattr(request, 'GET', None)
|
||||
)
|
||||
if params is None:
|
||||
pass
|
||||
try:
|
||||
filter_fields = params.get('_fields', None).split(',')
|
||||
except AttributeError:
|
||||
filter_fields = None
|
||||
|
||||
try:
|
||||
omit_fields = params.get('_exclude', None).split(',')
|
||||
except AttributeError:
|
||||
omit_fields = []
|
||||
|
||||
existing = set(fields.keys())
|
||||
if filter_fields is None:
|
||||
allowed = existing
|
||||
else:
|
||||
allowed = set(filter(None, filter_fields))
|
||||
|
||||
omitted = set(filter(None, omit_fields))
|
||||
for field in existing:
|
||||
if field not in allowed:
|
||||
fields.pop(field, None)
|
||||
if field in omitted:
|
||||
fields.pop(field, None)
|
||||
|
||||
return fields
|
||||
Reference in New Issue
Block a user