add ai drawing
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dashscope import ImageSynthesis
|
||||
@@ -5,22 +6,24 @@ from http import HTTPStatus
|
||||
|
||||
from sqlalchemy import desc
|
||||
|
||||
from llm.factory import get_adapter
|
||||
from models.ai import Drawing
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
def create_drawing_task(db: Session, user_id: int, platform: str, model: str, prompt: str, size: str, rsp,
|
||||
def create_drawing_task(db: Session, user, platform: str, model: str, prompt: str, size: str, rsp,
|
||||
options: str = None):
|
||||
# 写入数据库
|
||||
|
||||
drawing = Drawing(
|
||||
user_id=user_id,
|
||||
user_id=user['user_id'],
|
||||
creator=user['username'],
|
||||
modifier=user['username'],
|
||||
platform=platform,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
width=int(size.split('x')[0]),
|
||||
height=int(size.split('x')[1]),
|
||||
create_time=datetime.now(),
|
||||
update_time=datetime.now(),
|
||||
width=int(size.split('*')[0]),
|
||||
height=int(size.split('*')[1]),
|
||||
options=options,
|
||||
status=rsp['output']['task_status'],
|
||||
task_id=rsp['output']['task_id'],
|
||||
@@ -35,17 +38,32 @@ def fetch_drawing_task_status(db: Session, drawing_id: int):
|
||||
drawing = db.query(Drawing).filter(Drawing.id == drawing_id).first()
|
||||
if not drawing or not drawing.task_id:
|
||||
return None, "任务不存在"
|
||||
status = ImageSynthesis.fetch(drawing.task_id)
|
||||
if status.status_code == HTTPStatus.OK:
|
||||
# 可根据 status.output.task_status 更新数据库
|
||||
drawing.status = status.output.task_status
|
||||
if hasattr(status.output, 'results') and status.output.results:
|
||||
drawing.pic_url = status.output.results[0].url
|
||||
db.commit()
|
||||
db.refresh(drawing)
|
||||
if drawing.status in ("PENDING", 'RUNNING'):
|
||||
api_key = os.getenv("DASHSCOPE_API_KEY")
|
||||
adapter = get_adapter('tongyi', api_key=api_key, model='')
|
||||
rsp = adapter.fetch_drawing_task_status(drawing.task_id)
|
||||
if rsp['status_code'] == HTTPStatus.OK:
|
||||
# 可根据 status.output.task_status 更新数据库
|
||||
if rsp['output']['task_status'] == 'SUCCEEDED':
|
||||
drawing.update_time = datetime.now()
|
||||
drawing.status = rsp['output']['task_status']
|
||||
drawing.pic_url = rsp['output']['results'][0]['url']
|
||||
db.commit()
|
||||
db.refresh(drawing)
|
||||
elif rsp['output']['task_status'] == 'FAILED':
|
||||
drawing.update_time = datetime.now()
|
||||
drawing.status = rsp['output']['task_status']
|
||||
drawing.error_message = rsp['output']['message']
|
||||
db.commit()
|
||||
db.refresh(drawing)
|
||||
elif rsp['output']['task_status'] == 'RUNNING':
|
||||
drawing.update_time = datetime.now()
|
||||
drawing.status = rsp['output']['task_status']
|
||||
db.commit()
|
||||
db.refresh(drawing)
|
||||
return drawing, None
|
||||
else:
|
||||
return None, status.message
|
||||
return drawing, None
|
||||
|
||||
|
||||
def get_drawing_page(db: Session, user_id: int = None, page: int = 1, page_size: int = 12):
|
||||
|
||||
Reference in New Issue
Block a user