diff --git a/ai_service/api/v1/__init__.py b/ai_service/api/v1/__init__.py
new file mode 100644
index 0000000..3350a52
--- /dev/null
+++ b/ai_service/api/v1/__init__.py
@@ -0,0 +1,22 @@
+from fastapi import APIRouter
+from .chat import router as chat_router
+from .drawing import router as drawing_router
+# from .video import router as video_router
+# from .audio import router as audio_router
+# from .multimodal import router as multimodal_router
+# from .model_manage import router as model_manage_router
+# from .knowledge import router as knowledge_router
+# from .system import router as system_router
+# from .user import router as user_router
+
+api_v1_router = APIRouter(prefix="/api/ai/v1")
+
+api_v1_router.include_router(chat_router)
+api_v1_router.include_router(drawing_router)
+# api_v1_router.include_router(video_router)
+# api_v1_router.include_router(audio_router)
+# api_v1_router.include_router(multimodal_router)
+# api_v1_router.include_router(model_manage_router)
+# api_v1_router.include_router(knowledge_router)
+# api_v1_router.include_router(system_router)
+# api_v1_router.include_router(user_router)
\ No newline at end of file
diff --git a/ai_service/api/v1/ai_chat.py b/ai_service/api/v1/chat/__init__.py
similarity index 98%
rename from ai_service/api/v1/ai_chat.py
rename to ai_service/api/v1/chat/__init__.py
index 3cf5270..3b3c464 100644
--- a/ai_service/api/v1/ai_chat.py
+++ b/ai_service/api/v1/chat/__init__.py
@@ -8,7 +8,7 @@ from typing import List
 from pydantic import BaseModel, SecretStr
 from langchain.chains import ConversationChain
 
-from api.v1.vo import MessageVO
+from api.v1.chat.vo import MessageVO
 from deps.auth import get_current_user
 from services.chat_service import ChatDBService
 from db.session import get_db
@@ -16,7 +16,7 @@ from models.ai import ChatConversation, ChatMessage
 from utils.resp import resp_success, Response
 from langchain_deepseek import ChatDeepSeek
 
-router = APIRouter()
+router = APIRouter(prefix="/chat", tags=["chat"])
 
 def get_deepseek_llm(api_key: SecretStr, model: str):
     # DeepSeek is OpenAI-API compatible; a base_url must be specified
diff --git a/ai_service/api/v1/vo.py b/ai_service/api/v1/chat/vo.py
similarity index 100%
rename from ai_service/api/v1/vo.py
rename to ai_service/api/v1/chat/vo.py
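
Note: the aggregate router above nests each feature router under the shared `/api/ai/v1` prefix. A minimal, self-contained sketch (toy routers, not the actual chat/drawing modules) of how the nested prefixes compose into the final request paths:

```python
from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient

chat_router = APIRouter(prefix="/chat", tags=["chat"])
drawing_router = APIRouter(prefix="/drawing", tags=["drawing"])

@drawing_router.get("/")
def list_drawings():
    # Stand-in endpoint; the real handler lives in api/v1/drawing/__init__.py
    return {"items": []}

api_v1_router = APIRouter(prefix="/api/ai/v1")
api_v1_router.include_router(chat_router)
api_v1_router.include_router(drawing_router)

app = FastAPI()
app.include_router(api_v1_router)

# The drawing list endpoint is served at /api/ai/v1/drawing/
client = TestClient(app)
assert client.get("/api/ai/v1/drawing/").status_code == 200
```
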
diff --git a/ai_service/api/v1/drawing/__init__.py b/ai_service/api/v1/drawing/__init__.py
new file mode 100644
index 0000000..5e8a045
--- /dev/null
+++ b/ai_service/api/v1/drawing/__init__.py
@@ -0,0 +1,110 @@
+import json
+import os
+
+from fastapi import APIRouter, Depends, HTTPException, Body, Query
+from pydantic import BaseModel
+from sqlalchemy.orm import Session
+from db.session import get_db
+from deps.auth import get_current_user
+from llm.factory import get_adapter
+from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status
+
+router = APIRouter(prefix="/drawing", tags=["drawing"])
+
+@router.get("/")
+def api_get_image_page(
+    page: int = Query(1, ge=1),
+    page_size: int = Query(12, ge=1, le=100),
+    db: Session = Depends(get_db),
+    user=Depends(get_current_user)
+):
+    data = get_drawing_page(db, user_id=user["user_id"], page=page, page_size=page_size)
+    # serialize items
+    data["items"] = [
+        {
+            "id": img.id,
+            "prompt": img.prompt,
+            "pic_url": img.pic_url,
+            "status": img.status,
+            "error_message": img.error_message,
+            "created_at": img.created_at if hasattr(img, 'created_at') else None
+        }
+        for img in data["items"]
+    ]
+    return data
+
+class CreateDrawingTaskRequest(BaseModel):
+    prompt: str
+    style: str = 'auto'
+    size: str = '1024x1024'
+    model: str = 'wanx_v1'
+    platform: str = 'tongyi'
+    n: int = 1
+
+
+@router.post("/")
+def api_create_image_task(
+    req: CreateDrawingTaskRequest,
+    db: Session = Depends(get_db),
+    user=Depends(get_current_user)
+):
+    user_id = user["user_id"]
+    style = req.style
+    size = req.size
+    platform = req.platform
+    n = req.n
+    prompt = req.prompt
+    model = req.model
+    print(user_id, req.platform, req.size, req.model, req.prompt)
+    api_key = os.getenv("DASHSCOPE_API_KEY")
+    adapter = get_adapter('tongyi', api_key=api_key, model=model)
+    try:
+        # rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
+        # print(rsp, 'rsp')
+        res_json = {
+            "status_code": 200,
+            "request_id": "31b04171-011c-96bd-ac00-f0383b669cc7",
+            "code": "",
+            "message": "",
+            "output": {
+                "task_id": "4f90cf14-a34e-4eae-xxxxxxxx",
+                "task_status": "PENDING",
+                "results": []
+            },
+            "usage": None
+        }
+        rsp = res_json
+        if rsp['status_code'] != 200:
+            raise HTTPException(status_code=500, detail=rsp['message'])
+        option = {
+            'style': style
+        }
+        drawing = create_drawing_task(
+            db=db,
+            user_id=user["user_id"],
+            platform=platform,
+            model=model,
+            rsp=rsp,
+            prompt=prompt,
+            size=size,
+            options=json.dumps(option)
+        )
+        return {"id": drawing.id, "task_id": drawing.task_id, "status": drawing.status}
+    except NotImplementedError:
+        print("该服务商不支持图片生成")
+
+
+@router.get("/{id}")
+def api_fetch_image_task_status(
+    id: int,
+    db: Session = Depends(get_db)
+):
+    image, err = fetch_drawing_task_status(db, id)
+    if not image:
+        raise HTTPException(status_code=404, detail=err or "任务不存在")
+    return {
+        "id": image.id,
+        "status": image.status,
+        "pic_url": image.pic_url,
+        "error_message": image.error_message
+    }
diff --git a/ai_service/llm/adapter/openai.py b/ai_service/llm/adapter/openai.py
index 1a3adb9..021138a 100644
--- a/ai_service/llm/adapter/openai.py
+++ b/ai_service/llm/adapter/openai.py
@@ -15,7 +15,7 @@ class OpenAIAdapter(MultiModalAICapability):
             yield chunk
 
     # For image generation (DALL·E), it could be implemented as follows
-    def create_image_task(self, prompt, **kwargs):
+    def create_drawing_task(self, **kwargs):
         # pseudocode; would use openai.Image.create
         # import openai
         # response = openai.Image.create(api_key=self.api_key, prompt=prompt, ...)
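
The POST handler above stubs the provider call with a hard-coded DashScope-style response. A hedged sketch of how the commented-out adapter call could be wired in once `create_drawing_task` returns the provider response again (as the pre-rename `create_image_task` did); the attribute-style checks mirror the code removed from the Tongyi adapter below, and the HTTP error mapping is an assumption, not existing code:

```python
import os
from http import HTTPStatus

from fastapi import HTTPException

from llm.factory import get_adapter


def submit_drawing(prompt: str, model: str = 'wanx_v1', style: str = 'auto',
                   size: str = '1024x1024', n: int = 1):
    adapter = get_adapter('tongyi', api_key=os.getenv("DASHSCOPE_API_KEY"), model=model)
    try:
        rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
    except NotImplementedError:
        # Base-class fallback: this provider has no drawing capability.
        raise HTTPException(status_code=400, detail="drawing generation not supported by this provider")
    if rsp.status_code != HTTPStatus.OK:
        raise HTTPException(status_code=500, detail=rsp.message)
    # Hand the raw response to the service layer, as the handler above does with the mock.
    return rsp
```
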
diff --git a/ai_service/llm/adapter/tongyi.py b/ai_service/llm/adapter/tongyi.py
index 7d7af82..2079843 100644
--- a/ai_service/llm/adapter/tongyi.py
+++ b/ai_service/llm/adapter/tongyi.py
@@ -23,24 +23,20 @@ class TongYiAdapter(MultiModalAICapability):
         async for chunk in self.llm.astream(messages):
             yield chunk
 
-    @staticmethod
-    def create_image_task(api_key, model, prompt: str, style='', size='1024*1024', n=1):
+    def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
+        print(self.model, self.api_key, 'key')
         """Create an async image generation task"""
         rsp = ImageSynthesis.async_call(
-            api_key=api_key,
-            model=model,
+            api_key=self.api_key,
+            model=self.model,
             prompt=prompt,
             n=n,
-            style=style,
+            style=f'<{style}>',
             size=size
         )
-        if rsp.status_code == HTTPStatus.OK:
-            return rsp
-        else:
-            raise Exception(f"Failed, status_code: {rsp.status_code}, code: {rsp.code}, message: {rsp.message}")
+        print(rsp, 'rsp')
 
-    @staticmethod
-    def fetch_image_task_status(task):
+    def fetch_drawing_task_status(self, task):
         """Fetch the status of an async image generation task"""
         status = ImageSynthesis.fetch(task)
         if status.status_code == HTTPStatus.OK:
diff --git a/ai_service/llm/base.py b/ai_service/llm/base.py
index 8f82046..29c42c3 100644
--- a/ai_service/llm/base.py
+++ b/ai_service/llm/base.py
@@ -9,14 +9,14 @@ class MultiModalAICapability(ABC):
         raise NotImplementedError("stream_chat not supported by this provider")
 
     # Image generation capability
-    def create_image_task(self, prompt, **kwargs):
-        raise NotImplementedError("image generation not supported by this provider")
+    def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
+        raise NotImplementedError("drawing generation not supported by this provider")
 
-    def fetch_image_task_status(self, task):
-        raise NotImplementedError("image task status not supported by this provider")
+    def fetch_drawing_task_status(self, task):
+        raise NotImplementedError("drawing task status not supported by this provider")
 
-    def fetch_image_result(self, task):
-        raise NotImplementedError("image result not supported by this provider")
+    def fetch_drawing_result(self, task):
+        raise NotImplementedError("drawing result not supported by this provider")
 
     # Video generation capability
     def create_video_task(self, prompt, **kwargs):
diff --git a/ai_service/main.py b/ai_service/main.py
index 4e3a2a9..672d20f 100644
--- a/ai_service/main.py
+++ b/ai_service/main.py
@@ -1,8 +1,8 @@
 import os
 from fastapi import FastAPI
 from dotenv import load_dotenv
-from api.v1 import ai_chat
 from fastapi.middleware.cors import CORSMiddleware
+from api.v1 import api_v1_router
 
 # Load .env environment variables; the project root takes priority
 load_dotenv()
@@ -23,7 +23,8 @@ app.add_middleware(
 )
 
 # Register routes
-app.include_router(ai_chat.router, prefix="/api/ai/v1", tags=["chat"])
+# app.include_router(ai_chat.router, prefix="/api/ai/v1", tags=["chat"])
+app.include_router(api_v1_router)
 
 # Health check
 @app.get("/ping")
diff --git a/ai_service/models/ai.py b/ai_service/models/ai.py
index edd82e4..8aad029 100644
--- a/ai_service/models/ai.py
+++ b/ai_service/models/ai.py
@@ -253,8 +253,8 @@ class ChatRoleTool(Base):
     tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True)
 
 
-class Image(CoreModel):
-    __tablename__ = 'ai_image'
+class Drawing(CoreModel):
+    __tablename__ = 'ai_drawing'
 
     user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
     public_status = Column(Boolean, default=False, nullable=False)
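
The capability base class above is what each provider adapter overrides. A hypothetical example of how another provider would plug in: subclass `MultiModalAICapability`, override only the drawing methods it supports, and let everything else fall back to the base class's `NotImplementedError`. The import path `llm.base` and the `__init__` signature are assumed from the file layout in this diff; registering the class with `llm.factory.get_adapter` is left out because that module is not part of the diff.

```python
from llm.base import MultiModalAICapability


class ExampleDrawingAdapter(MultiModalAICapability):
    """Hypothetical provider adapter that only supports drawing."""

    def __init__(self, api_key: str, model: str):
        self.api_key = api_key
        self.model = model

    def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
        # Submit an async generation job via the provider's SDK and return its raw response.
        ...

    def fetch_drawing_task_status(self, task):
        # Look up the async job and return its current status / result payload.
        ...
```
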
diff --git a/ai_service/services/drawing_service.py b/ai_service/services/drawing_service.py
new file mode 100644
index 0000000..0792b50
--- /dev/null
+++ b/ai_service/services/drawing_service.py
@@ -0,0 +1,62 @@
+from datetime import datetime
+
+from dashscope import ImageSynthesis
+from http import HTTPStatus
+
+from sqlalchemy import desc
+
+from models.ai import Drawing
+from sqlalchemy.orm import Session
+
+
+def create_drawing_task(db: Session, user_id: int, platform: str, model: str, prompt: str, size: str, rsp,
+                        options: str = None):
+    # persist to the database
+    drawing = Drawing(
+        user_id=user_id,
+        platform=platform,
+        model=model,
+        prompt=prompt,
+        width=int(size.split('x')[0]),
+        height=int(size.split('x')[1]),
+        create_time=datetime.now(),
+        update_time=datetime.now(),
+        options=options,
+        status=rsp['output']['task_status'],
+        task_id=rsp['output']['task_id'],
+        error_message=rsp['message']
+    )
+    db.add(drawing)
+    db.commit()
+    db.refresh(drawing)
+    return drawing
+
+def fetch_drawing_task_status(db: Session, drawing_id: int):
+    drawing = db.query(Drawing).filter(Drawing.id == drawing_id).first()
+    if not drawing or not drawing.task_id:
+        return None, "任务不存在"
+    status = ImageSynthesis.fetch(drawing.task_id)
+    if status.status_code == HTTPStatus.OK:
+        # update the database based on status.output.task_status
+        drawing.status = status.output.task_status
+        if hasattr(status.output, 'results') and status.output.results:
+            drawing.pic_url = status.output.results[0].url
+        db.commit()
+        db.refresh(drawing)
+        return drawing, None
+    else:
+        return None, status.message
+
+
+def get_drawing_page(db: Session, user_id: int = None, page: int = 1, page_size: int = 12):
+    query = db.query(Drawing)
+    if user_id:
+        query = query.filter(Drawing.user_id == user_id)
+    total = query.count()
+    items = query.order_by(desc(Drawing.id)).offset((page - 1) * page_size).limit(page_size).all()
+    return {
+        'total': total,
+        'page': page,
+        'page_size': page_size,
+        'items': items
+    }
\ No newline at end of file
diff --git a/backend/ai/migrations/0006_rename_image_drawing_alter_drawing_table.py b/backend/ai/migrations/0006_rename_image_drawing_alter_drawing_table.py
new file mode 100644
index 0000000..f9e55c2
--- /dev/null
+++ b/backend/ai/migrations/0006_rename_image_drawing_alter_drawing_table.py
@@ -0,0 +1,23 @@
+# Generated by Django 5.2.1 on 2025-07-21 13:24
+
+from django.conf import settings
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("ai", "0005_image"),
+        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
+    ]
+
+    operations = [
+        migrations.RenameModel(
+            old_name="Image",
+            new_name="Drawing",
+        ),
+        migrations.AlterModelTable(
+            name="drawing",
+            table="ai_drawing",
+        ),
+    ]
diff --git a/backend/ai/models.py b/backend/ai/models.py
index 53d0ea7..010a10a 100644
--- a/backend/ai/models.py
+++ b/backend/ai/models.py
@@ -332,7 +332,7 @@ class ChatMessage(CoreModel):
         return self.content[:30]
 
 
-class Image(CoreModel):
+class Drawing(CoreModel):
 
     user = models.ForeignKey(
         settings.AUTH_USER_MODEL,
@@ -359,6 +359,6 @@ class Image(CoreModel):
     buttons = models.CharField(max_length=2048, null=True, verbose_name='mj buttons 按钮')
 
     class Meta:
-        db_table = 'ai_image'
+        db_table = 'ai_drawing'
         verbose_name = 'AI 绘画表'
         verbose_name_plural = verbose_name
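
`fetch_drawing_task_status` refreshes a single record on demand, so callers that need a final result have to poll it. A hedged sketch of such a polling loop; `SUCCEEDED` and `FAILED` are assumed DashScope terminal task states, not values defined anywhere in this diff:

```python
import time

from sqlalchemy.orm import Session

from services.drawing_service import fetch_drawing_task_status

# Assumed DashScope terminal task states; adjust if the provider reports others.
TERMINAL_STATUSES = {"SUCCEEDED", "FAILED"}


def wait_for_drawing(db: Session, drawing_id: int, timeout: float = 120.0, interval: float = 3.0):
    """Poll the drawing task until it reaches a terminal status or the timeout expires."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        drawing, err = fetch_drawing_task_status(db, drawing_id)
        if err:
            return None, err
        if drawing.status in TERMINAL_STATUSES:
            return drawing, None
        time.sleep(interval)
    return None, "timed out waiting for the drawing task"
```
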
diff --git a/web/apps/web-antd/src/api/ai/chat.ts b/web/apps/web-antd/src/api/ai/chat.ts
index c48364a..e590689 100644
--- a/web/apps/web-antd/src/api/ai/chat.ts
+++ b/web/apps/web-antd/src/api/ai/chat.ts
@@ -1,12 +1,12 @@
 import { fetchWithAuth } from '#/utils/fetch-with-auth';
 
 export async function getConversations() {
-  const res = await fetchWithAuth('/api/ai/v1/conversations');
+  const res = await fetchWithAuth('chat/conversations');
   return await res.json();
 }
 
 export async function createConversation() {
-  const response = await fetchWithAuth('/api/ai/v1/conversations', {
+  const response = await fetchWithAuth('chat/conversations', {
     method: 'POST',
   });
   if (!response.ok) {
@@ -17,7 +17,7 @@ export async function getMessages(conversationId: number) {
   const res = await fetchWithAuth(
-    `/api/ai/v1/messages?conversation_id=${conversationId}`,
+    `chat/messages?conversation_id=${conversationId}`,
   );
   return await res.json();
 }
@@ -32,7 +32,7 @@ export async function fetchAIStream({
   content,
   conversation_id,
 }: FetchAIStreamParams) {
-  const res = await fetchWithAuth('/api/ai/v1/stream', {
+  const res = await fetchWithAuth('chat/stream', {
     method: 'POST',
     body: JSON.stringify({ content, conversation_id }),
   });
diff --git a/web/apps/web-antd/src/api/ai/drawing.ts b/web/apps/web-antd/src/api/ai/drawing.ts
new file mode 100644
index 0000000..2eaf5b5
--- /dev/null
+++ b/web/apps/web-antd/src/api/ai/drawing.ts
@@ -0,0 +1,47 @@
+import { fetchWithAuth } from '#/utils/fetch-with-auth';
+
+export interface CreateImageTaskParams {
+  prompt: string;
+  style?: string;
+  size?: string;
+  model?: string;
+  platform?: string;
+  n?: number;
+}
+
+export async function createImageTask(params: CreateImageTaskParams) {
+  const res = await fetchWithAuth('drawing/', {
+    method: 'POST',
+    body: JSON.stringify(params),
+  });
+  if (!res.ok) {
+    throw new Error('创建图片任务失败');
+  }
+  return await res.json();
+}
+
+export async function fetchImageTaskStatus(id: number) {
+  const res = await fetchWithAuth(`drawing/${id}/`);
+  if (!res.ok) {
+    throw new Error('查询图片任务状态失败');
+  }
+  return await res.json();
+}
+
+export interface GetImagePageParams {
+  page?: number;
+  page_size?: number;
+}
+
+export async function getImagePage(params: GetImagePageParams = {}) {
+  const query = new URLSearchParams();
+  if (params.page) query.append('page', String(params.page));
+  if (params.page_size) query.append('page_size', String(params.page_size));
+  const res = await fetchWithAuth(
+    `drawing${query.toString() ? `?${query.toString()}` : ''}`,
+  );
+  if (!res.ok) {
+    throw new Error('获取图片分页失败');
+  }
+  return await res.json();
+}
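
For reference, the JSON shapes this client consumes come from the drawing endpoints added earlier, which currently return plain dicts. A hedged sketch of equivalent Pydantic response models; declaring them as `response_model`s on the FastAPI handlers would be a follow-up idea, not something this diff does:

```python
from datetime import datetime
from typing import List, Optional

from pydantic import BaseModel


class DrawingItem(BaseModel):
    """One row of the paginated list returned by GET /api/ai/v1/drawing/."""
    id: int
    prompt: str
    pic_url: Optional[str] = None
    status: str
    error_message: Optional[str] = None
    created_at: Optional[datetime] = None


class DrawingPage(BaseModel):
    total: int
    page: int
    page_size: int
    items: List[DrawingItem]


class DrawingTaskCreated(BaseModel):
    """Shape returned by POST /api/ai/v1/drawing/."""
    id: int
    task_id: str
    status: str
```
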
diff --git a/web/apps/web-antd/src/locales/langs/en-US/ai.json b/web/apps/web-antd/src/locales/langs/en-US/ai.json
index b5a7415..8f4ef08 100644
--- a/web/apps/web-antd/src/locales/langs/en-US/ai.json
+++ b/web/apps/web-antd/src/locales/langs/en-US/ai.json
@@ -20,6 +20,10 @@
     "title": "AI CHAT",
     "name": "AI CHAT"
   },
+  "drawing": {
+    "title": "AI DRAWING",
+    "name": "AI DRAWING"
+  },
   "chat_conversation": {
     "title": "CHAT Management",
     "name": "CHAT Management"
diff --git a/web/apps/web-antd/src/locales/langs/zh-CN/ai.json b/web/apps/web-antd/src/locales/langs/zh-CN/ai.json
index 47ad5b1..9d5a317 100644
--- a/web/apps/web-antd/src/locales/langs/zh-CN/ai.json
+++ b/web/apps/web-antd/src/locales/langs/zh-CN/ai.json
@@ -20,6 +20,10 @@
     "title": "AI对话",
     "name": "AI对话"
   },
+  "drawing": {
+    "title": "AI绘画",
+    "name": "AI绘画"
+  },
   "chat_conversation": {
     "title": "对话列表",
     "name": "对话列表"
diff --git a/web/apps/web-antd/src/utils/fetch-with-auth.ts b/web/apps/web-antd/src/utils/fetch-with-auth.ts
index 1a1543b..c53f527 100644
--- a/web/apps/web-antd/src/utils/fetch-with-auth.ts
+++ b/web/apps/web-antd/src/utils/fetch-with-auth.ts
@@ -1,11 +1,14 @@
-import { formatToken } from '#/utils/auth';
 import { useAccessStore } from '@vben/stores';
 
+import { formatToken } from '#/utils/auth';
+
+export const API_BASE = '/api/ai/v1/';
+
 export function fetchWithAuth(input: RequestInfo, init: RequestInit = {}) {
   const accessStore = useAccessStore();
   const token = accessStore.accessToken;
   const headers = new Headers(init.headers || {});
   headers.append('Content-Type', 'application/json');
   headers.append('Authorization', formatToken(token) as string);
-  return fetch(input, { ...init, headers });
+  return fetch(API_BASE + input, { ...init, headers });
 }
diff --git a/web/apps/web-antd/src/views/ai/drawing/index.vue b/web/apps/web-antd/src/views/ai/drawing/index.vue
new file mode 100644
index 0000000..540ba06
--- /dev/null
+++ b/web/apps/web-antd/src/views/ai/drawing/index.vue
@@ -0,0 +1,200 @@
+
+
+
+
+
diff --git a/web/apps/web-antd/src/views/ai/image/index.vue b/web/apps/web-antd/src/views/ai/image/index.vue
deleted file mode 100644
index 1d68782..0000000
--- a/web/apps/web-antd/src/views/ai/image/index.vue
+++ /dev/null
@@ -1,11 +0,0 @@
-
-
-
-