Files
django-vue3-admin-gd/backend/utils/serializers.py
2025-08-27 10:23:13 +08:00

237 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
@Remark: 自定义序列化器
"""
from rest_framework import serializers
from rest_framework.fields import empty
from rest_framework.request import Request
from rest_framework.serializers import ModelSerializer
from django.utils.functional import cached_property
from rest_framework.utils.serializer_helpers import BindingDict
class AuditUserFieldsMixin:
"""
用于自动赋值 creator 和 modifier 字段的 Mixin
"""
# 修改人的审计字段名称, 默认modifier, 继承使用时可自定义覆盖
modifier_field_id = 'modifier'
# 创建人的审计字段名称, 默认creator, 继承使用时可自定义覆盖
creator_field_id = 'creator'
def set_audit_user_fields(self, validated_data, is_create=True):
username = self.get_request_user_name() if hasattr(self, 'get_request_user_name') else None
if getattr(self, 'request', None):
if self.modifier_field_id in self.fields:
validated_data[self.modifier_field_id] = username
if is_create and self.creator_field_id in self.fields:
validated_data[self.creator_field_id] = username
class DesensitizationMixin:
"""
用于敏感字段脱敏的通用 Mixin
使用方法:
1. 在序列化器中继承此 Mixin
2. 设置 desensitize_fields 属性,指定需要脱敏的字段
3. 可选:设置 desensitize_prefix_length 和 desensitize_suffix_length 来自定义脱敏格式
"""
# 需要脱敏的字段列表,格式:['field_name', 'related_field.field_name']
desensitize_fields = []
# 脱敏时保留的前缀长度默认4
desensitize_prefix_length = 4
# 脱敏时保留的后缀长度默认4
desensitize_suffix_length = 4
# 脱敏阈值,字段长度小于等于此值时全部用*替换默认8
desensitize_threshold = 8
# 脱敏字符,默认使用*
desensitize_char = '*'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# 为每个脱敏字段创建脱敏方法
for field_name in self.desensitize_fields:
if '.' in field_name:
# 处理关联字段,如 'key.api_key'
method_name = f'get_{field_name.replace(".", "_")}_desensitized'
setattr(self, method_name, self._create_desensitize_method(field_name))
else:
# 处理直接字段
method_name = f'get_{field_name}_desensitized'
setattr(self, method_name, self._create_desensitize_method(field_name))
def _create_desensitize_method(self, field_name):
"""创建脱敏方法的闭包"""
def desensitize_method(obj):
return self._desensitize_field(obj, field_name)
return desensitize_method
def _desensitize_field(self, obj, field_name):
"""脱敏指定字段"""
if '.' in field_name:
# 处理关联字段,如 'key.api_key'
parts = field_name.split('.')
value = obj
for part in parts:
if hasattr(value, part):
value = getattr(value, part)
else:
return None
else:
# 处理直接字段
if not hasattr(obj, field_name):
return None
value = getattr(obj, field_name)
# 如果值为空,直接返回
if not value:
return value
# 检查用户权限
if self._can_view_full_value():
return value
# 执行脱敏
return self._apply_desensitization(str(value))
def _can_view_full_value(self):
"""检查当前用户是否可以查看完整值"""
# request = self.context.get('request')
# if not request or not request.user:
# return False
#
# # 超级用户或管理员可以查看完整值
# return request.user.is_superuser or request.user.is_staff
return False # 默认不允许查看完整值,需根据实际权限逻辑调整
def _apply_desensitization(self, value):
"""应用脱敏规则"""
if len(value) <= self.desensitize_threshold:
# 如果长度小于等于阈值,则全部用脱敏字符替换
return self.desensitize_char * len(value)
else:
# 显示前缀和后缀,中间用脱敏字符替换
prefix = value[:self.desensitize_prefix_length]
suffix = value[-self.desensitize_suffix_length:]
middle_length = len(value) - self.desensitize_prefix_length - self.desensitize_suffix_length
middle = self.desensitize_char * middle_length
return prefix + middle + suffix
def get_fields(self):
"""重写 get_fields 方法,为脱敏字段添加脱敏版本"""
fields = super().get_fields()
is_list = getattr(self.root, 'many', False)
for field_name in self.desensitize_fields:
if '.' in field_name:
# 处理关联字段,如 'key.api_key'
method_name = f'get_{field_name.replace(".", "_")}_desensitized'
field_key = f"{field_name.replace('.', '_')}_desensitized"
else:
# 处理直接字段
method_name = f'get_{field_name}_desensitized'
field_key = f"{field_name}_desensitized"
# 创建脱敏字段的 SerializerMethodField
fields[field_key] = serializers.SerializerMethodField(method_name=method_name)
# 保持原字段不变,确保创建/更新功能正常
# 原字段仍然可以接收输入数据
if is_list:
# 如果是列表序列化,移除原始字段
fields.pop(field_name, None)
return fields
class CustomModelSerializer(AuditUserFieldsMixin, ModelSerializer):
"""
增强DRF的ModelSerializer,可自动更新模型的审计字段记录
(1)self.request能获取到rest_framework.request.Request对象
"""
# 添加默认时间返回格式
create_time = serializers.DateTimeField(format="%Y-%m-%d %H:%M:%S", required=False, read_only=True)
update_time = serializers.DateTimeField(format="%Y-%m-%d %H:%M:%S", required=False)
def __init__(self, instance=None, data=empty, request=None, **kwargs):
super().__init__(instance, data, **kwargs)
self.request: Request = request or self.context.get('request', None)
def save(self, **kwargs):
return super().save(**kwargs)
def create(self, validated_data):
self.set_audit_user_fields(validated_data, is_create=True)
return super().create(validated_data)
def update(self, instance, validated_data):
self.set_audit_user_fields(validated_data, is_create=False)
return super().update(instance, validated_data)
def get_request_username(self):
if getattr(self.request, 'user', None):
return getattr(self.request.user, 'username', None)
return None
def get_request_name(self):
if getattr(self.request, 'user', None):
return getattr(self.request.user, 'name', None)
return None
def get_request_user_id(self):
if getattr(self.request, 'user', None):
return getattr(self.request.user, 'id', None)
return None
@cached_property
def fields(self):
fields = BindingDict(self)
for key, value in self.get_fields().items():
fields[key] = value
if not hasattr(self, '_context'):
return fields
is_root = self.root == self
parent_is_list_root = self.parent == self.root and getattr(self.parent, 'many', False)
if not (is_root or parent_is_list_root):
return fields
try:
request = self.request or self.context['request']
except KeyError:
return fields
params = getattr(
request, 'query_params', getattr(request, 'GET', None)
)
if params is None:
pass
try:
filter_fields = params.get('_fields', None).split(',')
except AttributeError:
filter_fields = None
try:
omit_fields = params.get('_exclude', None).split(',')
except AttributeError:
omit_fields = []
existing = set(fields.keys())
if filter_fields is None:
allowed = existing
else:
allowed = set(filter(None, filter_fields))
omitted = set(filter(None, omit_fields))
for field in existing:
if field not in allowed:
fields.pop(field, None)
if field in omitted:
fields.pop(field, None)
return fields