Files
django-vue3-admin-gd/ai_service/crud/base.py
2025-07-22 12:09:37 +08:00

90 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Generic, TypeVar, List, Optional, Dict, Any
from fastapi import HTTPException
from sqlalchemy.orm import Session
from datetime import datetime
# Generic type parameters, in order: the SQLAlchemy model class, the Pydantic
# schema used to create a record, and the Pydantic schema used to update one.
ModelType = TypeVar("ModelType")
CreateSchemaType = TypeVar("CreateSchemaType")
UpdateSchemaType = TypeVar("UpdateSchemaType")


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic create/read/update/delete helper bound to one model class.

    Every write method commits the session itself and then refreshes the
    affected object, so callers get back an object reflecting the stored row.
    """

    # Attribute names treated as the "last modified" timestamp by ``update``.
    # ``create`` fills ``update_time``; ``updated_at`` is kept because the
    # previous implementation of ``update`` refreshed (only) that name.
    _UPDATE_TIMESTAMP_FIELDS = ("update_time", "updated_at")

    def __init__(self, model: type[ModelType]):
        """Bind the helper to a model.

        :param model: the SQLAlchemy model *class* (not an instance) that all
            queries and inserts of this helper operate on.
        """
        self.model = model

    def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
        """Insert one record built from ``obj_in`` and return it.

        :param db: session used for the insert; it is committed here.
        :param obj_in: object exposing ``model_dump()``; the resulting dict is
            passed to the model constructor as keyword arguments.
        :return: the newly created, refreshed model instance.
        """
        values = obj_in.model_dump()

        # Stamp both audit fields with the same instant. Any value the caller
        # supplied for these fields is overwritten.
        now = datetime.now()
        for field in ("create_time", "update_time"):
            if hasattr(self.model, field):
                values[field] = now

        db_obj = self.model(**values)
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def get(self, db: Session, id: int) -> Optional[ModelType]:
        """Return the record whose ``id`` column equals ``id``, or ``None``."""
        return db.query(self.model).filter(self.model.id == id).first()

    def get_by(self, db: Session, **kwargs) -> Optional[ModelType]:
        """Return the first record matching all keyword filters, or ``None``.

        Example: ``get_by(db, name="test")``.
        """
        return db.query(self.model).filter_by(**kwargs).first()

    def get_multi(
        self, db: Session, *, page: int = 0, limit: int = 100
    ) -> List[ModelType]:
        """Return up to ``limit`` records, skipping the first ``page`` rows.

        NOTE: despite its name, ``page`` is forwarded unchanged as the query
        OFFSET (a row count); it is not multiplied by ``limit``. No ORDER BY is
        applied, so row order is whatever the database returns.
        """
        return db.query(self.model).offset(page).limit(limit).all()

    def update(
        self,
        db: Session,
        *,
        db_obj: ModelType,
        obj_in: UpdateSchemaType | Dict[str, Any]
    ) -> ModelType:
        """Apply the fields in ``obj_in`` to ``db_obj``, commit and return it.

        :param db: session used for the update; it is committed here.
        :param db_obj: the already-loaded instance to modify.
        :param obj_in: either a plain dict or an object exposing
            ``model_dump(exclude_unset=True)``, in which case only explicitly
            provided fields are applied.
        :return: the same ``db_obj``, refreshed after the commit.
        """
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.model_dump(exclude_unset=True)

        # Keys that do not correspond to an attribute of the object are
        # ignored rather than rejected.
        for field, value in update_data.items():
            if hasattr(db_obj, field):
                setattr(db_obj, field, value)

        # Refresh the modification timestamp under whichever name the model
        # uses, so models populated by ``create`` (``update_time``) are kept
        # current as well.
        now = datetime.now()
        for field in self._UPDATE_TIMESTAMP_FIELDS:
            if hasattr(db_obj, field):
                setattr(db_obj, field, now)

        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def remove(self, db: Session, *, id: int) -> ModelType:
        """Delete the record with the given ``id`` and return the deleted object.

        :raises HTTPException: with status 404 when no such record exists.
        """
        # Use the same lookup as ``get`` instead of the legacy ``Query.get``.
        obj = self.get(db, id)
        if obj is None:
            raise HTTPException(status_code=404, detail=f"{self.model.__name__}不存在")
        db.delete(obj)
        db.commit()
        return obj