feat: 修改drawing 为django 接口

This commit is contained in:
XIE7654
2025-10-31 22:16:55 +08:00
parent e4aa6ad18c
commit e5ec2fec56
15 changed files with 401 additions and 34 deletions

View File

@@ -0,0 +1,50 @@
import os
from enum import Enum
from pydantic import SecretStr
class ProviderEnum(str, Enum):
    """Supported LLM service providers."""
    # Members mix in ``str``, so each one compares equal to its plain
    # lowercase string value (e.g. ProviderEnum.OPENAI == "openai").
    DEEPSEEK = "deepseek"
    OPENAI = "openai"
    TONGYI = "tongyi"
class LLMFactory(object):
    """Factory that builds a streaming LLM client for a given provider."""

    @staticmethod
    def _require_api_key(env_var: str) -> SecretStr:
        """Read an API key from the environment and wrap it in ``SecretStr``.

        Raises:
            ValueError: if ``env_var`` is unset or empty. Without this check
                a missing key would be wrapped as ``SecretStr(None)`` and only
                fail later, inside the provider client, with an unrelated
                error.
        """
        api_key = os.getenv(env_var)
        if not api_key:
            raise ValueError(
                f"missing API key: environment variable {env_var} is not set"
            )
        return SecretStr(api_key)

    @staticmethod
    def get_llm(provider: ProviderEnum, model: str = None, **kwargs):
        """Create a streaming client for ``provider``.

        Args:
            provider: Which service to use (a ``ProviderEnum`` member or its
                string value).
            model: Model name; when falsy, a per-provider default is used.
            **kwargs: Extra keyword arguments forwarded to the client
                constructor.

        Returns:
            A ``ChatDeepSeek``, ``ChatOpenAI`` or ``Tongyi`` instance created
            with ``streaming=True``.

        Raises:
            ValueError: if the provider is not supported, or if the
                provider's API-key environment variable is not set.
        """
        # Provider SDKs are imported lazily so that only the selected
        # provider's package needs to be installed.
        if provider == ProviderEnum.DEEPSEEK:
            from langchain_deepseek import ChatDeepSeek as client_cls
            env_var, default_model = "DEEPSEEK_API_KEY", "deepseek-chat"
        elif provider == ProviderEnum.OPENAI:
            from langchain_openai import ChatOpenAI as client_cls
            env_var, default_model = "OPENAI_API_KEY", "gpt-3.5-turbo"
        elif provider == ProviderEnum.TONGYI:
            from langchain_community.llms import Tongyi as client_cls
            env_var, default_model = "DASHSCOPE_API_KEY", "qwen-turbo"
        else:
            raise ValueError(f"不支持的 LLM 服务商: {provider}")

        return client_cls(
            api_key=LLMFactory._require_api_key(env_var),
            model=model or default_model,
            streaming=True,
            **kwargs
        )

View File

View File

@@ -0,0 +1,17 @@
from langchain_deepseek import ChatDeepSeek
from llm.base import MultiModalAICapability
class DeepSeekAdapter(MultiModalAICapability):
    """Chat adapter backed by LangChain's ``ChatDeepSeek`` client."""

    def __init__(self, api_key, model, **kwargs):
        # Extra keyword arguments are accepted but not forwarded to the client.
        client = ChatDeepSeek(api_key=api_key, model=model, streaming=True)
        self.llm = client

    async def chat(self, messages, **kwargs):
        """Send ``messages`` and return the complete reply."""
        # DeepSeek-compatible invocation; per-call kwargs are not forwarded.
        reply = await self.llm.ainvoke(messages)
        return reply

    async def stream_chat(self, messages, **kwargs):
        """Yield reply chunks for ``messages`` as the client streams them."""
        stream = self.llm.astream(messages)
        async for piece in stream:
            yield piece

View File

@@ -0,0 +1,21 @@
# 假设有 google genai sdk
# from google_genai import GenAI
from llm.base import MultiModalAICapability
class GoogleGenAIAdapter(MultiModalAICapability):
    """Placeholder adapter for Google GenAI; no capability is implemented yet.

    The credentials and model name are stored so that a real SDK client can
    be wired in later.
    """

    def __init__(self, api_key, model, **kwargs):
        self.api_key = api_key
        self.model = model
        # TODO: create the SDK client here once a Google GenAI SDK is chosen,
        # e.g. self.llm = GenAI(api_key=api_key, model=model)

    async def chat(self, messages, **kwargs):
        """Not implemented; always raises ``NotImplementedError``."""
        # TODO: return await self.llm.chat(messages)
        raise NotImplementedError("Google GenAI chat未实现")

    async def stream_chat(self, messages, **kwargs):
        """Not implemented; raises ``NotImplementedError`` on first iteration.

        This is an asynchronous generator, like the ``stream_chat`` of the
        adapters that do implement streaming, so ``async for`` over it raises
        ``NotImplementedError`` rather than a ``TypeError`` about iterating a
        coroutine.
        """
        # TODO: async for chunk in self.llm.stream_chat(messages): yield chunk
        raise NotImplementedError("Google GenAI stream_chat未实现")
        # Unreachable on purpose: the presence of ``yield`` is what makes
        # this function an async generator instead of a coroutine.
        yield

    # Other capabilities follow the same pattern.

View File

@@ -0,0 +1,25 @@
from llm.base import MultiModalAICapability
from langchain_openai import ChatOpenAI
# from openai import OpenAI # 如需图片/音频/视频等API
class OpenAIAdapter(MultiModalAICapability):
    """Chat adapter backed by LangChain's ``ChatOpenAI`` client."""

    def __init__(self, api_key, model, **kwargs):
        # The raw key is kept on the instance as well; presumably for the
        # non-chat OpenAI APIs (image/audio/video) sketched below.
        self.api_key = api_key
        self.llm = ChatOpenAI(api_key=api_key, model=model, streaming=True)

    async def chat(self, messages, **kwargs):
        """Send ``messages`` and return the complete reply."""
        reply = await self.llm.ainvoke(messages)
        return reply

    async def stream_chat(self, messages, **kwargs):
        """Yield reply chunks for ``messages`` as the client streams them."""
        stream = self.llm.astream(messages)
        async for piece in stream:
            yield piece

    # Image generation (DALL·E) could be implemented along these lines.
    def create_drawing_task(self, **kwargs):
        """Image generation is not implemented; always raises."""
        # Sketch: call the OpenAI image API with self.api_key and the prompt,
        # then return its response.
        message = "OpenAI 图片生成请用 openai.Image.create 实现"
        raise NotImplementedError(message)

    # Other capabilities follow the same pattern.

View File

@@ -0,0 +1,44 @@
from langchain_community.chat_models import ChatTongyi
from http import HTTPStatus
from urllib.parse import urlparse, unquote
from pathlib import PurePosixPath
import requests
from dashscope import ImageSynthesis
import os
from llm.base import MultiModalAICapability
class TongYiAdapter(MultiModalAICapability):
    """Adapter for Tongyi: chat via ``ChatTongyi``, images via DashScope."""

    def __init__(self, api_key, model, **kwargs):
        # The same key and model name are used for both the chat client and
        # the DashScope image-synthesis calls below.
        self.api_key = api_key
        self.model = model
        self.llm = ChatTongyi(api_key=api_key, model=model, streaming=True)

    async def chat(self, messages, **kwargs):
        """Send ``messages`` and return the complete reply."""
        return await self.llm.ainvoke(messages)

    async def stream_chat(self, messages, **kwargs):
        """Yield reply chunks for ``messages`` as the client streams them."""
        async for chunk in self.llm.astream(messages):
            yield chunk

    def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
        """Create an asynchronous image-generation task.

        Args:
            prompt: Text description of the image.
            style: Style name without angle brackets; it is sent wrapped as
                ``'<style>'``.
            size: Image size as ``'<width>*<height>'``.
            n: Number of images to generate.

        Returns:
            The response object returned by ``ImageSynthesis.async_call``.
        """
        # Callers may forward optional request fields that were absent
        # (None). Fall back to the signature defaults instead of sending the
        # literal style '<None>' or a null size/count to the service.
        if style is None:
            style = 'watercolor'
        if size is None:
            size = '1024*1024'
        if n is None:
            n = 1
        rsp = ImageSynthesis.async_call(
            api_key=self.api_key,
            model=self.model,
            prompt=prompt,
            n=n,
            style=f'<{style}>',
            size=size
        )
        return rsp

    def fetch_drawing_task_status(self, task):
        """Fetch the current state of an asynchronous image task.

        Args:
            task: Task reference passed straight to ``ImageSynthesis.fetch``
                (callers in this project pass the stored task id).

        Returns:
            The response object returned by ``ImageSynthesis.fetch``.
        """
        rsp = ImageSynthesis.fetch(task, api_key=self.api_key)
        return rsp

37
backend/ai/llm/base.py Normal file
View File

@@ -0,0 +1,37 @@
from abc import ABC
class MultiModalAICapability(ABC):
    """Common capability interface for AI provider adapters.

    Every capability defaults to raising ``NotImplementedError``. Adapters
    override only what their provider supports, and callers can catch
    ``NotImplementedError`` to detect an unsupported capability.
    """

    # --- Chat ---
    async def chat(self, messages, **kwargs):
        """Return the complete reply for ``messages``."""
        raise NotImplementedError("chat not supported by this provider")

    async def stream_chat(self, messages, **kwargs):
        """Yield reply chunks for ``messages`` (asynchronous generator).

        Implementing adapters write this with ``yield``, so callers consume
        it with ``async for``. The default must therefore be an async
        generator too; otherwise ``async for`` would fail with a ``TypeError``
        about a coroutine instead of the intended ``NotImplementedError``.
        """
        raise NotImplementedError("stream_chat not supported by this provider")
        # Unreachable on purpose: the presence of ``yield`` is what makes
        # this function an async generator instead of a coroutine.
        yield

    # --- Image generation ---
    def create_drawing_task(self, prompt: str, style='watercolor', size='1024*1024', n=1, **kwargs):
        """Start an image-generation task and return the provider response."""
        raise NotImplementedError("drawing generation not supported by this provider")

    def fetch_drawing_task_status(self, task):
        """Return the current status of an image-generation task."""
        raise NotImplementedError("drawing task status not supported by this provider")

    def fetch_drawing_result(self, task):
        """Return the result of an image-generation task."""
        raise NotImplementedError("drawing result not supported by this provider")

    # --- Video generation ---
    def create_video_task(self, prompt, **kwargs):
        """Start a video-generation task and return the provider response."""
        raise NotImplementedError("video generation not supported by this provider")

    def fetch_video_task_status(self, task):
        """Return the current status of a video-generation task."""
        raise NotImplementedError("video task status not supported by this provider")

    def fetch_video_result(self, task):
        """Return the result of a video-generation task."""
        raise NotImplementedError("video result not supported by this provider")

    # --- Knowledge base ---
    def query_knowledge(self, query, **kwargs):
        """Query the provider's knowledge base."""
        raise NotImplementedError("knowledge query not supported by this provider")

    # --- Speech synthesis ---
    def synthesize_speech(self, text, **kwargs):
        """Synthesize speech for ``text``."""
        raise NotImplementedError("speech synthesis not supported by this provider")

32
backend/ai/llm/enums.py Normal file
View File

@@ -0,0 +1,32 @@
from enum import Enum
class LLMProvider(str, Enum):
    """Enumeration of LLM providers."""
    DEEPSEEK = "deepseek"
    TONGYI = "tongyi"
    OPENAI = "openai"
    GOOGLE_GENAI = "google-genai"

    @classmethod
    def get_model_by_platform(cls, platform: str) -> tuple[str, str]:
        """Return ``(model name, API-key env var name)`` for ``platform``.

        Unrecognised platforms get the DeepSeek pair.
        """
        deepseek_pair = ('deepseek-chat', 'DEEPSEEK_API_KEY')
        known_platforms = (
            (cls.TONGYI, ('qwen-plus', 'DASHSCOPE_API_KEY')),
            (cls.DEEPSEEK, deepseek_pair),
            (cls.OPENAI, ('gpt-3.5-turbo', 'OPENAI_API_KEY')),
            (cls.GOOGLE_GENAI, ('gemini-pro', 'GOOGLE_API_KEY')),
        )
        for member, pair in known_platforms:
            if platform == member:
                return pair
        # Default to DeepSeek.
        return deepseek_pair

    @classmethod
    def from_string(cls, platform: str) -> 'LLMProvider':
        """Convert ``platform`` to a member, defaulting to ``DEEPSEEK``."""
        try:
            member = cls(platform)
        except ValueError:
            # Not a known provider value: fall back to DeepSeek.
            member = cls.DEEPSEEK
        return member

35
backend/ai/llm/factory.py Normal file
View File

@@ -0,0 +1,35 @@
from .adapter.deepseek import DeepSeekAdapter
from .adapter.genai import GoogleGenAIAdapter
from .adapter.openai import OpenAIAdapter
from .adapter.tongyi import TongYiAdapter
from .enums import LLMProvider
def get_adapter(provider: LLMProvider, api_key, model, **kwargs):
    """Build the adapter that corresponds to ``provider``.

    Args:
        provider: An ``LLMProvider`` member (or an equal string value).
        api_key: Credential passed to the adapter constructor.
        model: Model name passed to the adapter constructor.
        **kwargs: Extra keyword arguments passed to the adapter constructor.

    Returns:
        A new adapter instance for the provider.

    Raises:
        ValueError: if ``provider`` matches none of the known providers.
    """
    adapters_by_provider = (
        (LLMProvider.DEEPSEEK, DeepSeekAdapter),
        (LLMProvider.TONGYI, TongYiAdapter),
        (LLMProvider.OPENAI, OpenAIAdapter),
        (LLMProvider.GOOGLE_GENAI, GoogleGenAIAdapter),
    )
    for candidate, adapter_cls in adapters_by_provider:
        if provider == candidate:
            return adapter_cls(api_key, model, **kwargs)
    raise ValueError('不支持的服务商')
# Usage example
# adapter = get_adapter(LLMProvider.TONGYI, api_key='xxx', model='wanx_v1')
#
# # Chat
# try:
#     result = await adapter.chat(messages)
# except NotImplementedError:
#     print("This provider does not support chat")
#
# # Image generation
# try:
#     task = adapter.create_drawing_task(prompt="a cat")
#     status = adapter.fetch_drawing_task_status(task)
#     result = adapter.fetch_drawing_result(task)
# except NotImplementedError:
#     print("This provider does not support image generation")

View File

@@ -10,6 +10,7 @@ router.register(r'tool', views.ToolViewSet)
router.register(r'knowledge', views.KnowledgeViewSet)
router.register(r'chat_conversation', views.ChatConversationViewSet)
router.register(r'chat_message', views.ChatMessageViewSet)
router.register(r'drawing', views.DrawingViewSet)
urlpatterns = [

View File

@@ -5,6 +5,7 @@ __all__ = [
'KnowledgeViewSet',
'ChatConversationViewSet',
'ChatMessageViewSet',
'DrawingViewSet',
]
from ai.views.ai_api_key import AIApiKeyViewSet
@@ -12,4 +13,5 @@ from ai.views.ai_model import AIModelViewSet
from ai.views.tool import ToolViewSet
from ai.views.knowledge import KnowledgeViewSet
from ai.views.chat_conversation import ChatConversationViewSet
from ai.views.chat_message import ChatMessageViewSet
from ai.views.chat_message import ChatMessageViewSet
from ai.views.drawing import DrawingViewSet

View File

@@ -0,0 +1,83 @@
import os
from datetime import datetime
from rest_framework.response import Response
from ai.models import Drawing
from backend import settings
from llm.enums import LLMProvider
from llm.factory import get_adapter
from utils.serializers import CustomModelSerializer
from utils.custom_model_viewSet import CustomModelViewSet
from django_filters import rest_framework as filters
class DrawingSerializer(CustomModelSerializer):
    """
    Serializer for the AI drawing table.

    Exposes every field of ``Drawing``; ``id``, ``create_time`` and
    ``update_time`` are read-only.
    """
    class Meta:
        model = Drawing
        fields = '__all__'
        read_only_fields = ['id', 'create_time', 'update_time']
class DrawingFilter(filters.FilterSet):
    """Filter set declaring which ``Drawing`` fields can be used as query filters."""
    class Meta:
        model = Drawing
        fields = ['id', 'remark', 'creator', 'modifier', 'is_deleted', 'public_status', 'platform',
                  'model', 'width', 'height', 'status', 'pic_url', 'error_message', 'task_id', 'buttons']
class DrawingViewSet(CustomModelViewSet):
    """
    ViewSet for the AI drawing table.

    ``create`` submits an asynchronous image-generation task through the
    Tongyi adapter before the record is persisted. ``retrieve`` polls that
    task while the record is still PENDING/RUNNING and stores the outcome.
    """
    queryset = Drawing.objects.filter(is_deleted=False).order_by('-id')
    serializer_class = DrawingSerializer
    filterset_class = DrawingFilter
    search_fields = ['name']  # adjust to the model's actual fields
    ordering_fields = ['create_time', 'id']
    ordering = ['-create_time']

    @staticmethod
    def _parse_size(size):
        """Split a ``'<width>*<height>'`` string into ``(width, height)`` ints.

        Raises:
            ValidationError: if ``size`` is missing or malformed, so the
                client gets a 400 response instead of an unhandled
                AttributeError/ValueError (HTTP 500).
        """
        from rest_framework.exceptions import ValidationError

        if isinstance(size, str):
            parts = size.split('*')
            try:
                return int(parts[0]), int(parts[1])
            except (ValueError, IndexError):
                pass
        raise ValidationError({'size': "size must look like '1024*1024'"})

    def create(self, request, *args, **kwargs):
        """Submit the drawing task, then create the record via the parent.

        Reads ``model``, ``prompt``, ``n``, ``style`` and ``size`` from the
        request body. On success the task status and task id returned by the
        service are written into the request data before delegating to
        ``super().create``. If the service rejects the task, its error code
        and message are returned with the service's status code.
        """
        data = request.data
        size = data.get('size')
        width, height = self._parse_size(size)
        # NOTE: derived fields are written back into request.data so that
        # super().create() persists them; this assumes a mutable (e.g. JSON)
        # request body.
        data['width'] = width
        data['height'] = height

        # Forward only the options the client actually supplied, so the
        # adapter's own defaults apply instead of an explicit None
        # (style=None would otherwise be sent as the literal '<None>').
        options = {}
        for option in ('n', 'style'):
            value = data.get(option)
            if value is not None:
                options[option] = value

        adapter = get_adapter(
            LLMProvider.TONGYI,
            api_key=settings.DASHSCOPE_API_KEY,
            model=data.get('model'),
        )
        rsp = adapter.create_drawing_task(prompt=data.get('prompt'), size=size, **options)
        if rsp['status_code'] != 200:
            # The service reports failures through 'code' and 'message'.
            error_body = {'code': rsp.get('code'), 'message': rsp.get('message')}
            return Response(error_body, status=rsp['status_code'])

        data['status'] = rsp['output']['task_status']
        data['task_id'] = rsp['output']['task_id']
        return super().create(request, *args, **kwargs)

    def _refresh_task_status(self, instance):
        """Poll the drawing task of ``instance`` and store any progress.

        Nothing is saved when the poll fails or the task is reported in a
        state other than SUCCEEDED, FAILED or RUNNING.
        """
        adapter = get_adapter(LLMProvider.TONGYI, api_key=settings.DASHSCOPE_API_KEY, model='')
        rsp = adapter.fetch_drawing_task_status(instance.task_id)
        if rsp['status_code'] != 200:
            return

        output = rsp['output']
        task_status = output['task_status']
        if task_status not in ('SUCCEEDED', 'FAILED', 'RUNNING'):
            return

        instance.update_time = datetime.now()
        instance.status = task_status
        if task_status == 'SUCCEEDED':
            # Use the first result that carries an image URL; a result list
            # that is empty or has no URL must not crash the request.
            results = output.get('results') or []
            pic_url = next((item['url'] for item in results if 'url' in item), None)
            if pic_url:
                instance.pic_url = pic_url
        elif task_status == 'FAILED':
            instance.error_message = output.get('message', '')
        instance.save()

    def retrieve(self, request, *args, **kwargs):
        """Return the record, refreshing an unfinished task's status first."""
        instance = self.get_object()
        if instance.status in ('PENDING', 'RUNNING'):
            self._refresh_task_status(instance)
        return super().retrieve(request, *args, **kwargs)

View File

@@ -242,5 +242,8 @@ ASGI_APPLICATION = 'backend.asgi.application'
# }
# }
# API keys for the LLM providers, read from the environment (empty string
# when unset). No trailing comma: `NAME = value,` would turn the setting into
# a one-element tuple instead of a string.
DEEPSEEK_API_KEY = os.getenv('DEEPSEEK_API_KEY', '')
DASHSCOPE_API_KEY = os.getenv('DASHSCOPE_API_KEY', '')
if os.path.exists(os.path.join(BASE_DIR, 'backend/local_settings.py')):
from backend.local_settings import *