diff --git a/chat/api/v1/ai_chat.py b/chat/api/v1/ai_chat.py index 46a3544..ea0971c 100644 --- a/chat/api/v1/ai_chat.py +++ b/chat/api/v1/ai_chat.py @@ -1,18 +1,21 @@ import os import asyncio - from fastapi import APIRouter, Depends, Request, Query from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session +from typing import List +from datetime import datetime from pydantic import BaseModel from langchain.chains import ConversationChain from langchain_community.chat_models import ChatOpenAI +from api.v1.vo import MessageVO, ConversationsVO from deps.auth import get_current_user from services.chat_service import ChatDBService from db.session import get_db -from models.ai import ChatConversation, ChatMessage +from models.ai import ChatConversation, ChatMessage, MessageType +from utils.resp import resp_success, Response router = APIRouter() @@ -34,6 +37,7 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess body = await request.json() content = body.get('content') conversation_id = body.get('conversation_id') + print(content, 'content') model = 'deepseek-chat' api_key = os.getenv("DEEPSEEK_API_KEY") openai_api_base = "https://api.deepseek.com/v1" @@ -51,6 +55,7 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess except ValueError as e: from fastapi.responses import JSONResponse return JSONResponse({"error": str(e)}, status_code=400) + # 2. 插入当前消息 ChatDBService.add_message(db, conversation, user_id, content) @@ -59,36 +64,44 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess history_contents = [msg.content for msg in history] context = '\n'.join(history_contents) + ai_reply = "" + async def event_generator(): + nonlocal ai_reply async for chunk in llm.astream(context): - # 只返回 chunk.content 内容 if hasattr(chunk, 'content'): + ai_reply += chunk.content yield f"data: {chunk.content}\n\n" else: + ai_reply += chunk yield f"data: {chunk}\n\n" await asyncio.sleep(0.01) + # 生成器结束时插入AI消息 + if ai_reply: + ChatDBService.insert_ai_message(db, conversation, user_id, ai_reply, model) return StreamingResponse(event_generator(), media_type='text/event-stream') @router.get('/conversations') -def get_conversations( +async def get_conversations( db: Session = Depends(get_db), user=Depends(get_current_user) ): - """获取当前用户的聊天对话列表""" + """获取当前用户的聊天对话列表,last_message为字符串""" user_id = user["user_id"] conversations = db.query(ChatConversation).filter(ChatConversation.user_id == user_id).order_by(ChatConversation.update_time.desc()).all() - return [ + return resp_success(data=[ { 'id': c.id, 'title': c.title, 'update_time': c.update_time, - 'last_message': c.messages[-1].content if c.messages else '', + 'last_message': c.messages[-1].content if c.messages else None, } for c in conversations - ] + ]) -@router.get('/messages') + +@router.get('/messages', response_model=Response[List[MessageVO]]) def get_messages( conversation_id: int = Query(...), db: Session = Depends(get_db), @@ -96,16 +109,7 @@ def get_messages( ): """获取指定会话的消息列表(当前用户)""" user_id = user["user_id"] - query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id, ChatMessage.user_id == user_id) - messages = query.order_by(ChatMessage.id).all() - return [ - { - 'id': m.id, - 'role': m.role_id, - 'content': m.content, - 'user_id': m.user_id, - 'conversation_id': m.conversation_id, - 'create_time': m.create_time, - } - for m in messages - ] \ No newline at end of file + query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id, + ChatMessage.user_id == user_id).order_by(ChatMessage.id).all() + return resp_success(data=query) + diff --git a/chat/api/v1/vo.py b/chat/api/v1/vo.py new file mode 100644 index 0000000..4ad59ae --- /dev/null +++ b/chat/api/v1/vo.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel +from datetime import datetime + +class MessageVO(BaseModel): + id: int + content: str + conversation_id: int + type: str + + class Config: + from_attributes = True # 启用ORM模式支持 + +class ConversationsVO(BaseModel): + id: int + title: str + update_time: datetime + last_message: str | None = None + + class Config: + from_attributes = True \ No newline at end of file diff --git a/chat/models/ai.py b/chat/models/ai.py index c014a4d..5c94dd3 100644 --- a/chat/models/ai.py +++ b/chat/models/ai.py @@ -29,14 +29,35 @@ class PlatformChoices: # 消息类型选择类(示例) class MessageType: - TEXT = 'text' - IMAGE = 'image' + SYSTEM = "system" # 系统指令 + USER = "user" # 用户消息 + ASSISTANT = "assistant" # 助手回复 + FUNCTION = "function" # 函数返回结果 @staticmethod def choices(): - return [('text', '文本'), ('image', '图片')] + """返回可用的消息角色选项""" + return [ + (MessageType.SYSTEM, "系统"), + (MessageType.USER, "用户"), + (MessageType.ASSISTANT, "助手"), + (MessageType.FUNCTION, "函数") + ] +class MessageContentType: + """消息内容类型""" + TEXT = "text" + FUNCTION_CALL = "function_call" + + @staticmethod + def choices(): + """返回可用的内容类型选项""" + return [ + (MessageContentType.TEXT, "文本"), + (MessageContentType.FUNCTION_CALL, "函数调用") + ] + # AI API 密钥表 class AIApiKey(CoreModel): __tablename__ = 'ai_api_key' diff --git a/chat/services/chat_service.py b/chat/services/chat_service.py index bf7a074..b5ffa8a 100644 --- a/chat/services/chat_service.py +++ b/chat/services/chat_service.py @@ -39,7 +39,7 @@ class ChatDBService: role_id=None, model=conversation.model, model_id=conversation.model_id, - type=MessageType.TEXT, + type=MessageType.USER, reply_id=None, content=content, use_context=True, @@ -52,6 +52,28 @@ class ChatDBService: db.commit() return message + @staticmethod + def insert_ai_message(db: Session, conversation, user_id: int, content: str, model: str): + from datetime import datetime + from models.ai import MessageType + message = ChatMessage( + conversation_id=conversation.id, + user_id=user_id, + role_id=None, + model=model, + model_id=conversation.model_id, + type=MessageType.ASSISTANT, + reply_id=None, + content=content, + use_context=True, + segment_ids=None, + create_time=datetime.now(), + update_time=datetime.now(), + is_deleted=False + ) + db.add(message) + db.commit() + @staticmethod def get_history(db: Session, conversation_id: int) -> list[ChatMessage]: return db.query(ChatMessage).filter_by(conversation_id=conversation_id).order_by(ChatMessage.id).all() diff --git a/chat/utils/resp.py b/chat/utils/resp.py index 607bd21..e251591 100644 --- a/chat/utils/resp.py +++ b/chat/utils/resp.py @@ -1,16 +1,17 @@ -from typing import Any, Optional +from typing import Generic, TypeVar, Optional + from pydantic import BaseModel +from pydantic.generics import GenericModel -class CommonResponse(BaseModel): - code: int = 0 - message: str = "success" - data: Any = None - error: Optional[Any] = None +T = TypeVar("T") +class Response(BaseModel, Generic[T]): + code: int + message: str + data: Optional[T] = None # ✅ 明确 data 可为 None -def resp_success(data=None, message="success"): - return CommonResponse(code=0, message=message, data=data, error=None) +def resp_success(data: T, message: str = "success") -> Response[T]: + return Response(code=0, message=message, data=data) - -def resp_error(message="error", code=1, error=None): - return CommonResponse(code=code, message=message, data=None, error=error) \ No newline at end of file +def resp_error(message="error", code=1) -> Response[None]: + return Response(code=code, message=message, data=None) \ No newline at end of file diff --git a/web/apps/web-antd/src/views/ai/chat/index.vue b/web/apps/web-antd/src/views/ai/chat/index.vue index 97ea09c..f15d63d 100644 --- a/web/apps/web-antd/src/views/ai/chat/index.vue +++ b/web/apps/web-antd/src/views/ai/chat/index.vue @@ -17,8 +17,8 @@ import { import { fetchAIStream, getConversations, getMessages } from '#/api/ai/chat'; interface Message { - id: number; - role: 'ai' | 'user'; + id: null | number; + type: 'assistant' | 'user'; content: string; } @@ -28,14 +28,14 @@ interface ChatItem { lastMessage: string; } -// mock 历史对话 +// 历史对话 const chatList = ref([]); -// mock 聊天消息 +// 聊天消息 const messages = ref([]); const currentMessages = ref([]); -// mock 模型列表 +// 模型列表 const modelOptions = [ { label: 'deepseek', value: 'deepseek' }, { label: 'GPT-4', value: 'gpt-4' }, @@ -56,7 +56,7 @@ const filteredChats = computed(() => { async function selectChat(id: number) { selectedChatId.value = id; - const data = await getMessages(id); + const { data } = await getMessages(id); currentMessages.value = data; nextTick(scrollToBottom); } @@ -74,16 +74,16 @@ function handleNewChat() { async function handleSend() { const msg: Message = { - id: Date.now(), - role: 'user', + id: null, + type: 'user', content: input.value, }; messages.value.push(msg); // 预留AI消息 const aiMsgObj: Message = { - id: Date.now() + 1, - role: 'ai', + id: null, + type: 'assistant', content: '', }; messages.value.push(aiMsgObj); @@ -122,17 +122,16 @@ function scrollToBottom() { // 获取历史对话 async function fetchConversations() { - const data = await getConversations(); + const { data } = await getConversations(); chatList.value = data.map((item: any) => ({ id: item.id, title: item.title, lastMessage: item.last_message || '', })); - console.log(chatList.value, 'chatList'); // 默认选中第一个对话 if (chatList.value.length > 0) { selectedChatId.value = chatList.value[0].id; - selectChat(selectedChatId.value) + await selectChat(selectedChatId.value); } } @@ -201,15 +200,15 @@ onMounted(() => { v-for="msg in currentMessages" :key="msg.id" class="chat-message" - :class="[msg.role]" + :class="[msg.type]" > -
- {{ msg.role === 'user' ? '我' : 'AI' }} +
+ {{ msg.type === 'user' ? '我' : 'AI' }} {{ msg.content }}