import asyncio
import os

from fastapi import APIRouter, Depends, Request, Query
from fastapi.responses import JSONResponse, StreamingResponse
from sqlalchemy.orm import Session
from pydantic import BaseModel
from langchain.chains import ConversationChain
from langchain_community.chat_models import ChatOpenAI

from deps.auth import get_current_user
from services.chat_service import ChatDBService
from db.session import get_db
from models.ai import ChatConversation, ChatMessage

router = APIRouter()

# Model name and endpoint used for every chat completion issued by this module.
DEEPSEEK_MODEL = 'deepseek-chat'
DEEPSEEK_API_BASE = "https://api.deepseek.com/v1"


class ChatRequest(BaseModel):
    """Request schema with a single ``prompt`` field.

    NOTE(review): not referenced by any route visible in this module;
    ``chat_stream`` parses the raw JSON body instead.
    """

    prompt: str


def get_deepseek_llm(api_key: str, model: str, openai_api_base: str):
    """Build a streaming ``ChatOpenAI`` client pointed at a DeepSeek endpoint.

    DeepSeek is OpenAI-API compatible, so the OpenAI client is reused with
    the base URL overridden.

    Args:
        api_key: API key passed to the client as ``openai_api_key``.
        model: Model name passed as ``model_name``.
        openai_api_base: Base URL of the DeepSeek API.

    Returns:
        A ``ChatOpenAI`` instance created with ``streaming=True``.
    """
    return ChatOpenAI(
        openai_api_key=api_key,
        model_name=model,
        streaming=True,
        openai_api_base=openai_api_base,  # DeepSeek API address
    )


def _format_sse_data(payload) -> str:
    """Encode *payload* as one Server-Sent Events ``data`` event.

    A raw newline inside a ``data:`` field terminates the field (and a blank
    line terminates the whole event), so multi-line text must be sent as one
    ``data:`` line per payload line. Clients following the SSE format join
    those lines back together with ``\\n``.
    """
    text = str(payload).replace("\r\n", "\n").replace("\r", "\n")
    return "".join(f"data: {line}\n" for line in text.split("\n")) + "\n"


@router.post('/stream')
async def chat_stream(request: Request, user=Depends(get_current_user), db: Session = Depends(get_db)):
    """Store the user's message and stream the model's reply as SSE.

    The JSON body must be an object with a non-empty string ``content`` and
    an optional ``conversation_id``.

    Returns:
        A ``text/event-stream`` ``StreamingResponse`` on success, or a
        ``JSONResponse`` with status 400 and an ``error`` field when the body
        is not a JSON object, ``content`` is missing/empty/not a string, or
        ``ChatDBService.get_or_create_conversation`` raises ``ValueError``.
    """
    # Malformed JSON raises a ValueError subclass; answer 400 instead of 500.
    try:
        body = await request.json()
    except ValueError:
        return JSONResponse({"error": "请求体必须是合法的JSON对象"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "请求体必须是合法的JSON对象"}, status_code=400)

    content = body.get('content')
    conversation_id = body.get('conversation_id')
    if not content or not isinstance(content, str):
        return JSONResponse({"error": "content不能为空"}, status_code=400)

    # Build the client only for valid requests, but before any database write
    # so a client construction failure cannot leave a stored, unanswered message.
    model = DEEPSEEK_MODEL
    api_key = os.getenv("DEEPSEEK_API_KEY")
    llm = get_deepseek_llm(api_key, model, DEEPSEEK_API_BASE)

    user_id = user["user_id"]

    # 1. Fetch the existing conversation or create a new one.
    try:
        conversation = ChatDBService.get_or_create_conversation(db, conversation_id, user_id, model)
    except ValueError as e:
        return JSONResponse({"error": str(e)}, status_code=400)

    # 2. Insert the current message.
    ChatDBService.add_message(db, conversation, user_id, content)

    # 3. Load the conversation history and assemble the prompt context.
    history = ChatDBService.get_history(db, conversation.id)
    context = '\n'.join(msg.content for msg in history)

    # NOTE(review): the streamed reply is not written back to the database in
    # this handler — confirm whether it is persisted elsewhere.
    async def event_generator():
        async for chunk in llm.astream(context):
            # Only the chunk's ``content`` is sent when it has one; otherwise
            # the chunk itself is sent.
            payload = chunk.content if hasattr(chunk, 'content') else chunk
            yield _format_sse_data(payload)
            await asyncio.sleep(0.01)

    return StreamingResponse(event_generator(), media_type='text/event-stream')


# NOTE(review): unlike ``chat_stream`` this route has no ``get_current_user``
# dependency, so any caller can list any user's conversations — confirm intended.
@router.get('/conversations')
def get_conversations(
    user_id: int = Query(None),
    db: Session = Depends(get_db)
):
    """Return the chat conversations of the given user, most recently updated first.

    Each item contains ``id``, ``title``, ``update_time`` and ``last_message``
    (the content of the last entry in ``messages``, or ``''`` when there are none).
    When ``user_id`` is omitted the filter compares against ``None``.
    """
    conversations = (
        db.query(ChatConversation)
        .filter(ChatConversation.user_id == user_id)
        .order_by(ChatConversation.update_time.desc())
        .all()
    )
    # Serialized by hand; extend the dict if more fields are needed.
    return [
        {
            'id': c.id,
            'title': c.title,
            'update_time': c.update_time,
            'last_message': c.messages[-1].content if c.messages else '',
        }
        for c in conversations
    ]


# NOTE(review): this route has no ``get_current_user`` dependency either —
# confirm that reading arbitrary conversations is intended.
@router.get('/messages')
def get_messages(
    conversation_id: int = Query(None),
    user_id: int = Query(None),
    db: Session = Depends(get_db)
):
    """Return the messages of a conversation ordered by id, optionally filtered by ``user_id``.

    When ``conversation_id`` is omitted the filter compares against ``None``.
    """
    query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id)
    if user_id is not None:
        query = query.filter(ChatMessage.user_id == user_id)
    messages = query.order_by(ChatMessage.id).all()
    return [
        {
            'id': m.id,
            'role': m.role_id,  # look up the role name separately if needed
            'content': m.content,
            'user_id': m.user_id,
            'conversation_id': m.conversation_id,
            'create_time': m.create_time,
        }
        for m in messages
    ]