init
This commit is contained in:
7
backend/utils/authentication.py
Normal file
7
backend/utils/authentication.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from rest_framework.authentication import TokenAuthentication
|
||||
|
||||
class BearerTokenAuthentication(TokenAuthentication):
|
||||
"""
|
||||
使用 'Bearer' 前缀的 Token 认证
|
||||
"""
|
||||
keyword = 'Bearer'
|
||||
149
backend/utils/custom_model_viewSet.py
Normal file
149
backend/utils/custom_model_viewSet.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from rest_framework import viewsets, status
|
||||
from rest_framework.response import Response
|
||||
|
||||
|
||||
class CustomModelViewSet(viewsets.ModelViewSet):
|
||||
"""
|
||||
自定义ModelViewSet,提供以下增强功能:
|
||||
- 基于动作的序列化器选择
|
||||
- 基于动作的权限控制
|
||||
- 标准化响应格式
|
||||
- 软删除支持
|
||||
- 批量操作支持
|
||||
"""
|
||||
# 动作到序列化器类的映射
|
||||
action_serializers = {}
|
||||
# 动作到权限类的映射
|
||||
action_permissions = {}
|
||||
# 软删除字段名
|
||||
soft_delete_field = 'is_deleted'
|
||||
# 是否支持软删除
|
||||
enable_soft_delete = False
|
||||
|
||||
def get_serializer_class(self):
|
||||
"""根据当前动作获取序列化器类"""
|
||||
return self.action_serializers.get(
|
||||
self.action,
|
||||
super().get_serializer_class()
|
||||
)
|
||||
|
||||
def get_permissions(self):
|
||||
"""根据当前动作获取权限类"""
|
||||
permissions = self.action_permissions.get(
|
||||
self.action,
|
||||
self.permission_classes
|
||||
)
|
||||
return [permission() for permission in permissions]
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
"""重写列表视图,支持软删除过滤"""
|
||||
queryset = self.get_queryset()
|
||||
|
||||
# 应用软删除过滤
|
||||
if self.enable_soft_delete:
|
||||
queryset = queryset.filter(**{self.soft_delete_field: False})
|
||||
|
||||
# 应用搜索和过滤
|
||||
queryset = self.filter_queryset(queryset)
|
||||
|
||||
page = self.paginate_queryset(queryset)
|
||||
if page is not None:
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
return self._build_response(
|
||||
data=serializer.data,
|
||||
message="ok",
|
||||
status=status.HTTP_200_OK
|
||||
)
|
||||
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
"""重写详情视图,支持软删除检查"""
|
||||
instance = self.get_object()
|
||||
|
||||
# 检查软删除状态
|
||||
if (self.enable_soft_delete and
|
||||
hasattr(instance, self.soft_delete_field) and
|
||||
getattr(instance, self.soft_delete_field)):
|
||||
return Response(status=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
serializer = self.get_serializer(instance)
|
||||
return self._build_response(
|
||||
data=serializer.data,
|
||||
message="Object retrieved successfully",
|
||||
status=status.HTTP_200_OK
|
||||
)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
"""重写创建视图,支持批量创建"""
|
||||
is_many = isinstance(request.data, list)
|
||||
|
||||
if is_many:
|
||||
serializer = self.get_serializer(data=request.data, many=True)
|
||||
else:
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_create(serializer)
|
||||
|
||||
return self._build_response(
|
||||
data=serializer.data,
|
||||
message="ok",
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
instance = self.get_object()
|
||||
self.perform_destroy(instance)
|
||||
return self._build_response(
|
||||
message="ok",
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
def update(self, request, *args, **kwargs):
|
||||
partial = kwargs.pop('partial', False)
|
||||
instance = self.get_object()
|
||||
serializer = self.get_serializer(instance, data=request.data, partial=partial)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_update(serializer)
|
||||
|
||||
if getattr(instance, '_prefetched_objects_cache', None):
|
||||
# If 'prefetch_related' has been applied to a queryset, we need to
|
||||
# forcibly invalidate the prefetch cache on the instance.
|
||||
instance._prefetched_objects_cache = {}
|
||||
return self._build_response(
|
||||
data=serializer.data,
|
||||
message="ok",
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
def _build_response(self, code=0, message="成功", data=None, status=status.HTTP_200_OK):
|
||||
"""
|
||||
构建标准化API响应格式
|
||||
|
||||
参数说明:
|
||||
- code: 业务状态码(0表示成功,非0表示错误)
|
||||
- message: 状态描述信息
|
||||
- data: 响应数据(可为None)
|
||||
- status: HTTP状态码(默认200)
|
||||
"""
|
||||
# 构建基础响应结构
|
||||
response_data = {
|
||||
"code": code,
|
||||
"message": message
|
||||
}
|
||||
|
||||
# 仅当data不为None时添加到响应中
|
||||
if data is not None:
|
||||
response_data["data"] = data
|
||||
|
||||
# 移除可能的空值(如message为空字符串)
|
||||
response_data = {k: v for k, v in response_data.items() if v is not None and v != ""}
|
||||
|
||||
# 返回DRF的Response对象
|
||||
return Response(
|
||||
data=response_data,
|
||||
status=status,
|
||||
content_type="application/json"
|
||||
)
|
||||
20
backend/utils/models.py
Normal file
20
backend/utils/models.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
@Remark: 公共基础model类
|
||||
"""
|
||||
from django.db import models
|
||||
|
||||
class CoreModel(models.Model):
|
||||
remark = models.CharField(max_length=256, verbose_name="备注", null=True, blank=True, help_text="备注")
|
||||
creator = models.CharField(max_length=64, null=True, blank=True, help_text="创建人", verbose_name="创建人")
|
||||
modifier = models.CharField(max_length=64, null=True, blank=True, help_text="修改人", verbose_name="修改人")
|
||||
update_time = models.DateTimeField(auto_now=True, null=True, blank=True, help_text="修改时间", verbose_name="修改时间")
|
||||
create_time = models.DateTimeField(auto_now_add=True, null=True, blank=True, help_text="创建时间",
|
||||
verbose_name="创建时间")
|
||||
is_deleted = models.BooleanField(default=False, verbose_name='是否软删除')
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
verbose_name = '核心模型'
|
||||
verbose_name_plural = verbose_name
|
||||
56
backend/utils/pagination.py
Normal file
56
backend/utils/pagination.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.core import paginator
|
||||
from django.core.paginator import Paginator as DjangoPaginator
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
from rest_framework.response import Response
|
||||
from django.core.paginator import InvalidPage
|
||||
|
||||
|
||||
class CustomPagination(PageNumberPagination):
|
||||
page_size = 20
|
||||
page_size_query_param = "pageSize"
|
||||
max_page_size = 999
|
||||
django_paginator_class = DjangoPaginator
|
||||
|
||||
def paginate_queryset(self, queryset, request, view=None):
|
||||
"""
|
||||
重写paginate_queryset让分页超过正常分页:有原来的4000错误无效页面。改写为返回2000成功,data=[]提示
|
||||
"""
|
||||
page_size = self.get_page_size(request)
|
||||
if not page_size:
|
||||
return None
|
||||
paginator = self.django_paginator_class(queryset, page_size)
|
||||
page_number = self.get_page_number(request, paginator)
|
||||
try:
|
||||
self.page = paginator.page(page_number)
|
||||
except InvalidPage as exc:
|
||||
self.page = []
|
||||
|
||||
if paginator.num_pages > 1 and self.template is not None:
|
||||
# The browsable API should display pagination controls.
|
||||
self.display_page_controls = True
|
||||
|
||||
self.request = request
|
||||
return list(self.page)
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
code = 0
|
||||
msg = 'ok'
|
||||
total = self.page.paginator.count if self.page else 0
|
||||
res = {
|
||||
"total": total,
|
||||
"items": data
|
||||
}
|
||||
if not data:
|
||||
code = 0
|
||||
msg = "暂无数据"
|
||||
res['data'] = []
|
||||
|
||||
return Response(OrderedDict([
|
||||
('code', code),
|
||||
('message', msg),
|
||||
('data', res),
|
||||
('error', None),
|
||||
]))
|
||||
8
backend/utils/permissions.py
Normal file
8
backend/utils/permissions.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from rest_framework import permissions
|
||||
|
||||
class IsSuperUserOrReadOnly(permissions.BasePermission):
|
||||
"""超级用户可读写,普通用户只读"""
|
||||
def has_permission(self, request, view):
|
||||
if request.method in permissions.SAFE_METHODS:
|
||||
return True
|
||||
return request.user and request.user.is_superuser
|
||||
118
backend/utils/serializers.py
Normal file
118
backend/utils/serializers.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
@Remark: 自定义序列化器
|
||||
"""
|
||||
from rest_framework import serializers
|
||||
from rest_framework.fields import empty
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from django.utils.functional import cached_property
|
||||
from rest_framework.utils.serializer_helpers import BindingDict
|
||||
|
||||
from system.models import User
|
||||
|
||||
|
||||
class CustomModelSerializer(ModelSerializer):
|
||||
"""
|
||||
增强DRF的ModelSerializer,可自动更新模型的审计字段记录
|
||||
(1)self.request能获取到rest_framework.request.Request对象
|
||||
"""
|
||||
# 修改人的审计字段名称, 默认modifier, 继承使用时可自定义覆盖
|
||||
modifier_field_id = 'modifier'
|
||||
modifier_name = serializers.SerializerMethodField(read_only=True)
|
||||
|
||||
def get_modifier_name(self, instance):
|
||||
if not hasattr(instance, 'modifier'):
|
||||
return None
|
||||
queryset = User.objects.filter(id=instance.modifier).values_list('name', flat=True).first()
|
||||
if queryset:
|
||||
return queryset
|
||||
return None
|
||||
|
||||
# 创建人的审计字段名称, 默认creator, 继承使用时可自定义覆盖
|
||||
creator_field_id = 'creator'
|
||||
# 添加默认时间返回格式
|
||||
create_time = serializers.DateTimeField(format="%Y-%m-%d %H:%M:%S", required=False, read_only=True)
|
||||
update_time = serializers.DateTimeField(format="%Y-%m-%d %H:%M:%S", required=False)
|
||||
|
||||
def __init__(self, instance=None, data=empty, request=None, **kwargs):
|
||||
super().__init__(instance, data, **kwargs)
|
||||
self.request: Request = request or self.context.get('request', None)
|
||||
|
||||
def save(self, **kwargs):
|
||||
return super().save(**kwargs)
|
||||
|
||||
def create(self, validated_data):
|
||||
if self.request:
|
||||
if self.modifier_field_id in self.fields.fields:
|
||||
validated_data[self.modifier_field_id] = self.get_request_username()
|
||||
if self.creator_field_id in self.fields.fields:
|
||||
validated_data[self.creator_field_id] = self.get_request_username()
|
||||
return super().create(validated_data)
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
if self.request:
|
||||
if hasattr(self.instance, self.modifier_field_id):
|
||||
self.instance.modifier = self.get_request_username()
|
||||
return super().update(instance, validated_data)
|
||||
|
||||
def get_request_username(self):
|
||||
if getattr(self.request, 'user', None):
|
||||
return getattr(self.request.user, 'username', None)
|
||||
return None
|
||||
|
||||
def get_request_name(self):
|
||||
if getattr(self.request, 'user', None):
|
||||
return getattr(self.request.user, 'name', None)
|
||||
return None
|
||||
|
||||
def get_request_user_id(self):
|
||||
if getattr(self.request, 'user', None):
|
||||
return getattr(self.request.user, 'id', None)
|
||||
return None
|
||||
|
||||
@cached_property
|
||||
def fields(self):
|
||||
fields = BindingDict(self)
|
||||
for key, value in self.get_fields().items():
|
||||
fields[key] = value
|
||||
|
||||
if not hasattr(self, '_context'):
|
||||
return fields
|
||||
is_root = self.root == self
|
||||
parent_is_list_root = self.parent == self.root and getattr(self.parent, 'many', False)
|
||||
if not (is_root or parent_is_list_root):
|
||||
return fields
|
||||
|
||||
try:
|
||||
request = self.request or self.context['request']
|
||||
except KeyError:
|
||||
return fields
|
||||
params = getattr(
|
||||
request, 'query_params', getattr(request, 'GET', None)
|
||||
)
|
||||
if params is None:
|
||||
pass
|
||||
try:
|
||||
filter_fields = params.get('_fields', None).split(',')
|
||||
except AttributeError:
|
||||
filter_fields = None
|
||||
|
||||
try:
|
||||
omit_fields = params.get('_exclude', None).split(',')
|
||||
except AttributeError:
|
||||
omit_fields = []
|
||||
|
||||
existing = set(fields.keys())
|
||||
if filter_fields is None:
|
||||
allowed = existing
|
||||
else:
|
||||
allowed = set(filter(None, filter_fields))
|
||||
|
||||
omitted = set(filter(None, omit_fields))
|
||||
for field in existing:
|
||||
if field not in allowed:
|
||||
fields.pop(field, None)
|
||||
if field in omitted:
|
||||
fields.pop(field, None)
|
||||
|
||||
return fields
|
||||
30
backend/utils/utils.py
Normal file
30
backend/utils/utils.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from django.utils import timezone
|
||||
|
||||
def validate_mobile(value):
|
||||
if value and not re.findall(r"1\d{10}", value):
|
||||
raise ValidationError('手机格式不正确')
|
||||
|
||||
|
||||
def validate_amount(value):
|
||||
if value is None:
|
||||
raise ValidationError('金额不能为空')
|
||||
if value and value < 0:
|
||||
raise ValidationError('金额不能为负')
|
||||
|
||||
|
||||
def to_cent(value):
|
||||
if value is None:
|
||||
value = 0
|
||||
return Decimal(value).quantize(Decimal('.01'), rounding=ROUND_HALF_UP)
|
||||
|
||||
# 定义一个小工具:从时间戳转换为 aware datetime(如果时间戳有效)
|
||||
def ts_to_aware(ts):
|
||||
if ts:
|
||||
naive_dt = datetime.fromtimestamp(ts)
|
||||
return timezone.make_aware(naive_dt)
|
||||
return None
|
||||
Reference in New Issue
Block a user