# Python source, 90 lines, 3.1 KiB
from datetime import datetime
from typing import Generic, TypeVar, List, Optional, Dict, Any

from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

# Generic type variables used to parametrize the CRUD base class:
#   ModelType        - the SQLAlchemy model
#   CreateSchemaType - the Pydantic schema accepted when creating a record
#   UpdateSchemaType - the Pydantic schema accepted when updating a record
ModelType = TypeVar("ModelType")
CreateSchemaType = TypeVar("CreateSchemaType")
UpdateSchemaType = TypeVar("UpdateSchemaType")


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic create/read/update/delete helper bound to one SQLAlchemy model.

    Every method receives the ``Session`` to work with; the instance itself
    only stores the model class. Write operations commit immediately and
    roll the session back if the commit fails.

    The model is expected to expose an ``id`` attribute, since ``get`` and
    ``remove`` filter on ``self.model.id``.
    """

    def __init__(self, model: type[ModelType]):
        """Bind this CRUD helper to a model class.

        :param model: SQLAlchemy model class (e.g. AIApiKey, AIModel).
        """
        self.model = model

    def _commit(self, db: Session) -> None:
        """Commit the session, rolling back before re-raising on failure."""
        try:
            db.commit()
        except SQLAlchemyError:
            # A failed flush/commit leaves the session in a state that
            # rejects further work until it is rolled back.
            db.rollback()
            raise

    # Create
    def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
        """Create a record from ``obj_in`` and return the refreshed instance.

        :param db: session used for the insert.
        :param obj_in: schema object exposing ``model_dump()``.
        :return: the newly persisted model instance.
        :raises SQLAlchemyError: if the commit fails (session is rolled back).
        """
        obj_in_data = obj_in.model_dump()  # schema -> plain dict

        # Auto-fill the timestamp fields when the model declares them. A single
        # timestamp is used so created_at and updated_at match on a new row.
        now = datetime.now()
        if hasattr(self.model, "created_at"):
            obj_in_data["created_at"] = now
        if hasattr(self.model, "updated_at"):
            obj_in_data["updated_at"] = now

        db_obj = self.model(**obj_in_data)
        db.add(db_obj)
        self._commit(db)
        db.refresh(db_obj)
        return db_obj

    # Fetch by ID
    def get(self, db: Session, id: int) -> Optional[ModelType]:
        """Return the record whose ``id`` equals ``id``, or ``None``."""
        return db.query(self.model).filter(self.model.id == id).first()

    # Fetch a single record by arbitrary conditions
    def get_by(self, db: Session, **kwargs) -> Optional[ModelType]:
        """Return the first record matching the keyword filters, or ``None``.

        Example: ``get_by(db, name="test")``.
        """
        return db.query(self.model).filter_by(**kwargs).first()

    # Paginated listing
    def get_multi(
        self, db: Session, *, page: int = 0, limit: int = 100
    ) -> List[ModelType]:
        """Return up to ``limit`` records, skipping the first ``page`` rows.

        NOTE(review): ``page`` is passed straight to ``OFFSET``, so it is a
        row offset rather than a page number - confirm against callers.
        No ``ORDER BY`` is applied, so row order is whatever the database
        returns.
        """
        return db.query(self.model).offset(page).limit(limit).all()

    # Update
    def update(
        self,
        db: Session,
        *,
        db_obj: ModelType,
        obj_in: UpdateSchemaType | Dict[str, Any]
    ) -> ModelType:
        """Apply ``obj_in`` to ``db_obj``, commit, and return the refreshed row.

        :param db: session used for the update.
        :param db_obj: the already-loaded model instance to modify.
        :param obj_in: either a dict of field values or a schema object; for
            a schema only the fields that were explicitly set are applied.
        :return: the updated model instance.
        :raises SQLAlchemyError: if the commit fails (session is rolled back).
        """
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            # Only touch the fields the caller actually provided.
            update_data = obj_in.model_dump(exclude_unset=True)

        # Keys that do not correspond to an attribute are silently ignored.
        for field, value in update_data.items():
            if hasattr(db_obj, field):
                setattr(db_obj, field, value)

        # Refresh the modification timestamp when the model has one.
        if hasattr(db_obj, "updated_at"):
            db_obj.updated_at = datetime.now()

        db.add(db_obj)
        self._commit(db)
        db.refresh(db_obj)
        return db_obj

    # Delete
    def remove(self, db: Session, *, id: int) -> ModelType:
        """Delete the record with the given ``id`` and return it.

        :raises HTTPException: 404 when no such record exists.
        :raises SQLAlchemyError: if the commit fails (session is rolled back).
        """
        # Reuse get() so lookups are consistent across the class and the
        # legacy Query.get() API is avoided.
        obj = self.get(db, id)
        if obj is None:
            raise HTTPException(status_code=404, detail=f"{self.model.__name__}不存在")
        db.delete(obj)
        self._commit(db)
        return obj