add ai drawing

This commit is contained in:
XIE7654
2025-07-22 12:09:37 +08:00
parent 71d5053b9c
commit 04ee6bf0e0
8 changed files with 240 additions and 112 deletions

View File

@@ -1,18 +1,20 @@
import json
import os
from fastapi import APIRouter, Depends, HTTPException, Body, Query
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from api.v1.drawing.vo import CreateDrawingTaskRequest
from db.session import get_db
from deps.auth import get_current_user
from llm.factory import get_adapter
from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status
from utils.resp import resp_error, resp_success
router = APIRouter(prefix="/drawing", tags=["drawing"])
@router.get("/")
def api_get_image_page(
def api_get_drawing_page(
page: int = Query(1, ge=1),
page_size: int = Query(12, ge=1, le=100),
db: Session = Depends(get_db),
@@ -27,19 +29,12 @@ def api_get_image_page(
"pic_url": img.pic_url,
"status": img.status,
"error_message": img.error_message,
"created_at": img.created_at if hasattr(img, 'created_at') else None
"create_time": img.create_time if hasattr(img, 'create_time') else None
}
for img in data["items"]
]
return data
class CreateDrawingTaskRequest(BaseModel):
prompt: str
style: str = 'auto'
size: str = '1024x1024'
model: str = 'wanx_v1'
platform: str = 'tongyi'
n: int = 1
@router.post("/")
@@ -48,40 +43,37 @@ def api_create_image_task(
db: Session = Depends(get_db),
user=Depends(get_current_user)
):
user_id = user["user_id"]
style = req.style
size = req.size
platform = req.platform
n = req.n
prompt = req.prompt
model = req.model
print(user_id, req.platform, req.size, req.model, req.prompt)
api_key = os.getenv("DASHSCOPE_API_KEY")
adapter = get_adapter('tongyi', api_key=api_key, model=model)
try:
# rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
# print(rsp, 'rsp')
res_json = {
"status_code": 200,
"request_id": "31b04171-011c-96bd-ac00-f0383b669cc7",
"code": "",
"message": "",
"output": {
"task_id": "4f90cf14-a34e-4eae-xxxxxxxx",
"task_status": "PENDING",
"results": []
},
"usage": None
}
rsp = res_json
rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
# rsp = {
# "status_code": 200,
# "request_id": "31b04171-011c-96bd-ac00-f0383b669cc7",
# "code": "",
# "message": "",
# "output": {
# "task_id": "4f90cf14-a34e-4eae-xxxxxxxx",
# "task_status": "PENDING",
# "results": []
# },
# "usage": None
# }
if rsp['status_code'] != 200:
raise HTTPException(status_code=500, detail=rsp['message'])
return resp_error(message=rsp['message'], code=rsp['status_code'])
# raise HTTPException(status_code=500, detail=rsp['message'])
option = {
'style': style
}
drawing = create_drawing_task(
db=db,
user_id=user["user_id"],
user=user,
platform=platform,
model=model,
rsp=rsp,
@@ -89,22 +81,28 @@ def api_create_image_task(
size=size,
options=json.dumps(option)
)
return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status}
return resp_success(data={
"id": drawing.id,
"task_id": drawing.task_id,
"status": drawing.status
})
# return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status}
except NotImplementedError:
print("该服务商不支持图片生成")
@router.get("/{id}")
def api_fetch_image_task_status(
id: int,
@router.get("/{drawing_id}/")
def api_fetch_drawing_task_status(
drawing_id: int,
db: Session = Depends(get_db)
):
image, err = fetch_drawing_task_status(db, id)
if not image:
drawing, err = fetch_drawing_task_status(db, drawing_id)
if not drawing:
raise HTTPException(status_code=404, detail=err or "任务不存在")
return {
"id": image.id,
"status": image.status,
"pic_url": image.pic_url,
"error_message": image.error_message
}
return resp_success(data={
"id": drawing.id,
"status": drawing.status,
"pic_url": drawing.pic_url,
"error_message": drawing.error_message
})

View File

@@ -0,0 +1,10 @@
from pydantic import BaseModel
class CreateDrawingTaskRequest(BaseModel):
prompt: str
style: str = 'auto'
size: str = '1024*1024'
model: str = 'wanx_v1'
platform: str = 'tongyi'
n: int = 1

View File

@@ -23,10 +23,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj_in_data = obj_in.model_dump() # 解构Pydantic模型为字典
# 自动填充时间字段如果模型有created_at/updated_at
if hasattr(self.model, "created_at"):
obj_in_data["created_at"] = datetime.now()
if hasattr(self.model, "updated_at"):
obj_in_data["updated_at"] = datetime.now()
if hasattr(self.model, "create_time"):
obj_in_data["create_time"] = datetime.now()
if hasattr(self.model, "update_time"):
obj_in_data["update_time"] = datetime.now()
db_obj = self.model(**obj_in_data) # 实例化模型
db.add(db_obj)

View File

@@ -24,7 +24,6 @@ class TongYiAdapter(MultiModalAICapability):
yield chunk
def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
print(self.model, self.api_key, 'key')
"""创建异步图片生成任务"""
rsp = ImageSynthesis.async_call(
api_key=self.api_key,
@@ -34,14 +33,12 @@ class TongYiAdapter(MultiModalAICapability):
style=f'<{style}>',
size=size
)
print(rsp, 'rsp')
return rsp
def fetch_drawing_task_status(self, task):
"""获取异步图片任务状态"""
status = ImageSynthesis.fetch(task)
if status.status_code == HTTPStatus.OK:
return status.output.task_status
else:
raise Exception(f"Failed, status_code: {status.status_code}, code: {status.code}, message: {status.message}")
rsp = ImageSynthesis.fetch(task, api_key=self.api_key)
return rsp

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from db.session import Base
from sqlalchemy import (
Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
@@ -8,6 +10,26 @@ class CoreModel(Base):
__abstract__ = True
id = Column(Integer, primary_key=True, autoincrement=True)
create_time = Column(DateTime)
update_time = Column(DateTime)
is_deleted = Column(Boolean, default=False)
remark = Column(String(256), nullable=True, comment="备注")
creator = Column(String(64), nullable=True, comment="创建人")
modifier = Column(String(64), nullable=True, comment="修改人")
# 创建时间 - 使用函数默认值,在插入时自动生成
create_time = Column(DateTime, default=datetime.now(), comment="创建时间")
# 修改时间 - 使用SQL函数在更新时自动触发
update_time = Column(DateTime, default=datetime.now(), onupdate=datetime.now(), comment="修改时间")
is_deleted = Column(Boolean, default=False, comment="是否软删除")
# 软删除方法
# def soft_delete(self, session):
# self.is_deleted = True
# self.modifier = get_current_user() # 需要实现这个函数获取当前用户
# self.update_time = datetime.utcnow()
# session.commit()
# 查询时自动过滤已删除记录
@classmethod
def get_active(cls, session):
return session.query(cls).filter(cls.is_deleted == False)

View File

@@ -1,3 +1,4 @@
import os
from datetime import datetime
from dashscope import ImageSynthesis
@@ -5,22 +6,24 @@ from http import HTTPStatus
from sqlalchemy import desc
from llm.factory import get_adapter
from models.ai import Drawing
from sqlalchemy.orm import Session
def create_drawing_task(db: Session, user_id: int, platform: str, model: str, prompt: str, size: str, rsp,
def create_drawing_task(db: Session, user, platform: str, model: str, prompt: str, size: str, rsp,
options: str = None):
# 写入数据库
drawing = Drawing(
user_id=user_id,
user_id=user['user_id'],
creator=user['username'],
modifier=user['username'],
platform=platform,
model=model,
prompt=prompt,
width=int(size.split('x')[0]),
height=int(size.split('x')[1]),
create_time=datetime.now(),
update_time=datetime.now(),
width=int(size.split('*')[0]),
height=int(size.split('*')[1]),
options=options,
status=rsp['output']['task_status'],
task_id=rsp['output']['task_id'],
@@ -35,17 +38,32 @@ def fetch_drawing_task_status(db: Session, drawing_id: int):
drawing = db.query(Drawing).filter(Drawing.id == drawing_id).first()
if not drawing or not drawing.task_id:
return None, "任务不存在"
status = ImageSynthesis.fetch(drawing.task_id)
if status.status_code == HTTPStatus.OK:
# 可根据 status.output.task_status 更新数据库
drawing.status = status.output.task_status
if hasattr(status.output, 'results') and status.output.results:
drawing.pic_url = status.output.results[0].url
db.commit()
db.refresh(drawing)
if drawing.status in ("PENDING", 'RUNNING'):
api_key = os.getenv("DASHSCOPE_API_KEY")
adapter = get_adapter('tongyi', api_key=api_key, model='')
rsp = adapter.fetch_drawing_task_status(drawing.task_id)
if rsp['status_code'] == HTTPStatus.OK:
# 可根据 status.output.task_status 更新数据库
if rsp['output']['task_status'] == 'SUCCEEDED':
drawing.update_time = datetime.now()
drawing.status = rsp['output']['task_status']
drawing.pic_url = rsp['output']['results'][0]['url']
db.commit()
db.refresh(drawing)
elif rsp['output']['task_status'] == 'FAILED':
drawing.update_time = datetime.now()
drawing.status = rsp['output']['task_status']
drawing.error_message = rsp['output']['message']
db.commit()
db.refresh(drawing)
elif rsp['output']['task_status'] == 'RUNNING':
drawing.update_time = datetime.now()
drawing.status = rsp['output']['task_status']
db.commit()
db.refresh(drawing)
return drawing, None
else:
return None, status.message
return drawing, None
def get_drawing_page(db: Session, user_id: int = None, page: int = 1, page_size: int = 12):