优化历史对话
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi import APIRouter, Depends, Request, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -12,6 +12,7 @@ from langchain_community.chat_models import ChatOpenAI
|
||||
from deps.auth import get_current_user
|
||||
from services.chat_service import ChatDBService
|
||||
from db.session import get_db
|
||||
from models.ai import ChatConversation, ChatMessage
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -29,7 +30,7 @@ def get_deepseek_llm(api_key: str, model: str, openai_api_base: str):
|
||||
)
|
||||
|
||||
@router.post('/stream')
|
||||
async def chat_stream(request: Request, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
async def chat_stream(request: Request, user=Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
body = await request.json()
|
||||
content = body.get('content')
|
||||
conversation_id = body.get('conversation_id')
|
||||
@@ -67,4 +68,45 @@ async def chat_stream(request: Request, db: Session = Depends(get_db), user=Depe
|
||||
yield f"data: {chunk}\n\n"
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
return StreamingResponse(event_generator(), media_type='text/event-stream')
|
||||
return StreamingResponse(event_generator(), media_type='text/event-stream')
|
||||
|
||||
@router.get('/conversations')
|
||||
def get_conversations(
|
||||
user_id: int = Query(None),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取指定用户的聊天对话列表"""
|
||||
conversations = db.query(ChatConversation).filter(ChatConversation.user_id == user_id).order_by(ChatConversation.update_time.desc()).all()
|
||||
# 可根据需要序列化
|
||||
return [
|
||||
{
|
||||
'id': c.id,
|
||||
'title': c.title,
|
||||
'update_time': c.update_time,
|
||||
'last_message': c.messages[-1].content if c.messages else '',
|
||||
}
|
||||
for c in conversations
|
||||
]
|
||||
|
||||
@router.get('/messages')
|
||||
def get_messages(
|
||||
conversation_id: int = Query(None),
|
||||
user_id: int = Query(None),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取指定会话的消息列表,可选user_id过滤"""
|
||||
query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id)
|
||||
if user_id is not None:
|
||||
query = query.filter(ChatMessage.user_id == user_id)
|
||||
messages = query.order_by(ChatMessage.id).all()
|
||||
return [
|
||||
{
|
||||
'id': m.id,
|
||||
'role': m.role_id, # 如需role名可再查
|
||||
'content': m.content,
|
||||
'user_id': m.user_id,
|
||||
'conversation_id': m.conversation_id,
|
||||
'create_time': m.create_time,
|
||||
}
|
||||
for m in messages
|
||||
]
|
||||
Reference in New Issue
Block a user