Modify styles and interfaces

@@ -1,18 +1,21 @@
 import os
 import asyncio
 
 from fastapi import APIRouter, Depends, Request, Query
 from fastapi.responses import StreamingResponse
 from sqlalchemy.orm import Session
+from typing import List
+from datetime import datetime
 
 from pydantic import BaseModel
 from langchain.chains import ConversationChain
 from langchain_community.chat_models import ChatOpenAI
 
+from api.v1.vo import MessageVO, ConversationsVO
 from deps.auth import get_current_user
 from services.chat_service import ChatDBService
 from db.session import get_db
-from models.ai import ChatConversation, ChatMessage
+from models.ai import ChatConversation, ChatMessage, MessageType
+from utils.resp import resp_success, Response
 
 router = APIRouter()
 
@@ -34,6 +37,7 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
     body = await request.json()
     content = body.get('content')
     conversation_id = body.get('conversation_id')
+    print(content, 'content')
     model = 'deepseek-chat'
     api_key = os.getenv("DEEPSEEK_API_KEY")
     openai_api_base = "https://api.deepseek.com/v1"
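
The construction of the streaming LLM that consumes these settings sits outside this hunk. A plausible sketch, using field names that exist on langchain_community's ChatOpenAI (the exact wiring in the repository may differ):

    # Sketch only: builds a streaming DeepSeek-backed chat model from the settings above.
    import os
    from langchain_community.chat_models import ChatOpenAI

    llm = ChatOpenAI(
        model_name='deepseek-chat',
        openai_api_key=os.getenv("DEEPSEEK_API_KEY"),
        openai_api_base="https://api.deepseek.com/v1",
        streaming=True,  # emit tokens incrementally instead of one final message
    )
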
@@ -51,6 +55,7 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
     except ValueError as e:
         from fastapi.responses import JSONResponse
         return JSONResponse({"error": str(e)}, status_code=400)
 
     # 2. Insert the current message
     ChatDBService.add_message(db, conversation, user_id, content)
+
@@ -59,36 +64,44 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
     history_contents = [msg.content for msg in history]
     context = '\n'.join(history_contents)
 
+    ai_reply = ""
 
     async def event_generator():
+        nonlocal ai_reply
         async for chunk in llm.astream(context):
-            # only yield the chunk.content payload
             if hasattr(chunk, 'content'):
+                ai_reply += chunk.content
                 yield f"data: {chunk.content}\n\n"
             else:
+                ai_reply += chunk
                 yield f"data: {chunk}\n\n"
             await asyncio.sleep(0.01)
+        # insert the AI message once the generator has finished
+        if ai_reply:
+            ChatDBService.insert_ai_message(db, conversation, user_id, ai_reply, model)
 
     return StreamingResponse(event_generator(), media_type='text/event-stream')
 
 @router.get('/conversations')
-def get_conversations(
+async def get_conversations(
     db: Session = Depends(get_db),
     user=Depends(get_current_user)
 ):
-    """Return the current user's chat conversation list"""
+    """Return the current user's chat conversation list; last_message is a string"""
     user_id = user["user_id"]
     conversations = db.query(ChatConversation).filter(ChatConversation.user_id == user_id).order_by(ChatConversation.update_time.desc()).all()
-    return [
+    return resp_success(data=[
         {
             'id': c.id,
             'title': c.title,
             'update_time': c.update_time,
-            'last_message': c.messages[-1].content if c.messages else '',
+            'last_message': c.messages[-1].content if c.messages else None,
         }
         for c in conversations
-    ]
+    ])
 
-@router.get('/messages')
+@router.get('/messages', response_model=Response[List[MessageVO]])
 def get_messages(
     conversation_id: int = Query(...),
     db: Session = Depends(get_db),
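
The endpoint streams Server-Sent-Events-style "data:" lines until the model finishes, then persists the accumulated reply. A minimal consumer sketch, assuming an httpx client, a locally running app, and a /chat/stream path with bearer auth (the route prefix and auth scheme are not visible in this diff):

    # Sketch only: consumes the text/event-stream response and prints the tokens.
    import asyncio
    import httpx

    async def consume_stream():
        payload = {"content": "Hello", "conversation_id": 1}
        headers = {"Authorization": "Bearer <token>"}  # placeholder; depends on get_current_user
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", "http://localhost:8000/chat/stream",
                                     json=payload, headers=headers) as resp:
                async for line in resp.aiter_lines():
                    if line.startswith("data: "):
                        print(line[len("data: "):], end="", flush=True)

    asyncio.run(consume_stream())
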
@@ -96,16 +109,7 @@ def get_messages(
 ):
     """Return the message list for the given conversation (current user)"""
     user_id = user["user_id"]
-    query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id, ChatMessage.user_id == user_id)
-    messages = query.order_by(ChatMessage.id).all()
-    return [
-        {
-            'id': m.id,
-            'role': m.role_id,
-            'content': m.content,
-            'user_id': m.user_id,
-            'conversation_id': m.conversation_id,
-            'create_time': m.create_time,
-        }
-        for m in messages
-    ]
+    query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id,
+                                         ChatMessage.user_id == user_id).order_by(ChatMessage.id).all()
+    return resp_success(data=query)

chat/api/v1/vo.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+from pydantic import BaseModel
+from datetime import datetime
+
+class MessageVO(BaseModel):
+    id: int
+    content: str
+    conversation_id: int
+    type: str
+
+    class Config:
+        from_attributes = True  # enable ORM-mode support
+
+class ConversationsVO(BaseModel):
+    id: int
+    title: str
+    update_time: datetime
+    last_message: str | None = None
+
+    class Config:
+        from_attributes = True
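
from_attributes is what lets get_messages hand raw ORM rows to resp_success and have FastAPI serialize them through response_model=Response[List[MessageVO]]. A self-contained sketch of that mechanism, assuming Pydantic v2 and using a hypothetical FakeMessage stand-in for a real ChatMessage row:

    # Sketch only: FakeMessage is a made-up stand-in for a SQLAlchemy ChatMessage row.
    from pydantic import BaseModel

    class MessageVO(BaseModel):
        id: int
        content: str
        conversation_id: int
        type: str

        class Config:
            from_attributes = True

    class FakeMessage:
        def __init__(self):
            self.id = 1
            self.content = "hi"
            self.conversation_id = 7
            self.type = "user"
            self.user_id = 42  # extra attributes are simply ignored by the VO

    vo = MessageVO.model_validate(FakeMessage())
    print(vo.model_dump())  # {'id': 1, 'content': 'hi', 'conversation_id': 7, 'type': 'user'}
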
@@ -29,14 +29,35 @@ class PlatformChoices:
 
 # Message type choice class (example)
 class MessageType:
-    TEXT = 'text'
-    IMAGE = 'image'
+    SYSTEM = "system"        # system instruction
+    USER = "user"            # user message
+    ASSISTANT = "assistant"  # assistant reply
+    FUNCTION = "function"    # function result
 
     @staticmethod
     def choices():
-        return [('text', '文本'), ('image', '图片')]
+        """Return the available message role options"""
+        return [
+            (MessageType.SYSTEM, "系统"),
+            (MessageType.USER, "用户"),
+            (MessageType.ASSISTANT, "助手"),
+            (MessageType.FUNCTION, "函数")
+        ]
 
 
+class MessageContentType:
+    """Message content type"""
+    TEXT = "text"
+    FUNCTION_CALL = "function_call"
+
+    @staticmethod
+    def choices():
+        """Return the available content type options"""
+        return [
+            (MessageContentType.TEXT, "文本"),
+            (MessageContentType.FUNCTION_CALL, "函数调用")
+        ]
+
 # AI API key table
 class AIApiKey(CoreModel):
     __tablename__ = 'ai_api_key'
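
The role vocabulary now matches the OpenAI-style system/user/assistant/function convention, so stored rows can be mapped directly onto chat-completion message dicts. An illustrative sketch (the rows here are invented; the real history comes from ChatDBService.get_history):

    # Illustrative only: turn (type, content) rows into chat-style message dicts.
    rows = [
        ("system", "You are a helpful assistant."),
        ("user", "Hello"),
        ("assistant", "Hi! How can I help?"),
    ]
    chat_messages = [{"role": role, "content": content} for role, content in rows]
    print(chat_messages)
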
@@ -39,7 +39,7 @@ class ChatDBService:
             role_id=None,
             model=conversation.model,
             model_id=conversation.model_id,
-            type=MessageType.TEXT,
+            type=MessageType.USER,
             reply_id=None,
             content=content,
             use_context=True,
@@ -52,6 +52,28 @@ class ChatDBService:
         db.commit()
         return message
 
+    @staticmethod
+    def insert_ai_message(db: Session, conversation, user_id: int, content: str, model: str):
+        from datetime import datetime
+        from models.ai import MessageType
+        message = ChatMessage(
+            conversation_id=conversation.id,
+            user_id=user_id,
+            role_id=None,
+            model=model,
+            model_id=conversation.model_id,
+            type=MessageType.ASSISTANT,
+            reply_id=None,
+            content=content,
+            use_context=True,
+            segment_ids=None,
+            create_time=datetime.now(),
+            update_time=datetime.now(),
+            is_deleted=False
+        )
+        db.add(message)
+        db.commit()
+
     @staticmethod
     def get_history(db: Session, conversation_id: int) -> list[ChatMessage]:
         return db.query(ChatMessage).filter_by(conversation_id=conversation_id).order_by(ChatMessage.id).all()

@@ -1,16 +1,17 @@
-from typing import Any, Optional
+from typing import Generic, TypeVar, Optional
 
 from pydantic import BaseModel
+from pydantic.generics import GenericModel
 
-class CommonResponse(BaseModel):
-    code: int = 0
-    message: str = "success"
-    data: Any = None
-    error: Optional[Any] = None
+T = TypeVar("T")
 
+class Response(BaseModel, Generic[T]):
+    code: int
+    message: str
+    data: Optional[T] = None  # ✅ explicitly allow data to be None
 
-def resp_success(data=None, message="success"):
-    return CommonResponse(code=0, message=message, data=data, error=None)
+def resp_success(data: T, message: str = "success") -> Response[T]:
+    return Response(code=0, message=message, data=data)
 
-def resp_error(message="error", code=1, error=None):
-    return CommonResponse(code=code, message=message, data=None, error=error)
+def resp_error(message="error", code=1) -> Response[None]:
+    return Response(code=code, message=message, data=None)
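
A self-contained sketch of how the generic envelope behaves (assuming Pydantic v2, where BaseModel plus Generic[T] works directly; on Pydantic v1 the class would need to subclass the imported GenericModel instead):

    # Sketch only: mirrors the envelope above to show parametrization and serialization.
    from typing import Generic, List, Optional, TypeVar
    from pydantic import BaseModel

    T = TypeVar("T")

    class Response(BaseModel, Generic[T]):
        code: int
        message: str
        data: Optional[T] = None

    def resp_success(data: T, message: str = "success") -> Response[T]:
        return Response(code=0, message=message, data=data)

    print(resp_success([1, 2, 3]).model_dump())
    # -> {'code': 0, 'message': 'success', 'data': [1, 2, 3]}

    # Parametrized form, as used in response_model=Response[List[MessageVO]]:
    print(Response[List[int]](code=0, message="ok", data=[1, 2]).model_dump())
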
@@ -17,8 +17,8 @@ import {
 import { fetchAIStream, getConversations, getMessages } from '#/api/ai/chat';
 
 interface Message {
-  id: number;
-  role: 'ai' | 'user';
+  id: null | number;
+  type: 'assistant' | 'user';
   content: string;
 }
 
@@ -28,14 +28,14 @@ interface ChatItem {
   lastMessage: string;
 }
 
-// mock conversation history
+// conversation history
 const chatList = ref<ChatItem[]>([]);
 
-// mock chat messages
+// chat messages
 const messages = ref<Message[]>([]);
 const currentMessages = ref<Message[]>([]);
 
-// mock model list
+// model list
 const modelOptions = [
   { label: 'deepseek', value: 'deepseek' },
   { label: 'GPT-4', value: 'gpt-4' },
@@ -56,7 +56,7 @@ const filteredChats = computed(() => {
 
 async function selectChat(id: number) {
   selectedChatId.value = id;
-  const data = await getMessages(id);
+  const { data } = await getMessages(id);
   currentMessages.value = data;
   nextTick(scrollToBottom);
 }
@@ -74,16 +74,16 @@ function handleNewChat() {
 
 async function handleSend() {
   const msg: Message = {
-    id: Date.now(),
-    role: 'user',
+    id: null,
+    type: 'user',
     content: input.value,
   };
   messages.value.push(msg);
 
   // reserve a placeholder message for the AI reply
   const aiMsgObj: Message = {
-    id: Date.now() + 1,
-    role: 'ai',
+    id: null,
+    type: 'assistant',
     content: '',
   };
   messages.value.push(aiMsgObj);
@@ -122,17 +122,16 @@ function scrollToBottom() {
 
 // fetch the conversation history
 async function fetchConversations() {
-  const data = await getConversations();
+  const { data } = await getConversations();
   chatList.value = data.map((item: any) => ({
     id: item.id,
     title: item.title,
     lastMessage: item.last_message || '',
   }));
-  console.log(chatList.value, 'chatList');
   // select the first conversation by default
   if (chatList.value.length > 0) {
     selectedChatId.value = chatList.value[0].id;
-    selectChat(selectedChatId.value)
+    await selectChat(selectedChatId.value);
   }
 }
 
@@ -201,15 +200,15 @@ onMounted(() => {
         v-for="msg in currentMessages"
         :key="msg.id"
         class="chat-message"
-        :class="[msg.role]"
+        :class="[msg.type]"
       >
-        <div class="bubble" :class="[msg.role]">
-          <span class="role">{{ msg.role === 'user' ? '我' : 'AI' }}</span>
+        <div class="bubble" :class="[msg.type]">
+          <span class="role">{{ msg.type === 'user' ? '我' : 'AI' }}</span>
           <span class="bubble-content">
             {{ msg.content }}
             <span
               v-if="
-                msg.role === 'ai' && isAiTyping && msg === currentAiMessage
+                msg.type === 'ai' && isAiTyping && msg === currentAiMessage
               "
               class="typing-cursor"
             ></span>