优化历史对话

This commit is contained in:
XIE7654
2025-07-17 16:45:18 +08:00
parent 6ed606f7a4
commit ff47665dcb
2 changed files with 188 additions and 28 deletions

View File

@@ -1,7 +1,7 @@
import os
import asyncio
from fastapi import APIRouter, Depends, Request
from fastapi import APIRouter, Depends, Request, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
@@ -12,6 +12,7 @@ from langchain_community.chat_models import ChatOpenAI
from deps.auth import get_current_user
from services.chat_service import ChatDBService
from db.session import get_db
from models.ai import ChatConversation, ChatMessage
router = APIRouter()
@@ -29,7 +30,7 @@ def get_deepseek_llm(api_key: str, model: str, openai_api_base: str):
)
@router.post('/stream')
async def chat_stream(request: Request, db: Session = Depends(get_db), user=Depends(get_current_user)):
async def chat_stream(request: Request, user=Depends(get_current_user), db: Session = Depends(get_db)):
body = await request.json()
content = body.get('content')
conversation_id = body.get('conversation_id')
@@ -67,4 +68,45 @@ async def chat_stream(request: Request, db: Session = Depends(get_db), user=Depe
yield f"data: {chunk}\n\n"
await asyncio.sleep(0.01)
return StreamingResponse(event_generator(), media_type='text/event-stream')
return StreamingResponse(event_generator(), media_type='text/event-stream')
@router.get('/conversations')
def get_conversations(
user_id: int = Query(None),
db: Session = Depends(get_db)
):
"""获取指定用户的聊天对话列表"""
conversations = db.query(ChatConversation).filter(ChatConversation.user_id == user_id).order_by(ChatConversation.update_time.desc()).all()
# 可根据需要序列化
return [
{
'id': c.id,
'title': c.title,
'update_time': c.update_time,
'last_message': c.messages[-1].content if c.messages else '',
}
for c in conversations
]
@router.get('/messages')
def get_messages(
conversation_id: int = Query(None),
user_id: int = Query(None),
db: Session = Depends(get_db)
):
"""获取指定会话的消息列表可选user_id过滤"""
query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id)
if user_id is not None:
query = query.filter(ChatMessage.user_id == user_id)
messages = query.order_by(ChatMessage.id).all()
return [
{
'id': m.id,
'role': m.role_id, # 如需role名可再查
'content': m.content,
'user_id': m.user_id,
'conversation_id': m.conversation_id,
'create_time': m.create_time,
}
for m in messages
]