From 04ee6bf0e0c0941a555818463731d5813e7d3f87 Mon Sep 17 00:00:00 2001 From: XIE7654 <765462425@qq.com> Date: Tue, 22 Jul 2025 12:09:37 +0800 Subject: [PATCH] add ai drawing --- ai_service/api/v1/drawing/__init__.py | 82 +++++---- ai_service/api/v1/drawing/vo.py | 10 ++ ai_service/crud/base.py | 8 +- ai_service/llm/adapter/tongyi.py | 11 +- ai_service/models/base.py | 28 +++- ai_service/services/drawing_service.py | 48 ++++-- web/apps/web-antd/src/api/ai/drawing.ts | 8 +- .../web-antd/src/views/ai/drawing/index.vue | 157 +++++++++++++----- 8 files changed, 240 insertions(+), 112 deletions(-) create mode 100644 ai_service/api/v1/drawing/vo.py diff --git a/ai_service/api/v1/drawing/__init__.py b/ai_service/api/v1/drawing/__init__.py index 5e8a045..5706182 100644 --- a/ai_service/api/v1/drawing/__init__.py +++ b/ai_service/api/v1/drawing/__init__.py @@ -1,18 +1,20 @@ import json import os -from fastapi import APIRouter, Depends, HTTPException, Body, Query -from pydantic import BaseModel +from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy.orm import Session + +from api.v1.drawing.vo import CreateDrawingTaskRequest from db.session import get_db from deps.auth import get_current_user from llm.factory import get_adapter from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status +from utils.resp import resp_error, resp_success router = APIRouter(prefix="/drawing", tags=["drawing"]) @router.get("/") -def api_get_image_page( +def api_get_drawing_page( page: int = Query(1, ge=1), page_size: int = Query(12, ge=1, le=100), db: Session = Depends(get_db), @@ -27,19 +29,12 @@ def api_get_image_page( "pic_url": img.pic_url, "status": img.status, "error_message": img.error_message, - "created_at": img.created_at if hasattr(img, 'created_at') else None + "create_time": img.create_time if hasattr(img, 'create_time') else None } for img in data["items"] ] return data -class CreateDrawingTaskRequest(BaseModel): - prompt: str - style: str = 'auto' - size: str = '1024x1024' - model: str = 'wanx_v1' - platform: str = 'tongyi' - n: int = 1 @router.post("/") @@ -48,40 +43,37 @@ def api_create_image_task( db: Session = Depends(get_db), user=Depends(get_current_user) ): - user_id = user["user_id"] style = req.style size = req.size platform = req.platform n = req.n prompt = req.prompt model = req.model - print(user_id, req.platform, req.size, req.model, req.prompt) api_key = os.getenv("DASHSCOPE_API_KEY") adapter = get_adapter('tongyi', api_key=api_key, model=model) try: - # rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size) - # print(rsp, 'rsp') - res_json = { - "status_code": 200, - "request_id": "31b04171-011c-96bd-ac00-f0383b669cc7", - "code": "", - "message": "", - "output": { - "task_id": "4f90cf14-a34e-4eae-xxxxxxxx", - "task_status": "PENDING", - "results": [] - }, - "usage": None - } - rsp = res_json + rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size) + # rsp = { + # "status_code": 200, + # "request_id": "31b04171-011c-96bd-ac00-f0383b669cc7", + # "code": "", + # "message": "", + # "output": { + # "task_id": "4f90cf14-a34e-4eae-xxxxxxxx", + # "task_status": "PENDING", + # "results": [] + # }, + # "usage": None + # } if rsp['status_code'] != 200: - raise HTTPException(status_code=500, detail=rsp['message']) + return resp_error(message=rsp['message'], code=rsp['status_code']) + # raise HTTPException(status_code=500, detail=rsp['message']) option = { 'style': style } drawing = create_drawing_task( db=db, - user_id=user["user_id"], + user=user, platform=platform, model=model, rsp=rsp, @@ -89,22 +81,28 @@ def api_create_image_task( size=size, options=json.dumps(option) ) - return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status} + return resp_success(data={ + "id": drawing.id, + "task_id": drawing.task_id, + "status": drawing.status + }) + # return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status} except NotImplementedError: print("该服务商不支持图片生成") -@router.get("/{id}") -def api_fetch_image_task_status( - id: int, +@router.get("/{drawing_id}/") +def api_fetch_drawing_task_status( + drawing_id: int, db: Session = Depends(get_db) ): - image, err = fetch_drawing_task_status(db, id) - if not image: + drawing, err = fetch_drawing_task_status(db, drawing_id) + if not drawing: raise HTTPException(status_code=404, detail=err or "任务不存在") - return { - "id": image.id, - "status": image.status, - "pic_url": image.pic_url, - "error_message": image.error_message - } + return resp_success(data={ + "id": drawing.id, + "status": drawing.status, + "pic_url": drawing.pic_url, + "error_message": drawing.error_message + }) + diff --git a/ai_service/api/v1/drawing/vo.py b/ai_service/api/v1/drawing/vo.py new file mode 100644 index 0000000..e1e0d62 --- /dev/null +++ b/ai_service/api/v1/drawing/vo.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + + +class CreateDrawingTaskRequest(BaseModel): + prompt: str + style: str = 'auto' + size: str = '1024*1024' + model: str = 'wanx_v1' + platform: str = 'tongyi' + n: int = 1 diff --git a/ai_service/crud/base.py b/ai_service/crud/base.py index b28212e..40b4f61 100644 --- a/ai_service/crud/base.py +++ b/ai_service/crud/base.py @@ -23,10 +23,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): obj_in_data = obj_in.model_dump() # 解构Pydantic模型为字典 # 自动填充时间字段(如果模型有created_at/updated_at) - if hasattr(self.model, "created_at"): - obj_in_data["created_at"] = datetime.now() - if hasattr(self.model, "updated_at"): - obj_in_data["updated_at"] = datetime.now() + if hasattr(self.model, "create_time"): + obj_in_data["create_time"] = datetime.now() + if hasattr(self.model, "update_time"): + obj_in_data["update_time"] = datetime.now() db_obj = self.model(**obj_in_data) # 实例化模型 db.add(db_obj) diff --git a/ai_service/llm/adapter/tongyi.py b/ai_service/llm/adapter/tongyi.py index 2079843..973926a 100644 --- a/ai_service/llm/adapter/tongyi.py +++ b/ai_service/llm/adapter/tongyi.py @@ -24,7 +24,6 @@ class TongYiAdapter(MultiModalAICapability): yield chunk def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs): - print(self.model, self.api_key, 'key') """创建异步图片生成任务""" rsp = ImageSynthesis.async_call( api_key=self.api_key, @@ -34,14 +33,12 @@ class TongYiAdapter(MultiModalAICapability): style=f'<{style}>', size=size ) - print(rsp, 'rsp') + return rsp def fetch_drawing_task_status(self, task): """获取异步图片任务状态""" - status = ImageSynthesis.fetch(task) - if status.status_code == HTTPStatus.OK: - return status.output.task_status - else: - raise Exception(f"Failed, status_code: {status.status_code}, code: {status.code}, message: {status.message}") + rsp = ImageSynthesis.fetch(task, api_key=self.api_key) + return rsp + \ No newline at end of file diff --git a/ai_service/models/base.py b/ai_service/models/base.py index 8cfe077..a55dc9d 100644 --- a/ai_service/models/base.py +++ b/ai_service/models/base.py @@ -1,3 +1,5 @@ +from datetime import datetime + from db.session import Base from sqlalchemy import ( Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey @@ -8,6 +10,26 @@ class CoreModel(Base): __abstract__ = True id = Column(Integer, primary_key=True, autoincrement=True) - create_time = Column(DateTime) - update_time = Column(DateTime) - is_deleted = Column(Boolean, default=False) + remark = Column(String(256), nullable=True, comment="备注") + creator = Column(String(64), nullable=True, comment="创建人") + modifier = Column(String(64), nullable=True, comment="修改人") + + # 创建时间 - 使用函数默认值,在插入时自动生成 + create_time = Column(DateTime, default=datetime.now(), comment="创建时间") + + # 修改时间 - 使用SQL函数,在更新时自动触发 + update_time = Column(DateTime, default=datetime.now(), onupdate=datetime.now(), comment="修改时间") + + is_deleted = Column(Boolean, default=False, comment="是否软删除") + + # 软删除方法 + # def soft_delete(self, session): + # self.is_deleted = True + # self.modifier = get_current_user() # 需要实现这个函数获取当前用户 + # self.update_time = datetime.utcnow() + # session.commit() + + # 查询时自动过滤已删除记录 + @classmethod + def get_active(cls, session): + return session.query(cls).filter(cls.is_deleted == False) \ No newline at end of file diff --git a/ai_service/services/drawing_service.py b/ai_service/services/drawing_service.py index 0792b50..2efd0db 100644 --- a/ai_service/services/drawing_service.py +++ b/ai_service/services/drawing_service.py @@ -1,3 +1,4 @@ +import os from datetime import datetime from dashscope import ImageSynthesis @@ -5,22 +6,24 @@ from http import HTTPStatus from sqlalchemy import desc +from llm.factory import get_adapter from models.ai import Drawing from sqlalchemy.orm import Session -def create_drawing_task(db: Session, user_id: int, platform: str, model: str, prompt: str, size: str, rsp, +def create_drawing_task(db: Session, user, platform: str, model: str, prompt: str, size: str, rsp, options: str = None): # 写入数据库 + drawing = Drawing( - user_id=user_id, + user_id=user['user_id'], + creator=user['username'], + modifier=user['username'], platform=platform, model=model, prompt=prompt, - width=int(size.split('x')[0]), - height=int(size.split('x')[1]), - create_time=datetime.now(), - update_time=datetime.now(), + width=int(size.split('*')[0]), + height=int(size.split('*')[1]), options=options, status=rsp['output']['task_status'], task_id=rsp['output']['task_id'], @@ -35,17 +38,32 @@ def fetch_drawing_task_status(db: Session, drawing_id: int): drawing = db.query(Drawing).filter(Drawing.id == drawing_id).first() if not drawing or not drawing.task_id: return None, "任务不存在" - status = ImageSynthesis.fetch(drawing.task_id) - if status.status_code == HTTPStatus.OK: - # 可根据 status.output.task_status 更新数据库 - drawing.status = status.output.task_status - if hasattr(status.output, 'results') and status.output.results: - drawing.pic_url = status.output.results[0].url - db.commit() - db.refresh(drawing) + if drawing.status in ("PENDING", 'RUNNING'): + api_key = os.getenv("DASHSCOPE_API_KEY") + adapter = get_adapter('tongyi', api_key=api_key, model='') + rsp = adapter.fetch_drawing_task_status(drawing.task_id) + if rsp['status_code'] == HTTPStatus.OK: + # 可根据 status.output.task_status 更新数据库 + if rsp['output']['task_status'] == 'SUCCEEDED': + drawing.update_time = datetime.now() + drawing.status = rsp['output']['task_status'] + drawing.pic_url = rsp['output']['results'][0]['url'] + db.commit() + db.refresh(drawing) + elif rsp['output']['task_status'] == 'FAILED': + drawing.update_time = datetime.now() + drawing.status = rsp['output']['task_status'] + drawing.error_message = rsp['output']['message'] + db.commit() + db.refresh(drawing) + elif rsp['output']['task_status'] == 'RUNNING': + drawing.update_time = datetime.now() + drawing.status = rsp['output']['task_status'] + db.commit() + db.refresh(drawing) return drawing, None else: - return None, status.message + return drawing, None def get_drawing_page(db: Session, user_id: int = None, page: int = 1, page_size: int = 12): diff --git a/web/apps/web-antd/src/api/ai/drawing.ts b/web/apps/web-antd/src/api/ai/drawing.ts index 2eaf5b5..541693b 100644 --- a/web/apps/web-antd/src/api/ai/drawing.ts +++ b/web/apps/web-antd/src/api/ai/drawing.ts @@ -9,7 +9,7 @@ export interface CreateImageTaskParams { n?: number; } -export async function createImageTask(params: CreateImageTaskParams) { +export async function createDrawing(params: CreateImageTaskParams) { const res = await fetchWithAuth('drawing/', { method: 'POST', body: JSON.stringify(params), @@ -20,7 +20,7 @@ export async function createImageTask(params: CreateImageTaskParams) { return await res.json(); } -export async function fetchImageTaskStatus(id: number) { +export async function getDrawingDetail(id: number) { const res = await fetchWithAuth(`drawing/${id}/`); if (!res.ok) { throw new Error('查询图片任务状态失败'); @@ -33,12 +33,12 @@ export interface GetImagePageParams { page_size?: number; } -export async function getImagePage(params: GetImagePageParams = {}) { +export async function getDrawingPage(params: GetImagePageParams = {}) { const query = new URLSearchParams(); if (params.page) query.append('page', String(params.page)); if (params.page_size) query.append('page_size', String(params.page_size)); const res = await fetchWithAuth( - `drawing${query.toString() ? `?${query.toString()}` : ''}`, + `drawing/${query.toString() ? `?${query.toString()}` : ''}`, ); if (!res.ok) { throw new Error('获取图片分页失败'); diff --git a/web/apps/web-antd/src/views/ai/drawing/index.vue b/web/apps/web-antd/src/views/ai/drawing/index.vue index 540ba06..e60b947 100644 --- a/web/apps/web-antd/src/views/ai/drawing/index.vue +++ b/web/apps/web-antd/src/views/ai/drawing/index.vue @@ -1,5 +1,5 @@