From a88f272c197939029b67052e2aeb5ed9d042d768 Mon Sep 17 00:00:00 2001 From: XIE7654 <765462425@qq.com> Date: Mon, 11 Aug 2025 10:26:42 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9LLMProvider=20?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai_service/api/v1/chat/__init__.py | 7 +++++- ai_service/api/v1/drawing/__init__.py | 3 ++- ai_service/api/v1/drawing/vo.py | 3 ++- ai_service/llm/enums.py | 32 ++++++++++++++++++++++++++ ai_service/llm/factory.py | 13 ++++++----- ai_service/models/ai.py | 5 ++-- ai_service/services/drawing_service.py | 3 ++- 7 files changed, 54 insertions(+), 12 deletions(-) create mode 100644 ai_service/llm/enums.py diff --git a/ai_service/api/v1/chat/__init__.py b/ai_service/api/v1/chat/__init__.py index 1a7c940..d06b9fa 100644 --- a/ai_service/api/v1/chat/__init__.py +++ b/ai_service/api/v1/chat/__init__.py @@ -11,6 +11,7 @@ from langchain.chains import ConversationChain from api.v1.chat.vo import MessageVO from deps.auth import get_current_user from llm.factory import get_adapter +from llm.enums import LLMProvider from services.chat_service import ChatDBService from db.session import get_db from models.ai import ChatConversation, ChatMessage @@ -37,11 +38,13 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess if platform == 'tongyi': model = 'qwen-plus' api_key = os.getenv("DASHSCOPE_API_KEY") + provider = LLMProvider.TONGYI else: # 默认使用 DeepSeek model = 'deepseek-chat' api_key = os.getenv("DEEPSEEK_API_KEY") - llm = get_adapter(platform, api_key=api_key, model=model) + provider = LLMProvider.DEEPSEEK + llm = get_adapter(provider, api_key=api_key, model=model) if not content or not isinstance(content, str): from fastapi.responses import JSONResponse @@ -94,9 +97,11 @@ async def create_conversation(request: Request, db: Session = Depends(get_db), u platform = body.get('platform') if platform == 'tongyi': model = 'qwen-plus' + # provider = LLMProvider.TONGYI else: # 默认使用 DeepSeek model = 'deepseek-chat' + # provider = LLMProvider.DEEPSEEK user_id = user["user_id"] conversation = ChatDBService.get_or_create_conversation(db, None, user_id, model, '新对话') return resp_success(data=conversation.id) diff --git a/ai_service/api/v1/drawing/__init__.py b/ai_service/api/v1/drawing/__init__.py index 5706182..af69092 100644 --- a/ai_service/api/v1/drawing/__init__.py +++ b/ai_service/api/v1/drawing/__init__.py @@ -8,6 +8,7 @@ from api.v1.drawing.vo import CreateDrawingTaskRequest from db.session import get_db from deps.auth import get_current_user from llm.factory import get_adapter +from llm.enums import LLMProvider from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status from utils.resp import resp_error, resp_success @@ -50,7 +51,7 @@ def api_create_image_task( prompt = req.prompt model = req.model api_key = os.getenv("DASHSCOPE_API_KEY") - adapter = get_adapter('tongyi', api_key=api_key, model=model) + adapter = get_adapter(LLMProvider.TONGYI, api_key=api_key, model=model) try: rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size) # rsp = { diff --git a/ai_service/api/v1/drawing/vo.py b/ai_service/api/v1/drawing/vo.py index e1e0d62..5b379b4 100644 --- a/ai_service/api/v1/drawing/vo.py +++ b/ai_service/api/v1/drawing/vo.py @@ -1,4 +1,5 @@ from pydantic import BaseModel +from llm.enums import LLMProvider class CreateDrawingTaskRequest(BaseModel): @@ -6,5 +7,5 @@ class CreateDrawingTaskRequest(BaseModel): style: str = 'auto' size: str = '1024*1024' model: str = 'wanx_v1' - platform: str = 'tongyi' + platform: str = LLMProvider.TONGYI n: int = 1 diff --git a/ai_service/llm/enums.py b/ai_service/llm/enums.py new file mode 100644 index 0000000..5bb0c3c --- /dev/null +++ b/ai_service/llm/enums.py @@ -0,0 +1,32 @@ +from enum import Enum + + +class LLMProvider(str, Enum): + """LLM 提供商枚举""" + DEEPSEEK = "deepseek" + TONGYI = "tongyi" + OPENAI = "openai" + GOOGLE_GENAI = "google-genai" + + @classmethod + def get_model_by_platform(cls, platform: str) -> tuple[str, str]: + """根据平台名称获取对应的模型和API密钥环境变量名""" + if platform == cls.TONGYI: + return 'qwen-plus', 'DASHSCOPE_API_KEY' + elif platform == cls.DEEPSEEK: + return 'deepseek-chat', 'DEEPSEEK_API_KEY' + elif platform == cls.OPENAI: + return 'gpt-3.5-turbo', 'OPENAI_API_KEY' + elif platform == cls.GOOGLE_GENAI: + return 'gemini-pro', 'GOOGLE_API_KEY' + else: + # 默认使用 DeepSeek + return 'deepseek-chat', 'DEEPSEEK_API_KEY' + + @classmethod + def from_string(cls, platform: str) -> 'LLMProvider': + """从字符串创建枚举值,如果不存在则返回默认值""" + try: + return cls(platform) + except ValueError: + return cls.DEEPSEEK # 默认返回 DeepSeek \ No newline at end of file diff --git a/ai_service/llm/factory.py b/ai_service/llm/factory.py index 0fce39f..a4de606 100644 --- a/ai_service/llm/factory.py +++ b/ai_service/llm/factory.py @@ -2,22 +2,23 @@ from .adapter.deepseek import DeepSeekAdapter from .adapter.genai import GoogleGenAIAdapter from .adapter.openai import OpenAIAdapter from .adapter.tongyi import TongYiAdapter +from .enums import LLMProvider -def get_adapter(provider, api_key, model, **kwargs): - if provider == 'deepseek': +def get_adapter(provider: LLMProvider, api_key, model, **kwargs): + if provider == LLMProvider.DEEPSEEK: return DeepSeekAdapter(api_key, model, **kwargs) - elif provider == 'tongyi': + elif provider == LLMProvider.TONGYI: return TongYiAdapter(api_key, model, **kwargs) - elif provider == 'openai': + elif provider == LLMProvider.OPENAI: return OpenAIAdapter(api_key, model, **kwargs) - elif provider == 'google-genai': + elif provider == LLMProvider.GOOGLE_GENAI: return GoogleGenAIAdapter(api_key, model, **kwargs) else: raise ValueError('不支持的服务商') # 使用示例 -# adapter = get_adapter('tongyi', api_key='xxx', model='wanx_v1') +# adapter = get_adapter(LLMProvider.TONGYI, api_key='xxx', model='wanx_v1') # 对话 # try: diff --git a/ai_service/models/ai.py b/ai_service/models/ai.py index 8aad029..2509a04 100644 --- a/ai_service/models/ai.py +++ b/ai_service/models/ai.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import relationship, declarative_base from db.session import Base from models.base import CoreModel from models.user import DjangoUser # 确保导入 DjangoUser +from llm.enums import LLMProvider # 状态选择类(示例) class CommonStatus: @@ -19,12 +20,12 @@ class CommonStatus: # 平台选择类(示例) class PlatformChoices: - OPENAI = 'openai' + OPENAI = LLMProvider.OPENAI ALIMNS = 'alimns' @staticmethod def choices(): - return [('openai', 'OpenAI'), ('alimns', '阿里云MNS')] + return [(LLMProvider.OPENAI, 'OpenAI'), ('alimns', '阿里云MNS')] # 消息类型选择类(示例) diff --git a/ai_service/services/drawing_service.py b/ai_service/services/drawing_service.py index 2efd0db..4d4b2b0 100644 --- a/ai_service/services/drawing_service.py +++ b/ai_service/services/drawing_service.py @@ -7,6 +7,7 @@ from http import HTTPStatus from sqlalchemy import desc from llm.factory import get_adapter +from llm.enums import LLMProvider from models.ai import Drawing from sqlalchemy.orm import Session @@ -40,7 +41,7 @@ def fetch_drawing_task_status(db: Session, drawing_id: int): return None, "任务不存在" if drawing.status in ("PENDING", 'RUNNING'): api_key = os.getenv("DASHSCOPE_API_KEY") - adapter = get_adapter('tongyi', api_key=api_key, model='') + adapter = get_adapter(LLMProvider.TONGYI, api_key=api_key, model='') rsp = adapter.fetch_drawing_task_status(drawing.task_id) if rsp['status_code'] == HTTPStatus.OK: # 可根据 status.output.task_status 更新数据库