add ai drawing
This commit is contained in:
@@ -1,18 +1,20 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body, Query
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.v1.drawing.vo import CreateDrawingTaskRequest
|
||||
from db.session import get_db
|
||||
from deps.auth import get_current_user
|
||||
from llm.factory import get_adapter
|
||||
from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status
|
||||
from utils.resp import resp_error, resp_success
|
||||
|
||||
router = APIRouter(prefix="/drawing", tags=["drawing"])
|
||||
|
||||
@router.get("/")
|
||||
def api_get_image_page(
|
||||
def api_get_drawing_page(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(12, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
@@ -27,19 +29,12 @@ def api_get_image_page(
|
||||
"pic_url": img.pic_url,
|
||||
"status": img.status,
|
||||
"error_message": img.error_message,
|
||||
"created_at": img.created_at if hasattr(img, 'created_at') else None
|
||||
"create_time": img.create_time if hasattr(img, 'create_time') else None
|
||||
}
|
||||
for img in data["items"]
|
||||
]
|
||||
return data
|
||||
|
||||
class CreateDrawingTaskRequest(BaseModel):
|
||||
prompt: str
|
||||
style: str = 'auto'
|
||||
size: str = '1024x1024'
|
||||
model: str = 'wanx_v1'
|
||||
platform: str = 'tongyi'
|
||||
n: int = 1
|
||||
|
||||
|
||||
@router.post("/")
|
||||
@@ -48,40 +43,37 @@ def api_create_image_task(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user)
|
||||
):
|
||||
user_id = user["user_id"]
|
||||
style = req.style
|
||||
size = req.size
|
||||
platform = req.platform
|
||||
n = req.n
|
||||
prompt = req.prompt
|
||||
model = req.model
|
||||
print(user_id, req.platform, req.size, req.model, req.prompt)
|
||||
api_key = os.getenv("DASHSCOPE_API_KEY")
|
||||
adapter = get_adapter('tongyi', api_key=api_key, model=model)
|
||||
try:
|
||||
# rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
|
||||
# print(rsp, 'rsp')
|
||||
res_json = {
|
||||
"status_code": 200,
|
||||
"request_id": "31b04171-011c-96bd-ac00-f0383b669cc7",
|
||||
"code": "",
|
||||
"message": "",
|
||||
"output": {
|
||||
"task_id": "4f90cf14-a34e-4eae-xxxxxxxx",
|
||||
"task_status": "PENDING",
|
||||
"results": []
|
||||
},
|
||||
"usage": None
|
||||
}
|
||||
rsp = res_json
|
||||
rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
|
||||
# rsp = {
|
||||
# "status_code": 200,
|
||||
# "request_id": "31b04171-011c-96bd-ac00-f0383b669cc7",
|
||||
# "code": "",
|
||||
# "message": "",
|
||||
# "output": {
|
||||
# "task_id": "4f90cf14-a34e-4eae-xxxxxxxx",
|
||||
# "task_status": "PENDING",
|
||||
# "results": []
|
||||
# },
|
||||
# "usage": None
|
||||
# }
|
||||
if rsp['status_code'] != 200:
|
||||
raise HTTPException(status_code=500, detail=rsp['message'])
|
||||
return resp_error(message=rsp['message'], code=rsp['status_code'])
|
||||
# raise HTTPException(status_code=500, detail=rsp['message'])
|
||||
option = {
|
||||
'style': style
|
||||
}
|
||||
drawing = create_drawing_task(
|
||||
db=db,
|
||||
user_id=user["user_id"],
|
||||
user=user,
|
||||
platform=platform,
|
||||
model=model,
|
||||
rsp=rsp,
|
||||
@@ -89,22 +81,28 @@ def api_create_image_task(
|
||||
size=size,
|
||||
options=json.dumps(option)
|
||||
)
|
||||
return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status}
|
||||
return resp_success(data={
|
||||
"id": drawing.id,
|
||||
"task_id": drawing.task_id,
|
||||
"status": drawing.status
|
||||
})
|
||||
# return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status}
|
||||
except NotImplementedError:
|
||||
print("该服务商不支持图片生成")
|
||||
|
||||
|
||||
@router.get("/{id}")
|
||||
def api_fetch_image_task_status(
|
||||
id: int,
|
||||
@router.get("/{drawing_id}/")
|
||||
def api_fetch_drawing_task_status(
|
||||
drawing_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
image, err = fetch_drawing_task_status(db, id)
|
||||
if not image:
|
||||
drawing, err = fetch_drawing_task_status(db, drawing_id)
|
||||
if not drawing:
|
||||
raise HTTPException(status_code=404, detail=err or "任务不存在")
|
||||
return {
|
||||
"id": image.id,
|
||||
"status": image.status,
|
||||
"pic_url": image.pic_url,
|
||||
"error_message": image.error_message
|
||||
}
|
||||
return resp_success(data={
|
||||
"id": drawing.id,
|
||||
"status": drawing.status,
|
||||
"pic_url": drawing.pic_url,
|
||||
"error_message": drawing.error_message
|
||||
})
|
||||
|
||||
|
||||
10
ai_service/api/v1/drawing/vo.py
Normal file
10
ai_service/api/v1/drawing/vo.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CreateDrawingTaskRequest(BaseModel):
|
||||
prompt: str
|
||||
style: str = 'auto'
|
||||
size: str = '1024*1024'
|
||||
model: str = 'wanx_v1'
|
||||
platform: str = 'tongyi'
|
||||
n: int = 1
|
||||
@@ -23,10 +23,10 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
obj_in_data = obj_in.model_dump() # 解构Pydantic模型为字典
|
||||
|
||||
# 自动填充时间字段(如果模型有created_at/updated_at)
|
||||
if hasattr(self.model, "created_at"):
|
||||
obj_in_data["created_at"] = datetime.now()
|
||||
if hasattr(self.model, "updated_at"):
|
||||
obj_in_data["updated_at"] = datetime.now()
|
||||
if hasattr(self.model, "create_time"):
|
||||
obj_in_data["create_time"] = datetime.now()
|
||||
if hasattr(self.model, "update_time"):
|
||||
obj_in_data["update_time"] = datetime.now()
|
||||
|
||||
db_obj = self.model(**obj_in_data) # 实例化模型
|
||||
db.add(db_obj)
|
||||
|
||||
@@ -24,7 +24,6 @@ class TongYiAdapter(MultiModalAICapability):
|
||||
yield chunk
|
||||
|
||||
def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
|
||||
print(self.model, self.api_key, 'key')
|
||||
"""创建异步图片生成任务"""
|
||||
rsp = ImageSynthesis.async_call(
|
||||
api_key=self.api_key,
|
||||
@@ -34,14 +33,12 @@ class TongYiAdapter(MultiModalAICapability):
|
||||
style=f'<{style}>',
|
||||
size=size
|
||||
)
|
||||
print(rsp, 'rsp')
|
||||
return rsp
|
||||
|
||||
def fetch_drawing_task_status(self, task):
|
||||
"""获取异步图片任务状态"""
|
||||
status = ImageSynthesis.fetch(task)
|
||||
if status.status_code == HTTPStatus.OK:
|
||||
return status.output.task_status
|
||||
else:
|
||||
raise Exception(f"Failed, status_code: {status.status_code}, code: {status.code}, message: {status.message}")
|
||||
rsp = ImageSynthesis.fetch(task, api_key=self.api_key)
|
||||
return rsp
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
from db.session import Base
|
||||
from sqlalchemy import (
|
||||
Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
|
||||
@@ -8,6 +10,26 @@ class CoreModel(Base):
|
||||
__abstract__ = True
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
create_time = Column(DateTime)
|
||||
update_time = Column(DateTime)
|
||||
is_deleted = Column(Boolean, default=False)
|
||||
remark = Column(String(256), nullable=True, comment="备注")
|
||||
creator = Column(String(64), nullable=True, comment="创建人")
|
||||
modifier = Column(String(64), nullable=True, comment="修改人")
|
||||
|
||||
# 创建时间 - 使用函数默认值,在插入时自动生成
|
||||
create_time = Column(DateTime, default=datetime.now(), comment="创建时间")
|
||||
|
||||
# 修改时间 - 使用SQL函数,在更新时自动触发
|
||||
update_time = Column(DateTime, default=datetime.now(), onupdate=datetime.now(), comment="修改时间")
|
||||
|
||||
is_deleted = Column(Boolean, default=False, comment="是否软删除")
|
||||
|
||||
# 软删除方法
|
||||
# def soft_delete(self, session):
|
||||
# self.is_deleted = True
|
||||
# self.modifier = get_current_user() # 需要实现这个函数获取当前用户
|
||||
# self.update_time = datetime.utcnow()
|
||||
# session.commit()
|
||||
|
||||
# 查询时自动过滤已删除记录
|
||||
@classmethod
|
||||
def get_active(cls, session):
|
||||
return session.query(cls).filter(cls.is_deleted == False)
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from dashscope import ImageSynthesis
|
||||
@@ -5,22 +6,24 @@ from http import HTTPStatus
|
||||
|
||||
from sqlalchemy import desc
|
||||
|
||||
from llm.factory import get_adapter
|
||||
from models.ai import Drawing
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
def create_drawing_task(db: Session, user_id: int, platform: str, model: str, prompt: str, size: str, rsp,
|
||||
def create_drawing_task(db: Session, user, platform: str, model: str, prompt: str, size: str, rsp,
|
||||
options: str = None):
|
||||
# 写入数据库
|
||||
|
||||
drawing = Drawing(
|
||||
user_id=user_id,
|
||||
user_id=user['user_id'],
|
||||
creator=user['username'],
|
||||
modifier=user['username'],
|
||||
platform=platform,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
width=int(size.split('x')[0]),
|
||||
height=int(size.split('x')[1]),
|
||||
create_time=datetime.now(),
|
||||
update_time=datetime.now(),
|
||||
width=int(size.split('*')[0]),
|
||||
height=int(size.split('*')[1]),
|
||||
options=options,
|
||||
status=rsp['output']['task_status'],
|
||||
task_id=rsp['output']['task_id'],
|
||||
@@ -35,17 +38,32 @@ def fetch_drawing_task_status(db: Session, drawing_id: int):
|
||||
drawing = db.query(Drawing).filter(Drawing.id == drawing_id).first()
|
||||
if not drawing or not drawing.task_id:
|
||||
return None, "任务不存在"
|
||||
status = ImageSynthesis.fetch(drawing.task_id)
|
||||
if status.status_code == HTTPStatus.OK:
|
||||
# 可根据 status.output.task_status 更新数据库
|
||||
drawing.status = status.output.task_status
|
||||
if hasattr(status.output, 'results') and status.output.results:
|
||||
drawing.pic_url = status.output.results[0].url
|
||||
db.commit()
|
||||
db.refresh(drawing)
|
||||
if drawing.status in ("PENDING", 'RUNNING'):
|
||||
api_key = os.getenv("DASHSCOPE_API_KEY")
|
||||
adapter = get_adapter('tongyi', api_key=api_key, model='')
|
||||
rsp = adapter.fetch_drawing_task_status(drawing.task_id)
|
||||
if rsp['status_code'] == HTTPStatus.OK:
|
||||
# 可根据 status.output.task_status 更新数据库
|
||||
if rsp['output']['task_status'] == 'SUCCEEDED':
|
||||
drawing.update_time = datetime.now()
|
||||
drawing.status = rsp['output']['task_status']
|
||||
drawing.pic_url = rsp['output']['results'][0]['url']
|
||||
db.commit()
|
||||
db.refresh(drawing)
|
||||
elif rsp['output']['task_status'] == 'FAILED':
|
||||
drawing.update_time = datetime.now()
|
||||
drawing.status = rsp['output']['task_status']
|
||||
drawing.error_message = rsp['output']['message']
|
||||
db.commit()
|
||||
db.refresh(drawing)
|
||||
elif rsp['output']['task_status'] == 'RUNNING':
|
||||
drawing.update_time = datetime.now()
|
||||
drawing.status = rsp['output']['task_status']
|
||||
db.commit()
|
||||
db.refresh(drawing)
|
||||
return drawing, None
|
||||
else:
|
||||
return None, status.message
|
||||
return drawing, None
|
||||
|
||||
|
||||
def get_drawing_page(db: Session, user_id: int = None, page: int = 1, page_size: int = 12):
|
||||
|
||||
@@ -9,7 +9,7 @@ export interface CreateImageTaskParams {
|
||||
n?: number;
|
||||
}
|
||||
|
||||
export async function createImageTask(params: CreateImageTaskParams) {
|
||||
export async function createDrawing(params: CreateImageTaskParams) {
|
||||
const res = await fetchWithAuth('drawing/', {
|
||||
method: 'POST',
|
||||
body: JSON.stringify(params),
|
||||
@@ -20,7 +20,7 @@ export async function createImageTask(params: CreateImageTaskParams) {
|
||||
return await res.json();
|
||||
}
|
||||
|
||||
export async function fetchImageTaskStatus(id: number) {
|
||||
export async function getDrawingDetail(id: number) {
|
||||
const res = await fetchWithAuth(`drawing/${id}/`);
|
||||
if (!res.ok) {
|
||||
throw new Error('查询图片任务状态失败');
|
||||
@@ -33,12 +33,12 @@ export interface GetImagePageParams {
|
||||
page_size?: number;
|
||||
}
|
||||
|
||||
export async function getImagePage(params: GetImagePageParams = {}) {
|
||||
export async function getDrawingPage(params: GetImagePageParams = {}) {
|
||||
const query = new URLSearchParams();
|
||||
if (params.page) query.append('page', String(params.page));
|
||||
if (params.page_size) query.append('page_size', String(params.page_size));
|
||||
const res = await fetchWithAuth(
|
||||
`drawing${query.toString() ? `?${query.toString()}` : ''}`,
|
||||
`drawing/${query.toString() ? `?${query.toString()}` : ''}`,
|
||||
);
|
||||
if (!res.ok) {
|
||||
throw new Error('获取图片分页失败');
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<script setup lang="ts">
|
||||
import { computed, h, reactive, ref } from 'vue';
|
||||
import { onMounted, reactive, ref } from 'vue';
|
||||
|
||||
import { Page } from '@vben/common-ui';
|
||||
|
||||
@@ -13,9 +13,22 @@ import {
|
||||
Pagination,
|
||||
Row,
|
||||
Select,
|
||||
Spin,
|
||||
} from 'ant-design-vue';
|
||||
|
||||
import { createImageTask } from '#/api/ai/drawing';
|
||||
import {
|
||||
createDrawing,
|
||||
getDrawingDetail,
|
||||
getDrawingPage,
|
||||
} from '#/api/ai/drawing';
|
||||
|
||||
// 定义图片对象类型
|
||||
interface DrawingImage {
|
||||
id: number;
|
||||
status: string;
|
||||
pic_url?: string;
|
||||
// 其他属性
|
||||
}
|
||||
|
||||
// 表单选项
|
||||
const platforms = [
|
||||
@@ -24,14 +37,14 @@ const platforms = [
|
||||
// { label: 'OpenAI', value: 'openai' },
|
||||
// { label: 'Google GenAI', value: 'google-genai' },
|
||||
];
|
||||
const models = {
|
||||
|
||||
const models: Record<string, { label: string; value: string }[]> = {
|
||||
tongyi: [{ label: 'wanx_v1', value: 'wanx_v1' }],
|
||||
// deepseek: [{ label: 'deepseek-img', value: 'deepseek-img' }],
|
||||
// openai: [{ label: 'dall-e-3', value: 'dall-e-3' }],
|
||||
// 'google-genai': [{ label: 'imagen', value: 'imagen' }],
|
||||
// 其他平台...
|
||||
};
|
||||
|
||||
const sizes = [
|
||||
{ label: '1024x1024', value: '1024x1024' },
|
||||
{ label: '1024*1024', value: '1024*1024' },
|
||||
{ label: '720*1280', value: '720*1280' },
|
||||
{ label: '768*1152', value: '768*1152' },
|
||||
{ label: '1280*720', value: '1280*720' },
|
||||
@@ -51,26 +64,24 @@ const styles = [
|
||||
|
||||
// 表单数据
|
||||
const form = reactive({
|
||||
prompt: '近景镜头,18岁的中国女孩,古代服饰,圆脸,正面看着镜头,民族优雅的服装,商业摄影,室外,电影级光照,半身特写,精致的淡妆,锐利的边缘。',
|
||||
prompt:
|
||||
'近景镜头,18岁的中国女孩,古代服饰,圆脸,正面看着镜头,民族优雅的服装,商业摄影,室外,电影级光照,半身特写,精致的淡妆,锐利的边缘。',
|
||||
platform: 'tongyi',
|
||||
model: 'wanx_v1',
|
||||
size: '1024x1024',
|
||||
model: 'wanx-v1',
|
||||
size: '1024*1024',
|
||||
style: 'watercolor',
|
||||
});
|
||||
|
||||
// 图片数据与分页
|
||||
const images = ref<string[]>([]);
|
||||
const images = ref<DrawingImage[]>([]);
|
||||
const loading = ref(false);
|
||||
const page = ref(1);
|
||||
const pageSize = 12;
|
||||
const total = computed(() => images.value.length);
|
||||
const pagedImages = computed(() =>
|
||||
images.value.slice((page.value - 1) * pageSize, page.value * pageSize),
|
||||
);
|
||||
const pageSize = ref(9);
|
||||
const total = ref(0);
|
||||
|
||||
// 平台切换时自动切换模型
|
||||
const onPlatformChange = (val: string) => {
|
||||
form.model = models[val][0].value;
|
||||
const onPlatformChange = (value: number | string) => {
|
||||
form.model = models[value as string]?.[0]?.value ?? '';
|
||||
};
|
||||
|
||||
// 提交表单,调用AI画图API
|
||||
@@ -78,14 +89,15 @@ async function handleDraw() {
|
||||
loading.value = true;
|
||||
try {
|
||||
// 这里调用你的AI画图API,返回图片url数组
|
||||
const res = await createImageTask(form);
|
||||
const data = await createDrawing(form);
|
||||
if (data.code !== 0) {
|
||||
message.error(data.message || '生成失败');
|
||||
return;
|
||||
}
|
||||
page.value = 1;
|
||||
await fetchDrawingList(page.value, pageSize.value); // 刷新第一页图片列表
|
||||
// images.value = res.data.images;
|
||||
// DEMO用假数据
|
||||
// images.value = Array.from(
|
||||
// { length: 30 },
|
||||
// (_, i) => `https://picsum.photos/seed/${form.prompt}-${i}/300/300`,
|
||||
// );
|
||||
page.value = 1;
|
||||
message.success('生成成功');
|
||||
} catch {
|
||||
message.error('生成失败');
|
||||
@@ -93,6 +105,59 @@ async function handleDraw() {
|
||||
loading.value = false;
|
||||
}
|
||||
}
|
||||
|
||||
// 轮询获取图片详情
|
||||
const pollDrawingDetail = async (id: number) => {
|
||||
fetchDrawingDetail(id).then((res) => {
|
||||
if (res && res.status === 'RUNNING') {
|
||||
setTimeout(() => pollDrawingDetail(id), 5000);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
// 获取图片分页列表
|
||||
async function fetchDrawingList(pageNum = 1, pageSize = 9) {
|
||||
try {
|
||||
const res = await getDrawingPage({ page: pageNum, page_size: pageSize });
|
||||
images.value = res.items;
|
||||
// images.value = res.items.map(item => item.pic_url);
|
||||
total.value = res.total;
|
||||
// 检查每个 item 的状态
|
||||
for (const item of res.items) {
|
||||
if (item.status === 'PENDING') {
|
||||
fetchDrawingDetail(item.id);
|
||||
} else if (item.status === 'RUNNING') {
|
||||
pollDrawingDetail(item.id);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
} catch {
|
||||
message.error('获取图片列表失败');
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
// 获取图片详情
|
||||
const fetchDrawingDetail = async (id: number) => {
|
||||
try {
|
||||
const res = await getDrawingDetail(id);
|
||||
// 更新 images 中对应项
|
||||
const idx = images.value.findIndex((item) => item.id === id);
|
||||
if (idx !== -1) {
|
||||
images.value[idx] = { ...images.value[idx], ...res.data };
|
||||
}
|
||||
// 处理详情数据
|
||||
return res;
|
||||
} catch {
|
||||
message.error('获取图片详情失败');
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
// 页面加载时调用获取图片列表
|
||||
onMounted(() => {
|
||||
fetchDrawingList();
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -166,28 +231,46 @@ async function handleDraw() {
|
||||
<Card title="生成结果" bordered>
|
||||
<Row :gutter="16">
|
||||
<Col
|
||||
v-for="(img, idx) in pagedImages"
|
||||
:key="img"
|
||||
:span="4"
|
||||
v-for="(img, idx) in images"
|
||||
:key="idx"
|
||||
:span="8"
|
||||
style="margin-bottom: 16px"
|
||||
>
|
||||
<Card
|
||||
hoverable
|
||||
:cover="
|
||||
h('img', {
|
||||
src: img,
|
||||
style: 'width:100%;height:180px;object-fit:cover;',
|
||||
})
|
||||
"
|
||||
/>
|
||||
<Card hoverable>
|
||||
<template #cover>
|
||||
<div
|
||||
v-if="img.status === 'PENDING' || img.status === 'RUNNING'"
|
||||
style="
|
||||
width: 100%;
|
||||
height: 180px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
"
|
||||
>
|
||||
<Spin size="large" />
|
||||
</div>
|
||||
<img
|
||||
v-else
|
||||
:src="img.pic_url"
|
||||
style="width: 100%; height: 180px; object-fit: cover"
|
||||
/>
|
||||
</template>
|
||||
</Card>
|
||||
</Col>
|
||||
</Row>
|
||||
<Pagination
|
||||
v-if="total > pageSize"
|
||||
v-model:current="page"
|
||||
:total="total"
|
||||
:page-size="pageSize"
|
||||
style="margin-top: 16px; text-align: right"
|
||||
@change="
|
||||
(p, ps) => {
|
||||
page = p;
|
||||
pageSize = ps;
|
||||
fetchDrawingList(p, ps);
|
||||
}
|
||||
"
|
||||
/>
|
||||
</Card>
|
||||
</Col>
|
||||
|
||||
Reference in New Issue
Block a user