add ai drawing

This commit is contained in:
XIE7654
2025-07-22 12:09:37 +08:00
parent 71d5053b9c
commit 04ee6bf0e0
8 changed files with 240 additions and 112 deletions

View File

@@ -1,18 +1,20 @@
import json
import os
from fastapi import APIRouter, Depends, HTTPException, Body, Query
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from api.v1.drawing.vo import CreateDrawingTaskRequest
from db.session import get_db
from deps.auth import get_current_user
from llm.factory import get_adapter
from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status
from utils.resp import resp_error, resp_success
router = APIRouter(prefix="/drawing", tags=["drawing"])
@router.get("/")
def api_get_image_page(
def api_get_drawing_page(
page: int = Query(1, ge=1),
page_size: int = Query(12, ge=1, le=100),
db: Session = Depends(get_db),
@@ -27,19 +29,12 @@ def api_get_image_page(
"pic_url": img.pic_url,
"status": img.status,
"error_message": img.error_message,
"created_at": img.created_at if hasattr(img, 'created_at') else None
"create_time": img.create_time if hasattr(img, 'create_time') else None
}
for img in data["items"]
]
return data
class CreateDrawingTaskRequest(BaseModel):
prompt: str
style: str = 'auto'
size: str = '1024x1024'
model: str = 'wanx_v1'
platform: str = 'tongyi'
n: int = 1
@router.post("/")
@@ -48,40 +43,37 @@ def api_create_image_task(
db: Session = Depends(get_db),
user=Depends(get_current_user)
):
user_id = user["user_id"]
style = req.style
size = req.size
platform = req.platform
n = req.n
prompt = req.prompt
model = req.model
print(user_id, req.platform, req.size, req.model, req.prompt)
api_key = os.getenv("DASHSCOPE_API_KEY")
adapter = get_adapter('tongyi', api_key=api_key, model=model)
try:
# rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
# print(rsp, 'rsp')
res_json = {
"status_code": 200,
"request_id": "31b04171-011c-96bd-ac00-f0383b669cc7",
"code": "",
"message": "",
"output": {
"task_id": "4f90cf14-a34e-4eae-xxxxxxxx",
"task_status": "PENDING",
"results": []
},
"usage": None
}
rsp = res_json
rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
# rsp = {
# "status_code": 200,
# "request_id": "31b04171-011c-96bd-ac00-f0383b669cc7",
# "code": "",
# "message": "",
# "output": {
# "task_id": "4f90cf14-a34e-4eae-xxxxxxxx",
# "task_status": "PENDING",
# "results": []
# },
# "usage": None
# }
if rsp['status_code'] != 200:
raise HTTPException(status_code=500, detail=rsp['message'])
return resp_error(message=rsp['message'], code=rsp['status_code'])
# raise HTTPException(status_code=500, detail=rsp['message'])
option = {
'style': style
}
drawing = create_drawing_task(
db=db,
user_id=user["user_id"],
user=user,
platform=platform,
model=model,
rsp=rsp,
@@ -89,22 +81,28 @@ def api_create_image_task(
size=size,
options=json.dumps(option)
)
return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status}
return resp_success(data={
"id": drawing.id,
"task_id": drawing.task_id,
"status": drawing.status
})
# return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status}
except NotImplementedError:
print("该服务商不支持图片生成")
@router.get("/{id}")
def api_fetch_image_task_status(
id: int,
@router.get("/{drawing_id}/")
def api_fetch_drawing_task_status(
drawing_id: int,
db: Session = Depends(get_db)
):
image, err = fetch_drawing_task_status(db, id)
if not image:
drawing, err = fetch_drawing_task_status(db, drawing_id)
if not drawing:
raise HTTPException(status_code=404, detail=err or "任务不存在")
return {
"id": image.id,
"status": image.status,
"pic_url": image.pic_url,
"error_message": image.error_message
}
return resp_success(data={
"id": drawing.id,
"status": drawing.status,
"pic_url": drawing.pic_url,
"error_message": drawing.error_message
})

View File

@@ -0,0 +1,10 @@
from pydantic import BaseModel
class CreateDrawingTaskRequest(BaseModel):
prompt: str
style: str = 'auto'
size: str = '1024*1024'
model: str = 'wanx_v1'
platform: str = 'tongyi'
n: int = 1

View File

@@ -23,10 +23,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj_in_data = obj_in.model_dump() # 解构Pydantic模型为字典
# 自动填充时间字段如果模型有created_at/updated_at
if hasattr(self.model, "created_at"):
obj_in_data["created_at"] = datetime.now()
if hasattr(self.model, "updated_at"):
obj_in_data["updated_at"] = datetime.now()
if hasattr(self.model, "create_time"):
obj_in_data["create_time"] = datetime.now()
if hasattr(self.model, "update_time"):
obj_in_data["update_time"] = datetime.now()
db_obj = self.model(**obj_in_data) # 实例化模型
db.add(db_obj)

View File

@@ -24,7 +24,6 @@ class TongYiAdapter(MultiModalAICapability):
yield chunk
def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
print(self.model, self.api_key, 'key')
"""创建异步图片生成任务"""
rsp = ImageSynthesis.async_call(
api_key=self.api_key,
@@ -34,14 +33,12 @@ class TongYiAdapter(MultiModalAICapability):
style=f'<{style}>',
size=size
)
print(rsp, 'rsp')
return rsp
def fetch_drawing_task_status(self, task):
"""获取异步图片任务状态"""
status = ImageSynthesis.fetch(task)
if status.status_code == HTTPStatus.OK:
return status.output.task_status
else:
raise Exception(f"Failed, status_code: {status.status_code}, code: {status.code}, message: {status.message}")
rsp = ImageSynthesis.fetch(task, api_key=self.api_key)
return rsp

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from db.session import Base
from sqlalchemy import (
Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
@@ -8,6 +10,26 @@ class CoreModel(Base):
__abstract__ = True
id = Column(Integer, primary_key=True, autoincrement=True)
create_time = Column(DateTime)
update_time = Column(DateTime)
is_deleted = Column(Boolean, default=False)
remark = Column(String(256), nullable=True, comment="备注")
creator = Column(String(64), nullable=True, comment="创建人")
modifier = Column(String(64), nullable=True, comment="修改人")
# 创建时间 - 使用函数默认值,在插入时自动生成
create_time = Column(DateTime, default=datetime.now(), comment="创建时间")
# 修改时间 - 使用SQL函数在更新时自动触发
update_time = Column(DateTime, default=datetime.now(), onupdate=datetime.now(), comment="修改时间")
is_deleted = Column(Boolean, default=False, comment="是否软删除")
# 软删除方法
# def soft_delete(self, session):
# self.is_deleted = True
# self.modifier = get_current_user() # 需要实现这个函数获取当前用户
# self.update_time = datetime.utcnow()
# session.commit()
# 查询时自动过滤已删除记录
@classmethod
def get_active(cls, session):
return session.query(cls).filter(cls.is_deleted == False)

View File

@@ -1,3 +1,4 @@
import os
from datetime import datetime
from dashscope import ImageSynthesis
@@ -5,22 +6,24 @@ from http import HTTPStatus
from sqlalchemy import desc
from llm.factory import get_adapter
from models.ai import Drawing
from sqlalchemy.orm import Session
def create_drawing_task(db: Session, user_id: int, platform: str, model: str, prompt: str, size: str, rsp,
def create_drawing_task(db: Session, user, platform: str, model: str, prompt: str, size: str, rsp,
options: str = None):
# 写入数据库
drawing = Drawing(
user_id=user_id,
user_id=user['user_id'],
creator=user['username'],
modifier=user['username'],
platform=platform,
model=model,
prompt=prompt,
width=int(size.split('x')[0]),
height=int(size.split('x')[1]),
create_time=datetime.now(),
update_time=datetime.now(),
width=int(size.split('*')[0]),
height=int(size.split('*')[1]),
options=options,
status=rsp['output']['task_status'],
task_id=rsp['output']['task_id'],
@@ -35,17 +38,32 @@ def fetch_drawing_task_status(db: Session, drawing_id: int):
drawing = db.query(Drawing).filter(Drawing.id == drawing_id).first()
if not drawing or not drawing.task_id:
return None, "任务不存在"
status = ImageSynthesis.fetch(drawing.task_id)
if status.status_code == HTTPStatus.OK:
# 可根据 status.output.task_status 更新数据库
drawing.status = status.output.task_status
if hasattr(status.output, 'results') and status.output.results:
drawing.pic_url = status.output.results[0].url
db.commit()
db.refresh(drawing)
if drawing.status in ("PENDING", 'RUNNING'):
api_key = os.getenv("DASHSCOPE_API_KEY")
adapter = get_adapter('tongyi', api_key=api_key, model='')
rsp = adapter.fetch_drawing_task_status(drawing.task_id)
if rsp['status_code'] == HTTPStatus.OK:
# 可根据 status.output.task_status 更新数据库
if rsp['output']['task_status'] == 'SUCCEEDED':
drawing.update_time = datetime.now()
drawing.status = rsp['output']['task_status']
drawing.pic_url = rsp['output']['results'][0]['url']
db.commit()
db.refresh(drawing)
elif rsp['output']['task_status'] == 'FAILED':
drawing.update_time = datetime.now()
drawing.status = rsp['output']['task_status']
drawing.error_message = rsp['output']['message']
db.commit()
db.refresh(drawing)
elif rsp['output']['task_status'] == 'RUNNING':
drawing.update_time = datetime.now()
drawing.status = rsp['output']['task_status']
db.commit()
db.refresh(drawing)
return drawing, None
else:
return None, status.message
return drawing, None
def get_drawing_page(db: Session, user_id: int = None, page: int = 1, page_size: int = 12):

View File

@@ -9,7 +9,7 @@ export interface CreateImageTaskParams {
n?: number;
}
export async function createImageTask(params: CreateImageTaskParams) {
export async function createDrawing(params: CreateImageTaskParams) {
const res = await fetchWithAuth('drawing/', {
method: 'POST',
body: JSON.stringify(params),
@@ -20,7 +20,7 @@ export async function createImageTask(params: CreateImageTaskParams) {
return await res.json();
}
export async function fetchImageTaskStatus(id: number) {
export async function getDrawingDetail(id: number) {
const res = await fetchWithAuth(`drawing/${id}/`);
if (!res.ok) {
throw new Error('查询图片任务状态失败');
@@ -33,12 +33,12 @@ export interface GetImagePageParams {
page_size?: number;
}
export async function getImagePage(params: GetImagePageParams = {}) {
export async function getDrawingPage(params: GetImagePageParams = {}) {
const query = new URLSearchParams();
if (params.page) query.append('page', String(params.page));
if (params.page_size) query.append('page_size', String(params.page_size));
const res = await fetchWithAuth(
`drawing${query.toString() ? `?${query.toString()}` : ''}`,
`drawing/${query.toString() ? `?${query.toString()}` : ''}`,
);
if (!res.ok) {
throw new Error('获取图片分页失败');

View File

@@ -1,5 +1,5 @@
<script setup lang="ts">
import { computed, h, reactive, ref } from 'vue';
import { onMounted, reactive, ref } from 'vue';
import { Page } from '@vben/common-ui';
@@ -13,9 +13,22 @@ import {
Pagination,
Row,
Select,
Spin,
} from 'ant-design-vue';
import { createImageTask } from '#/api/ai/drawing';
import {
createDrawing,
getDrawingDetail,
getDrawingPage,
} from '#/api/ai/drawing';
// 定义图片对象类型
interface DrawingImage {
id: number;
status: string;
pic_url?: string;
// 其他属性
}
// 表单选项
const platforms = [
@@ -24,14 +37,14 @@ const platforms = [
// { label: 'OpenAI', value: 'openai' },
// { label: 'Google GenAI', value: 'google-genai' },
];
const models = {
const models: Record<string, { label: string; value: string }[]> = {
tongyi: [{ label: 'wanx_v1', value: 'wanx_v1' }],
// deepseek: [{ label: 'deepseek-img', value: 'deepseek-img' }],
// openai: [{ label: 'dall-e-3', value: 'dall-e-3' }],
// 'google-genai': [{ label: 'imagen', value: 'imagen' }],
// 其他平台...
};
const sizes = [
{ label: '1024x1024', value: '1024x1024' },
{ label: '1024*1024', value: '1024*1024' },
{ label: '720*1280', value: '720*1280' },
{ label: '768*1152', value: '768*1152' },
{ label: '1280*720', value: '1280*720' },
@@ -51,26 +64,24 @@ const styles = [
// 表单数据
const form = reactive({
prompt: '近景镜头18岁的中国女孩古代服饰圆脸正面看着镜头民族优雅的服装商业摄影室外电影级光照半身特写精致的淡妆锐利的边缘。',
prompt:
'近景镜头18岁的中国女孩古代服饰圆脸正面看着镜头民族优雅的服装商业摄影室外电影级光照半身特写精致的淡妆锐利的边缘。',
platform: 'tongyi',
model: 'wanx_v1',
size: '1024x1024',
model: 'wanx-v1',
size: '1024*1024',
style: 'watercolor',
});
// 图片数据与分页
const images = ref<string[]>([]);
const images = ref<DrawingImage[]>([]);
const loading = ref(false);
const page = ref(1);
const pageSize = 12;
const total = computed(() => images.value.length);
const pagedImages = computed(() =>
images.value.slice((page.value - 1) * pageSize, page.value * pageSize),
);
const pageSize = ref(9);
const total = ref(0);
// 平台切换时自动切换模型
const onPlatformChange = (val: string) => {
form.model = models[val][0].value;
const onPlatformChange = (value: number | string) => {
form.model = models[value as string]?.[0]?.value ?? '';
};
// 提交表单调用AI画图API
@@ -78,14 +89,15 @@ async function handleDraw() {
loading.value = true;
try {
// 这里调用你的AI画图API返回图片url数组
const res = await createImageTask(form);
const data = await createDrawing(form);
if (data.code !== 0) {
message.error(data.message || '生成失败');
return;
}
page.value = 1;
await fetchDrawingList(page.value, pageSize.value); // 刷新第一页图片列表
// images.value = res.data.images;
// DEMO用假数据
// images.value = Array.from(
// { length: 30 },
// (_, i) => `https://picsum.photos/seed/${form.prompt}-${i}/300/300`,
// );
page.value = 1;
message.success('生成成功');
} catch {
message.error('生成失败');
@@ -93,6 +105,59 @@ async function handleDraw() {
loading.value = false;
}
}
// 轮询获取图片详情
const pollDrawingDetail = async (id: number) => {
fetchDrawingDetail(id).then((res) => {
if (res && res.status === 'RUNNING') {
setTimeout(() => pollDrawingDetail(id), 5000);
}
});
};
// 获取图片分页列表
async function fetchDrawingList(pageNum = 1, pageSize = 9) {
try {
const res = await getDrawingPage({ page: pageNum, page_size: pageSize });
images.value = res.items;
// images.value = res.items.map(item => item.pic_url);
total.value = res.total;
// 检查每个 item 的状态
for (const item of res.items) {
if (item.status === 'PENDING') {
fetchDrawingDetail(item.id);
} else if (item.status === 'RUNNING') {
pollDrawingDetail(item.id);
}
}
return res;
} catch {
message.error('获取图片列表失败');
return null;
}
}
// 获取图片详情
const fetchDrawingDetail = async (id: number) => {
try {
const res = await getDrawingDetail(id);
// 更新 images 中对应项
const idx = images.value.findIndex((item) => item.id === id);
if (idx !== -1) {
images.value[idx] = { ...images.value[idx], ...res.data };
}
// 处理详情数据
return res;
} catch {
message.error('获取图片详情失败');
return null;
}
};
// 页面加载时调用获取图片列表
onMounted(() => {
fetchDrawingList();
});
</script>
<template>
@@ -166,28 +231,46 @@ async function handleDraw() {
<Card title="生成结果" bordered>
<Row :gutter="16">
<Col
v-for="(img, idx) in pagedImages"
:key="img"
:span="4"
v-for="(img, idx) in images"
:key="idx"
:span="8"
style="margin-bottom: 16px"
>
<Card
hoverable
:cover="
h('img', {
src: img,
style: 'width:100%;height:180px;object-fit:cover;',
})
"
/>
<Card hoverable>
<template #cover>
<div
v-if="img.status === 'PENDING' || img.status === 'RUNNING'"
style="
width: 100%;
height: 180px;
display: flex;
align-items: center;
justify-content: center;
"
>
<Spin size="large" />
</div>
<img
v-else
:src="img.pic_url"
style="width: 100%; height: 180px; object-fit: cover"
/>
</template>
</Card>
</Col>
</Row>
<Pagination
v-if="total > pageSize"
v-model:current="page"
:total="total"
:page-size="pageSize"
style="margin-top: 16px; text-align: right"
@change="
(p, ps) => {
page = p;
pageSize = ps;
fetchDrawingList(p, ps);
}
"
/>
</Card>
</Col>