mock create drawing

This commit is contained in:
XIE7654
2025-07-21 22:22:32 +08:00
parent 816668530c
commit 71d5053b9c
19 changed files with 504 additions and 43 deletions

View File

@@ -0,0 +1,22 @@
from fastapi import APIRouter
from .chat import router as chat_router
from .drawing import router as drawing_router
# from .video import router as video_router
# from .audio import router as audio_router
# from .multimodal import router as multimodal_router
# from .model_manage import router as model_manage_router
# from .knowledge import router as knowledge_router
# from .system import router as system_router
# from .user import router as user_router
api_v1_router = APIRouter(prefix="/api/ai/v1")
api_v1_router.include_router(chat_router)
api_v1_router.include_router(drawing_router)
# api_v1_router.include_router(video_router)
# api_v1_router.include_router(audio_router)
# api_v1_router.include_router(multimodal_router)
# api_v1_router.include_router(model_manage_router)
# api_v1_router.include_router(knowledge_router)
# api_v1_router.include_router(system_router)
# api_v1_router.include_router(user_router)

View File

@@ -8,7 +8,7 @@ from typing import List
from pydantic import BaseModel, SecretStr
from langchain.chains import ConversationChain
from api.v1.vo import MessageVO
from api.v1.chat.vo import MessageVO
from deps.auth import get_current_user
from services.chat_service import ChatDBService
from db.session import get_db
@@ -16,7 +16,7 @@ from models.ai import ChatConversation, ChatMessage
from utils.resp import resp_success, Response
from langchain_deepseek import ChatDeepSeek
router = APIRouter()
router = APIRouter(prefix="/chat", tags=["chat"])
def get_deepseek_llm(api_key: SecretStr, model: str):
# deepseek 兼容 OpenAI API需指定 base_url

View File

@@ -0,0 +1,110 @@
import json
import os
from fastapi import APIRouter, Depends, HTTPException, Body, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from db.session import get_db
from deps.auth import get_current_user
from llm.factory import get_adapter
from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status
router = APIRouter(prefix="/drawing", tags=["drawing"])
@router.get("/")
def api_get_image_page(
page: int = Query(1, ge=1),
page_size: int = Query(12, ge=1, le=100),
db: Session = Depends(get_db),
user=Depends(get_current_user)
):
data = get_drawing_page(db, user_id=user["user_id"], page=page, page_size=page_size)
# 序列化 items
data["items"] = [
{
"id": img.id,
"prompt": img.prompt,
"pic_url": img.pic_url,
"status": img.status,
"error_message": img.error_message,
"created_at": img.created_at if hasattr(img, 'created_at') else None
}
for img in data["items"]
]
return data
class CreateDrawingTaskRequest(BaseModel):
prompt: str
style: str = 'auto'
size: str = '1024x1024'
model: str = 'wanx_v1'
platform: str = 'tongyi'
n: int = 1
@router.post("/")
def api_create_image_task(
req: CreateDrawingTaskRequest,
db: Session = Depends(get_db),
user=Depends(get_current_user)
):
user_id = user["user_id"]
style = req.style
size = req.size
platform = req.platform
n = req.n
prompt = req.prompt
model = req.model
print(user_id, req.platform, req.size, req.model, req.prompt)
api_key = os.getenv("DASHSCOPE_API_KEY")
adapter = get_adapter('tongyi', api_key=api_key, model=model)
try:
# rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
# print(rsp, 'rsp')
res_json = {
"status_code": 200,
"request_id": "31b04171-011c-96bd-ac00-f0383b669cc7",
"code": "",
"message": "",
"output": {
"task_id": "4f90cf14-a34e-4eae-xxxxxxxx",
"task_status": "PENDING",
"results": []
},
"usage": None
}
rsp = res_json
if rsp['status_code'] != 200:
raise HTTPException(status_code=500, detail=rsp['message'])
option = {
'style': style
}
drawing = create_drawing_task(
db=db,
user_id=user["user_id"],
platform=platform,
model=model,
rsp=rsp,
prompt=prompt,
size=size,
options=json.dumps(option)
)
return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status}
except NotImplementedError:
print("该服务商不支持图片生成")
@router.get("/{id}")
def api_fetch_image_task_status(
id: int,
db: Session = Depends(get_db)
):
image, err = fetch_drawing_task_status(db, id)
if not image:
raise HTTPException(status_code=404, detail=err or "任务不存在")
return {
"id": image.id,
"status": image.status,
"pic_url": image.pic_url,
"error_message": image.error_message
}

View File

@@ -15,7 +15,7 @@ class OpenAIAdapter(MultiModalAICapability):
yield chunk
# 如需图片生成DALL·E可实现如下
def create_image_task(self, prompt, **kwargs):
def create_drawing_task(self, **kwargs):
# 伪代码,需用 openai.Image.create
# import openai
# response = openai.Image.create(api_key=self.api_key, prompt=prompt, ...)

View File

@@ -23,24 +23,20 @@ class TongYiAdapter(MultiModalAICapability):
async for chunk in self.llm.astream(messages):
yield chunk
@staticmethod
def create_image_task(api_key, model, prompt: str, style='<watercolor>', size='1024*1024', n=1):
def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
print(self.model, self.api_key, 'key')
"""创建异步图片生成任务"""
rsp = ImageSynthesis.async_call(
api_key=api_key,
model=model,
api_key=self.api_key,
model=self.model,
prompt=prompt,
n=n,
style=style,
style=f'<{style}>',
size=size
)
if rsp.status_code == HTTPStatus.OK:
return rsp
else:
raise Exception(f"Failed, status_code: {rsp.status_code}, code: {rsp.code}, message: {rsp.message}")
print(rsp, 'rsp')
@staticmethod
def fetch_image_task_status(task):
def fetch_drawing_task_status(self, task):
"""获取异步图片任务状态"""
status = ImageSynthesis.fetch(task)
if status.status_code == HTTPStatus.OK:

View File

@@ -9,14 +9,14 @@ class MultiModalAICapability(ABC):
raise NotImplementedError("stream_chat not supported by this provider")
# 图片生成能力
def create_image_task(self, prompt, **kwargs):
raise NotImplementedError("image generation not supported by this provider")
def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
raise NotImplementedError("drawing generation not supported by this provider")
def fetch_image_task_status(self, task):
raise NotImplementedError("image task status not supported by this provider")
def fetch_drawing_task_status(self, task):
raise NotImplementedError("drawing task status not supported by this provider")
def fetch_image_result(self, task):
raise NotImplementedError("image result not supported by this provider")
def fetch_drawing_result(self, task):
raise NotImplementedError("drawing result not supported by this provider")
# 视频生成能力
def create_video_task(self, prompt, **kwargs):

View File

@@ -1,8 +1,8 @@
import os
from fastapi import FastAPI
from dotenv import load_dotenv
from api.v1 import ai_chat
from fastapi.middleware.cors import CORSMiddleware
from api.v1 import api_v1_router
# 加载.env环境变量优先项目根目录
load_dotenv()
@@ -23,7 +23,8 @@ app.add_middleware(
)
# 注册路由
app.include_router(ai_chat.router, prefix="/api/ai/v1", tags=["chat"])
# app.include_router(ai_chat.router, prefix="/api/ai/v1", tags=["chat"])
app.include_router(api_v1_router)
# 健康检查
@app.get("/ping")

View File

@@ -253,8 +253,8 @@ class ChatRoleTool(Base):
tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True)
class Image(CoreModel):
__tablename__ = 'ai_image'
class Drawing(CoreModel):
__tablename__ = 'ai_drawing'
user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
public_status = Column(Boolean, default=False, nullable=False)

View File

@@ -0,0 +1,62 @@
from datetime import datetime
from dashscope import ImageSynthesis
from http import HTTPStatus
from sqlalchemy import desc
from models.ai import Drawing
from sqlalchemy.orm import Session
def create_drawing_task(db: Session, user_id: int, platform: str, model: str, prompt: str, size: str, rsp,
options: str = None):
# 写入数据库
drawing = Drawing(
user_id=user_id,
platform=platform,
model=model,
prompt=prompt,
width=int(size.split('x')[0]),
height=int(size.split('x')[1]),
create_time=datetime.now(),
update_time=datetime.now(),
options=options,
status=rsp['output']['task_status'],
task_id=rsp['output']['task_id'],
error_message=rsp['message']
)
db.add(drawing)
db.commit()
db.refresh(drawing)
return drawing
def fetch_drawing_task_status(db: Session, drawing_id: int):
drawing = db.query(Drawing).filter(Drawing.id == drawing_id).first()
if not drawing or not drawing.task_id:
return None, "任务不存在"
status = ImageSynthesis.fetch(drawing.task_id)
if status.status_code == HTTPStatus.OK:
# 可根据 status.output.task_status 更新数据库
drawing.status = status.output.task_status
if hasattr(status.output, 'results') and status.output.results:
drawing.pic_url = status.output.results[0].url
db.commit()
db.refresh(drawing)
return drawing, None
else:
return None, status.message
def get_drawing_page(db: Session, user_id: int = None, page: int = 1, page_size: int = 12):
query = db.query(Drawing)
if user_id:
query = query.filter(Drawing.user_id == user_id)
total = query.count()
items = query.order_by(desc(Drawing.id)).offset((page - 1) * page_size).limit(page_size).all()
return {
'total': total,
'page': page,
'page_size': page_size,
'items': items
}