新对话bug

This commit is contained in:
XIE7654
2025-07-17 23:45:55 +08:00
parent fc96f77499
commit 66d5971570
4 changed files with 48 additions and 39 deletions

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from typing import List from typing import List
from datetime import datetime from datetime import datetime
from pydantic import BaseModel from pydantic import BaseModel, SecretStr
from langchain.chains import ConversationChain from langchain.chains import ConversationChain
from langchain_community.chat_models import ChatOpenAI from langchain_community.chat_models import ChatOpenAI
@@ -16,6 +16,7 @@ from services.chat_service import ChatDBService
from db.session import get_db from db.session import get_db
from models.ai import ChatConversation, ChatMessage, MessageType from models.ai import ChatConversation, ChatMessage, MessageType
from utils.resp import resp_success, Response from utils.resp import resp_success, Response
from langchain_deepseek import ChatDeepSeek
router = APIRouter() router = APIRouter()
@@ -23,13 +24,13 @@ class ChatRequest(BaseModel):
prompt: str prompt: str
def get_deepseek_llm(api_key: str, model: str, openai_api_base: str):
def get_deepseek_llm(api_key: SecretStr, model: str):
# deepseek 兼容 OpenAI API,需指定 base_url # deepseek 兼容 OpenAI API,需指定 base_url
return ChatOpenAI( return ChatDeepSeek(
openai_api_key=api_key, api_key=api_key,
model_name=model, model=model,
streaming=True, streaming=True,
openai_api_base=openai_api_base, # deepseek的API地址
) )
@router.post('/stream') @router.post('/stream')
@@ -41,29 +42,34 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
model = 'deepseek-chat' model = 'deepseek-chat'
api_key = os.getenv("DEEPSEEK_API_KEY") api_key = os.getenv("DEEPSEEK_API_KEY")
openai_api_base = "https://api.deepseek.com/v1" openai_api_base = "https://api.deepseek.com/v1"
llm = get_deepseek_llm(api_key, model, openai_api_base) llm = get_deepseek_llm(api_key, model)
if not content or not isinstance(content, str): if not content or not isinstance(content, str):
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
return JSONResponse({"error": "content不能为空"}, status_code=400) return JSONResponse({"error": "content不能为空"}, status_code=400)
user_id = user["user_id"] user_id = user["user_id"]
print(conversation_id, 'conversation_id')
# 1. 获取或新建对话 # 1. 获取或新建对话
try: try:
conversation = ChatDBService.get_or_create_conversation(db, conversation_id, user_id, model) conversation = ChatDBService.get_or_create_conversation(db, conversation_id, user_id, model, content)
except ValueError as e: except ValueError as e:
print(23232)
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
return JSONResponse({"error": str(e)}, status_code=400) return JSONResponse({"error": str(e)}, status_code=400)
print(conversation, 'dsds')
# 2. 插入当前消息 # 2. 插入当前消息
ChatDBService.add_message(db, conversation, user_id, content) ChatDBService.add_message(db, conversation, user_id, content)
context = [
("system", "You are a helpful assistant. Answer all questions to the best of your ability in {language}.")
]
# 3. 查询历史消息,组装上下文 # 3. 查询历史消息,组装上下文
history = ChatDBService.get_history(db, conversation.id) history = ChatDBService.get_history(db, conversation.id)
history_contents = [msg.content for msg in history] for msg in history:
context = '\n'.join(history_contents) # 假设 msg.type 存储的是 'user' 或 'assistant'
# role = msg.type if msg.type in ("user", "assistant") else "user"
context.append((msg.type, msg.content))
print('context', context)
ai_reply = "" ai_reply = ""
async def event_generator(): async def event_generator():
@@ -76,7 +82,7 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
ai_reply += chunk ai_reply += chunk
yield f"data: {chunk}\n\n" yield f"data: {chunk}\n\n"
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
# 生成器结束时插入AI消息 # 只保留最新AI回复
if ai_reply: if ai_reply:
ChatDBService.insert_ai_message(db, conversation, user_id, ai_reply, model) ChatDBService.insert_ai_message(db, conversation, user_id, ai_reply, model)

View File

@@ -5,13 +5,14 @@ from models.ai import ChatConversation, ChatMessage, MessageType
class ChatDBService: class ChatDBService:
@staticmethod @staticmethod
def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str) -> ChatConversation: def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str, content: str) -> ChatConversation:
if not conversation_id: if not conversation_id:
print(conversation_id, 'conversation_id')
conversation = ChatConversation( conversation = ChatConversation(
title="新对话", title=content,
user_id=user_id, user_id=user_id,
role_id=None, role_id=None,
model_id=1, # 需根据实际模型id调整 model_id=None, # 需根据实际模型id调整
model=model, model=model,
system_message=None, system_message=None,
temperature=0.7, temperature=0.7,

View File

@@ -1,7 +1,6 @@
from typing import Generic, TypeVar, Optional from typing import Generic, TypeVar, Optional
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.generics import GenericModel
T = TypeVar("T") T = TypeVar("T")

View File

@@ -23,7 +23,7 @@ interface Message {
} }
interface ChatItem { interface ChatItem {
id: number; id: null | number;
title: string; title: string;
lastMessage: string; lastMessage: string;
} }
@@ -33,7 +33,6 @@ const chatList = ref<ChatItem[]>([]);
// 聊天消息 // 聊天消息
const messages = ref<Message[]>([]); const messages = ref<Message[]>([]);
const currentMessages = ref<Message[]>([]);
// 模型列表 // 模型列表
const modelOptions = [ const modelOptions = [
@@ -57,18 +56,19 @@ const filteredChats = computed(() => {
async function selectChat(id: number) { async function selectChat(id: number) {
selectedChatId.value = id; selectedChatId.value = id;
const { data } = await getMessages(id); const { data } = await getMessages(id);
currentMessages.value = data; messages.value = data;
nextTick(scrollToBottom); nextTick(scrollToBottom);
} }
function handleNewChat() { function handleNewChat() {
const newId = Date.now(); const newId = null;
chatList.value.unshift({ chatList.value.unshift({
id: newId, id: newId,
title: `新对话${chatList.value.length + 1}`, title: `新对话${chatList.value.length + 1}`,
lastMessage: '', lastMessage: '',
}); });
selectedChatId.value = newId; selectedChatId.value = newId;
messages.value = [];
nextTick(scrollToBottom); nextTick(scrollToBottom);
} }
@@ -87,30 +87,30 @@ async function handleSend() {
content: '', content: '',
}; };
messages.value.push(aiMsgObj); messages.value.push(aiMsgObj);
currentAiMessage.value = aiMsgObj; const aiMsgIndex = messages.value.length - 1; // 记录AI消息的索引
isAiTyping.value = true; isAiTyping.value = true;
const stream = await fetchAIStream({ const stream = await fetchAIStream({
content: input.value, content: input.value,
conversation_id: selectedChatId.value, // 新增 conversation_id: selectedChatId.value, // 新增
}); });
if (chatList.value.length > 0) {
chatList.value[0]!.title = input.value.slice(0, 10);
}
// 立刻清空输入框
input.value = '';
for await (const chunk of stream) { for await (const chunk of stream) {
for (const char of chunk) { for (const char of chunk) {
aiMsgObj.content += char; messages.value[aiMsgIndex]!.content += char;
// 保证messages数组响应式更新 // 用 splice 替换,确保响应式
const idx = messages.value.indexOf(aiMsgObj); messages.value.splice(aiMsgIndex, 1, { ...messages.value[aiMsgIndex]! });
if (idx !== -1) {
messages.value.splice(idx, 1, { ...aiMsgObj });
}
currentAiMessage.value = { ...aiMsgObj };
await nextTick(); await nextTick();
scrollToBottom(); scrollToBottom();
await new Promise((resolve) => setTimeout(resolve, 15)); await new Promise((resolve) => setTimeout(resolve, 15));
} }
} }
isAiTyping.value = false; isAiTyping.value = false;
input.value = '';
nextTick(scrollToBottom); nextTick(scrollToBottom);
} }
@@ -144,7 +144,7 @@ onMounted(() => {
<Page auto-content-height> <Page auto-content-height>
<Row style="height: 100%"> <Row style="height: 100%">
<!-- 左侧历史对话 --> <!-- 左侧历史对话 -->
<Col :span="6" class="chat-sider"> <Col :span="5" class="chat-sider">
<div class="sider-header"> <div class="sider-header">
<Button type="primary" @click="handleNewChat">新建对话</Button> <Button type="primary" @click="handleNewChat">新建对话</Button>
<Input <Input
@@ -195,9 +195,9 @@ onMounted(() => {
/> />
</div> </div>
</div> </div>
<div class="chat-messages" ref="messagesRef"> <div class="chat-messages" style="height: 100%;" ref="messagesRef">
<div <div
v-for="msg in currentMessages" v-for="msg in messages"
:key="msg.id" :key="msg.id"
class="chat-message" class="chat-message"
:class="[msg.type]" :class="[msg.type]"
@@ -208,7 +208,9 @@ onMounted(() => {
{{ msg.content }} {{ msg.content }}
<span <span
v-if=" v-if="
msg.type === 'ai' && isAiTyping && msg === currentAiMessage msg.type === 'assistant' &&
isAiTyping &&
msg === currentAiMessage
" "
class="typing-cursor" class="typing-cursor"
></span> ></span>
@@ -277,6 +279,7 @@ onMounted(() => {
} }
.chat-content { .chat-content {
display: flex; display: flex;
height: 100%;
flex-direction: column; flex-direction: column;
padding: 16px 24px 8px 24px; padding: 16px 24px 8px 24px;
background: #f6f8fa; background: #f6f8fa;
@@ -292,13 +295,13 @@ onMounted(() => {
justify-content: flex-end; justify-content: flex-end;
} }
.chat-messages { .chat-messages {
flex: 1; flex: 1 1 auto;
overflow-y: auto; overflow-y: auto;
background: #fff; background: #fff;
border-radius: 8px; border-radius: 8px;
padding: 24px 16px 80px 16px; padding: 24px 16px 80px 16px;
margin-bottom: 0; margin-bottom: 0;
min-height: 300px; /* min-height: 300px; */
box-shadow: 0 2px 8px #0001; box-shadow: 0 2px 8px #0001;
transition: box-shadow 0.2s; transition: box-shadow 0.2s;
scrollbar-width: thin; scrollbar-width: thin;