新对话bug
This commit is contained in:
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
|||||||
from typing import List
|
from typing import List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, SecretStr
|
||||||
from langchain.chains import ConversationChain
|
from langchain.chains import ConversationChain
|
||||||
from langchain_community.chat_models import ChatOpenAI
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
|
|
||||||
@@ -16,6 +16,7 @@ from services.chat_service import ChatDBService
|
|||||||
from db.session import get_db
|
from db.session import get_db
|
||||||
from models.ai import ChatConversation, ChatMessage, MessageType
|
from models.ai import ChatConversation, ChatMessage, MessageType
|
||||||
from utils.resp import resp_success, Response
|
from utils.resp import resp_success, Response
|
||||||
|
from langchain_deepseek import ChatDeepSeek
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -23,13 +24,13 @@ class ChatRequest(BaseModel):
|
|||||||
prompt: str
|
prompt: str
|
||||||
|
|
||||||
|
|
||||||
def get_deepseek_llm(api_key: str, model: str, openai_api_base: str):
|
|
||||||
|
def get_deepseek_llm(api_key: SecretStr, model: str):
|
||||||
# deepseek 兼容 OpenAI API,需指定 base_url
|
# deepseek 兼容 OpenAI API,需指定 base_url
|
||||||
return ChatOpenAI(
|
return ChatDeepSeek(
|
||||||
openai_api_key=api_key,
|
api_key=api_key,
|
||||||
model_name=model,
|
model=model,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
openai_api_base=openai_api_base, # deepseek的API地址
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@router.post('/stream')
|
@router.post('/stream')
|
||||||
@@ -41,29 +42,34 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
|
|||||||
model = 'deepseek-chat'
|
model = 'deepseek-chat'
|
||||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||||
openai_api_base = "https://api.deepseek.com/v1"
|
openai_api_base = "https://api.deepseek.com/v1"
|
||||||
llm = get_deepseek_llm(api_key, model, openai_api_base)
|
llm = get_deepseek_llm(api_key, model)
|
||||||
|
|
||||||
if not content or not isinstance(content, str):
|
if not content or not isinstance(content, str):
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
return JSONResponse({"error": "content不能为空"}, status_code=400)
|
return JSONResponse({"error": "content不能为空"}, status_code=400)
|
||||||
|
|
||||||
user_id = user["user_id"]
|
user_id = user["user_id"]
|
||||||
|
print(conversation_id, 'conversation_id')
|
||||||
# 1. 获取或新建对话
|
# 1. 获取或新建对话
|
||||||
try:
|
try:
|
||||||
conversation = ChatDBService.get_or_create_conversation(db, conversation_id, user_id, model)
|
conversation = ChatDBService.get_or_create_conversation(db, conversation_id, user_id, model, content)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
print(23232)
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
return JSONResponse({"error": str(e)}, status_code=400)
|
return JSONResponse({"error": str(e)}, status_code=400)
|
||||||
|
print(conversation, 'dsds')
|
||||||
# 2. 插入当前消息
|
# 2. 插入当前消息
|
||||||
ChatDBService.add_message(db, conversation, user_id, content)
|
ChatDBService.add_message(db, conversation, user_id, content)
|
||||||
|
context = [
|
||||||
|
("system", "You are a helpful assistant. Answer all questions to the best of your ability in {language}.")
|
||||||
|
]
|
||||||
# 3. 查询历史消息,组装上下文
|
# 3. 查询历史消息,组装上下文
|
||||||
history = ChatDBService.get_history(db, conversation.id)
|
history = ChatDBService.get_history(db, conversation.id)
|
||||||
history_contents = [msg.content for msg in history]
|
for msg in history:
|
||||||
context = '\n'.join(history_contents)
|
# 假设 msg.type 存储的是 'user' 或 'assistant'
|
||||||
|
# role = msg.type if msg.type in ("user", "assistant") else "user"
|
||||||
|
context.append((msg.type, msg.content))
|
||||||
|
print('context', context)
|
||||||
ai_reply = ""
|
ai_reply = ""
|
||||||
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
@@ -76,7 +82,7 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
|
|||||||
ai_reply += chunk
|
ai_reply += chunk
|
||||||
yield f"data: {chunk}\n\n"
|
yield f"data: {chunk}\n\n"
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
# 生成器结束时插入AI消息
|
# 只保留最新AI回复
|
||||||
if ai_reply:
|
if ai_reply:
|
||||||
ChatDBService.insert_ai_message(db, conversation, user_id, ai_reply, model)
|
ChatDBService.insert_ai_message(db, conversation, user_id, ai_reply, model)
|
||||||
|
|
||||||
|
|||||||
@@ -5,13 +5,14 @@ from models.ai import ChatConversation, ChatMessage, MessageType
|
|||||||
|
|
||||||
class ChatDBService:
|
class ChatDBService:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str) -> ChatConversation:
|
def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str, content: str) -> ChatConversation:
|
||||||
if not conversation_id:
|
if not conversation_id:
|
||||||
|
print(conversation_id, 'conversation_id')
|
||||||
conversation = ChatConversation(
|
conversation = ChatConversation(
|
||||||
title="新对话",
|
title=content,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
role_id=None,
|
role_id=None,
|
||||||
model_id=1, # 需根据实际模型id调整
|
model_id=None, # 需根据实际模型id调整
|
||||||
model=model,
|
model=model,
|
||||||
system_message=None,
|
system_message=None,
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from typing import Generic, TypeVar, Optional
|
from typing import Generic, TypeVar, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.generics import GenericModel
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ interface Message {
|
|||||||
}
|
}
|
||||||
|
|
||||||
interface ChatItem {
|
interface ChatItem {
|
||||||
id: number;
|
id: null | number;
|
||||||
title: string;
|
title: string;
|
||||||
lastMessage: string;
|
lastMessage: string;
|
||||||
}
|
}
|
||||||
@@ -33,7 +33,6 @@ const chatList = ref<ChatItem[]>([]);
|
|||||||
|
|
||||||
// 聊天消息
|
// 聊天消息
|
||||||
const messages = ref<Message[]>([]);
|
const messages = ref<Message[]>([]);
|
||||||
const currentMessages = ref<Message[]>([]);
|
|
||||||
|
|
||||||
// 模型列表
|
// 模型列表
|
||||||
const modelOptions = [
|
const modelOptions = [
|
||||||
@@ -57,18 +56,19 @@ const filteredChats = computed(() => {
|
|||||||
async function selectChat(id: number) {
|
async function selectChat(id: number) {
|
||||||
selectedChatId.value = id;
|
selectedChatId.value = id;
|
||||||
const { data } = await getMessages(id);
|
const { data } = await getMessages(id);
|
||||||
currentMessages.value = data;
|
messages.value = data;
|
||||||
nextTick(scrollToBottom);
|
nextTick(scrollToBottom);
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleNewChat() {
|
function handleNewChat() {
|
||||||
const newId = Date.now();
|
const newId = null;
|
||||||
chatList.value.unshift({
|
chatList.value.unshift({
|
||||||
id: newId,
|
id: newId,
|
||||||
title: `新对话${chatList.value.length + 1}`,
|
title: `新对话${chatList.value.length + 1}`,
|
||||||
lastMessage: '',
|
lastMessage: '',
|
||||||
});
|
});
|
||||||
selectedChatId.value = newId;
|
selectedChatId.value = newId;
|
||||||
|
messages.value = [];
|
||||||
nextTick(scrollToBottom);
|
nextTick(scrollToBottom);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,30 +87,30 @@ async function handleSend() {
|
|||||||
content: '',
|
content: '',
|
||||||
};
|
};
|
||||||
messages.value.push(aiMsgObj);
|
messages.value.push(aiMsgObj);
|
||||||
currentAiMessage.value = aiMsgObj;
|
const aiMsgIndex = messages.value.length - 1; // 记录AI消息的索引
|
||||||
|
|
||||||
isAiTyping.value = true;
|
isAiTyping.value = true;
|
||||||
|
|
||||||
const stream = await fetchAIStream({
|
const stream = await fetchAIStream({
|
||||||
content: input.value,
|
content: input.value,
|
||||||
conversation_id: selectedChatId.value, // 新增
|
conversation_id: selectedChatId.value, // 新增
|
||||||
});
|
});
|
||||||
|
if (chatList.value.length > 0) {
|
||||||
|
chatList.value[0]!.title = input.value.slice(0, 10);
|
||||||
|
}
|
||||||
|
// 立刻清空输入框
|
||||||
|
input.value = '';
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
for (const char of chunk) {
|
for (const char of chunk) {
|
||||||
aiMsgObj.content += char;
|
messages.value[aiMsgIndex]!.content += char;
|
||||||
// 保证messages数组响应式更新
|
// 用 splice 替换,确保响应式
|
||||||
const idx = messages.value.indexOf(aiMsgObj);
|
messages.value.splice(aiMsgIndex, 1, { ...messages.value[aiMsgIndex]! });
|
||||||
if (idx !== -1) {
|
|
||||||
messages.value.splice(idx, 1, { ...aiMsgObj });
|
|
||||||
}
|
|
||||||
currentAiMessage.value = { ...aiMsgObj };
|
|
||||||
await nextTick();
|
await nextTick();
|
||||||
scrollToBottom();
|
scrollToBottom();
|
||||||
await new Promise((resolve) => setTimeout(resolve, 15));
|
await new Promise((resolve) => setTimeout(resolve, 15));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
isAiTyping.value = false;
|
isAiTyping.value = false;
|
||||||
input.value = '';
|
|
||||||
nextTick(scrollToBottom);
|
nextTick(scrollToBottom);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,7 +144,7 @@ onMounted(() => {
|
|||||||
<Page auto-content-height>
|
<Page auto-content-height>
|
||||||
<Row style="height: 100%">
|
<Row style="height: 100%">
|
||||||
<!-- 左侧历史对话 -->
|
<!-- 左侧历史对话 -->
|
||||||
<Col :span="6" class="chat-sider">
|
<Col :span="5" class="chat-sider">
|
||||||
<div class="sider-header">
|
<div class="sider-header">
|
||||||
<Button type="primary" @click="handleNewChat">新建对话</Button>
|
<Button type="primary" @click="handleNewChat">新建对话</Button>
|
||||||
<Input
|
<Input
|
||||||
@@ -195,9 +195,9 @@ onMounted(() => {
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="chat-messages" ref="messagesRef">
|
<div class="chat-messages" style="height: 100%;" ref="messagesRef">
|
||||||
<div
|
<div
|
||||||
v-for="msg in currentMessages"
|
v-for="msg in messages"
|
||||||
:key="msg.id"
|
:key="msg.id"
|
||||||
class="chat-message"
|
class="chat-message"
|
||||||
:class="[msg.type]"
|
:class="[msg.type]"
|
||||||
@@ -208,7 +208,9 @@ onMounted(() => {
|
|||||||
{{ msg.content }}
|
{{ msg.content }}
|
||||||
<span
|
<span
|
||||||
v-if="
|
v-if="
|
||||||
msg.type === 'ai' && isAiTyping && msg === currentAiMessage
|
msg.type === 'assistant' &&
|
||||||
|
isAiTyping &&
|
||||||
|
msg === currentAiMessage
|
||||||
"
|
"
|
||||||
class="typing-cursor"
|
class="typing-cursor"
|
||||||
></span>
|
></span>
|
||||||
@@ -277,6 +279,7 @@ onMounted(() => {
|
|||||||
}
|
}
|
||||||
.chat-content {
|
.chat-content {
|
||||||
display: flex;
|
display: flex;
|
||||||
|
height: 100%;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
padding: 16px 24px 8px 24px;
|
padding: 16px 24px 8px 24px;
|
||||||
background: #f6f8fa;
|
background: #f6f8fa;
|
||||||
@@ -292,13 +295,13 @@ onMounted(() => {
|
|||||||
justify-content: flex-end;
|
justify-content: flex-end;
|
||||||
}
|
}
|
||||||
.chat-messages {
|
.chat-messages {
|
||||||
flex: 1;
|
flex: 1 1 auto;
|
||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
background: #fff;
|
background: #fff;
|
||||||
border-radius: 8px;
|
border-radius: 8px;
|
||||||
padding: 24px 16px 80px 16px;
|
padding: 24px 16px 80px 16px;
|
||||||
margin-bottom: 0;
|
margin-bottom: 0;
|
||||||
min-height: 300px;
|
/* min-height: 300px; */
|
||||||
box-shadow: 0 2px 8px #0001;
|
box-shadow: 0 2px 8px #0001;
|
||||||
transition: box-shadow 0.2s;
|
transition: box-shadow 0.2s;
|
||||||
scrollbar-width: thin;
|
scrollbar-width: thin;
|
||||||
|
|||||||
Reference in New Issue
Block a user