This commit is contained in:
sheng
2023-07-25 15:25:57 +08:00
parent 8f57a240db
commit 7ccbcebe77
373 changed files with 41578 additions and 0 deletions

View File

@@ -0,0 +1,39 @@
import hashlib
import logging
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from django.contrib.auth.hashers import check_password
from django.utils import timezone
from dvadmin.utils.validator import CustomValidationError
logger = logging.getLogger(__name__)
UserModel = get_user_model()
class CustomBackend(ModelBackend):
"""
Django原生认证方式
"""
def authenticate(self, request, username=None, password=None, **kwargs):
msg = '%s 正在使用本地登录...' % username
logger.info(msg)
if username is None:
username = kwargs.get(UserModel.USERNAME_FIELD)
try:
user = UserModel._default_manager.get_by_natural_key(username)
except UserModel.DoesNotExist:
UserModel().set_password(password)
else:
verify_password = check_password(password, user.password)
if not verify_password:
password = hashlib.md5(password.encode(encoding='UTF-8')).hexdigest()
verify_password = check_password(password, user.password)
if verify_password:
if self.user_can_authenticate(user):
user.last_login = timezone.now()
user.save()
return user
raise CustomValidationError("当前用户已被禁用,请联系管理员!")

View File

@@ -0,0 +1,89 @@
# 初始化基类
import json
import os
from django.apps import apps
from rest_framework import request
from application import settings
from dvadmin.system.models import Users
class CoreInitialize:
"""
使用方法:继承此类,重写 run方法在 run 中调用 save 进行数据初始化
"""
creator_id = None
reset = False
request = request
file_path = None
def __init__(self, reset=False, creator_id=None, app=None):
"""
reset: 是否重置初始化数据
creator_id: 创建人id
"""
self.reset = reset or self.reset
self.creator_id = creator_id or self.creator_id
self.app = app or ''
self.request.user = Users.objects.order_by('create_datetime').first()
def init_base(self, Serializer, unique_fields=None):
model = Serializer.Meta.model
path_file = os.path.join(apps.get_app_config(self.app.split('.')[-1]).path, 'fixtures',
f'init_{Serializer.Meta.model._meta.model_name}.json')
if not os.path.isfile(path_file):
print("文件不存在,跳过初始化")
return
with open(path_file,encoding="utf-8") as f:
for data in json.load(f):
filter_data = {}
# 配置过滤条件,如果有唯一标识字段则使用唯一标识字段,否则使用全部字段
if unique_fields:
for field in unique_fields:
if field in data:
filter_data[field] = data[field]
else:
for key, value in data.items():
if isinstance(value, list) or value == None or value == '':
continue
filter_data[key] = value
instance = model.objects.filter(**filter_data).first()
data["reset"] = self.reset
serializer = Serializer(instance, data=data, request=self.request)
serializer.is_valid(raise_exception=True)
serializer.save()
print(f"[{self.app}][{model._meta.model_name}]初始化完成")
def save(self, obj, data: list, name=None, no_reset=False):
name = name or obj._meta.verbose_name
print(f"正在初始化[{obj._meta.label} => {name}]")
if not no_reset and self.reset and obj not in settings.INITIALIZE_RESET_LIST:
try:
obj.objects.all().delete()
settings.INITIALIZE_RESET_LIST.append(obj)
except Exception:
pass
for ele in data:
m2m_dict = {}
new_data = {}
for key, value in ele.items():
# 判断传的 value 为 list 的多对多进行抽离使用set 进行更新
if isinstance(value, list) and value and isinstance(value[0], int):
m2m_dict[key] = value
else:
new_data[key] = value
object, _ = obj.objects.get_or_create(id=ele.get("id"), defaults=new_data)
for key, m2m in m2m_dict.items():
m2m = list(set(m2m))
if m2m and len(m2m) > 0 and m2m[0]:
exec(f"""
if object.{key}:
values_list = object.{key}.all().values_list('id', flat=True)
values_list = list(set(list(values_list) + {m2m}))
object.{key}.set(values_list)
""")
print(f"初始化完成[{obj._meta.label} => {name}]")
def run(self):
raise NotImplementedError('.run() must be overridden')

View File

@@ -0,0 +1,155 @@
# -*- coding: utf-8 -*-
from rest_framework.decorators import action
from rest_framework.permissions import AllowAny
from dvadmin.utils.json_response import DetailResponse
class FastCrudMixin:
"""
定义快速CRUD数据操作的通用方法
"""
# 需要CRUD的字段
crud_fields = None
# 排除CRUD的字段
exclude_fields = None
# 自定义CRUD的JSON
custom_crud_json = None
# 需要修改的CRUD键值对
crud_update_key_value = None
# 将Django的字段类型处理为JS类型
def __handle_type(self, type):
if type in ['BigAutoField', 'CharField']:
return "input"
if type == 'DateTimeField':
return "datetime"
if type == 'DateField':
return "date"
if type == 'IntegerField':
return "number"
if type == 'BooleanField':
return "dict-switch"
# 获取字段属性信息
def __get_field_attribute(self):
result = []
queryset = self.get_queryset()
__name = ""
__verbose_name = ""
__type = "text"
# 判断指定CRUD字段
if self.crud_fields and type(self.crud_fields == list):
for item in self.crud_fields:
try:
field = queryset.model._meta.get_field(item)
field_type = field.get_internal_type()
__name = field.name
# 判断类型是否为外键类型,外键类型需要特殊方式获取verbose_name
if field_type in ['ForeignKey', 'OneToOneField', 'ManyToManyField']:
continue
# try:
# verbose_name = Users._meta.get_field(str(field.name)).verbose_name
# except:
# pass
else:
__verbose_name = field.verbose_name
__type = self.__handle_type(field_type)
except:
continue
result.append({"key": __name, "title": __verbose_name, "type": __type})
else:
# 获取model的所有字段及属性
model_fields = queryset.model._meta.get_fields()
# 遍历所有字段属性
for field in model_fields:
field_type = field.get_internal_type()
__name = field.name
# 判断需要排除的CRUD字段
if self.exclude_fields and type(self.exclude_fields == list):
if __name in self.exclude_fields:
continue
# 判断类型是否为外键类型,外键类型需要特殊方式获取verbose_name
if field_type in ['ForeignKey', 'OneToOneField', 'ManyToManyField']:
continue
# try:
# verbose_name = Users._meta.get_field(str(field.name)).verbose_name
# except:
# pass
else:
__verbose_name = field.verbose_name
__type = self.__handle_type(field_type)
result.append({"key": __name, "title": __verbose_name, "type": __type})
return result
#获取key
def __find_key(self,dct: dict,
target_key: str,
level: int = -1,
index: int = -1) -> tuple:
"""Find a key within a nested dictionary and return its level and index."""
for k, v in dct.items():
level += 1
index += 1
if k == target_key:
return level, index
elif isinstance(v, list):
for i, dct_ in enumerate(v):
if isinstance(dct_, dict):
result = self.__find_key(dct_, target_key)
if result is not None:
return result
else:
continue
elif isinstance(v, str) or isinstance(v, int) or isinstance(v, float):
continue
# 修改字典中key的value
def __update_nested_dict(self,nested_dict: dict,
target_key: str,
new_value) -> dict:
"""Update a nested dictionary with a new value."""
split_target_key = target_key.split('.')
if len(split_target_key) > 1:
new_dict = nested_dict[split_target_key[0]]
for item in split_target_key[1:-1]:
new_dict = new_dict[item]
self.__update_nested_dict(new_dict, split_target_key[-1], new_value)
else:
nested_dict[target_key] = new_value
return nested_dict
# 处理crud,返回columns
def __handle_crud(self):
result = self.__get_field_attribute()
columns = dict()
for item in result:
key = item.get('key')
title = item.get('title')
type = item.get('type')
columns[key] = {
"title": title,
"key": key,
"type": type
}
# 对自定义的crud配置合并
if self.custom_crud_json and isinstance(self.custom_crud_json,dict):
columns = columns | self.custom_crud_json
# 对curd进行修改配置
if self.crud_update_key_value and isinstance(self.crud_update_key_value,dict):
for key, value in self.crud_update_key_value.items():
columns = self.__update_nested_dict(columns,key,value)
return columns
@action(methods=['get'], detail=False,permission_classes=[AllowAny])
def init_crud(self, request):
self.permission_classes = [AllowAny]
columns = self.__handle_crud()
expose = "({expose,dict})=>{"
ret = "return {"
res = "}}"
data = f"""{expose}
{ret}
columns:{columns}
{res}
"""
return DetailResponse(data=data)

View File

@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
"""
@author: 猿小天
@contact: QQ:1638245306
@Created on: 2021/6/2 002 16:06
@Remark: 自定义异常处理
"""
import logging
import traceback
from django.db.models import ProtectedError
from django.http import Http404
from rest_framework.exceptions import APIException as DRFAPIException, AuthenticationFailed, NotAuthenticated
from rest_framework.status import HTTP_401_UNAUTHORIZED
from rest_framework.views import set_rollback, exception_handler
from dvadmin.utils.json_response import ErrorResponse
logger = logging.getLogger(__name__)
class CustomAuthenticationFailed(NotAuthenticated):
# 设置 status_code 属性为 400
status_code = 400
def CustomExceptionHandler(ex, context):
"""
统一异常拦截处理
目的:(1)取消所有的500异常响应,统一响应为标准错误返回
(2)准确显示错误信息
:param ex:
:param context:
:return:
"""
msg = ''
code = 4000
# 调用默认的异常处理函数
response = exception_handler(ex, context)
if isinstance(ex, AuthenticationFailed):
# 如果是身份验证错误
if response and response.data.get('detail') == "Given token not valid for any token type":
code = 401
msg = ex.detail
elif response and response.data.get('detail') == "Token is blacklisted":
# token在黑名单
return ErrorResponse(status=HTTP_401_UNAUTHORIZED)
else:
code = 401
msg = ex.detail
elif isinstance(ex,Http404):
code = 400
msg = "接口地址不正确"
elif isinstance(ex, DRFAPIException):
set_rollback()
msg = ex.detail
if isinstance(msg,dict):
for k, v in msg.items():
for i in v:
msg = "%s:%s" % (k, i)
elif isinstance(ex, ProtectedError):
set_rollback()
msg = "删除失败:该条数据与其他数据有相关绑定"
# elif isinstance(ex, DatabaseError):
# set_rollback()
# msg = "接口服务器异常,请联系管理员"
elif isinstance(ex, Exception):
logger.exception(traceback.format_exc())
msg = str(ex)
return ErrorResponse(msg=msg, code=code)

View File

@@ -0,0 +1,331 @@
# -*- coding: utf-8 -*-
"""
@author: 猿小天
@contact: QQ:1638245306
@Created on: 2021/6/6 006 12:39
@Remark: 自定义过滤器
"""
import operator
import re
from collections import OrderedDict
from functools import reduce
import six
from django.db.models import Q, F
from django.db.models.constants import LOOKUP_SEP
from django_filters import utils
from django_filters.filters import CharFilter
from django_filters.rest_framework import DjangoFilterBackend
from django_filters.utils import get_model_field
from rest_framework.filters import BaseFilterBackend
from dvadmin.system.models import Dept, ApiWhiteList, RoleMenuButtonPermission
def get_dept(dept_id: int, dept_all_list=None, dept_list=None):
"""
递归获取部门的所有下级部门
:param dept_id: 需要获取的部门id
:param dept_all_list: 所有部门列表
:param dept_list: 递归部门list
:return:
"""
if not dept_all_list:
dept_all_list = Dept.objects.all().values("id", "parent")
if dept_list is None:
dept_list = [dept_id]
for ele in dept_all_list:
if ele.get("parent") == dept_id:
dept_list.append(ele.get("id"))
get_dept(ele.get("id"), dept_all_list, dept_list)
return list(set(dept_list))
class DataLevelPermissionsFilter(BaseFilterBackend):
"""
数据 级权限过滤器
0. 获取用户的部门id没有部门则返回空
1. 判断过滤的数据是否有创建人所在部门 "creator" 字段,没有则返回全部
2. 如果用户没有关联角色则返回本部门数据
3. 根据角色的最大权限进行数据过滤(会有多个角色,进行去重取最大权限)
3.1 判断用户是否为超级管理员角色/如果有1(所有数据) 则返回所有数据
4. 只为仅本人数据权限时只返回过滤本人数据,并且部门为自己本部门(考虑到用户会变部门,只能看当前用户所在的部门数据)
5. 自定数据权限 获取部门,根据部门过滤
"""
def filter_queryset(self, request, queryset, view):
"""
接口白名单是否认证数据权限
"""
api = request.path # 当前请求接口
method = request.method # 当前请求方法
methodList = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
method = methodList.index(method)
# ***接口白名单***
api_white_list = ApiWhiteList.objects.filter(enable_datasource=False).values(
permission__api=F("url"), permission__method=F("method")
)
api_white_list = [
str(item.get("permission__api").replace("{id}", ".*?"))
+ ":"
+ str(item.get("permission__method"))
for item in api_white_list
if item.get("permission__api")
]
for item in api_white_list:
new_api = f"{api}:{method}"
matchObj = re.match(item, new_api, re.M | re.I)
if matchObj is None:
continue
else:
return queryset
"""
判断是否为超级管理员:
如果不是超级管理员,则进入下一步权限判断
"""
if request.user.is_superuser == 0:
return self._extracted_from_filter_queryset_33(request, queryset, api, method)
else:
return queryset
# TODO Rename this here and in `filter_queryset`
def _extracted_from_filter_queryset_33(self, request, queryset, api, method):
# 0. 获取用户的部门id没有部门则返回空
user_dept_id = getattr(request.user, "dept_id", None)
if not user_dept_id:
return queryset.none()
# 1. 判断过滤的数据是否有创建人所在部门 "dept_belong_id" 字段
if not getattr(queryset.model, "dept_belong_id", None):
return queryset
# 2. 如果用户没有关联角色则返回本部门数据
if not hasattr(request.user, "role"):
return queryset.filter(dept_belong_id=user_dept_id)
# 3. 根据所有角色 获取所有权限范围
# (0, "仅本人数据权限"),
# (1, "本部门及以下数据权限"),
# (2, "本部门数据权限"),
# (3, "全部数据权限"),
# (4, "自定数据权限")
replace_str = re.compile('\d')
re_api = replace_str.sub('{id}', api)
role_id_list = request.user.role.values_list('id', flat=True)
role_permission_list=RoleMenuButtonPermission.objects.filter(
role__in=role_id_list,
role__status=1,
menu_button__api=re_api,
menu_button__method=method).values(
'data_range',
role_admin=F('role__admin')
)
dataScope_list = [] # 权限范围列表
for ele in role_permission_list:
# 判断用户是否为超级管理员角色/如果拥有[全部数据权限]则返回所有数据
if ele.get("data_range") == 3 or ele.get("role_admin") == True:
return queryset
dataScope_list.append(ele.get("data_range"))
dataScope_list = list(set(dataScope_list))
# 4. 只为仅本人数据权限时只返回过滤本人数据,并且部门为自己本部门(考虑到用户会变部门,只能看当前用户所在的部门数据)
if 0 in dataScope_list:
return queryset.filter(
creator=request.user, dept_belong_id=user_dept_id
)
# 5. 自定数据权限 获取部门,根据部门过滤
dept_list = []
for ele in dataScope_list:
if ele == 1:
dept_list.append(user_dept_id)
dept_list.extend(
get_dept(
user_dept_id,
)
)
elif ele == 2:
dept_list.append(user_dept_id)
elif ele == 4:
dept_list.extend(
request.user.role.filter(status=1).values_list(
"dept__id", flat=True
)
)
if queryset.model._meta.model_name == 'dept':
return queryset.filter(id__in=list(set(dept_list)))
return queryset.filter(dept_belong_id__in=list(set(dept_list)))
class CustomDjangoFilterBackend(DjangoFilterBackend):
lookup_prefixes = {
"^": "istartswith",
"=": "iexact",
"@": "search",
"$": "iregex",
"~": "icontains",
}
def construct_search(self, field_name, lookup_expr=None):
lookup = self.lookup_prefixes.get(field_name[0])
if lookup:
field_name = field_name[1:]
else:
lookup = lookup_expr
if field_name.endswith(lookup):
return field_name
return LOOKUP_SEP.join([field_name, lookup])
def find_filter_lookups(self, orm_lookups, search_term_key):
for lookup in orm_lookups:
# if lookup.find(search_term_key) >= 0:
new_lookup = lookup.split("__")[0]
# 修复条件搜索错误 bug
if new_lookup == search_term_key:
return lookup
return None
def get_filterset_class(self, view, queryset=None):
"""
Return the `FilterSet` class used to filter the queryset.
"""
filterset_class = getattr(view, "filterset_class", None)
filterset_fields = getattr(view, "filterset_fields", None)
# TODO: remove assertion in 2.1
if filterset_class is None and hasattr(view, "filter_class"):
utils.deprecate(
"`%s.filter_class` attribute should be renamed `filterset_class`."
% view.__class__.__name__
)
filterset_class = getattr(view, "filter_class", None)
# TODO: remove assertion in 2.1
if filterset_fields is None and hasattr(view, "filter_fields"):
utils.deprecate(
"`%s.filter_fields` attribute should be renamed `filterset_fields`."
% view.__class__.__name__
)
filterset_fields = getattr(view, "filter_fields", None)
if filterset_class:
filterset_model = filterset_class._meta.model
# FilterSets do not need to specify a Meta class
if filterset_model and queryset is not None:
assert issubclass(
queryset.model, filterset_model
), "FilterSet model %s does not match queryset model %s" % (
filterset_model,
queryset.model,
)
return filterset_class
if filterset_fields and queryset is not None:
MetaBase = getattr(self.filterset_base, "Meta", object)
class AutoFilterSet(self.filterset_base):
@classmethod
def get_filters(cls):
"""
Get all filters for the filterset. This is the combination of declared and
generated filters.
"""
# No model specified - skip filter generation
if not cls._meta.model:
return cls.declared_filters.copy()
# Determine the filters that should be included on the filterset.
filters = OrderedDict()
fields = cls.get_fields()
undefined = []
for field_name, lookups in fields.items():
field = get_model_field(cls._meta.model, field_name)
from django.db import models
from timezone_field import TimeZoneField
# 不进行 过滤的model 类
if isinstance(field, (models.JSONField, TimeZoneField)):
continue
# warn if the field doesn't exist.
if field is None:
undefined.append(field_name)
# 更新默认字符串搜索为模糊搜索
if isinstance(field, (models.CharField)) and filterset_fields == '__all__' and lookups == [
'exact']:
lookups = ['icontains']
for lookup_expr in lookups:
filter_name = cls.get_filter_name(field_name, lookup_expr)
# If the filter is explicitly declared on the class, skip generation
if filter_name in cls.declared_filters:
filters[filter_name] = cls.declared_filters[filter_name]
continue
if field is not None:
filters[filter_name] = cls.filter_for_field(
field, field_name, lookup_expr
)
# Allow Meta.fields to contain declared filters *only* when a list/tuple
if isinstance(cls._meta.fields, (list, tuple)):
undefined = [
f for f in undefined if f not in cls.declared_filters
]
if undefined:
raise TypeError(
"'Meta.fields' must not contain non-model field names: %s"
% ", ".join(undefined)
)
# Add in declared filters. This is necessary since we don't enforce adding
# declared filters to the 'Meta.fields' option
filters.update(cls.declared_filters)
return filters
class Meta(MetaBase):
model = queryset.model
fields = filterset_fields
return AutoFilterSet
return None
def filter_queryset(self, request, queryset, view):
filterset = self.get_filterset(request, queryset, view)
if filterset is None:
return queryset
if filterset.__class__.__name__ == "AutoFilterSet":
queryset = filterset.queryset
orm_lookups = []
for search_field in filterset.filters:
if isinstance(filterset.filters[search_field], CharFilter):
orm_lookups.append(
self.construct_search(six.text_type(search_field), filterset.filters[search_field].lookup_expr)
)
else:
orm_lookups.append(search_field)
conditions = []
queries = []
for search_term_key in filterset.data.keys():
orm_lookup = self.find_filter_lookups(orm_lookups, search_term_key)
if not orm_lookup:
continue
query = Q(**{orm_lookup: filterset.data[search_term_key]})
queries.append(query)
if len(queries) > 0:
conditions.append(reduce(operator.and_, queries))
queryset = queryset.filter(reduce(operator.and_, conditions))
return queryset
else:
return queryset
if not filterset.is_valid() and self.raise_exception:
raise utils.translate_validation(filterset.errors)
return filterset.qs

View File

@@ -0,0 +1,104 @@
import os
from git.repo import Repo
from git.repo.fun import is_git_dir
class GitRepository(object):
"""
git仓库管理
"""
def __init__(self, local_path, repo_url, branch='master'):
self.local_path = local_path
self.repo_url = repo_url
self.repo = None
self.initial(self.repo_url, branch)
def initial(self, repo_url, branch):
"""
初始化git仓库
:param repo_url:
:param branch:
:return:
"""
if not os.path.exists(self.local_path):
os.makedirs(self.local_path)
git_local_path = os.path.join(self.local_path, '.git')
if not is_git_dir(git_local_path):
self.repo = Repo.clone_from(repo_url, to_path=self.local_path, branch=branch)
else:
self.repo = Repo(self.local_path)
def pull(self):
"""
从线上拉最新代码
:return:
"""
self.repo.git.pull()
def branches(self):
"""
获取所有分支
:return:
"""
branches = self.repo.remote().refs
return [item.remote_head for item in branches if item.remote_head not in ['HEAD', ]]
def commits(self):
"""
获取所有提交记录
:return:
"""
commit_log = self.repo.git.log('--pretty={"commit":"%h","author":"%an","summary":"%s","date":"%cd"}',
max_count=50,
date='format:%Y-%m-%d %H:%M')
log_list = commit_log.split("\n")
return [eval(item) for item in log_list]
def tags(self):
"""
获取所有tag
:return:
"""
return [tag.name for tag in self.repo.tags]
def tags_exists(self, tag):
"""
tag是否存在
:return:
"""
return tag in self.tags()
def change_to_branch(self, branch):
"""
切换分支
:param branch:
:return:
"""
self.repo.git.checkout(branch)
def change_to_commit(self, branch, commit):
"""
切换commit
:param branch:
:param commit:
:return:
"""
self.change_to_branch(branch=branch)
self.repo.git.reset('--hard', commit)
def change_to_tag(self, tag):
"""
切换tag
:param tag:
:return:
"""
self.repo.git.checkout(tag)
# if __name__ == '__main__':
# local_path = os.path.join('codes', 't1')
# repo = GitRepository(local_path, remote_path)
# branch_list = repo.branches()
# print(branch_list)
# repo.change_to_branch('dev')
# repo.pull()

View File

@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
import os
import re
from datetime import datetime
import openpyxl
from django.conf import settings
from dvadmin.utils.validator import CustomValidationError
def import_to_data(file_url, field_data, m2m_fields=None):
"""
读取导入的excel文件
:param file_url:
:param field_data: 首行数据源
:param m2m_fields: 多对多字段
:return:
"""
# 读取excel 文件
file_path_dir = os.path.join(settings.BASE_DIR, file_url)
workbook = openpyxl.load_workbook(file_path_dir)
table = workbook[workbook.sheetnames[0]]
theader = tuple(table.values)[0] #Excel的表头
is_update = '更新主键(勿改)' in theader #是否导入更新
if is_update is False: #不是更新时,删除id列
field_data.pop('id')
# 获取参数映射
validation_data_dict = {}
for key, value in field_data.items():
if isinstance(value, dict):
choices = value.get("choices", {})
data_dict = {}
if choices.get("data"):
for k, v in choices.get("data").items():
data_dict[k] = v
elif choices.get("queryset") and choices.get("values_name"):
data_list = choices.get("queryset").values(choices.get("values_name"), "id")
for ele in data_list:
data_dict[ele.get(choices.get("values_name"))] = ele.get("id")
else:
continue
validation_data_dict[key] = data_dict
# 创建一个空列表存储Excel的数据
tables = []
for i, row in enumerate(range(table.max_row)):
if i == 0:
continue
array = {}
for index, item in enumerate(field_data.items()):
items = list(item)
key = items[0]
values = items[1]
value_type = 'str'
if isinstance(values, dict):
value_type = values.get('type','str')
cell_value = table.cell(row=row + 1, column=index + 2).value
if cell_value is None or cell_value=='':
continue
elif value_type == 'date':
print(61, datetime.strptime(str(cell_value), '%Y-%m-%d %H:%M:%S').date())
try:
cell_value = datetime.strptime(str(cell_value), '%Y-%m-%d %H:%M:%S').date()
except:
raise CustomValidationError('日期格式不正确')
elif value_type == 'datetime':
cell_value = datetime.strptime(str(cell_value), '%Y-%m-%d %H:%M:%S')
else:
# 由于excel导入数字类型后会出现数字加 .0 的,进行处理
if type(cell_value) is float and str(cell_value).split(".")[1] == "0":
cell_value = int(str(cell_value).split(".")[0])
elif type(cell_value) is str:
cell_value = cell_value.strip(" \t\n\r")
if key in validation_data_dict:
array[key] = validation_data_dict.get(key, {}).get(cell_value, None)
if key in m2m_fields:
array[key] = list(
filter(
lambda x: x,
[
validation_data_dict.get(key, {}).get(value, None)
for value in re.split(r"[|.,;:\s]\s*", cell_value)
],
)
)
else:
array[key] = cell_value
tables.append(array)
return tables

View File

@@ -0,0 +1,345 @@
# -*- coding: utf-8 -*-
from urllib.parse import quote
from django.db import transaction
from django.http import HttpResponse
from openpyxl import Workbook
from openpyxl.worksheet.datavalidation import DataValidation
from openpyxl.utils import get_column_letter, quote_sheetname
from openpyxl.worksheet.table import Table, TableStyleInfo
from rest_framework.decorators import action
from rest_framework.request import Request
from dvadmin.utils.import_export import import_to_data
from dvadmin.utils.json_response import DetailResponse
from dvadmin.utils.request_util import get_verbose_name
class ImportSerializerMixin:
"""
自定义导入模板、导入功能
"""
# 导入字段
import_field_dict = {}
# 导入序列化器
import_serializer_class = None
# 表格表头最大宽度默认50个字符
export_column_width = 50
def is_number(self,num):
try:
float(num)
return True
except ValueError:
pass
try:
import unicodedata
unicodedata.numeric(num)
return True
except (TypeError, ValueError):
pass
return False
def get_string_len(self, string):
"""
获取字符串最大长度
:param string:
:return:
"""
length = 4
if string is None:
return length
if self.is_number(string):
return length
for char in string:
length += 2.1 if ord(char) > 256 else 1
return round(length, 1) if length <= self.export_column_width else self.export_column_width
@action(methods=['get','post'],detail=False)
@transaction.atomic # Django 事务,防止出错
def import_data(self, request: Request, *args, **kwargs):
"""
导入模板
:param request:
:param args:
:param kwargs:
:return:
"""
assert self.import_field_dict, "'%s' 请配置对应的导出模板字段。" % self.__class__.__name__
# 导出模板
if request.method == "GET":
# 示例数据
queryset = self.filter_queryset(self.get_queryset())
# 导出excel 表
response = HttpResponse(content_type="application/msexcel")
response["Access-Control-Expose-Headers"] = f"Content-Disposition"
response[
"Content-Disposition"
] = f'attachment;filename={quote(str(f"导入{get_verbose_name(queryset)}模板.xlsx"))}'
wb = Workbook()
ws1 = wb.create_sheet("data", 1)
ws1.sheet_state = "hidden"
ws = wb.active
row = get_column_letter(len(self.import_field_dict) + 1)
column = 10
header_data = [
"序号",
]
validation_data_dict = {}
for index, ele in enumerate(self.import_field_dict.values()):
if isinstance(ele, dict):
header_data.append(ele.get("title"))
choices = ele.get("choices", {})
if choices.get("data"):
data_list = []
data_list.extend(choices.get("data").keys())
validation_data_dict[ele.get("title")] = data_list
elif choices.get("queryset") and choices.get("values_name"):
data_list = choices.get("queryset").values_list(choices.get("values_name"), flat=True)
validation_data_dict[ele.get("title")] = list(data_list)
else:
continue
column_letter = get_column_letter(len(validation_data_dict))
dv = DataValidation(
type="list",
formula1=f"{quote_sheetname('data')}!${column_letter}$2:${column_letter}${len(validation_data_dict[ele.get('title')]) + 1}",
allow_blank=True,
)
ws.add_data_validation(dv)
dv.add(f"{get_column_letter(index + 2)}2:{get_column_letter(index + 2)}1048576")
else:
header_data.append(ele)
# 添加数据列
ws1.append(list(validation_data_dict.keys()))
for index, validation_data in enumerate(validation_data_dict.values()):
for inx, ele in enumerate(validation_data):
ws1[f"{get_column_letter(index + 1)}{inx + 2}"] = ele
# 插入导出模板正式数据
df_len_max = [self.get_string_len(ele) for ele in header_data]
ws.append(header_data)
#  更新列宽
for index, width in enumerate(df_len_max):
ws.column_dimensions[get_column_letter(index + 1)].width = width
tab = Table(displayName="Table1", ref=f"A1:{row}{column}") # 名称管理器
style = TableStyleInfo(
name="TableStyleLight11",
showFirstColumn=True,
showLastColumn=True,
showRowStripes=True,
showColumnStripes=True,
)
tab.tableStyleInfo = style
ws.add_table(tab)
wb.save(response)
return response
else:
# 从excel中组织对应的数据结构然后使用序列化器保存
queryset = self.filter_queryset(self.get_queryset())
# 获取多对多字段
m2m_fields = [
ele.name
for ele in queryset.model._meta.get_fields()
if hasattr(ele, "many_to_many") and ele.many_to_many == True
]
import_field_dict = {'id':'更新主键(勿改)',**self.import_field_dict}
data = import_to_data(request.data.get("url"), import_field_dict, m2m_fields)
for ele in data:
filter_dic = {'id':ele.get('id')}
instance = filter_dic and queryset.filter(**filter_dic).first()
# print(156,ele)
serializer = self.import_serializer_class(instance, data=ele, request=request)
serializer.is_valid(raise_exception=True)
serializer.save()
return DetailResponse(msg=f"导入成功!")
@action(methods=['get'],detail=False)
def update_template(self,request):
queryset = self.filter_queryset(self.get_queryset())
assert self.import_field_dict, "'%s' 请配置对应的导入模板字段。" % self.__class__.__name__
assert self.import_serializer_class, "'%s' 请配置对应的导入序列化器。" % self.__class__.__name__
data = self.import_serializer_class(queryset, many=True, request=request).data
# 导出excel 表
response = HttpResponse(content_type="application/msexcel")
response["Access-Control-Expose-Headers"] = f"Content-Disposition"
response["content-disposition"] = f'attachment;filename={quote(str(f"导出{get_verbose_name(queryset)}.xlsx"))}'
wb = Workbook()
ws1 = wb.create_sheet("data", 1)
ws1.sheet_state = "hidden"
ws = wb.active
import_field_dict = {}
header_data = ["序号","更新主键(勿改)"]
hidden_header = ["#","id"]
#----设置选项----
validation_data_dict = {}
for index, item in enumerate(self.import_field_dict.items()):
items = list(item)
key = items[0]
value = items[1]
if isinstance(value, dict):
header_data.append(value.get("title"))
hidden_header.append(value.get('display'))
choices = value.get("choices", {})
if choices.get("data"):
data_list = []
data_list.extend(choices.get("data").keys())
validation_data_dict[value.get("title")] = data_list
elif choices.get("queryset") and choices.get("values_name"):
data_list = choices.get("queryset").values_list(choices.get("values_name"), flat=True)
validation_data_dict[value.get("title")] = list(data_list)
else:
continue
column_letter = get_column_letter(len(validation_data_dict))
dv = DataValidation(
type="list",
formula1=f"{quote_sheetname('data')}!${column_letter}$2:${column_letter}${len(validation_data_dict[value.get('title')]) + 1}",
allow_blank=True,
)
ws.add_data_validation(dv)
dv.add(f"{get_column_letter(index + 3)}2:{get_column_letter(index + 3)}1048576")
else:
header_data.append(value)
hidden_header.append(key)
# 添加数据列
ws1.append(list(validation_data_dict.keys()))
for index, validation_data in enumerate(validation_data_dict.values()):
for inx, ele in enumerate(validation_data):
ws1[f"{get_column_letter(index + 1)}{inx + 2}"] = ele
#--------
df_len_max = [self.get_string_len(ele) for ele in header_data]
row = get_column_letter(len(hidden_header) + 1)
column = 1
ws.append(header_data)
for index, results in enumerate(data):
results_list = []
for h_index, h_item in enumerate(hidden_header):
for key, val in results.items():
if key == h_item:
if val is None or val == "":
results_list.append("")
elif isinstance(val,list):
results_list.append(str(val))
else:
results_list.append(val)
# 计算最大列宽度
if isinstance(val,str):
result_column_width = self.get_string_len(val)
if h_index != 0 and result_column_width > df_len_max[h_index]:
df_len_max[h_index] = result_column_width
ws.append([index+1,*results_list])
column += 1
#  更新列宽
for index, width in enumerate(df_len_max):
ws.column_dimensions[get_column_letter(index + 1)].width = width
tab = Table(displayName="Table", ref=f"A1:{row}{column}") # 名称管理器
style = TableStyleInfo(
name="TableStyleLight11",
showFirstColumn=True,
showLastColumn=True,
showRowStripes=True,
showColumnStripes=True,
)
tab.tableStyleInfo = style
ws.add_table(tab)
wb.save(response)
return response
class ExportSerializerMixin:
"""
自定义导出功能
"""
# 导出字段
export_field_label = []
# 导出序列化器
export_serializer_class = None
# 表格表头最大宽度默认50个字符
export_column_width = 50
def is_number(self,num):
try:
float(num)
return True
except ValueError:
pass
try:
import unicodedata
unicodedata.numeric(num)
return True
except (TypeError, ValueError):
pass
return False
def get_string_len(self, string):
"""
获取字符串最大长度
:param string:
:return:
"""
length = 4
if string is None:
return length
if self.is_number(string):
return length
for char in string:
length += 2.1 if ord(char) > 256 else 1
return round(length, 1) if length <= self.export_column_width else self.export_column_width
@action(methods=['get'],detail=False)
def export_data(self, request: Request, *args, **kwargs):
"""
导出功能
:param request:
:param args:
:param kwargs:
:return:
"""
queryset = self.filter_queryset(self.get_queryset())
assert self.export_field_label, "'%s' 请配置对应的导出模板字段。" % self.__class__.__name__
assert self.export_serializer_class, "'%s' 请配置对应的导出序列化器。" % self.__class__.__name__
data = self.export_serializer_class(queryset, many=True, request=request).data
# 导出excel 表
response = HttpResponse(content_type="application/msexcel")
response["Access-Control-Expose-Headers"] = f"Content-Disposition"
response["content-disposition"] = f'attachment;filename={quote(str(f"导出{get_verbose_name(queryset)}.xlsx"))}'
wb = Workbook()
ws = wb.active
header_data = ["序号", *self.export_field_label.values()]
hidden_header = ["#", *self.export_field_label.keys()]
df_len_max = [self.get_string_len(ele) for ele in header_data]
row = get_column_letter(len(self.export_field_label) + 1)
column = 1
ws.append(header_data)
for index, results in enumerate(data):
results_list = []
for h_index, h_item in enumerate(hidden_header):
for key,val in results.items():
if key == h_item:
if val is None or val=="":
results_list.append("")
else:
results_list.append(val)
# 计算最大列宽度
result_column_width = self.get_string_len(val)
if h_index !=0 and result_column_width > df_len_max[h_index]:
df_len_max[h_index] = result_column_width
ws.append([index + 1, *results_list])
column += 1
#  更新列宽
for index, width in enumerate(df_len_max):
ws.column_dimensions[get_column_letter(index + 1)].width = width
tab = Table(displayName="Table", ref=f"A1:{row}{column}") # 名称管理器
style = TableStyleInfo(
name="TableStyleLight11",
showFirstColumn=True,
showLastColumn=True,
showRowStripes=True,
showColumnStripes=True,
)
tab.tableStyleInfo = style
ws.add_table(tab)
wb.save(response)
return response

View File

@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""
@author: 猿小天
@contact: QQ:1638245306
@Created on: 2021/6/2 002 14:43
@Remark: 自定义的JsonResonpse文件
"""
from rest_framework.response import Response
class SuccessResponse(Response):
"""
标准响应成功的返回, SuccessResponse(data)或者SuccessResponse(data=data)
(1)默认code返回2000, 不支持指定其他返回码
"""
def __init__(self, data=None, msg='success', status=None, template_name=None, headers=None, exception=False,
content_type=None,page=1,limit=1,total=1):
std_data = {
"code": 2000,
"page": page,
"limit": limit,
"total": total,
"data": data,
"msg": msg
}
super().__init__(std_data, status, template_name, headers, exception, content_type)
class DetailResponse(Response):
"""
不包含分页信息的接口返回,主要用于单条数据查询
(1)默认code返回2000, 不支持指定其他返回码
"""
def __init__(self, data=None, msg='success', status=None, template_name=None, headers=None, exception=False,
content_type=None,):
std_data = {
"code": 2000,
"data": data,
"msg": msg
}
super().__init__(std_data, status, template_name, headers, exception, content_type)
class ErrorResponse(Response):
"""
标准响应错误的返回,ErrorResponse(msg='xxx')
(1)默认错误码返回400, 也可以指定其他返回码:ErrorResponse(code=xxx)
"""
def __init__(self, data=None, msg='error', code=400, status=None, template_name=None, headers=None,
exception=False, content_type=None):
std_data = {
"code": code,
"data": data,
"msg": msg
}
super().__init__(std_data, status, template_name, headers, exception, content_type)

View File

@@ -0,0 +1,89 @@
"""
日志 django中间件
"""
import json
from django.conf import settings
from django.contrib.auth.models import AnonymousUser
from django.utils.deprecation import MiddlewareMixin
from dvadmin.system.models import OperationLog
from dvadmin.utils.request_util import get_request_user, get_request_ip, get_request_data, get_request_path, get_os, \
get_browser, get_verbose_name
class ApiLoggingMiddleware(MiddlewareMixin):
"""
用于记录API访问日志中间件
"""
def __init__(self, get_response=None):
super().__init__(get_response)
self.enable = getattr(settings, 'API_LOG_ENABLE', None) or False
self.methods = getattr(settings, 'API_LOG_METHODS', None) or set()
self.operation_log_id = None
@classmethod
def __handle_request(cls, request):
request.request_ip = get_request_ip(request)
request.request_data = get_request_data(request)
request.request_path = get_request_path(request)
def __handle_response(self, request, response):
# request_data,request_ip由PermissionInterfaceMiddleware中间件中添加的属性
body = getattr(request, 'request_data', {})
# 请求含有password则用*替换掉(暂时先用于所有接口的password请求参数)
if isinstance(body, dict) and body.get('password', ''):
body['password'] = '*' * len(body['password'])
if not hasattr(response, 'data') or not isinstance(response.data, dict):
response.data = {}
try:
if not response.data and response.content:
content = json.loads(response.content.decode())
response.data = content if isinstance(content, dict) else {}
except Exception:
return
user = get_request_user(request)
info = {
'request_ip': getattr(request, 'request_ip', 'unknown'),
'creator': user if not isinstance(user, AnonymousUser) else None,
'dept_belong_id': getattr(request.user, 'dept_id', None),
'request_method': request.method,
'request_path': request.request_path,
'request_body': body,
'response_code': response.data.get('code'),
'request_os': get_os(request),
'request_browser': get_browser(request),
'request_msg': request.session.get('request_msg'),
'status': True if response.data.get('code') in [2000, ] else False,
'json_result': {"code": response.data.get('code'), "msg": response.data.get('msg')},
}
operation_log, creat = OperationLog.objects.update_or_create(defaults=info, id=self.operation_log_id)
if not operation_log.request_modular and settings.API_MODEL_MAP.get(request.request_path, None):
operation_log.request_modular = settings.API_MODEL_MAP[request.request_path]
operation_log.save()
def process_view(self, request, view_func, view_args, view_kwargs):
if hasattr(view_func, 'cls') and hasattr(view_func.cls, 'queryset'):
if self.enable:
if self.methods == 'ALL' or request.method in self.methods:
log = OperationLog(request_modular=get_verbose_name(view_func.cls.queryset))
log.save()
self.operation_log_id = log.id
return
def process_request(self, request):
self.__handle_request(request)
def process_response(self, request, response):
"""
主要请求处理完之后记录
:param request:
:param response:
:return:
"""
if self.enable:
if self.methods == 'ALL' or request.method in self.methods:
self.__handle_response(request, response)
return response

View File

@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
"""
@author: 猿小天
@contact: QQ:1638245306
@Created on: 2021/5/31 031 22:08
@Remark: 公共基础model类
"""
import uuid
from django.apps import apps
from django.db import models
from django.db.models import QuerySet
from application import settings
table_prefix = settings.TABLE_PREFIX # 数据库表名前缀
class SoftDeleteQuerySet(QuerySet):
pass
class SoftDeleteManager(models.Manager):
"""支持软删除"""
def __init__(self, *args, **kwargs):
self.__add_is_del_filter = False
super(SoftDeleteManager, self).__init__(*args, **kwargs)
def filter(self, *args, **kwargs):
# 考虑是否主动传入is_deleted
if not kwargs.get('is_deleted') is None:
self.__add_is_del_filter = True
return super(SoftDeleteManager, self).filter(*args, **kwargs)
def get_queryset(self):
if self.__add_is_del_filter:
return SoftDeleteQuerySet(self.model, using=self._db).exclude(is_deleted=False)
return SoftDeleteQuerySet(self.model).exclude(is_deleted=True)
def get_by_natural_key(self,name):
return SoftDeleteQuerySet(self.model).get(username=name)
class SoftDeleteModel(models.Model):
"""
软删除模型
一旦继承,就将开启软删除
"""
is_deleted = models.BooleanField(verbose_name="是否软删除", help_text='是否软删除', default=False, db_index=True)
objects = SoftDeleteManager()
class Meta:
abstract = True
verbose_name = '软删除模型'
verbose_name_plural = verbose_name
def delete(self, using=None, soft_delete=True, *args, **kwargs):
"""
重写删除方法,直接开启软删除
"""
self.is_deleted = True
self.save(using=using)
class CoreModel(models.Model):
"""
核心标准抽象模型模型,可直接继承使用
增加审计字段, 覆盖字段时, 字段名称请勿修改, 必须统一审计字段名称
"""
id = models.BigAutoField(primary_key=True, help_text="Id", verbose_name="Id")
description = models.CharField(max_length=255, verbose_name="描述", null=True, blank=True, help_text="描述")
creator = models.ForeignKey(to=settings.AUTH_USER_MODEL, related_query_name='creator_query', null=True,
verbose_name='创建人', help_text="创建人", on_delete=models.SET_NULL, db_constraint=False)
modifier = models.CharField(max_length=255, null=True, blank=True, help_text="修改人", verbose_name="修改人")
dept_belong_id = models.CharField(max_length=255, help_text="数据归属部门", null=True, blank=True, verbose_name="数据归属部门")
update_datetime = models.DateTimeField(auto_now=True, null=True, blank=True, help_text="修改时间", verbose_name="修改时间")
create_datetime = models.DateTimeField(auto_now_add=True, null=True, blank=True, help_text="创建时间",
verbose_name="创建时间")
class Meta:
abstract = True
verbose_name = '核心模型'
verbose_name_plural = verbose_name
def get_all_models_objects(model_name=None):
"""
获取所有 models 对象
:return: {}
"""
settings.ALL_MODELS_OBJECTS = {}
if not settings.ALL_MODELS_OBJECTS:
all_models = apps.get_models()
for item in list(all_models):
table = {
"tableName": item._meta.verbose_name,
"table": item.__name__,
"tableFields": []
}
for field in item._meta.fields:
fields = {
"title": field.verbose_name,
"field": field.name
}
table['tableFields'].append(fields)
settings.ALL_MODELS_OBJECTS.setdefault(item.__name__, {"table": table, "object": item})
if model_name:
return settings.ALL_MODELS_OBJECTS[model_name] or {}
return settings.ALL_MODELS_OBJECTS or {}

View File

@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
"""
@author: 猿小天
@contact: QQ:1638245306
@Created on: 2020/4/16 23:35
"""
from collections import OrderedDict
from django.core import paginator
from django.core.paginator import Paginator as DjangoPaginator, InvalidPage
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response
class CustomPagination(PageNumberPagination):
page_size = 10
page_size_query_param = "limit"
max_page_size = 999
django_paginator_class = DjangoPaginator
def paginate_queryset(self, queryset, request, view=None):
"""
Paginate a queryset if required, either returning a
page object, or `None` if pagination is not configured for this view.
"""
empty = True
page_size = self.get_page_size(request)
if not page_size:
return None
paginator = self.django_paginator_class(queryset, page_size)
page_number = request.query_params.get(self.page_query_param, 1)
if page_number in self.last_page_strings:
page_number = paginator.num_pages
try:
self.page = paginator.page(page_number)
except InvalidPage as exc:
# msg = self.invalid_page_message.format(
# page_number=page_number, message=str(exc)
# )
# raise NotFound(msg)
empty = False
pass
if paginator.num_pages > 1 and self.template is not None:
# The browsable API should display pagination controls.
self.display_page_controls = True
self.request = request
if not empty:
self.page = []
return list(self.page)
def get_paginated_response(self, data):
code = 2000
msg = 'success'
page =int(self.get_page_number(self.request, paginator)) or 1
total=self.page.paginator.count if self.page else 0
limit= int(self.get_page_size(self.request)) or 10
is_next= self.page.has_next()
is_previous= self.page.has_previous()
data=data
if not data:
code = 2000
msg = "暂无数据"
data = []
return Response(OrderedDict([
('code', code),
('msg', msg),
('page', page),
('limit', limit),
('total',total),
('is_next',is_next),
('is_previous', is_previous),
('data', data)
]))

View File

@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
"""
@author: 猿小天
@contact: QQ:1638245306
@Created on: 2021/6/6 006 10:30
@Remark: 自定义权限
"""
import re
from django.contrib.auth.models import AnonymousUser
from django.db.models import F
from rest_framework.permissions import BasePermission
from dvadmin.system.models import ApiWhiteList, RoleMenuButtonPermission
def ValidationApi(reqApi, validApi):
"""
验证当前用户是否有接口权限
:param reqApi: 当前请求的接口
:param validApi: 用于验证的接口
:return: True或者False
"""
if validApi is not None:
valid_api = validApi.replace('{id}', '.*?')
matchObj = re.match(valid_api, reqApi, re.M | re.I)
if matchObj:
return True
else:
return False
else:
return False
class AnonymousUserPermission(BasePermission):
"""
匿名用户权限
"""
def has_permission(self, request, view):
if isinstance(request.user, AnonymousUser):
return False
return True
def ReUUID(api):
"""
将接口的uuid替换掉
:param api:
:return:
"""
pattern = re.compile(r'[a-f\d]{4}(?:[a-f\d]{4}-){4}[a-f\d]{12}/$')
m = pattern.search(api)
if m:
res = api.replace(m.group(0), ".*/")
return res
else:
return None
class CustomPermission(BasePermission):
"""自定义权限"""
def has_permission(self, request, view):
if isinstance(request.user, AnonymousUser):
return False
# 判断是否是超级管理员
if request.user.is_superuser:
return True
else:
api = request.path # 当前请求接口
method = request.method # 当前请求方法
methodList = ['GET', 'POST', 'PUT', 'DELETE', 'OPTIONS', 'PATCH']
method = methodList.index(method)
# ***接口白名单***
api_white_list = ApiWhiteList.objects.values(permission__api=F('url'), permission__method=F('method'))
api_white_list = [
str(item.get('permission__api').replace('{id}', '([a-zA-Z0-9-]+)')) + ":" + str(
item.get('permission__method')) + '$' for item in api_white_list if item.get('permission__api')]
# ********#
if not hasattr(request.user, "role"):
return False
role_id_list = request.user.role.values_list('id',flat=True)
userApiList = RoleMenuButtonPermission.objects.filter(role__in=role_id_list).values(permission__api=F('menu_button__api'), permission__method=F('menu_button__method')) # 获取当前用户的角色拥有的所有接口
ApiList = [
str(item.get('permission__api').replace('{id}', '([a-zA-Z0-9-]+)')) + ":" + str(
item.get('permission__method')) + '$' for item in userApiList if item.get('permission__api')]
new_api_ist = api_white_list + ApiList
new_api = api + ":" + str(method)
for item in new_api_ist:
matchObj = re.match(item, new_api, re.M | re.I)
if matchObj is None:
continue
else:
return True
else:
return False

View File

@@ -0,0 +1,219 @@
"""
Request工具类
"""
import json
import requests
from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser
from django.contrib.auth.models import AnonymousUser
from django.urls.resolvers import ResolverMatch
from rest_framework_simplejwt.authentication import JWTAuthentication
from user_agents import parse
from dvadmin.system.models import LoginLog
def get_request_user(request):
"""
获取请求user
(1)如果request里的user没有认证,那么则手动认证一次
:param request:
:return:
"""
user: AbstractBaseUser = getattr(request, 'user', None)
if user and user.is_authenticated:
return user
try:
user, tokrn = JWTAuthentication().authenticate(request)
except Exception as e:
pass
return user or AnonymousUser()
def get_request_ip(request):
"""
获取请求IP
:param request:
:return:
"""
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR', '')
if x_forwarded_for:
ip = x_forwarded_for.split(',')[-1].strip()
return ip
ip = request.META.get('REMOTE_ADDR', '') or getattr(request, 'request_ip', None)
return ip or 'unknown'
def get_request_data(request):
"""
获取请求参数
:param request:
:return:
"""
request_data = getattr(request, 'request_data', None)
if request_data:
return request_data
data: dict = {**request.GET.dict(), **request.POST.dict()}
if not data:
try:
body = request.body
if body:
data = json.loads(body)
except Exception as e:
pass
if not isinstance(data, dict):
data = {'data': data}
return data
def get_request_path(request, *args, **kwargs):
"""
获取请求路径
:param request:
:param args:
:param kwargs:
:return:
"""
request_path = getattr(request, 'request_path', None)
if request_path:
return request_path
values = []
for arg in args:
if len(arg) == 0:
continue
if isinstance(arg, str):
values.append(arg)
elif isinstance(arg, (tuple, set, list)):
values.extend(arg)
elif isinstance(arg, dict):
values.extend(arg.values())
if len(values) == 0:
return request.path
path: str = request.path
for value in values:
path = path.replace('/' + value, '/' + '{id}')
return path
def get_request_canonical_path(request, ):
"""
获取请求路径
:param request:
:param args:
:param kwargs:
:return:
"""
request_path = getattr(request, 'request_canonical_path', None)
if request_path:
return request_path
path: str = request.path
resolver_match: ResolverMatch = request.resolver_match
for value in resolver_match.args:
path = path.replace(f"/{value}", "/{id}")
for key, value in resolver_match.kwargs.items():
if key == 'pk':
path = path.replace(f"/{value}", f"/{{id}}")
continue
path = path.replace(f"/{value}", f"/{{{key}}}")
return path
def get_browser(request, ):
"""
获取浏览器名
:param request:
:param args:
:param kwargs:
:return:
"""
ua_string = request.META['HTTP_USER_AGENT']
user_agent = parse(ua_string)
return user_agent.get_browser()
def get_os(request, ):
"""
获取操作系统
:param request:
:param args:
:param kwargs:
:return:
"""
ua_string = request.META['HTTP_USER_AGENT']
user_agent = parse(ua_string)
return user_agent.get_os()
def get_verbose_name(queryset=None, view=None, model=None):
"""
获取 verbose_name
:param request:
:param view:
:return:
"""
try:
if queryset is not None and hasattr(queryset, 'model'):
model = queryset.model
elif view and hasattr(view.get_queryset(), 'model'):
model = view.get_queryset().model
elif view and hasattr(view.get_serializer(), 'Meta') and hasattr(view.get_serializer().Meta, 'model'):
model = view.get_serializer().Meta.model
if model:
return getattr(model, '_meta').verbose_name
else:
model = queryset.model._meta.verbose_name
except Exception as e:
pass
return model if model else ""
def get_ip_analysis(ip):
"""
获取ip详细概略
:param ip: ip地址
:return:
"""
data = {
"continent": "",
"country": "",
"province": "",
"city": "",
"district": "",
"isp": "",
"area_code": "",
"country_english": "",
"country_code": "",
"longitude": "",
"latitude": ""
}
if ip != 'unknown' and ip:
if getattr(settings, 'ENABLE_LOGIN_ANALYSIS_LOG', True):
try:
res = requests.get(url='https://ip.django-vue-admin.com/ip/analysis', params={"ip": ip}, timeout=5)
if res.status_code == 200:
res_data = res.json()
if res_data.get('code') == 0:
data = res_data.get('data')
return data
except Exception as e:
print(e)
return data
def save_login_log(request):
"""
保存登录日志
:return:
"""
ip = get_request_ip(request=request)
analysis_data = get_ip_analysis(ip)
analysis_data['username'] = request.user.username
analysis_data['ip'] = ip
analysis_data['agent'] = str(parse(request.META['HTTP_USER_AGENT']))
analysis_data['browser'] = get_browser(request)
analysis_data['os'] = get_os(request)
analysis_data['creator_id'] = request.user.id
analysis_data['dept_belong_id'] = getattr(request.user, 'dept_id', '')
LoginLog.objects.create(**analysis_data)

View File

@@ -0,0 +1,170 @@
# -*- coding: utf-8 -*-
"""
@author: 猿小天
@contact: QQ:1638245306
@Created on: 2021/6/1 001 22:47
@Remark: 自定义序列化器
"""
from rest_framework import serializers
from rest_framework.fields import empty
from rest_framework.request import Request
from rest_framework.serializers import ModelSerializer
from django.utils.functional import cached_property
from rest_framework.utils.serializer_helpers import BindingDict
from dvadmin.system.models import Users
from django_restql.mixins import DynamicFieldsMixin
class CustomModelSerializer(DynamicFieldsMixin, ModelSerializer):
"""
增强DRF的ModelSerializer,可自动更新模型的审计字段记录
(1)self.request能获取到rest_framework.request.Request对象
"""
# 修改人的审计字段名称, 默认modifier, 继承使用时可自定义覆盖
modifier_field_id = "modifier"
modifier_name = serializers.SerializerMethodField(read_only=True)
def get_modifier_name(self, instance):
if not hasattr(instance, "modifier"):
return None
queryset = (
Users.objects.filter(id=instance.modifier)
.values_list("name", flat=True)
.first()
)
if queryset:
return queryset
return None
# 创建人的审计字段名称, 默认creator, 继承使用时可自定义覆盖
creator_field_id = "creator"
creator_name = serializers.SlugRelatedField(
slug_field="name", source="creator", read_only=True
)
# 数据所属部门字段
dept_belong_id_field_name = "dept_belong_id"
# 添加默认时间返回格式
create_datetime = serializers.DateTimeField(
format="%Y-%m-%d %H:%M:%S", required=False, read_only=True
)
update_datetime = serializers.DateTimeField(
format="%Y-%m-%d %H:%M:%S", required=False
)
def __init__(self, instance=None, data=empty, request=None, **kwargs):
super().__init__(instance, data, **kwargs)
self.request: Request = request or self.context.get("request", None)
def save(self, **kwargs):
return super().save(**kwargs)
def create(self, validated_data):
if self.request:
if str(self.request.user) != "AnonymousUser":
if self.modifier_field_id in self.fields.fields:
validated_data[self.modifier_field_id] = self.get_request_user_id()
if self.creator_field_id in self.fields.fields:
validated_data[self.creator_field_id] = self.request.user
if (
self.dept_belong_id_field_name in self.fields.fields
and validated_data.get(self.dept_belong_id_field_name, None) is None
):
validated_data[self.dept_belong_id_field_name] = getattr(
self.request.user, "dept_id", None
)
return super().create(validated_data)
def update(self, instance, validated_data):
if self.request:
if str(self.request.user) != "AnonymousUser":
if self.modifier_field_id in self.fields.fields:
validated_data[self.modifier_field_id] = self.get_request_user_id()
if hasattr(self.instance, self.modifier_field_id):
setattr(
self.instance, self.modifier_field_id, self.get_request_user_id()
)
return super().update(instance, validated_data)
def get_request_username(self):
if getattr(self.request, "user", None):
return getattr(self.request.user, "username", None)
return None
def get_request_name(self):
if getattr(self.request, "user", None):
return getattr(self.request.user, "name", None)
return None
def get_request_user_id(self):
if getattr(self.request, "user", None):
return getattr(self.request.user, "id", None)
return None
@property
def errors(self):
# get errors
errors = super().errors
verbose_errors = {}
# fields = { field.name: field.verbose_name } for each field in model
fields = {field.name: field.verbose_name for field in
self.Meta.model._meta.get_fields() if hasattr(field, 'verbose_name')}
# iterate over errors and replace error key with verbose name if exists
for field_name, error in errors.items():
if field_name in fields:
verbose_errors[str(fields[field_name])] = error
else:
verbose_errors[field_name] = error
return verbose_errors
# @cached_property
# def fields(self):
# fields = BindingDict(self)
# for key, value in self.get_fields().items():
# fields[key] = value
#
# if not hasattr(self, '_context'):
# return fields
# is_root = self.root == self
# parent_is_list_root = self.parent == self.root and getattr(self.parent, 'many', False)
# if not (is_root or parent_is_list_root):
# return fields
#
# try:
# request = self.request or self.context['request']
# except KeyError:
# return fields
# params = getattr(
# request, 'query_params', getattr(request, 'GET', None)
# )
# if params is None:
# pass
# try:
# filter_fields = params.get('_fields', None).split(',')
# except AttributeError:
# filter_fields = None
#
# try:
# omit_fields = params.get('_exclude', None).split(',')
# except AttributeError:
# omit_fields = []
#
# existing = set(fields.keys())
# if filter_fields is None:
# allowed = existing
# else:
# allowed = set(filter(None, filter_fields))
#
# omitted = set(filter(None, omit_fields))
# for field in existing:
# if field not in allowed:
# fields.pop(field, None)
# if field in omitted:
# fields.pop(field, None)
#
# return fields

View File

@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
"""
@author: 猿小天
@contact: QQ:1638245306
@Created on: 2021/8/21 021 9:48
@Remark:
"""
import hashlib
import random
CHAR_SET = ("2", "3", "4", "5",
"6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H",
"J", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "U", "V",
"W", "X", "Y", "Z")
def random_str(number=16):
"""
返回特定长度的随机字符串(非进制)
:return:
"""
result = ""
for i in range(0, number):
inx = random.randint(0, len(CHAR_SET) - 1)
result += CHAR_SET[inx]
return result
def has_md5(str, salt='123456'):
"""
md5 加密
:param str:
:param salt:
:return:
"""
# satl是盐值默认是123456
str = str + salt
md = hashlib.md5() # 构造一个md5对象
md.update(str.encode())
res = md.hexdigest()
return res

View File

@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
"""
@author: 猿小天
@contact: QQ:1638245306
@Created on: 2021/8/12 012 10:25
@Remark: swagger配置
"""
from drf_yasg.generators import OpenAPISchemaGenerator
from drf_yasg.inspectors import SwaggerAutoSchema
from application.settings import SWAGGER_SETTINGS
def get_summary(string):
if string is not None:
result = string.strip().replace(" ","").split("\n")
return result[0]
class CustomSwaggerAutoSchema(SwaggerAutoSchema):
def get_tags(self, operation_keys=None):
tags = super().get_tags(operation_keys)
if "api" in tags and operation_keys:
# `operation_keys` 内容像这样 ['v1', 'prize_join_log', 'create']
tags[0] = operation_keys[SWAGGER_SETTINGS.get('AUTO_SCHEMA_TYPE', 2)]
return tags
def get_summary_and_description(self):
summary_and_description = super().get_summary_and_description()
summary = get_summary(self.__dict__.get('view').__doc__)
description = summary_and_description[1]
return summary,description
class CustomOpenAPISchemaGenerator(OpenAPISchemaGenerator):
def get_schema(self, request=None, public=False):
"""Generate a :class:`.Swagger` object with custom tags"""
swagger = super().get_schema(request, public)
swagger.tags = [
{
"name": "token",
"description": "认证相关"
},
]
return swagger

View File

@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
"""
@author: 猿小天
@contact: QQ:1638245306
@Created on: 2021/6/2 002 17:03
@Remark: 自定义验证器
"""
from django.db import DataError
from rest_framework.exceptions import APIException
from rest_framework.validators import UniqueValidator
class CustomValidationError(APIException):
"""
继承并重写验证器返回的结果,避免暴露字段
"""
def __init__(self, detail):
self.detail = detail
def qs_exists(queryset):
try:
return queryset.exists()
except (TypeError, ValueError, DataError):
return False
def qs_filter(queryset, **kwargs):
try:
return queryset.filter(**kwargs)
except (TypeError, ValueError, DataError):
return queryset.none()
class CustomUniqueValidator(UniqueValidator):
"""
继承,重写必填字段的验证器结果,防止字段暴露
"""
def filter_queryset(self, value, queryset, field_name):
"""
Filter the queryset to all instances matching the given attribute.
"""
filter_kwargs = {'%s__%s' % (field_name, self.lookup): value}
return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, queryset, instance):
"""
If an instance is being updated, then do not include
that instance itself as a uniqueness conflict.
"""
if instance is not None:
return queryset.exclude(pk=instance.pk)
return queryset
def __call__(self, value, serializer_field):
# Determine the underlying model field name. This may not be the
# same as the serializer field name if `source=<>` is set.
field_name = serializer_field.source_attrs[-1]
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer_field.parent, 'instance', None)
queryset = self.queryset
queryset = self.filter_queryset(value, queryset, field_name)
queryset = self.exclude_current_instance(queryset, instance)
if qs_exists(queryset):
raise CustomValidationError(self.message)
def __repr__(self):
return super().__repr__()

View File

@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
"""
@author: 猿小天
@contact: QQ:1638245306
@Created on: 2021/6/1 001 22:57
@Remark: 自定义视图集
"""
import uuid
from django.db import transaction
from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.viewsets import ModelViewSet
from dvadmin.utils.filters import DataLevelPermissionsFilter
from dvadmin.utils.import_export_mixin import ExportSerializerMixin, ImportSerializerMixin
from dvadmin.utils.json_response import SuccessResponse, ErrorResponse, DetailResponse
from dvadmin.utils.permission import CustomPermission
from django_restql.mixins import QueryArgumentsMixin
class CustomModelViewSet(ModelViewSet, ImportSerializerMixin, ExportSerializerMixin, QueryArgumentsMixin):
"""
自定义的ModelViewSet:
统一标准的返回格式;新增,查询,修改可使用不同序列化器
(1)ORM性能优化, 尽可能使用values_queryset形式
(2)xxx_serializer_class 某个方法下使用的序列化器(xxx=create|update|list|retrieve|destroy)
(3)filter_fields = '__all__' 默认支持全部model中的字段查询(除json字段外)
(4)import_field_dict={} 导入时的字段字典 {model值: model的label}
(5)export_field_label = [] 导出时的字段
"""
values_queryset = None
ordering_fields = '__all__'
create_serializer_class = None
update_serializer_class = None
filter_fields = '__all__'
search_fields = ()
extra_filter_class = [DataLevelPermissionsFilter]
permission_classes = [CustomPermission]
import_field_dict = {}
export_field_label = {}
def filter_queryset(self, queryset):
for backend in set(set(self.filter_backends) | set(self.extra_filter_class or [])):
queryset = backend().filter_queryset(self.request, queryset, self)
return queryset
def get_queryset(self):
if getattr(self, 'values_queryset', None):
return self.values_queryset
return super().get_queryset()
def get_serializer_class(self):
action_serializer_name = f"{self.action}_serializer_class"
action_serializer_class = getattr(self, action_serializer_name, None)
if action_serializer_class:
return action_serializer_class
return super().get_serializer_class()
# 通过many=True直接改造原有的API使其可以批量创建
def get_serializer(self, *args, **kwargs):
serializer_class = self.get_serializer_class()
kwargs.setdefault('context', self.get_serializer_context())
if isinstance(self.request.data, list):
with transaction.atomic():
return serializer_class(many=True, *args, **kwargs)
else:
return serializer_class(*args, **kwargs)
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data, request=request)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
return DetailResponse(data=serializer.data, msg="新增成功")
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True, request=request)
return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(queryset, many=True, request=request)
return SuccessResponse(data=serializer.data, msg="获取成功")
def retrieve(self, request, *args, **kwargs):
instance = self.get_object()
serializer = self.get_serializer(instance)
return DetailResponse(data=serializer.data, msg="获取成功")
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, request=request, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
if getattr(instance, '_prefetched_objects_cache', None):
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance.
instance._prefetched_objects_cache = {}
return DetailResponse(data=serializer.data, msg="更新成功")
def destroy(self, request, *args, **kwargs):
instance = self.get_object()
instance.delete()
return DetailResponse(data=[], msg="删除成功")
keys = openapi.Schema(description='主键列表', type=openapi.TYPE_ARRAY, items=openapi.TYPE_STRING)
@swagger_auto_schema(request_body=openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['keys'],
properties={'keys': keys}
), operation_summary='批量删除')
@action(methods=['delete'], detail=False)
def multiple_delete(self, request, *args, **kwargs):
request_data = request.data
keys = request_data.get('keys', None)
if keys:
self.get_queryset().filter(id__in=keys).delete()
return SuccessResponse(data=[], msg="删除成功")
else:
return ErrorResponse(msg="未获取到keys字段")