mock create drawing

This commit is contained in:
XIE7654
2025-07-21 22:22:32 +08:00
parent 816668530c
commit 71d5053b9c
19 changed files with 504 additions and 43 deletions

View File

@@ -0,0 +1,22 @@
from fastapi import APIRouter
from .chat import router as chat_router
from .drawing import router as drawing_router
# from .video import router as video_router
# from .audio import router as audio_router
# from .multimodal import router as multimodal_router
# from .model_manage import router as model_manage_router
# from .knowledge import router as knowledge_router
# from .system import router as system_router
# from .user import router as user_router
api_v1_router = APIRouter(prefix="/api/ai/v1")
api_v1_router.include_router(chat_router)
api_v1_router.include_router(drawing_router)
# api_v1_router.include_router(video_router)
# api_v1_router.include_router(audio_router)
# api_v1_router.include_router(multimodal_router)
# api_v1_router.include_router(model_manage_router)
# api_v1_router.include_router(knowledge_router)
# api_v1_router.include_router(system_router)
# api_v1_router.include_router(user_router)

View File

@@ -8,7 +8,7 @@ from typing import List
from pydantic import BaseModel, SecretStr from pydantic import BaseModel, SecretStr
from langchain.chains import ConversationChain from langchain.chains import ConversationChain
from api.v1.vo import MessageVO from api.v1.chat.vo import MessageVO
from deps.auth import get_current_user from deps.auth import get_current_user
from services.chat_service import ChatDBService from services.chat_service import ChatDBService
from db.session import get_db from db.session import get_db
@@ -16,7 +16,7 @@ from models.ai import ChatConversation, ChatMessage
from utils.resp import resp_success, Response from utils.resp import resp_success, Response
from langchain_deepseek import ChatDeepSeek from langchain_deepseek import ChatDeepSeek
router = APIRouter() router = APIRouter(prefix="/chat", tags=["chat"])
def get_deepseek_llm(api_key: SecretStr, model: str): def get_deepseek_llm(api_key: SecretStr, model: str):
# deepseek 兼容 OpenAI API需指定 base_url # deepseek 兼容 OpenAI API需指定 base_url

View File

@@ -0,0 +1,110 @@
import json
import os
from fastapi import APIRouter, Depends, HTTPException, Body, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from db.session import get_db
from deps.auth import get_current_user
from llm.factory import get_adapter
from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status
router = APIRouter(prefix="/drawing", tags=["drawing"])
@router.get("/")
def api_get_image_page(
page: int = Query(1, ge=1),
page_size: int = Query(12, ge=1, le=100),
db: Session = Depends(get_db),
user=Depends(get_current_user)
):
data = get_drawing_page(db, user_id=user["user_id"], page=page, page_size=page_size)
# 序列化 items
data["items"] = [
{
"id": img.id,
"prompt": img.prompt,
"pic_url": img.pic_url,
"status": img.status,
"error_message": img.error_message,
"created_at": img.created_at if hasattr(img, 'created_at') else None
}
for img in data["items"]
]
return data
class CreateDrawingTaskRequest(BaseModel):
prompt: str
style: str = 'auto'
size: str = '1024x1024'
model: str = 'wanx_v1'
platform: str = 'tongyi'
n: int = 1
@router.post("/")
def api_create_image_task(
req: CreateDrawingTaskRequest,
db: Session = Depends(get_db),
user=Depends(get_current_user)
):
user_id = user["user_id"]
style = req.style
size = req.size
platform = req.platform
n = req.n
prompt = req.prompt
model = req.model
print(user_id, req.platform, req.size, req.model, req.prompt)
api_key = os.getenv("DASHSCOPE_API_KEY")
adapter = get_adapter('tongyi', api_key=api_key, model=model)
try:
# rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
# print(rsp, 'rsp')
res_json = {
"status_code": 200,
"request_id": "31b04171-011c-96bd-ac00-f0383b669cc7",
"code": "",
"message": "",
"output": {
"task_id": "4f90cf14-a34e-4eae-xxxxxxxx",
"task_status": "PENDING",
"results": []
},
"usage": None
}
rsp = res_json
if rsp['status_code'] != 200:
raise HTTPException(status_code=500, detail=rsp['message'])
option = {
'style': style
}
drawing = create_drawing_task(
db=db,
user_id=user["user_id"],
platform=platform,
model=model,
rsp=rsp,
prompt=prompt,
size=size,
options=json.dumps(option)
)
return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status}
except NotImplementedError:
print("该服务商不支持图片生成")
@router.get("/{id}")
def api_fetch_image_task_status(
id: int,
db: Session = Depends(get_db)
):
image, err = fetch_drawing_task_status(db, id)
if not image:
raise HTTPException(status_code=404, detail=err or "任务不存在")
return {
"id": image.id,
"status": image.status,
"pic_url": image.pic_url,
"error_message": image.error_message
}

View File

@@ -15,7 +15,7 @@ class OpenAIAdapter(MultiModalAICapability):
yield chunk yield chunk
# 如需图片生成DALL·E可实现如下 # 如需图片生成DALL·E可实现如下
def create_image_task(self, prompt, **kwargs): def create_drawing_task(self, **kwargs):
# 伪代码,需用 openai.Image.create # 伪代码,需用 openai.Image.create
# import openai # import openai
# response = openai.Image.create(api_key=self.api_key, prompt=prompt, ...) # response = openai.Image.create(api_key=self.api_key, prompt=prompt, ...)

View File

@@ -23,24 +23,20 @@ class TongYiAdapter(MultiModalAICapability):
async for chunk in self.llm.astream(messages): async for chunk in self.llm.astream(messages):
yield chunk yield chunk
@staticmethod def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
def create_image_task(api_key, model, prompt: str, style='<watercolor>', size='1024*1024', n=1): print(self.model, self.api_key, 'key')
"""创建异步图片生成任务""" """创建异步图片生成任务"""
rsp = ImageSynthesis.async_call( rsp = ImageSynthesis.async_call(
api_key=api_key, api_key=self.api_key,
model=model, model=self.model,
prompt=prompt, prompt=prompt,
n=n, n=n,
style=style, style=f'<{style}>',
size=size size=size
) )
if rsp.status_code == HTTPStatus.OK: print(rsp, 'rsp')
return rsp
else:
raise Exception(f"Failed, status_code: {rsp.status_code}, code: {rsp.code}, message: {rsp.message}")
@staticmethod def fetch_drawing_task_status(self, task):
def fetch_image_task_status(task):
"""获取异步图片任务状态""" """获取异步图片任务状态"""
status = ImageSynthesis.fetch(task) status = ImageSynthesis.fetch(task)
if status.status_code == HTTPStatus.OK: if status.status_code == HTTPStatus.OK:

View File

@@ -9,14 +9,14 @@ class MultiModalAICapability(ABC):
raise NotImplementedError("stream_chat not supported by this provider") raise NotImplementedError("stream_chat not supported by this provider")
# 图片生成能力 # 图片生成能力
def create_image_task(self, prompt, **kwargs): def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
raise NotImplementedError("image generation not supported by this provider") raise NotImplementedError("drawing generation not supported by this provider")
def fetch_image_task_status(self, task): def fetch_drawing_task_status(self, task):
raise NotImplementedError("image task status not supported by this provider") raise NotImplementedError("drawing task status not supported by this provider")
def fetch_image_result(self, task): def fetch_drawing_result(self, task):
raise NotImplementedError("image result not supported by this provider") raise NotImplementedError("drawing result not supported by this provider")
# 视频生成能力 # 视频生成能力
def create_video_task(self, prompt, **kwargs): def create_video_task(self, prompt, **kwargs):

View File

@@ -1,8 +1,8 @@
import os import os
from fastapi import FastAPI from fastapi import FastAPI
from dotenv import load_dotenv from dotenv import load_dotenv
from api.v1 import ai_chat
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from api.v1 import api_v1_router
# 加载.env环境变量优先项目根目录 # 加载.env环境变量优先项目根目录
load_dotenv() load_dotenv()
@@ -23,7 +23,8 @@ app.add_middleware(
) )
# 注册路由 # 注册路由
app.include_router(ai_chat.router, prefix="/api/ai/v1", tags=["chat"]) # app.include_router(ai_chat.router, prefix="/api/ai/v1", tags=["chat"])
app.include_router(api_v1_router)
# 健康检查 # 健康检查
@app.get("/ping") @app.get("/ping")

View File

@@ -253,8 +253,8 @@ class ChatRoleTool(Base):
tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True) tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True)
class Image(CoreModel): class Drawing(CoreModel):
__tablename__ = 'ai_image' __tablename__ = 'ai_drawing'
user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True) user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
public_status = Column(Boolean, default=False, nullable=False) public_status = Column(Boolean, default=False, nullable=False)

View File

@@ -0,0 +1,62 @@
from datetime import datetime
from dashscope import ImageSynthesis
from http import HTTPStatus
from sqlalchemy import desc
from models.ai import Drawing
from sqlalchemy.orm import Session
def create_drawing_task(db: Session, user_id: int, platform: str, model: str, prompt: str, size: str, rsp,
options: str = None):
# 写入数据库
drawing = Drawing(
user_id=user_id,
platform=platform,
model=model,
prompt=prompt,
width=int(size.split('x')[0]),
height=int(size.split('x')[1]),
create_time=datetime.now(),
update_time=datetime.now(),
options=options,
status=rsp['output']['task_status'],
task_id=rsp['output']['task_id'],
error_message=rsp['message']
)
db.add(drawing)
db.commit()
db.refresh(drawing)
return drawing
def fetch_drawing_task_status(db: Session, drawing_id: int):
drawing = db.query(Drawing).filter(Drawing.id == drawing_id).first()
if not drawing or not drawing.task_id:
return None, "任务不存在"
status = ImageSynthesis.fetch(drawing.task_id)
if status.status_code == HTTPStatus.OK:
# 可根据 status.output.task_status 更新数据库
drawing.status = status.output.task_status
if hasattr(status.output, 'results') and status.output.results:
drawing.pic_url = status.output.results[0].url
db.commit()
db.refresh(drawing)
return drawing, None
else:
return None, status.message
def get_drawing_page(db: Session, user_id: int = None, page: int = 1, page_size: int = 12):
query = db.query(Drawing)
if user_id:
query = query.filter(Drawing.user_id == user_id)
total = query.count()
items = query.order_by(desc(Drawing.id)).offset((page - 1) * page_size).limit(page_size).all()
return {
'total': total,
'page': page,
'page_size': page_size,
'items': items
}

View File

@@ -0,0 +1,23 @@
# Generated by Django 5.2.1 on 2025-07-21 13:24
from django.conf import settings
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("ai", "0005_image"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.RenameModel(
old_name="Image",
new_name="Drawing",
),
migrations.AlterModelTable(
name="drawing",
table="ai_drawing",
),
]

View File

@@ -332,7 +332,7 @@ class ChatMessage(CoreModel):
return self.content[:30] return self.content[:30]
class Image(CoreModel): class Drawing(CoreModel):
user = models.ForeignKey( user = models.ForeignKey(
settings.AUTH_USER_MODEL, settings.AUTH_USER_MODEL,
@@ -359,6 +359,6 @@ class Image(CoreModel):
buttons = models.CharField(max_length=2048, null=True, verbose_name='mj buttons 按钮') buttons = models.CharField(max_length=2048, null=True, verbose_name='mj buttons 按钮')
class Meta: class Meta:
db_table = 'ai_image' db_table = 'ai_drawing'
verbose_name = 'AI 绘画表' verbose_name = 'AI 绘画表'
verbose_name_plural = verbose_name verbose_name_plural = verbose_name

View File

@@ -1,12 +1,12 @@
import { fetchWithAuth } from '#/utils/fetch-with-auth'; import { fetchWithAuth } from '#/utils/fetch-with-auth';
export async function getConversations() { export async function getConversations() {
const res = await fetchWithAuth('/api/ai/v1/conversations'); const res = await fetchWithAuth('chat/conversations');
return await res.json(); return await res.json();
} }
export async function createConversation() { export async function createConversation() {
const response = await fetchWithAuth('/api/ai/v1/conversations', { const response = await fetchWithAuth('chat/conversations', {
method: 'POST', method: 'POST',
}); });
if (!response.ok) { if (!response.ok) {
@@ -17,7 +17,7 @@ export async function createConversation() {
export async function getMessages(conversationId: number) { export async function getMessages(conversationId: number) {
const res = await fetchWithAuth( const res = await fetchWithAuth(
`/api/ai/v1/messages?conversation_id=${conversationId}`, `chat/messages?conversation_id=${conversationId}`,
); );
return await res.json(); return await res.json();
} }
@@ -32,7 +32,7 @@ export async function fetchAIStream({
content, content,
conversation_id, conversation_id,
}: FetchAIStreamParams) { }: FetchAIStreamParams) {
const res = await fetchWithAuth('/api/ai/v1/stream', { const res = await fetchWithAuth('chat/stream', {
method: 'POST', method: 'POST',
body: JSON.stringify({ content, conversation_id }), body: JSON.stringify({ content, conversation_id }),
}); });

View File

@@ -0,0 +1,47 @@
import { fetchWithAuth } from '#/utils/fetch-with-auth';
export interface CreateImageTaskParams {
prompt: string;
style?: string;
size?: string;
model?: string;
platform?: string;
n?: number;
}
export async function createImageTask(params: CreateImageTaskParams) {
const res = await fetchWithAuth('drawing/', {
method: 'POST',
body: JSON.stringify(params),
});
if (!res.ok) {
throw new Error('创建图片任务失败');
}
return await res.json();
}
export async function fetchImageTaskStatus(id: number) {
const res = await fetchWithAuth(`drawing/${id}/`);
if (!res.ok) {
throw new Error('查询图片任务状态失败');
}
return await res.json();
}
export interface GetImagePageParams {
page?: number;
page_size?: number;
}
export async function getImagePage(params: GetImagePageParams = {}) {
const query = new URLSearchParams();
if (params.page) query.append('page', String(params.page));
if (params.page_size) query.append('page_size', String(params.page_size));
const res = await fetchWithAuth(
`drawing${query.toString() ? `?${query.toString()}` : ''}`,
);
if (!res.ok) {
throw new Error('获取图片分页失败');
}
return await res.json();
}

View File

@@ -20,6 +20,10 @@
"title": "AI CHAT", "title": "AI CHAT",
"name": "AI CHAT" "name": "AI CHAT"
}, },
"drawing": {
"title": "AI DRAWING",
"name": "AI DRAWING"
},
"chat_conversation": { "chat_conversation": {
"title": "CHAT Management", "title": "CHAT Management",
"name": "CHAT Management" "name": "CHAT Management"

View File

@@ -20,6 +20,10 @@
"title": "AI对话", "title": "AI对话",
"name": "AI对话" "name": "AI对话"
}, },
"drawing": {
"title": "AI绘画",
"name": "AI绘画"
},
"chat_conversation": { "chat_conversation": {
"title": "对话列表", "title": "对话列表",
"name": "对话列表" "name": "对话列表"

View File

@@ -1,11 +1,14 @@
import { formatToken } from '#/utils/auth';
import { useAccessStore } from '@vben/stores'; import { useAccessStore } from '@vben/stores';
import { formatToken } from '#/utils/auth';
export const API_BASE = '/api/ai/v1/';
export function fetchWithAuth(input: RequestInfo, init: RequestInit = {}) { export function fetchWithAuth(input: RequestInfo, init: RequestInit = {}) {
const accessStore = useAccessStore(); const accessStore = useAccessStore();
const token = accessStore.accessToken; const token = accessStore.accessToken;
const headers = new Headers(init.headers || {}); const headers = new Headers(init.headers || {});
headers.append('Content-Type', 'application/json'); headers.append('Content-Type', 'application/json');
headers.append('Authorization', formatToken(token) as string); headers.append('Authorization', formatToken(token) as string);
return fetch(input, { ...init, headers }); return fetch(API_BASE + input, { ...init, headers });
} }

View File

@@ -0,0 +1,200 @@
<script setup lang="ts">
import { computed, h, reactive, ref } from 'vue';
import { Page } from '@vben/common-ui';
import {
Button,
Card,
Col,
Form,
Input,
message,
Pagination,
Row,
Select,
} from 'ant-design-vue';
import { createImageTask } from '#/api/ai/drawing';
// 表单选项
const platforms = [
{ label: '通义千问', value: 'tongyi' },
// { label: 'DeepSeek', value: 'deepseek' },
// { label: 'OpenAI', value: 'openai' },
// { label: 'Google GenAI', value: 'google-genai' },
];
const models = {
tongyi: [{ label: 'wanx_v1', value: 'wanx_v1' }],
// deepseek: [{ label: 'deepseek-img', value: 'deepseek-img' }],
// openai: [{ label: 'dall-e-3', value: 'dall-e-3' }],
// 'google-genai': [{ label: 'imagen', value: 'imagen' }],
};
const sizes = [
{ label: '1024x1024', value: '1024x1024' },
{ label: '720*1280', value: '720*1280' },
{ label: '768*1152', value: '768*1152' },
{ label: '1280*720', value: '1280*720' },
];
const styles = [
{ label: '默认(由模型随机输出风格)', value: 'auto' },
{ label: '摄影', value: 'photography' },
{ label: '人像写真', value: 'portrait' },
{ label: '3D卡通', value: '3d cartoon' },
{ label: '动画', value: 'anime' },
{ label: '油画', value: 'oil painting' },
{ label: '水彩', value: 'watercolor' },
{ label: '素描', value: 'sketch' },
{ label: '中国画', value: 'chinese painting' },
{ label: '扁平插画', value: 'flat illustration' },
];
// 表单数据
const form = reactive({
prompt: '近景镜头18岁的中国女孩古代服饰圆脸正面看着镜头民族优雅的服装商业摄影室外电影级光照半身特写精致的淡妆锐利的边缘。',
platform: 'tongyi',
model: 'wanx_v1',
size: '1024x1024',
style: 'watercolor',
});
// 图片数据与分页
const images = ref<string[]>([]);
const loading = ref(false);
const page = ref(1);
const pageSize = 12;
const total = computed(() => images.value.length);
const pagedImages = computed(() =>
images.value.slice((page.value - 1) * pageSize, page.value * pageSize),
);
// 平台切换时自动切换模型
const onPlatformChange = (val: string) => {
form.model = models[val][0].value;
};
// 提交表单调用AI画图API
async function handleDraw() {
loading.value = true;
try {
// 这里调用你的AI画图API返回图片url数组
const res = await createImageTask(form);
// images.value = res.data.images;
// DEMO用假数据
// images.value = Array.from(
// { length: 30 },
// (_, i) => `https://picsum.photos/seed/${form.prompt}-${i}/300/300`,
// );
page.value = 1;
message.success('生成成功');
} catch {
message.error('生成失败');
} finally {
loading.value = false;
}
}
</script>
<template>
<Page auto-content-height>
<Row :gutter="24">
<!-- 左侧表单 -->
<Col :span="8">
<Card title="AI画图" bordered>
<Form layout="vertical" @submit.prevent="handleDraw">
<Form.Item label="画画描述">
<Input.TextArea
v-model:value="form.prompt"
:autosize="true"
placeholder="请输入画面描述"
/>
</Form.Item>
<Form.Item label="平台选择">
<Select v-model:value="form.platform" @change="onPlatformChange">
<Select.Option
v-for="item in platforms"
:key="item.value"
:value="item.value"
>
{{ item.label }}
</Select.Option>
</Select>
</Form.Item>
<Form.Item label="模型选择">
<Select v-model:value="form.model">
<Select.Option
v-for="item in models[form.platform]"
:key="item.value"
:value="item.value"
>
{{ item.label }}
</Select.Option>
</Select>
</Form.Item>
<Form.Item label="图片尺寸">
<Select v-model:value="form.size">
<Select.Option
v-for="item in sizes"
:key="item.value"
:value="item.value"
>
{{ item.label }}
</Select.Option>
</Select>
</Form.Item>
<Form.Item label="图像风格">
<Select v-model:value="form.style">
<Select.Option
v-for="item in styles"
:key="item.value"
:value="item.value"
>
{{ item.label }}
</Select.Option>
</Select>
</Form.Item>
<Form.Item>
<Button type="primary" html-type="submit" :loading="loading">
生成图片
</Button>
</Form.Item>
</Form>
</Card>
</Col>
<!-- 右侧图片展示 -->
<Col :span="16">
<Card title="生成结果" bordered>
<Row :gutter="16">
<Col
v-for="(img, idx) in pagedImages"
:key="img"
:span="4"
style="margin-bottom: 16px"
>
<Card
hoverable
:cover="
h('img', {
src: img,
style: 'width:100%;height:180px;object-fit:cover;',
})
"
/>
</Col>
</Row>
<Pagination
v-if="total > pageSize"
v-model:current="page"
:total="total"
:page-size="pageSize"
style="margin-top: 16px; text-align: right"
/>
</Card>
</Col>
</Row>
</Page>
</template>
<style scoped>
/* 可根据需要自定义样式 */
</style>

View File

@@ -1,11 +0,0 @@
<script setup lang="ts">
</script>
<template>
<div>dsadsa</div>
</template>
<style scoped lang="css">
</style>