diff --git a/backend/ai/llm/__init__.py b/backend/ai/llm/__init__.py new file mode 100644 index 0000000..059fdd8 --- /dev/null +++ b/backend/ai/llm/__init__.py @@ -0,0 +1,50 @@ +import os +from enum import Enum + +from pydantic import SecretStr + + +class ProviderEnum(str, Enum): + """支持的 LLM 服务商""" + DEEPSEEK = "deepseek" + OPENAI = "openai" + TONGYI = "tongyi" + +class LLMFactory(object): + + @staticmethod + def get_llm(provider: ProviderEnum, model: str = None, **kwargs): + if provider == ProviderEnum.DEEPSEEK: + from langchain_deepseek import ChatDeepSeek + api_key = os.getenv("DEEPSEEK_API_KEY") + model = model or "deepseek-chat" + return ChatDeepSeek( + api_key=SecretStr(api_key), + model=model, + streaming=True, + **kwargs + ) + + elif provider == ProviderEnum.OPENAI: + from langchain_openai import ChatOpenAI + api_key = os.getenv("OPENAI_API_KEY") + model = model or "gpt-3.5-turbo" + return ChatOpenAI( + api_key=SecretStr(api_key), + model=model, + streaming=True, + **kwargs + ) + + elif provider == ProviderEnum.TONGYI: + from langchain_community.llms import Tongyi + api_key = os.getenv("DASHSCOPE_API_KEY") + model = model or "qwen-turbo" + return Tongyi( + api_key=SecretStr(api_key), + model=model, + streaming=True, + **kwargs + ) + else: + raise ValueError(f"不支持的 LLM 服务商: {provider}") diff --git a/backend/ai/llm/adapter/__init__.py b/backend/ai/llm/adapter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/ai/llm/adapter/deepseek.py b/backend/ai/llm/adapter/deepseek.py new file mode 100644 index 0000000..615923e --- /dev/null +++ b/backend/ai/llm/adapter/deepseek.py @@ -0,0 +1,17 @@ +from langchain_deepseek import ChatDeepSeek + +from llm.base import MultiModalAICapability + + +class DeepSeekAdapter(MultiModalAICapability): + def __init__(self, api_key, model, **kwargs): + + self.llm = ChatDeepSeek(api_key=api_key, model=model, streaming=True) + + async def chat(self, messages, **kwargs): + # 兼容 DeepSeek 的调用方式 + return await self.llm.ainvoke(messages) + + async def stream_chat(self, messages, **kwargs): + async for chunk in self.llm.astream(messages): + yield chunk \ No newline at end of file diff --git a/backend/ai/llm/adapter/genai.py b/backend/ai/llm/adapter/genai.py new file mode 100644 index 0000000..f16c710 --- /dev/null +++ b/backend/ai/llm/adapter/genai.py @@ -0,0 +1,21 @@ +# 假设有 google genai sdk +# from google_genai import GenAI +from llm.base import MultiModalAICapability + + +class GoogleGenAIAdapter(MultiModalAICapability): + def __init__(self, api_key, model, **kwargs): + self.api_key = api_key + self.model = model + # self.llm = GenAI(api_key=api_key, model=model) + + async def chat(self, messages, **kwargs): + # return await self.llm.chat(messages) + raise NotImplementedError("Google GenAI chat未实现") + + async def stream_chat(self, messages, **kwargs): + # async for chunk in self.llm.stream_chat(messages): + # yield chunk + raise NotImplementedError("Google GenAI stream_chat未实现") + + # 其他能力同理 \ No newline at end of file diff --git a/backend/ai/llm/adapter/openai.py b/backend/ai/llm/adapter/openai.py new file mode 100644 index 0000000..021138a --- /dev/null +++ b/backend/ai/llm/adapter/openai.py @@ -0,0 +1,25 @@ +from llm.base import MultiModalAICapability +from langchain_openai import ChatOpenAI +# from openai import OpenAI # 如需图片/音频/视频等API + +class OpenAIAdapter(MultiModalAICapability): + def __init__(self, api_key, model, **kwargs): + self.llm = ChatOpenAI(api_key=api_key, model=model, streaming=True) + self.api_key = api_key + + async def chat(self, messages, **kwargs): + return await self.llm.ainvoke(messages) + + async def stream_chat(self, messages, **kwargs): + async for chunk in self.llm.astream(messages): + yield chunk + + # 如需图片生成(DALL·E),可实现如下 + def create_drawing_task(self, **kwargs): + # 伪代码,需用 openai.Image.create + # import openai + # response = openai.Image.create(api_key=self.api_key, prompt=prompt, ...) + # return response + raise NotImplementedError("OpenAI 图片生成请用 openai.Image.create 实现") + + # 其他能力同理 \ No newline at end of file diff --git a/backend/ai/llm/adapter/tongyi.py b/backend/ai/llm/adapter/tongyi.py new file mode 100644 index 0000000..1633563 --- /dev/null +++ b/backend/ai/llm/adapter/tongyi.py @@ -0,0 +1,44 @@ +from langchain_community.chat_models import ChatTongyi +from http import HTTPStatus +from urllib.parse import urlparse, unquote +from pathlib import PurePosixPath +import requests +from dashscope import ImageSynthesis +import os + +from llm.base import MultiModalAICapability + + +class TongYiAdapter(MultiModalAICapability): + def __init__(self, api_key, model, **kwargs): + self.api_key = api_key + self.model = model + self.llm = ChatTongyi(api_key=api_key, model=model, streaming=True) + + async def chat(self, messages, **kwargs): + # 兼容 DeepSeek 的调用方式 + return await self.llm.ainvoke(messages) + + async def stream_chat(self, messages, **kwargs): + async for chunk in self.llm.astream(messages): + yield chunk + + def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs): + """创建异步图片生成任务""" + rsp = ImageSynthesis.async_call( + api_key=self.api_key, + model=self.model, + prompt=prompt, + n=n, + style=f'<{style}>', + size=size + ) + return rsp + + def fetch_drawing_task_status(self, task): + """获取异步图片任务状态""" + rsp = ImageSynthesis.fetch(task, api_key=self.api_key) + return rsp + + + \ No newline at end of file diff --git a/backend/ai/llm/base.py b/backend/ai/llm/base.py new file mode 100644 index 0000000..29c42c3 --- /dev/null +++ b/backend/ai/llm/base.py @@ -0,0 +1,37 @@ +from abc import ABC + +class MultiModalAICapability(ABC): + # 对话能力 + async def chat(self, messages, **kwargs): + raise NotImplementedError("chat not supported by this provider") + + async def stream_chat(self, messages, **kwargs): + raise NotImplementedError("stream_chat not supported by this provider") + + # 图片生成能力 + def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs): + raise NotImplementedError("drawing generation not supported by this provider") + + def fetch_drawing_task_status(self, task): + raise NotImplementedError("drawing task status not supported by this provider") + + def fetch_drawing_result(self, task): + raise NotImplementedError("drawing result not supported by this provider") + + # 视频生成能力 + def create_video_task(self, prompt, **kwargs): + raise NotImplementedError("video generation not supported by this provider") + + def fetch_video_task_status(self, task): + raise NotImplementedError("video task status not supported by this provider") + + def fetch_video_result(self, task): + raise NotImplementedError("video result not supported by this provider") + + # 知识库能力 + def query_knowledge(self, query, **kwargs): + raise NotImplementedError("knowledge query not supported by this provider") + + # 语音合成能力 + def synthesize_speech(self, text, **kwargs): + raise NotImplementedError("speech synthesis not supported by this provider") \ No newline at end of file diff --git a/backend/ai/llm/enums.py b/backend/ai/llm/enums.py new file mode 100644 index 0000000..5bb0c3c --- /dev/null +++ b/backend/ai/llm/enums.py @@ -0,0 +1,32 @@ +from enum import Enum + + +class LLMProvider(str, Enum): + """LLM 提供商枚举""" + DEEPSEEK = "deepseek" + TONGYI = "tongyi" + OPENAI = "openai" + GOOGLE_GENAI = "google-genai" + + @classmethod + def get_model_by_platform(cls, platform: str) -> tuple[str, str]: + """根据平台名称获取对应的模型和API密钥环境变量名""" + if platform == cls.TONGYI: + return 'qwen-plus', 'DASHSCOPE_API_KEY' + elif platform == cls.DEEPSEEK: + return 'deepseek-chat', 'DEEPSEEK_API_KEY' + elif platform == cls.OPENAI: + return 'gpt-3.5-turbo', 'OPENAI_API_KEY' + elif platform == cls.GOOGLE_GENAI: + return 'gemini-pro', 'GOOGLE_API_KEY' + else: + # 默认使用 DeepSeek + return 'deepseek-chat', 'DEEPSEEK_API_KEY' + + @classmethod + def from_string(cls, platform: str) -> 'LLMProvider': + """从字符串创建枚举值,如果不存在则返回默认值""" + try: + return cls(platform) + except ValueError: + return cls.DEEPSEEK # 默认返回 DeepSeek \ No newline at end of file diff --git a/backend/ai/llm/factory.py b/backend/ai/llm/factory.py new file mode 100644 index 0000000..a4de606 --- /dev/null +++ b/backend/ai/llm/factory.py @@ -0,0 +1,35 @@ +from .adapter.deepseek import DeepSeekAdapter +from .adapter.genai import GoogleGenAIAdapter +from .adapter.openai import OpenAIAdapter +from .adapter.tongyi import TongYiAdapter +from .enums import LLMProvider + + +def get_adapter(provider: LLMProvider, api_key, model, **kwargs): + if provider == LLMProvider.DEEPSEEK: + return DeepSeekAdapter(api_key, model, **kwargs) + elif provider == LLMProvider.TONGYI: + return TongYiAdapter(api_key, model, **kwargs) + elif provider == LLMProvider.OPENAI: + return OpenAIAdapter(api_key, model, **kwargs) + elif provider == LLMProvider.GOOGLE_GENAI: + return GoogleGenAIAdapter(api_key, model, **kwargs) + else: + raise ValueError('不支持的服务商') + +# 使用示例 +# adapter = get_adapter(LLMProvider.TONGYI, api_key='xxx', model='wanx_v1') + +# 对话 +# try: +# result = await adapter.chat(messages) +# except NotImplementedError: +# print("该服务商不支持对话能力") + +# # 图片生成 +# try: +# task = adapter.create_image_task(prompt="一只猫") +# status = adapter.fetch_image_task_status(task) +# result = adapter.fetch_image_result(task) +# except NotImplementedError: +# print("该服务商不支持图片生成") \ No newline at end of file diff --git a/backend/ai/urls.py b/backend/ai/urls.py index 3df0452..b4a04e5 100644 --- a/backend/ai/urls.py +++ b/backend/ai/urls.py @@ -10,6 +10,7 @@ router.register(r'tool', views.ToolViewSet) router.register(r'knowledge', views.KnowledgeViewSet) router.register(r'chat_conversation', views.ChatConversationViewSet) router.register(r'chat_message', views.ChatMessageViewSet) +router.register(r'drawing', views.DrawingViewSet) urlpatterns = [ diff --git a/backend/ai/views/__init__.py b/backend/ai/views/__init__.py index 7e5fa57..e253712 100644 --- a/backend/ai/views/__init__.py +++ b/backend/ai/views/__init__.py @@ -5,6 +5,7 @@ __all__ = [ 'KnowledgeViewSet', 'ChatConversationViewSet', 'ChatMessageViewSet', + 'DrawingViewSet', ] from ai.views.ai_api_key import AIApiKeyViewSet @@ -12,4 +13,5 @@ from ai.views.ai_model import AIModelViewSet from ai.views.tool import ToolViewSet from ai.views.knowledge import KnowledgeViewSet from ai.views.chat_conversation import ChatConversationViewSet -from ai.views.chat_message import ChatMessageViewSet \ No newline at end of file +from ai.views.chat_message import ChatMessageViewSet +from ai.views.drawing import DrawingViewSet \ No newline at end of file diff --git a/backend/ai/views/drawing.py b/backend/ai/views/drawing.py new file mode 100644 index 0000000..a0554de --- /dev/null +++ b/backend/ai/views/drawing.py @@ -0,0 +1,83 @@ +import os +from datetime import datetime + +from rest_framework.response import Response + +from ai.models import Drawing +from backend import settings +from llm.enums import LLMProvider +from llm.factory import get_adapter +from utils.serializers import CustomModelSerializer +from utils.custom_model_viewSet import CustomModelViewSet +from django_filters import rest_framework as filters + + +class DrawingSerializer(CustomModelSerializer): + """ + AI 绘画表 序列化器 + """ + class Meta: + model = Drawing + fields = '__all__' + read_only_fields = ['id', 'create_time', 'update_time'] + + +class DrawingFilter(filters.FilterSet): + + class Meta: + model = Drawing + fields = ['id', 'remark', 'creator', 'modifier', 'is_deleted', 'public_status', 'platform', + 'model', 'width', 'height', 'status', 'pic_url', 'error_message', 'task_id', 'buttons'] + + +class DrawingViewSet(CustomModelViewSet): + """ + AI 绘画表 视图集 + """ + queryset = Drawing.objects.filter(is_deleted=False).order_by('-id') + serializer_class = DrawingSerializer + filterset_class = DrawingFilter + search_fields = ['name'] # 根据实际字段调整 + ordering_fields = ['create_time', 'id'] + ordering = ['-create_time'] + + def create(self, request, *args, **kwargs): + model = request.data.get('model') + prompt = request.data.get('prompt') + n = request.data.get('n', 1) + style = request.data.get('style') + size = request.data.get('size') + api_key = settings.DASHSCOPE_API_KEY + request.data['width'] = int(size.split('*')[0]) + request.data['height'] = int(size.split('*')[1]) + adapter = get_adapter(LLMProvider.TONGYI, api_key=api_key, model=model) + rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size) + if rsp['status_code'] != 200: + return Response(rsp['data'], status=rsp['status_code']) + else: + request.data['status'] = rsp['output']['task_status'] + request.data['task_id'] = rsp['output']['task_id'] + return super().create(request, *args, **kwargs) + + def retrieve(self, request, *args, **kwargs): + instance = self.get_object() + if instance.status in ("PENDING", 'RUNNING'): + api_key = settings.DASHSCOPE_API_KEY + adapter = get_adapter(LLMProvider.TONGYI, api_key=api_key, model='') + rsp = adapter.fetch_drawing_task_status(instance.task_id) + print(rsp, 'sadsadas') + if rsp['status_code'] == 200: + # 可根据 status.output.task_status 更新数据库 + if rsp['output']['task_status'] == 'SUCCEEDED': + instance.update_time = datetime.now() + instance.status = rsp['output']['task_status'] + instance.pic_url = rsp['output']['results'][0]['url'] + elif rsp['output']['task_status'] == 'FAILED': + instance.update_time = datetime.now() + instance.status = rsp['output']['task_status'] + instance.error_message = rsp['output']['message'] + elif rsp['output']['task_status'] == 'RUNNING': + instance.update_time = datetime.now() + instance.status = rsp['output']['task_status'] + instance.save() + return super().retrieve(request, *args, **kwargs) diff --git a/backend/backend/settings.py b/backend/backend/settings.py index d2c8497..a5ca487 100644 --- a/backend/backend/settings.py +++ b/backend/backend/settings.py @@ -242,5 +242,8 @@ ASGI_APPLICATION = 'backend.asgi.application' # } # } +DEEPSEEK_API_KEY = os.getenv('DEEPSEEK_API_KEY', ''), +DASHSCOPE_API_KEY = os.getenv('DASHSCOPE_API_KEY', ''), + if os.path.exists(os.path.join(BASE_DIR, 'backend/local_settings.py')): from backend.local_settings import * \ No newline at end of file diff --git a/web/apps/web-antd/src/models/ai/drawing.ts b/web/apps/web-antd/src/models/ai/drawing.ts new file mode 100644 index 0000000..7011491 --- /dev/null +++ b/web/apps/web-antd/src/models/ai/drawing.ts @@ -0,0 +1,19 @@ +import { BaseModel } from '#/models/base'; + +export namespace DrawingApi { + export interface Drawing { + id: number; + remark: string; + creator: string; + modifier: string; + update_time: string; + create_time: string; + is_deleted: boolean; + } +} + +export class DrawingModel extends BaseModel { + constructor() { + super('/ai/drawing/'); + } +} diff --git a/web/apps/web-antd/src/views/ai/drawing/index.vue b/web/apps/web-antd/src/views/ai/drawing/index.vue index 544eb41..fcc9922 100644 --- a/web/apps/web-antd/src/views/ai/drawing/index.vue +++ b/web/apps/web-antd/src/views/ai/drawing/index.vue @@ -16,13 +16,10 @@ import { Spin, } from 'ant-design-vue'; -import { - createDrawing, - getDrawingDetail, - getDrawingPage, -} from '#/api/ai/drawing'; - +import { DrawingModel } from '#/models/ai/drawing'; +// DrawingModel // 定义图片对象类型 +const operator = new DrawingModel(); interface DrawingImage { id: number; status: string; @@ -89,11 +86,7 @@ async function handleDraw() { loading.value = true; try { // 这里调用你的AI画图API,返回图片url数组 - const data = await createDrawing(form); - if (data.code !== 0) { - message.error(data.message || '生成失败'); - return; - } + await operator.create(form); page.value = 1; await fetchDrawingList(page.value, pageSize.value); // 刷新第一页图片列表 // images.value = res.data.images; @@ -105,20 +98,42 @@ async function handleDraw() { loading.value = false; } } +let drawingTimer: NodeJS.Timeout | null = null; // 轮询获取图片详情 const pollDrawingDetail = async (id: number) => { - fetchDrawingDetail(id).then((res) => { - if (res && res.data.status === 'RUNNING') { - setTimeout(() => pollDrawingDetail(id), 5000); + try { + const res = await operator.retrieve(id); // 改用 await 简化代码 + if (res?.status === 'RUNNING') { + // 保存定时器 ID,方便后续清除 + drawingTimer = setTimeout(() => pollDrawingDetail(id), 5000); + } else if ( + res?.status === '"SUCCEEDED"' && // 当状态为 "SUCCEEDED" 时,清除定时器 + drawingTimer + ) { + clearTimeout(drawingTimer); + drawingTimer = null; // 重置定时器 ID + await fetchDrawingList(); + // 可在此处添加任务成功后的其他逻辑(如刷新页面、提示用户等) } - }); + // 可根据需要添加对其他状态的处理(如 FAILED) + } catch (error) { + // 处理请求失败的情况,避免轮询异常中断 + console.error('轮询失败:', error); + if (drawingTimer) { + clearTimeout(drawingTimer); + drawingTimer = null; + } + } }; // 获取图片分页列表 async function fetchDrawingList(pageNum = 1, pageSize = 9) { try { - const res = await getDrawingPage({ page: pageNum, page_size: pageSize }); + const res = await operator.list({ + page: pageNum, + pageSize, + }); images.value = res.items; // images.value = res.items.map(item => item.pic_url); total.value = res.total; @@ -135,23 +150,6 @@ async function fetchDrawingList(pageNum = 1, pageSize = 9) { } } -// 获取图片详情 -const fetchDrawingDetail = async (id: number) => { - try { - const res = await getDrawingDetail(id); - // 更新 images 中对应项 - const idx = images.value.findIndex((item) => item.id === id); - if (idx !== -1) { - images.value[idx] = { ...images.value[idx], ...res.data }; - } - // 处理详情数据 - return res; - } catch { - message.error('获取图片详情失败'); - return null; - } -}; - // 页面加载时调用获取图片列表 onMounted(() => { fetchDrawingList();