fix: modify the LLMProvider type
@@ -11,6 +11,7 @@ from langchain.chains import ConversationChain
 from api.v1.chat.vo import MessageVO
 from deps.auth import get_current_user
 from llm.factory import get_adapter
+from llm.enums import LLMProvider
 from services.chat_service import ChatDBService
 from db.session import get_db
 from models.ai import ChatConversation, ChatMessage
@@ -37,11 +38,13 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
     if platform == 'tongyi':
         model = 'qwen-plus'
         api_key = os.getenv("DASHSCOPE_API_KEY")
+        provider = LLMProvider.TONGYI
     else:
         # Default to DeepSeek
         model = 'deepseek-chat'
         api_key = os.getenv("DEEPSEEK_API_KEY")
-    llm = get_adapter(platform, api_key=api_key, model=model)
+        provider = LLMProvider.DEEPSEEK
+    llm = get_adapter(provider, api_key=api_key, model=model)
 
     if not content or not isinstance(content, str):
         from fastapi.responses import JSONResponse
@@ -94,9 +97,11 @@ async def create_conversation(request: Request, db: Session = Depends(get_db), u
     platform = body.get('platform')
     if platform == 'tongyi':
         model = 'qwen-plus'
+        # provider = LLMProvider.TONGYI
     else:
         # Default to DeepSeek
         model = 'deepseek-chat'
+        # provider = LLMProvider.DEEPSEEK
     user_id = user["user_id"]
     conversation = ChatDBService.get_or_create_conversation(db, None, user_id, model, '新对话')
     return resp_success(data=conversation.id)
@@ -8,6 +8,7 @@ from api.v1.drawing.vo import CreateDrawingTaskRequest
 from db.session import get_db
 from deps.auth import get_current_user
 from llm.factory import get_adapter
+from llm.enums import LLMProvider
 from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status
 from utils.resp import resp_error, resp_success
 
@@ -50,7 +51,7 @@ def api_create_image_task(
     prompt = req.prompt
     model = req.model
     api_key = os.getenv("DASHSCOPE_API_KEY")
-    adapter = get_adapter('tongyi', api_key=api_key, model=model)
+    adapter = get_adapter(LLMProvider.TONGYI, api_key=api_key, model=model)
     try:
         rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
         # rsp = {
@@ -1,4 +1,5 @@
 from pydantic import BaseModel
+from llm.enums import LLMProvider
 
 
 class CreateDrawingTaskRequest(BaseModel):
@@ -6,5 +7,5 @@ class CreateDrawingTaskRequest(BaseModel):
     style: str = 'auto'
     size: str = '1024*1024'
     model: str = 'wanx_v1'
-    platform: str = 'tongyi'
+    platform: str = LLMProvider.TONGYI
     n: int = 1
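Since LLMProvider mixes in str (see the new llm/enums.py below), a field annotated as str can take the enum member as its default and behave like the old literal. A minimal sketch of that, using a stand-in model rather than the project's CreateDrawingTaskRequest:

from pydantic import BaseModel

from llm.enums import LLMProvider


class DrawingOptions(BaseModel):
    # Same pattern as CreateDrawingTaskRequest.platform in this commit;
    # DrawingOptions itself is only an illustration.
    platform: str = LLMProvider.TONGYI


opts = DrawingOptions()
# The str-mixin enum member compares equal to the literal it replaces,
# so callers that check platform == 'tongyi' see no behavior change.
print(opts.platform == "tongyi")  # True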
ai_service/llm/enums.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+from enum import Enum
+
+
+class LLMProvider(str, Enum):
+    """LLM provider enum."""
+    DEEPSEEK = "deepseek"
+    TONGYI = "tongyi"
+    OPENAI = "openai"
+    GOOGLE_GENAI = "google-genai"
+
+    @classmethod
+    def get_model_by_platform(cls, platform: str) -> tuple[str, str]:
+        """Return the default model and API-key environment variable name for a platform."""
+        if platform == cls.TONGYI:
+            return 'qwen-plus', 'DASHSCOPE_API_KEY'
+        elif platform == cls.DEEPSEEK:
+            return 'deepseek-chat', 'DEEPSEEK_API_KEY'
+        elif platform == cls.OPENAI:
+            return 'gpt-3.5-turbo', 'OPENAI_API_KEY'
+        elif platform == cls.GOOGLE_GENAI:
+            return 'gemini-pro', 'GOOGLE_API_KEY'
+        else:
+            # Default to DeepSeek
+            return 'deepseek-chat', 'DEEPSEEK_API_KEY'
+
+    @classmethod
+    def from_string(cls, platform: str) -> 'LLMProvider':
+        """Create an enum value from a string; fall back to the default if the value is unknown."""
+        try:
+            return cls(platform)
+        except ValueError:
+            return cls.DEEPSEEK  # default to DeepSeek
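A short usage sketch of the two helpers added above, roughly as an endpoint might call them; the platform value and variable names here are illustrative, not taken from the commit:

import os

from llm.enums import LLMProvider
from llm.factory import get_adapter

platform = "tongyi"  # e.g. read from the request body

# Resolve the raw platform string to a typed provider; unknown values fall back to DEEPSEEK.
provider = LLMProvider.from_string(platform)

# Look up the default model and the name of the API-key environment variable for that platform.
model, key_env = LLMProvider.get_model_by_platform(platform)

llm = get_adapter(provider, api_key=os.getenv(key_env), model=model)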
@@ -2,22 +2,23 @@ from .adapter.deepseek import DeepSeekAdapter
 from .adapter.genai import GoogleGenAIAdapter
 from .adapter.openai import OpenAIAdapter
 from .adapter.tongyi import TongYiAdapter
+from .enums import LLMProvider
 
 
-def get_adapter(provider, api_key, model, **kwargs):
-    if provider == 'deepseek':
+def get_adapter(provider: LLMProvider, api_key, model, **kwargs):
+    if provider == LLMProvider.DEEPSEEK:
         return DeepSeekAdapter(api_key, model, **kwargs)
-    elif provider == 'tongyi':
+    elif provider == LLMProvider.TONGYI:
         return TongYiAdapter(api_key, model, **kwargs)
-    elif provider == 'openai':
+    elif provider == LLMProvider.OPENAI:
         return OpenAIAdapter(api_key, model, **kwargs)
-    elif provider == 'google-genai':
+    elif provider == LLMProvider.GOOGLE_GENAI:
         return GoogleGenAIAdapter(api_key, model, **kwargs)
     else:
         raise ValueError('不支持的服务商')
 
 # Usage example
-# adapter = get_adapter('tongyi', api_key='xxx', model='wanx_v1')
+# adapter = get_adapter(LLMProvider.TONGYI, api_key='xxx', model='wanx_v1')
 
 # Chat
 # try:
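One property worth noting about the comparisons above: because LLMProvider mixes in str, an enum member is equal to its raw value, so any call site that still passes a plain string such as 'tongyi' keeps matching the new branches. A quick standard-library check (reduced copy of the enum, for illustration only):

from enum import Enum


class LLMProvider(str, Enum):
    DEEPSEEK = "deepseek"
    TONGYI = "tongyi"


# A str-mixin enum member compares equal to the string it wraps, in both directions.
print(LLMProvider.TONGYI == "tongyi")       # True
print("deepseek" == LLMProvider.DEEPSEEK)   # True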
@@ -6,6 +6,7 @@ from sqlalchemy.orm import relationship, declarative_base
 from db.session import Base
 from models.base import CoreModel
 from models.user import DjangoUser  # make sure DjangoUser is imported
+from llm.enums import LLMProvider
 
 # Status choices class (example)
 class CommonStatus:
@@ -19,12 +20,12 @@ class CommonStatus:
 
 # Platform choices class (example)
 class PlatformChoices:
-    OPENAI = 'openai'
+    OPENAI = LLMProvider.OPENAI
     ALIMNS = 'alimns'
 
     @staticmethod
     def choices():
-        return [('openai', 'OpenAI'), ('alimns', '阿里云MNS')]
+        return [(LLMProvider.OPENAI, 'OpenAI'), ('alimns', '阿里云MNS')]
 
 
 # Message type choices class (example)
@@ -7,6 +7,7 @@ from http import HTTPStatus
 from sqlalchemy import desc
 
 from llm.factory import get_adapter
+from llm.enums import LLMProvider
 from models.ai import Drawing
 from sqlalchemy.orm import Session
 
@@ -40,7 +41,7 @@ def fetch_drawing_task_status(db: Session, drawing_id: int):
         return None, "任务不存在"
     if drawing.status in ("PENDING", 'RUNNING'):
         api_key = os.getenv("DASHSCOPE_API_KEY")
-        adapter = get_adapter('tongyi', api_key=api_key, model='')
+        adapter = get_adapter(LLMProvider.TONGYI, api_key=api_key, model='')
         rsp = adapter.fetch_drawing_task_status(drawing.task_id)
         if rsp['status_code'] == HTTPStatus.OK:
             # The database can be updated based on status.output.task_status