This commit is contained in:
xie7654
2025-06-29 21:45:27 +08:00
commit f6e68e37c8
1539 changed files with 129319 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
from rest_framework.authentication import TokenAuthentication
class BearerTokenAuthentication(TokenAuthentication):
"""
使用 'Bearer' 前缀的 Token 认证
"""
keyword = 'Bearer'

View File

@@ -0,0 +1,149 @@
from rest_framework import viewsets, status
from rest_framework.response import Response
class CustomModelViewSet(viewsets.ModelViewSet):
"""
自定义ModelViewSet提供以下增强功能
- 基于动作的序列化器选择
- 基于动作的权限控制
- 标准化响应格式
- 软删除支持
- 批量操作支持
"""
# 动作到序列化器类的映射
action_serializers = {}
# 动作到权限类的映射
action_permissions = {}
# 软删除字段名
soft_delete_field = 'is_deleted'
# 是否支持软删除
enable_soft_delete = False
def get_serializer_class(self):
"""根据当前动作获取序列化器类"""
return self.action_serializers.get(
self.action,
super().get_serializer_class()
)
def get_permissions(self):
"""根据当前动作获取权限类"""
permissions = self.action_permissions.get(
self.action,
self.permission_classes
)
return [permission() for permission in permissions]
def list(self, request, *args, **kwargs):
"""重写列表视图,支持软删除过滤"""
queryset = self.get_queryset()
# 应用软删除过滤
if self.enable_soft_delete:
queryset = queryset.filter(**{self.soft_delete_field: False})
# 应用搜索和过滤
queryset = self.filter_queryset(queryset)
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(queryset, many=True)
return self._build_response(
data=serializer.data,
message="ok",
status=status.HTTP_200_OK
)
def retrieve(self, request, *args, **kwargs):
"""重写详情视图,支持软删除检查"""
instance = self.get_object()
# 检查软删除状态
if (self.enable_soft_delete and
hasattr(instance, self.soft_delete_field) and
getattr(instance, self.soft_delete_field)):
return Response(status=status.HTTP_404_NOT_FOUND)
serializer = self.get_serializer(instance)
return self._build_response(
data=serializer.data,
message="Object retrieved successfully",
status=status.HTTP_200_OK
)
def create(self, request, *args, **kwargs):
"""重写创建视图,支持批量创建"""
is_many = isinstance(request.data, list)
if is_many:
serializer = self.get_serializer(data=request.data, many=True)
else:
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
return self._build_response(
data=serializer.data,
message="ok",
status=status.HTTP_200_OK,
)
def destroy(self, request, *args, **kwargs):
instance = self.get_object()
self.perform_destroy(instance)
return self._build_response(
message="ok",
status=status.HTTP_200_OK,
)
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
if getattr(instance, '_prefetched_objects_cache', None):
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance.
instance._prefetched_objects_cache = {}
return self._build_response(
data=serializer.data,
message="ok",
status=status.HTTP_200_OK,
)
def _build_response(self, code=0, message="成功", data=None, status=status.HTTP_200_OK):
"""
构建标准化API响应格式
参数说明:
- code: 业务状态码0表示成功非0表示错误
- message: 状态描述信息
- data: 响应数据可为None
- status: HTTP状态码默认200
"""
# 构建基础响应结构
response_data = {
"code": code,
"message": message
}
# 仅当data不为None时添加到响应中
if data is not None:
response_data["data"] = data
# 移除可能的空值如message为空字符串
response_data = {k: v for k, v in response_data.items() if v is not None and v != ""}
# 返回DRF的Response对象
return Response(
data=response_data,
status=status,
content_type="application/json"
)

20
backend/utils/models.py Normal file
View File

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""
@Remark: 公共基础model类
"""
from django.db import models
class CoreModel(models.Model):
remark = models.CharField(max_length=256, verbose_name="备注", null=True, blank=True, help_text="备注")
creator = models.CharField(max_length=64, null=True, blank=True, help_text="创建人", verbose_name="创建人")
modifier = models.CharField(max_length=64, null=True, blank=True, help_text="修改人", verbose_name="修改人")
update_time = models.DateTimeField(auto_now=True, null=True, blank=True, help_text="修改时间", verbose_name="修改时间")
create_time = models.DateTimeField(auto_now_add=True, null=True, blank=True, help_text="创建时间",
verbose_name="创建时间")
is_deleted = models.BooleanField(default=False, verbose_name='是否软删除')
class Meta:
abstract = True
verbose_name = '核心模型'
verbose_name_plural = verbose_name

View File

@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
from collections import OrderedDict
from django.core import paginator
from django.core.paginator import Paginator as DjangoPaginator
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response
from django.core.paginator import InvalidPage
class CustomPagination(PageNumberPagination):
page_size = 20
page_size_query_param = "pageSize"
max_page_size = 999
django_paginator_class = DjangoPaginator
def paginate_queryset(self, queryset, request, view=None):
"""
重写paginate_queryset让分页超过正常分页:有原来的4000错误无效页面。改写为返回2000成功data=[]提示
"""
page_size = self.get_page_size(request)
if not page_size:
return None
paginator = self.django_paginator_class(queryset, page_size)
page_number = self.get_page_number(request, paginator)
try:
self.page = paginator.page(page_number)
except InvalidPage as exc:
self.page = []
if paginator.num_pages > 1 and self.template is not None:
# The browsable API should display pagination controls.
self.display_page_controls = True
self.request = request
return list(self.page)
def get_paginated_response(self, data):
code = 0
msg = 'ok'
total = self.page.paginator.count if self.page else 0
res = {
"total": total,
"items": data
}
if not data:
code = 0
msg = "暂无数据"
res['data'] = []
return Response(OrderedDict([
('code', code),
('message', msg),
('data', res),
('error', None),
]))

View File

@@ -0,0 +1,8 @@
from rest_framework import permissions
class IsSuperUserOrReadOnly(permissions.BasePermission):
"""超级用户可读写,普通用户只读"""
def has_permission(self, request, view):
if request.method in permissions.SAFE_METHODS:
return True
return request.user and request.user.is_superuser

View File

@@ -0,0 +1,118 @@
"""
@Remark: 自定义序列化器
"""
from rest_framework import serializers
from rest_framework.fields import empty
from rest_framework.request import Request
from rest_framework.serializers import ModelSerializer
from django.utils.functional import cached_property
from rest_framework.utils.serializer_helpers import BindingDict
from system.models import User
class CustomModelSerializer(ModelSerializer):
"""
增强DRF的ModelSerializer,可自动更新模型的审计字段记录
(1)self.request能获取到rest_framework.request.Request对象
"""
# 修改人的审计字段名称, 默认modifier, 继承使用时可自定义覆盖
modifier_field_id = 'modifier'
modifier_name = serializers.SerializerMethodField(read_only=True)
def get_modifier_name(self, instance):
if not hasattr(instance, 'modifier'):
return None
queryset = User.objects.filter(id=instance.modifier).values_list('name', flat=True).first()
if queryset:
return queryset
return None
# 创建人的审计字段名称, 默认creator, 继承使用时可自定义覆盖
creator_field_id = 'creator'
# 添加默认时间返回格式
create_time = serializers.DateTimeField(format="%Y-%m-%d %H:%M:%S", required=False, read_only=True)
update_time = serializers.DateTimeField(format="%Y-%m-%d %H:%M:%S", required=False)
def __init__(self, instance=None, data=empty, request=None, **kwargs):
super().__init__(instance, data, **kwargs)
self.request: Request = request or self.context.get('request', None)
def save(self, **kwargs):
return super().save(**kwargs)
def create(self, validated_data):
if self.request:
if self.modifier_field_id in self.fields.fields:
validated_data[self.modifier_field_id] = self.get_request_username()
if self.creator_field_id in self.fields.fields:
validated_data[self.creator_field_id] = self.get_request_username()
return super().create(validated_data)
def update(self, instance, validated_data):
if self.request:
if hasattr(self.instance, self.modifier_field_id):
self.instance.modifier = self.get_request_username()
return super().update(instance, validated_data)
def get_request_username(self):
if getattr(self.request, 'user', None):
return getattr(self.request.user, 'username', None)
return None
def get_request_name(self):
if getattr(self.request, 'user', None):
return getattr(self.request.user, 'name', None)
return None
def get_request_user_id(self):
if getattr(self.request, 'user', None):
return getattr(self.request.user, 'id', None)
return None
@cached_property
def fields(self):
fields = BindingDict(self)
for key, value in self.get_fields().items():
fields[key] = value
if not hasattr(self, '_context'):
return fields
is_root = self.root == self
parent_is_list_root = self.parent == self.root and getattr(self.parent, 'many', False)
if not (is_root or parent_is_list_root):
return fields
try:
request = self.request or self.context['request']
except KeyError:
return fields
params = getattr(
request, 'query_params', getattr(request, 'GET', None)
)
if params is None:
pass
try:
filter_fields = params.get('_fields', None).split(',')
except AttributeError:
filter_fields = None
try:
omit_fields = params.get('_exclude', None).split(',')
except AttributeError:
omit_fields = []
existing = set(fields.keys())
if filter_fields is None:
allowed = existing
else:
allowed = set(filter(None, filter_fields))
omitted = set(filter(None, omit_fields))
for field in existing:
if field not in allowed:
fields.pop(field, None)
if field in omitted:
fields.pop(field, None)
return fields

30
backend/utils/utils.py Normal file
View File

@@ -0,0 +1,30 @@
import re
from datetime import datetime
from decimal import Decimal, ROUND_HALF_UP
from rest_framework.exceptions import ValidationError
from django.utils import timezone
def validate_mobile(value):
if value and not re.findall(r"1\d{10}", value):
raise ValidationError('手机格式不正确')
def validate_amount(value):
if value is None:
raise ValidationError('金额不能为空')
if value and value < 0:
raise ValidationError('金额不能为负')
def to_cent(value):
if value is None:
value = 0
return Decimal(value).quantize(Decimal('.01'), rounding=ROUND_HALF_UP)
# 定义一个小工具:从时间戳转换为 aware datetime如果时间戳有效
def ts_to_aware(ts):
if ts:
naive_dt = datetime.fromtimestamp(ts)
return timezone.make_aware(naive_dt)
return None