修改样式,接口

This commit is contained in:
XIE7654
2025-07-17 22:22:10 +08:00
parent e79eb196f2
commit fc96f77499
6 changed files with 121 additions and 54 deletions

View File

@@ -1,18 +1,21 @@
import os import os
import asyncio import asyncio
from fastapi import APIRouter, Depends, Request, Query from fastapi import APIRouter, Depends, Request, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List
from datetime import datetime
from pydantic import BaseModel from pydantic import BaseModel
from langchain.chains import ConversationChain from langchain.chains import ConversationChain
from langchain_community.chat_models import ChatOpenAI from langchain_community.chat_models import ChatOpenAI
from api.v1.vo import MessageVO, ConversationsVO
from deps.auth import get_current_user from deps.auth import get_current_user
from services.chat_service import ChatDBService from services.chat_service import ChatDBService
from db.session import get_db from db.session import get_db
from models.ai import ChatConversation, ChatMessage from models.ai import ChatConversation, ChatMessage, MessageType
from utils.resp import resp_success, Response
router = APIRouter() router = APIRouter()
@@ -34,6 +37,7 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
body = await request.json() body = await request.json()
content = body.get('content') content = body.get('content')
conversation_id = body.get('conversation_id') conversation_id = body.get('conversation_id')
print(content, 'content')
model = 'deepseek-chat' model = 'deepseek-chat'
api_key = os.getenv("DEEPSEEK_API_KEY") api_key = os.getenv("DEEPSEEK_API_KEY")
openai_api_base = "https://api.deepseek.com/v1" openai_api_base = "https://api.deepseek.com/v1"
@@ -51,6 +55,7 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
except ValueError as e: except ValueError as e:
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
return JSONResponse({"error": str(e)}, status_code=400) return JSONResponse({"error": str(e)}, status_code=400)
# 2. 插入当前消息 # 2. 插入当前消息
ChatDBService.add_message(db, conversation, user_id, content) ChatDBService.add_message(db, conversation, user_id, content)
@@ -59,36 +64,44 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
history_contents = [msg.content for msg in history] history_contents = [msg.content for msg in history]
context = '\n'.join(history_contents) context = '\n'.join(history_contents)
ai_reply = ""
async def event_generator(): async def event_generator():
nonlocal ai_reply
async for chunk in llm.astream(context): async for chunk in llm.astream(context):
# 只返回 chunk.content 内容
if hasattr(chunk, 'content'): if hasattr(chunk, 'content'):
ai_reply += chunk.content
yield f"data: {chunk.content}\n\n" yield f"data: {chunk.content}\n\n"
else: else:
ai_reply += chunk
yield f"data: {chunk}\n\n" yield f"data: {chunk}\n\n"
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
# 生成器结束时插入AI消息
if ai_reply:
ChatDBService.insert_ai_message(db, conversation, user_id, ai_reply, model)
return StreamingResponse(event_generator(), media_type='text/event-stream') return StreamingResponse(event_generator(), media_type='text/event-stream')
@router.get('/conversations') @router.get('/conversations')
def get_conversations( async def get_conversations(
db: Session = Depends(get_db), db: Session = Depends(get_db),
user=Depends(get_current_user) user=Depends(get_current_user)
): ):
"""获取当前用户的聊天对话列表""" """获取当前用户的聊天对话列表last_message为字符串"""
user_id = user["user_id"] user_id = user["user_id"]
conversations = db.query(ChatConversation).filter(ChatConversation.user_id == user_id).order_by(ChatConversation.update_time.desc()).all() conversations = db.query(ChatConversation).filter(ChatConversation.user_id == user_id).order_by(ChatConversation.update_time.desc()).all()
return [ return resp_success(data=[
{ {
'id': c.id, 'id': c.id,
'title': c.title, 'title': c.title,
'update_time': c.update_time, 'update_time': c.update_time,
'last_message': c.messages[-1].content if c.messages else '', 'last_message': c.messages[-1].content if c.messages else None,
} }
for c in conversations for c in conversations
] ])
@router.get('/messages')
@router.get('/messages', response_model=Response[List[MessageVO]])
def get_messages( def get_messages(
conversation_id: int = Query(...), conversation_id: int = Query(...),
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -96,16 +109,7 @@ def get_messages(
): ):
"""获取指定会话的消息列表(当前用户)""" """获取指定会话的消息列表(当前用户)"""
user_id = user["user_id"] user_id = user["user_id"]
query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id, ChatMessage.user_id == user_id) query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id,
messages = query.order_by(ChatMessage.id).all() ChatMessage.user_id == user_id).order_by(ChatMessage.id).all()
return [ return resp_success(data=query)
{
'id': m.id,
'role': m.role_id,
'content': m.content,
'user_id': m.user_id,
'conversation_id': m.conversation_id,
'create_time': m.create_time,
}
for m in messages
]

20
chat/api/v1/vo.py Normal file
View File

@@ -0,0 +1,20 @@
from pydantic import BaseModel
from datetime import datetime
class MessageVO(BaseModel):
id: int
content: str
conversation_id: int
type: str
class Config:
from_attributes = True # 启用ORM模式支持
class ConversationsVO(BaseModel):
id: int
title: str
update_time: datetime
last_message: str | None = None
class Config:
from_attributes = True

View File

@@ -29,14 +29,35 @@ class PlatformChoices:
# 消息类型选择类(示例) # 消息类型选择类(示例)
class MessageType: class MessageType:
TEXT = 'text' SYSTEM = "system" # 系统指令
IMAGE = 'image' USER = "user" # 用户消息
ASSISTANT = "assistant" # 助手回复
FUNCTION = "function" # 函数返回结果
@staticmethod @staticmethod
def choices(): def choices():
return [('text', '文本'), ('image', '图片')] """返回可用的消息角色选项"""
return [
(MessageType.SYSTEM, "系统"),
(MessageType.USER, "用户"),
(MessageType.ASSISTANT, "助手"),
(MessageType.FUNCTION, "函数")
]
class MessageContentType:
"""消息内容类型"""
TEXT = "text"
FUNCTION_CALL = "function_call"
@staticmethod
def choices():
"""返回可用的内容类型选项"""
return [
(MessageContentType.TEXT, "文本"),
(MessageContentType.FUNCTION_CALL, "函数调用")
]
# AI API 密钥表 # AI API 密钥表
class AIApiKey(CoreModel): class AIApiKey(CoreModel):
__tablename__ = 'ai_api_key' __tablename__ = 'ai_api_key'

View File

@@ -39,7 +39,7 @@ class ChatDBService:
role_id=None, role_id=None,
model=conversation.model, model=conversation.model,
model_id=conversation.model_id, model_id=conversation.model_id,
type=MessageType.TEXT, type=MessageType.USER,
reply_id=None, reply_id=None,
content=content, content=content,
use_context=True, use_context=True,
@@ -52,6 +52,28 @@ class ChatDBService:
db.commit() db.commit()
return message return message
@staticmethod
def insert_ai_message(db: Session, conversation, user_id: int, content: str, model: str):
from datetime import datetime
from models.ai import MessageType
message = ChatMessage(
conversation_id=conversation.id,
user_id=user_id,
role_id=None,
model=model,
model_id=conversation.model_id,
type=MessageType.ASSISTANT,
reply_id=None,
content=content,
use_context=True,
segment_ids=None,
create_time=datetime.now(),
update_time=datetime.now(),
is_deleted=False
)
db.add(message)
db.commit()
@staticmethod @staticmethod
def get_history(db: Session, conversation_id: int) -> list[ChatMessage]: def get_history(db: Session, conversation_id: int) -> list[ChatMessage]:
return db.query(ChatMessage).filter_by(conversation_id=conversation_id).order_by(ChatMessage.id).all() return db.query(ChatMessage).filter_by(conversation_id=conversation_id).order_by(ChatMessage.id).all()

View File

@@ -1,16 +1,17 @@
from typing import Any, Optional from typing import Generic, TypeVar, Optional
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.generics import GenericModel
class CommonResponse(BaseModel): T = TypeVar("T")
code: int = 0
message: str = "success"
data: Any = None
error: Optional[Any] = None
class Response(BaseModel, Generic[T]):
code: int
message: str
data: Optional[T] = None # ✅ 明确 data 可为 None
def resp_success(data=None, message="success"): def resp_success(data: T, message: str = "success") -> Response[T]:
return CommonResponse(code=0, message=message, data=data, error=None) return Response(code=0, message=message, data=data)
def resp_error(message="error", code=1) -> Response[None]:
def resp_error(message="error", code=1, error=None): return Response(code=code, message=message, data=None)
return CommonResponse(code=code, message=message, data=None, error=error)

View File

@@ -17,8 +17,8 @@ import {
import { fetchAIStream, getConversations, getMessages } from '#/api/ai/chat'; import { fetchAIStream, getConversations, getMessages } from '#/api/ai/chat';
interface Message { interface Message {
id: number; id: null | number;
role: 'ai' | 'user'; type: 'assistant' | 'user';
content: string; content: string;
} }
@@ -28,14 +28,14 @@ interface ChatItem {
lastMessage: string; lastMessage: string;
} }
// mock 历史对话 // 历史对话
const chatList = ref<ChatItem[]>([]); const chatList = ref<ChatItem[]>([]);
// mock 聊天消息 // 聊天消息
const messages = ref<Message[]>([]); const messages = ref<Message[]>([]);
const currentMessages = ref<Message[]>([]); const currentMessages = ref<Message[]>([]);
// mock 模型列表 // 模型列表
const modelOptions = [ const modelOptions = [
{ label: 'deepseek', value: 'deepseek' }, { label: 'deepseek', value: 'deepseek' },
{ label: 'GPT-4', value: 'gpt-4' }, { label: 'GPT-4', value: 'gpt-4' },
@@ -56,7 +56,7 @@ const filteredChats = computed(() => {
async function selectChat(id: number) { async function selectChat(id: number) {
selectedChatId.value = id; selectedChatId.value = id;
const data = await getMessages(id); const { data } = await getMessages(id);
currentMessages.value = data; currentMessages.value = data;
nextTick(scrollToBottom); nextTick(scrollToBottom);
} }
@@ -74,16 +74,16 @@ function handleNewChat() {
async function handleSend() { async function handleSend() {
const msg: Message = { const msg: Message = {
id: Date.now(), id: null,
role: 'user', type: 'user',
content: input.value, content: input.value,
}; };
messages.value.push(msg); messages.value.push(msg);
// 预留AI消息 // 预留AI消息
const aiMsgObj: Message = { const aiMsgObj: Message = {
id: Date.now() + 1, id: null,
role: 'ai', type: 'assistant',
content: '', content: '',
}; };
messages.value.push(aiMsgObj); messages.value.push(aiMsgObj);
@@ -122,17 +122,16 @@ function scrollToBottom() {
// 获取历史对话 // 获取历史对话
async function fetchConversations() { async function fetchConversations() {
const data = await getConversations(); const { data } = await getConversations();
chatList.value = data.map((item: any) => ({ chatList.value = data.map((item: any) => ({
id: item.id, id: item.id,
title: item.title, title: item.title,
lastMessage: item.last_message || '', lastMessage: item.last_message || '',
})); }));
console.log(chatList.value, 'chatList');
// 默认选中第一个对话 // 默认选中第一个对话
if (chatList.value.length > 0) { if (chatList.value.length > 0) {
selectedChatId.value = chatList.value[0].id; selectedChatId.value = chatList.value[0].id;
selectChat(selectedChatId.value) await selectChat(selectedChatId.value);
} }
} }
@@ -201,15 +200,15 @@ onMounted(() => {
v-for="msg in currentMessages" v-for="msg in currentMessages"
:key="msg.id" :key="msg.id"
class="chat-message" class="chat-message"
:class="[msg.role]" :class="[msg.type]"
> >
<div class="bubble" :class="[msg.role]"> <div class="bubble" :class="[msg.type]">
<span class="role">{{ msg.role === 'user' ? '我' : 'AI' }}</span> <span class="role">{{ msg.type === 'user' ? '我' : 'AI' }}</span>
<span class="bubble-content"> <span class="bubble-content">
{{ msg.content }} {{ msg.content }}
<span <span
v-if=" v-if="
msg.role === 'ai' && isAiTyping && msg === currentAiMessage msg.type === 'ai' && isAiTyping && msg === currentAiMessage
" "
class="typing-cursor" class="typing-cursor"
></span> ></span>