fix: 修改LLMProvider 类型
This commit is contained in:
@@ -11,6 +11,7 @@ from langchain.chains import ConversationChain
|
||||
from api.v1.chat.vo import MessageVO
|
||||
from deps.auth import get_current_user
|
||||
from llm.factory import get_adapter
|
||||
from llm.enums import LLMProvider
|
||||
from services.chat_service import ChatDBService
|
||||
from db.session import get_db
|
||||
from models.ai import ChatConversation, ChatMessage
|
||||
@@ -37,11 +38,13 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
|
||||
if platform == 'tongyi':
|
||||
model = 'qwen-plus'
|
||||
api_key = os.getenv("DASHSCOPE_API_KEY")
|
||||
provider = LLMProvider.TONGYI
|
||||
else:
|
||||
# 默认使用 DeepSeek
|
||||
model = 'deepseek-chat'
|
||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||
llm = get_adapter(platform, api_key=api_key, model=model)
|
||||
provider = LLMProvider.DEEPSEEK
|
||||
llm = get_adapter(provider, api_key=api_key, model=model)
|
||||
|
||||
if not content or not isinstance(content, str):
|
||||
from fastapi.responses import JSONResponse
|
||||
@@ -94,9 +97,11 @@ async def create_conversation(request: Request, db: Session = Depends(get_db), u
|
||||
platform = body.get('platform')
|
||||
if platform == 'tongyi':
|
||||
model = 'qwen-plus'
|
||||
# provider = LLMProvider.TONGYI
|
||||
else:
|
||||
# 默认使用 DeepSeek
|
||||
model = 'deepseek-chat'
|
||||
# provider = LLMProvider.DEEPSEEK
|
||||
user_id = user["user_id"]
|
||||
conversation = ChatDBService.get_or_create_conversation(db, None, user_id, model, '新对话')
|
||||
return resp_success(data=conversation.id)
|
||||
|
||||
@@ -8,6 +8,7 @@ from api.v1.drawing.vo import CreateDrawingTaskRequest
|
||||
from db.session import get_db
|
||||
from deps.auth import get_current_user
|
||||
from llm.factory import get_adapter
|
||||
from llm.enums import LLMProvider
|
||||
from services.drawing_service import get_drawing_page, create_drawing_task, fetch_drawing_task_status
|
||||
from utils.resp import resp_error, resp_success
|
||||
|
||||
@@ -50,7 +51,7 @@ def api_create_image_task(
|
||||
prompt = req.prompt
|
||||
model = req.model
|
||||
api_key = os.getenv("DASHSCOPE_API_KEY")
|
||||
adapter = get_adapter('tongyi', api_key=api_key, model=model)
|
||||
adapter = get_adapter(LLMProvider.TONGYI, api_key=api_key, model=model)
|
||||
try:
|
||||
rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
|
||||
# rsp = {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
from llm.enums import LLMProvider
|
||||
|
||||
|
||||
class CreateDrawingTaskRequest(BaseModel):
|
||||
@@ -6,5 +7,5 @@ class CreateDrawingTaskRequest(BaseModel):
|
||||
style: str = 'auto'
|
||||
size: str = '1024*1024'
|
||||
model: str = 'wanx_v1'
|
||||
platform: str = 'tongyi'
|
||||
platform: str = LLMProvider.TONGYI
|
||||
n: int = 1
|
||||
|
||||
Reference in New Issue
Block a user