fix: 修改LLMProvider 类型

This commit is contained in:
XIE7654
2025-08-11 10:26:42 +08:00
parent e21a1ac716
commit a88f272c19
7 changed files with 54 additions and 12 deletions

View File

@@ -11,6 +11,7 @@ from langchain.chains import ConversationChain
from api.v1.chat.vo import MessageVO from api.v1.chat.vo import MessageVO
from deps.auth import get_current_user from deps.auth import get_current_user
from llm.factory import get_adapter from llm.factory import get_adapter
from llm.enums import LLMProvider
from services.chat_service import ChatDBService from services.chat_service import ChatDBService
from db.session import get_db from db.session import get_db
from models.ai import ChatConversation, ChatMessage from models.ai import ChatConversation, ChatMessage
@@ -37,11 +38,13 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
if platform == 'tongyi': if platform == 'tongyi':
model = 'qwen-plus' model = 'qwen-plus'
api_key = os.getenv("DASHSCOPE_API_KEY") api_key = os.getenv("DASHSCOPE_API_KEY")
provider = LLMProvider.TONGYI
else: else:
# 默认使用 DeepSeek # 默认使用 DeepSeek
model = 'deepseek-chat' model = 'deepseek-chat'
api_key = os.getenv("DEEPSEEK_API_KEY") api_key = os.getenv("DEEPSEEK_API_KEY")
llm = get_adapter(platform, api_key=api_key, model=model) provider = LLMProvider.DEEPSEEK
llm = get_adapter(provider, api_key=api_key, model=model)
if not content or not isinstance(content, str): if not content or not isinstance(content, str):
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@@ -94,9 +97,11 @@ async def create_conversation(request: Request, db: Session = Depends(get_db), u
platform = body.get('platform') platform = body.get('platform')
if platform == 'tongyi': if platform == 'tongyi':
model = 'qwen-plus' model = 'qwen-plus'
# provider = LLMProvider.TONGYI
else: else:
# 默认使用 DeepSeek # 默认使用 DeepSeek
model = 'deepseek-chat' model = 'deepseek-chat'
# provider = LLMProvider.DEEPSEEK
user_id = user["user_id"] user_id = user["user_id"]
conversation = ChatDBService.get_or_create_conversation(db, None, user_id, model, '新对话') conversation = ChatDBService.get_or_create_conversation(db, None, user_id, model, '新对话')
return resp_success(data=conversation.id) return resp_success(data=conversation.id)

View File

@@ -8,6 +8,7 @@ from api.v1.drawing.vo import CreateDrawingTaskRequest
from db.session import get_db from db.session import get_db
from deps.auth import get_current_user from deps.auth import get_current_user
from llm.factory import get_adapter from llm.factory import get_adapter
from llm.enums import LLMProvider
from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status
from utils.resp import resp_error, resp_success from utils.resp import resp_error, resp_success
@@ -50,7 +51,7 @@ def api_create_image_task(
prompt = req.prompt prompt = req.prompt
model = req.model model = req.model
api_key = os.getenv("DASHSCOPE_API_KEY") api_key = os.getenv("DASHSCOPE_API_KEY")
adapter = get_adapter('tongyi', api_key=api_key, model=model) adapter = get_adapter(LLMProvider.TONGYI, api_key=api_key, model=model)
try: try:
rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size) rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
# rsp = { # rsp = {

View File

@@ -1,4 +1,5 @@
from pydantic import BaseModel from pydantic import BaseModel
from llm.enums import LLMProvider
class CreateDrawingTaskRequest(BaseModel): class CreateDrawingTaskRequest(BaseModel):
@@ -6,5 +7,5 @@ class CreateDrawingTaskRequest(BaseModel):
style: str = 'auto' style: str = 'auto'
size: str = '1024*1024' size: str = '1024*1024'
model: str = 'wanx_v1' model: str = 'wanx_v1'
platform: str = 'tongyi' platform: str = LLMProvider.TONGYI
n: int = 1 n: int = 1

32
ai_service/llm/enums.py Normal file
View File

@@ -0,0 +1,32 @@
from enum import Enum
class LLMProvider(str, Enum):
"""LLM 提供商枚举"""
DEEPSEEK = "deepseek"
TONGYI = "tongyi"
OPENAI = "openai"
GOOGLE_GENAI = "google-genai"
@classmethod
def get_model_by_platform(cls, platform: str) -> tuple[str, str]:
"""根据平台名称获取对应的模型和API密钥环境变量名"""
if platform == cls.TONGYI:
return 'qwen-plus', 'DASHSCOPE_API_KEY'
elif platform == cls.DEEPSEEK:
return 'deepseek-chat', 'DEEPSEEK_API_KEY'
elif platform == cls.OPENAI:
return 'gpt-3.5-turbo', 'OPENAI_API_KEY'
elif platform == cls.GOOGLE_GENAI:
return 'gemini-pro', 'GOOGLE_API_KEY'
else:
# 默认使用 DeepSeek
return 'deepseek-chat', 'DEEPSEEK_API_KEY'
@classmethod
def from_string(cls, platform: str) -> 'LLMProvider':
"""从字符串创建枚举值,如果不存在则返回默认值"""
try:
return cls(platform)
except ValueError:
return cls.DEEPSEEK # 默认返回 DeepSeek

View File

@@ -2,22 +2,23 @@ from .adapter.deepseek import DeepSeekAdapter
from .adapter.genai import GoogleGenAIAdapter from .adapter.genai import GoogleGenAIAdapter
from .adapter.openai import OpenAIAdapter from .adapter.openai import OpenAIAdapter
from .adapter.tongyi import TongYiAdapter from .adapter.tongyi import TongYiAdapter
from .enums import LLMProvider
def get_adapter(provider, api_key, model, **kwargs): def get_adapter(provider: LLMProvider, api_key, model, **kwargs):
if provider == 'deepseek': if provider == LLMProvider.DEEPSEEK:
return DeepSeekAdapter(api_key, model, **kwargs) return DeepSeekAdapter(api_key, model, **kwargs)
elif provider == 'tongyi': elif provider == LLMProvider.TONGYI:
return TongYiAdapter(api_key, model, **kwargs) return TongYiAdapter(api_key, model, **kwargs)
elif provider == 'openai': elif provider == LLMProvider.OPENAI:
return OpenAIAdapter(api_key, model, **kwargs) return OpenAIAdapter(api_key, model, **kwargs)
elif provider == 'google-genai': elif provider == LLMProvider.GOOGLE_GENAI:
return GoogleGenAIAdapter(api_key, model, **kwargs) return GoogleGenAIAdapter(api_key, model, **kwargs)
else: else:
raise ValueError('不支持的服务商') raise ValueError('不支持的服务商')
# 使用示例 # 使用示例
# adapter = get_adapter('tongyi', api_key='xxx', model='wanx_v1') # adapter = get_adapter(LLMProvider.TONGYI, api_key='xxx', model='wanx_v1')
# 对话 # 对话
# try: # try:

View File

@@ -6,6 +6,7 @@ from sqlalchemy.orm import relationship, declarative_base
from db.session import Base from db.session import Base
from models.base import CoreModel from models.base import CoreModel
from models.user import DjangoUser # 确保导入 DjangoUser from models.user import DjangoUser # 确保导入 DjangoUser
from llm.enums import LLMProvider
# 状态选择类(示例) # 状态选择类(示例)
class CommonStatus: class CommonStatus:
@@ -19,12 +20,12 @@ class CommonStatus:
# 平台选择类(示例) # 平台选择类(示例)
class PlatformChoices: class PlatformChoices:
OPENAI = 'openai' OPENAI = LLMProvider.OPENAI
ALIMNS = 'alimns' ALIMNS = 'alimns'
@staticmethod @staticmethod
def choices(): def choices():
return [('openai', 'OpenAI'), ('alimns', '阿里云MNS')] return [(LLMProvider.OPENAI, 'OpenAI'), ('alimns', '阿里云MNS')]
# 消息类型选择类(示例) # 消息类型选择类(示例)

View File

@@ -7,6 +7,7 @@ from http import HTTPStatus
from sqlalchemy import desc from sqlalchemy import desc
from llm.factory import get_adapter from llm.factory import get_adapter
from llm.enums import LLMProvider
from models.ai import Drawing from models.ai import Drawing
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -40,7 +41,7 @@ def fetch_drawing_task_status(db: Session, drawing_id: int):
return None, "任务不存在" return None, "任务不存在"
if drawing.status in ("PENDING", 'RUNNING'): if drawing.status in ("PENDING", 'RUNNING'):
api_key = os.getenv("DASHSCOPE_API_KEY") api_key = os.getenv("DASHSCOPE_API_KEY")
adapter = get_adapter('tongyi', api_key=api_key, model='') adapter = get_adapter(LLMProvider.TONGYI, api_key=api_key, model='')
rsp = adapter.fetch_drawing_task_status(drawing.task_id) rsp = adapter.fetch_drawing_task_status(drawing.task_id)
if rsp['status_code'] == HTTPStatus.OK: if rsp['status_code'] == HTTPStatus.OK:
# 可根据 status.output.task_status 更新数据库 # 可根据 status.output.task_status 更新数据库