add ai drawing

This commit is contained in:
XIE7654
2025-07-22 12:09:37 +08:00
parent 71d5053b9c
commit 04ee6bf0e0
8 changed files with 240 additions and 112 deletions

View File

@@ -1,18 +1,20 @@
import json import json
import os import os
from fastapi import APIRouter, Depends, HTTPException, Body, Query from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from api.v1.drawing.vo import CreateDrawingTaskRequest
from db.session import get_db from db.session import get_db
from deps.auth import get_current_user from deps.auth import get_current_user
from llm.factory import get_adapter from llm.factory import get_adapter
from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status
from utils.resp import resp_error, resp_success
router = APIRouter(prefix="/drawing", tags=["drawing"]) router = APIRouter(prefix="/drawing", tags=["drawing"])
@router.get("/") @router.get("/")
def api_get_image_page( def api_get_drawing_page(
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
page_size: int = Query(12, ge=1, le=100), page_size: int = Query(12, ge=1, le=100),
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -27,19 +29,12 @@ def api_get_image_page(
"pic_url": img.pic_url, "pic_url": img.pic_url,
"status": img.status, "status": img.status,
"error_message": img.error_message, "error_message": img.error_message,
"created_at": img.created_at if hasattr(img, 'created_at') else None "create_time": img.create_time if hasattr(img, 'create_time') else None
} }
for img in data["items"] for img in data["items"]
] ]
return data return data
class CreateDrawingTaskRequest(BaseModel):
prompt: str
style: str = 'auto'
size: str = '1024x1024'
model: str = 'wanx_v1'
platform: str = 'tongyi'
n: int = 1
@router.post("/") @router.post("/")
@@ -48,40 +43,37 @@ def api_create_image_task(
db: Session = Depends(get_db), db: Session = Depends(get_db),
user=Depends(get_current_user) user=Depends(get_current_user)
): ):
user_id = user["user_id"]
style = req.style style = req.style
size = req.size size = req.size
platform = req.platform platform = req.platform
n = req.n n = req.n
prompt = req.prompt prompt = req.prompt
model = req.model model = req.model
print(user_id, req.platform, req.size, req.model, req.prompt)
api_key = os.getenv("DASHSCOPE_API_KEY") api_key = os.getenv("DASHSCOPE_API_KEY")
adapter = get_adapter('tongyi', api_key=api_key, model=model) adapter = get_adapter('tongyi', api_key=api_key, model=model)
try: try:
# rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size) rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
# print(rsp, 'rsp') # rsp = {
res_json = { # "status_code": 200,
"status_code": 200, # "request_id": "31b04171-011c-96bd-ac00-f0383b669cc7",
"request_id": "31b04171-011c-96bd-ac00-f0383b669cc7", # "code": "",
"code": "", # "message": "",
"message": "", # "output": {
"output": { # "task_id": "4f90cf14-a34e-4eae-xxxxxxxx",
"task_id": "4f90cf14-a34e-4eae-xxxxxxxx", # "task_status": "PENDING",
"task_status": "PENDING", # "results": []
"results": [] # },
}, # "usage": None
"usage": None # }
}
rsp = res_json
if rsp['status_code'] != 200: if rsp['status_code'] != 200:
raise HTTPException(status_code=500, detail=rsp['message']) return resp_error(message=rsp['message'], code=rsp['status_code'])
# raise HTTPException(status_code=500, detail=rsp['message'])
option = { option = {
'style': style 'style': style
} }
drawing = create_drawing_task( drawing = create_drawing_task(
db=db, db=db,
user_id=user["user_id"], user=user,
platform=platform, platform=platform,
model=model, model=model,
rsp=rsp, rsp=rsp,
@@ -89,22 +81,28 @@ def api_create_image_task(
size=size, size=size,
options=json.dumps(option) options=json.dumps(option)
) )
return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status} return resp_success(data={
"id": drawing.id,
"task_id": drawing.task_id,
"status": drawing.status
})
# return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status}
except NotImplementedError: except NotImplementedError:
print("该服务商不支持图片生成") print("该服务商不支持图片生成")
@router.get("/{id}") @router.get("/{drawing_id}/")
def api_fetch_image_task_status( def api_fetch_drawing_task_status(
id: int, drawing_id: int,
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
image, err = fetch_drawing_task_status(db, id) drawing, err = fetch_drawing_task_status(db, drawing_id)
if not image: if not drawing:
raise HTTPException(status_code=404, detail=err or "任务不存在") raise HTTPException(status_code=404, detail=err or "任务不存在")
return { return resp_success(data={
"id": image.id, "id": drawing.id,
"status": image.status, "status": drawing.status,
"pic_url": image.pic_url, "pic_url": drawing.pic_url,
"error_message": image.error_message "error_message": drawing.error_message
} })

View File

@@ -0,0 +1,10 @@
from pydantic import BaseModel
class CreateDrawingTaskRequest(BaseModel):
prompt: str
style: str = 'auto'
size: str = '1024*1024'
model: str = 'wanx_v1'
platform: str = 'tongyi'
n: int = 1

View File

@@ -23,10 +23,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
obj_in_data = obj_in.model_dump() # 解构Pydantic模型为字典 obj_in_data = obj_in.model_dump() # 解构Pydantic模型为字典
# 自动填充时间字段如果模型有created_at/updated_at # 自动填充时间字段如果模型有created_at/updated_at
if hasattr(self.model, "created_at"): if hasattr(self.model, "create_time"):
obj_in_data["created_at"] = datetime.now() obj_in_data["create_time"] = datetime.now()
if hasattr(self.model, "updated_at"): if hasattr(self.model, "update_time"):
obj_in_data["updated_at"] = datetime.now() obj_in_data["update_time"] = datetime.now()
db_obj = self.model(**obj_in_data) # 实例化模型 db_obj = self.model(**obj_in_data) # 实例化模型
db.add(db_obj) db.add(db_obj)

View File

@@ -24,7 +24,6 @@ class TongYiAdapter(MultiModalAICapability):
yield chunk yield chunk
def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs): def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
print(self.model, self.api_key, 'key')
"""创建异步图片生成任务""" """创建异步图片生成任务"""
rsp = ImageSynthesis.async_call( rsp = ImageSynthesis.async_call(
api_key=self.api_key, api_key=self.api_key,
@@ -34,14 +33,12 @@ class TongYiAdapter(MultiModalAICapability):
style=f'<{style}>', style=f'<{style}>',
size=size size=size
) )
print(rsp, 'rsp') return rsp
def fetch_drawing_task_status(self, task): def fetch_drawing_task_status(self, task):
"""获取异步图片任务状态""" """获取异步图片任务状态"""
status = ImageSynthesis.fetch(task) rsp = ImageSynthesis.fetch(task, api_key=self.api_key)
if status.status_code == HTTPStatus.OK: return rsp
return status.output.task_status
else:
raise Exception(f"Failed, status_code: {status.status_code}, code: {status.code}, message: {status.message}")

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from db.session import Base from db.session import Base
from sqlalchemy import ( from sqlalchemy import (
Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
@@ -8,6 +10,26 @@ class CoreModel(Base):
__abstract__ = True __abstract__ = True
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
create_time = Column(DateTime) remark = Column(String(256), nullable=True, comment="备注")
update_time = Column(DateTime) creator = Column(String(64), nullable=True, comment="创建人")
is_deleted = Column(Boolean, default=False) modifier = Column(String(64), nullable=True, comment="修改人")
# 创建时间 - 使用函数默认值,在插入时自动生成
create_time = Column(DateTime, default=datetime.now(), comment="创建时间")
# 修改时间 - 使用SQL函数在更新时自动触发
update_time = Column(DateTime, default=datetime.now(), onupdate=datetime.now(), comment="修改时间")
is_deleted = Column(Boolean, default=False, comment="是否软删除")
# 软删除方法
# def soft_delete(self, session):
# self.is_deleted = True
# self.modifier = get_current_user() # 需要实现这个函数获取当前用户
# self.update_time = datetime.utcnow()
# session.commit()
# 查询时自动过滤已删除记录
@classmethod
def get_active(cls, session):
return session.query(cls).filter(cls.is_deleted == False)

View File

@@ -1,3 +1,4 @@
import os
from datetime import datetime from datetime import datetime
from dashscope import ImageSynthesis from dashscope import ImageSynthesis
@@ -5,22 +6,24 @@ from http import HTTPStatus
from sqlalchemy import desc from sqlalchemy import desc
from llm.factory import get_adapter
from models.ai import Drawing from models.ai import Drawing
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
def create_drawing_task(db: Session, user_id: int, platform: str, model: str, prompt: str, size: str, rsp, def create_drawing_task(db: Session, user, platform: str, model: str, prompt: str, size: str, rsp,
options: str = None): options: str = None):
# 写入数据库 # 写入数据库
drawing = Drawing( drawing = Drawing(
user_id=user_id, user_id=user['user_id'],
creator=user['username'],
modifier=user['username'],
platform=platform, platform=platform,
model=model, model=model,
prompt=prompt, prompt=prompt,
width=int(size.split('x')[0]), width=int(size.split('*')[0]),
height=int(size.split('x')[1]), height=int(size.split('*')[1]),
create_time=datetime.now(),
update_time=datetime.now(),
options=options, options=options,
status=rsp['output']['task_status'], status=rsp['output']['task_status'],
task_id=rsp['output']['task_id'], task_id=rsp['output']['task_id'],
@@ -35,17 +38,32 @@ def fetch_drawing_task_status(db: Session, drawing_id: int):
drawing = db.query(Drawing).filter(Drawing.id == drawing_id).first() drawing = db.query(Drawing).filter(Drawing.id == drawing_id).first()
if not drawing or not drawing.task_id: if not drawing or not drawing.task_id:
return None, "任务不存在" return None, "任务不存在"
status = ImageSynthesis.fetch(drawing.task_id) if drawing.status in ("PENDING", 'RUNNING'):
if status.status_code == HTTPStatus.OK: api_key = os.getenv("DASHSCOPE_API_KEY")
# 可根据 status.output.task_status 更新数据库 adapter = get_adapter('tongyi', api_key=api_key, model='')
drawing.status = status.output.task_status rsp = adapter.fetch_drawing_task_status(drawing.task_id)
if hasattr(status.output, 'results') and status.output.results: if rsp['status_code'] == HTTPStatus.OK:
drawing.pic_url = status.output.results[0].url # 可根据 status.output.task_status 更新数据库
db.commit() if rsp['output']['task_status'] == 'SUCCEEDED':
db.refresh(drawing) drawing.update_time = datetime.now()
drawing.status = rsp['output']['task_status']
drawing.pic_url = rsp['output']['results'][0]['url']
db.commit()
db.refresh(drawing)
elif rsp['output']['task_status'] == 'FAILED':
drawing.update_time = datetime.now()
drawing.status = rsp['output']['task_status']
drawing.error_message = rsp['output']['message']
db.commit()
db.refresh(drawing)
elif rsp['output']['task_status'] == 'RUNNING':
drawing.update_time = datetime.now()
drawing.status = rsp['output']['task_status']
db.commit()
db.refresh(drawing)
return drawing, None return drawing, None
else: else:
return None, status.message return drawing, None
def get_drawing_page(db: Session, user_id: int = None, page: int = 1, page_size: int = 12): def get_drawing_page(db: Session, user_id: int = None, page: int = 1, page_size: int = 12):

View File

@@ -9,7 +9,7 @@ export interface CreateImageTaskParams {
n?: number; n?: number;
} }
export async function createImageTask(params: CreateImageTaskParams) { export async function createDrawing(params: CreateImageTaskParams) {
const res = await fetchWithAuth('drawing/', { const res = await fetchWithAuth('drawing/', {
method: 'POST', method: 'POST',
body: JSON.stringify(params), body: JSON.stringify(params),
@@ -20,7 +20,7 @@ export async function createImageTask(params: CreateImageTaskParams) {
return await res.json(); return await res.json();
} }
export async function fetchImageTaskStatus(id: number) { export async function getDrawingDetail(id: number) {
const res = await fetchWithAuth(`drawing/${id}/`); const res = await fetchWithAuth(`drawing/${id}/`);
if (!res.ok) { if (!res.ok) {
throw new Error('查询图片任务状态失败'); throw new Error('查询图片任务状态失败');
@@ -33,12 +33,12 @@ export interface GetImagePageParams {
page_size?: number; page_size?: number;
} }
export async function getImagePage(params: GetImagePageParams = {}) { export async function getDrawingPage(params: GetImagePageParams = {}) {
const query = new URLSearchParams(); const query = new URLSearchParams();
if (params.page) query.append('page', String(params.page)); if (params.page) query.append('page', String(params.page));
if (params.page_size) query.append('page_size', String(params.page_size)); if (params.page_size) query.append('page_size', String(params.page_size));
const res = await fetchWithAuth( const res = await fetchWithAuth(
`drawing${query.toString() ? `?${query.toString()}` : ''}`, `drawing/${query.toString() ? `?${query.toString()}` : ''}`,
); );
if (!res.ok) { if (!res.ok) {
throw new Error('获取图片分页失败'); throw new Error('获取图片分页失败');

View File

@@ -1,5 +1,5 @@
<script setup lang="ts"> <script setup lang="ts">
import { computed, h, reactive, ref } from 'vue'; import { onMounted, reactive, ref } from 'vue';
import { Page } from '@vben/common-ui'; import { Page } from '@vben/common-ui';
@@ -13,9 +13,22 @@ import {
Pagination, Pagination,
Row, Row,
Select, Select,
Spin,
} from 'ant-design-vue'; } from 'ant-design-vue';
import { createImageTask } from '#/api/ai/drawing'; import {
createDrawing,
getDrawingDetail,
getDrawingPage,
} from '#/api/ai/drawing';
// 定义图片对象类型
interface DrawingImage {
id: number;
status: string;
pic_url?: string;
// 其他属性
}
// 表单选项 // 表单选项
const platforms = [ const platforms = [
@@ -24,14 +37,14 @@ const platforms = [
// { label: 'OpenAI', value: 'openai' }, // { label: 'OpenAI', value: 'openai' },
// { label: 'Google GenAI', value: 'google-genai' }, // { label: 'Google GenAI', value: 'google-genai' },
]; ];
const models = {
const models: Record<string, { label: string; value: string }[]> = {
tongyi: [{ label: 'wanx_v1', value: 'wanx_v1' }], tongyi: [{ label: 'wanx_v1', value: 'wanx_v1' }],
// deepseek: [{ label: 'deepseek-img', value: 'deepseek-img' }], // 其他平台...
// openai: [{ label: 'dall-e-3', value: 'dall-e-3' }],
// 'google-genai': [{ label: 'imagen', value: 'imagen' }],
}; };
const sizes = [ const sizes = [
{ label: '1024x1024', value: '1024x1024' }, { label: '1024*1024', value: '1024*1024' },
{ label: '720*1280', value: '720*1280' }, { label: '720*1280', value: '720*1280' },
{ label: '768*1152', value: '768*1152' }, { label: '768*1152', value: '768*1152' },
{ label: '1280*720', value: '1280*720' }, { label: '1280*720', value: '1280*720' },
@@ -51,26 +64,24 @@ const styles = [
// 表单数据 // 表单数据
const form = reactive({ const form = reactive({
prompt: '近景镜头18岁的中国女孩古代服饰圆脸正面看着镜头民族优雅的服装商业摄影室外电影级光照半身特写精致的淡妆锐利的边缘。', prompt:
'近景镜头18岁的中国女孩古代服饰圆脸正面看着镜头民族优雅的服装商业摄影室外电影级光照半身特写精致的淡妆锐利的边缘。',
platform: 'tongyi', platform: 'tongyi',
model: 'wanx_v1', model: 'wanx-v1',
size: '1024x1024', size: '1024*1024',
style: 'watercolor', style: 'watercolor',
}); });
// 图片数据与分页 // 图片数据与分页
const images = ref<string[]>([]); const images = ref<DrawingImage[]>([]);
const loading = ref(false); const loading = ref(false);
const page = ref(1); const page = ref(1);
const pageSize = 12; const pageSize = ref(9);
const total = computed(() => images.value.length); const total = ref(0);
const pagedImages = computed(() =>
images.value.slice((page.value - 1) * pageSize, page.value * pageSize),
);
// 平台切换时自动切换模型 // 平台切换时自动切换模型
const onPlatformChange = (val: string) => { const onPlatformChange = (value: number | string) => {
form.model = models[val][0].value; form.model = models[value as string]?.[0]?.value ?? '';
}; };
// 提交表单调用AI画图API // 提交表单调用AI画图API
@@ -78,14 +89,15 @@ async function handleDraw() {
loading.value = true; loading.value = true;
try { try {
// 这里调用你的AI画图API返回图片url数组 // 这里调用你的AI画图API返回图片url数组
const res = await createImageTask(form); const data = await createDrawing(form);
if (data.code !== 0) {
message.error(data.message || '生成失败');
return;
}
page.value = 1;
await fetchDrawingList(page.value, pageSize.value); // 刷新第一页图片列表
// images.value = res.data.images; // images.value = res.data.images;
// DEMO用假数据 // DEMO用假数据
// images.value = Array.from(
// { length: 30 },
// (_, i) => `https://picsum.photos/seed/${form.prompt}-${i}/300/300`,
// );
page.value = 1;
message.success('生成成功'); message.success('生成成功');
} catch { } catch {
message.error('生成失败'); message.error('生成失败');
@@ -93,6 +105,59 @@ async function handleDraw() {
loading.value = false; loading.value = false;
} }
} }
// 轮询获取图片详情
const pollDrawingDetail = async (id: number) => {
fetchDrawingDetail(id).then((res) => {
if (res && res.status === 'RUNNING') {
setTimeout(() => pollDrawingDetail(id), 5000);
}
});
};
// 获取图片分页列表
async function fetchDrawingList(pageNum = 1, pageSize = 9) {
try {
const res = await getDrawingPage({ page: pageNum, page_size: pageSize });
images.value = res.items;
// images.value = res.items.map(item => item.pic_url);
total.value = res.total;
// 检查每个 item 的状态
for (const item of res.items) {
if (item.status === 'PENDING') {
fetchDrawingDetail(item.id);
} else if (item.status === 'RUNNING') {
pollDrawingDetail(item.id);
}
}
return res;
} catch {
message.error('获取图片列表失败');
return null;
}
}
// 获取图片详情
const fetchDrawingDetail = async (id: number) => {
try {
const res = await getDrawingDetail(id);
// 更新 images 中对应项
const idx = images.value.findIndex((item) => item.id === id);
if (idx !== -1) {
images.value[idx] = { ...images.value[idx], ...res.data };
}
// 处理详情数据
return res;
} catch {
message.error('获取图片详情失败');
return null;
}
};
// 页面加载时调用获取图片列表
onMounted(() => {
fetchDrawingList();
});
</script> </script>
<template> <template>
@@ -166,28 +231,46 @@ async function handleDraw() {
<Card title="生成结果" bordered> <Card title="生成结果" bordered>
<Row :gutter="16"> <Row :gutter="16">
<Col <Col
v-for="(img, idx) in pagedImages" v-for="(img, idx) in images"
:key="img" :key="idx"
:span="4" :span="8"
style="margin-bottom: 16px" style="margin-bottom: 16px"
> >
<Card <Card hoverable>
hoverable <template #cover>
:cover=" <div
h('img', { v-if="img.status === 'PENDING' || img.status === 'RUNNING'"
src: img, style="
style: 'width:100%;height:180px;object-fit:cover;', width: 100%;
}) height: 180px;
" display: flex;
/> align-items: center;
justify-content: center;
"
>
<Spin size="large" />
</div>
<img
v-else
:src="img.pic_url"
style="width: 100%; height: 180px; object-fit: cover"
/>
</template>
</Card>
</Col> </Col>
</Row> </Row>
<Pagination <Pagination
v-if="total > pageSize"
v-model:current="page" v-model:current="page"
:total="total" :total="total"
:page-size="pageSize" :page-size="pageSize"
style="margin-top: 16px; text-align: right" style="margin-top: 16px; text-align: right"
@change="
(p, ps) => {
page = p;
pageSize = ps;
fetchDrawingList(p, ps);
}
"
/> />
</Card> </Card>
</Col> </Col>