diff --git a/chat/api/v1/ai_chat.py b/chat/api/v1/ai_chat.py index c2f3087..46a3544 100644 --- a/chat/api/v1/ai_chat.py +++ b/chat/api/v1/ai_chat.py @@ -72,12 +72,12 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess @router.get('/conversations') def get_conversations( - user_id: int = Query(None), - db: Session = Depends(get_db) + db: Session = Depends(get_db), + user=Depends(get_current_user) ): - """获取指定用户的聊天对话列表""" + """获取当前用户的聊天对话列表""" + user_id = user["user_id"] conversations = db.query(ChatConversation).filter(ChatConversation.user_id == user_id).order_by(ChatConversation.update_time.desc()).all() - # 可根据需要序列化 return [ { 'id': c.id, @@ -90,19 +90,18 @@ def get_conversations( @router.get('/messages') def get_messages( - conversation_id: int = Query(None), - user_id: int = Query(None), - db: Session = Depends(get_db) + conversation_id: int = Query(...), + db: Session = Depends(get_db), + user=Depends(get_current_user) ): - """获取指定会话的消息列表,可选user_id过滤""" - query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id) - if user_id is not None: - query = query.filter(ChatMessage.user_id == user_id) + """获取指定会话的消息列表(当前用户)""" + user_id = user["user_id"] + query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id, ChatMessage.user_id == user_id) messages = query.order_by(ChatMessage.id).all() return [ { 'id': m.id, - 'role': m.role_id, # 如需role名可再查 + 'role': m.role_id, 'content': m.content, 'user_id': m.user_id, 'conversation_id': m.conversation_id, diff --git a/web/apps/web-antd/src/api/ai/chat.ts b/web/apps/web-antd/src/api/ai/chat.ts index c3375bc..537a610 100644 --- a/web/apps/web-antd/src/api/ai/chat.ts +++ b/web/apps/web-antd/src/api/ai/chat.ts @@ -1,7 +1,18 @@ -import { useAccessStore } from '@vben/stores'; +import { fetchWithAuth } from '#/utils/fetch-with-auth'; -import { formatToken } from '#/utils/auth'; +export async function getConversations() { + const res = await fetchWithAuth('/chat/api/v1/conversations'); + return await res.json(); +} +export async function getMessages(conversationId: number) { + const res = await fetchWithAuth( + `/chat/api/v1/messages?conversation_id=${conversationId}`, + ); + return await res.json(); +} + +// 你原有的fetchAIStream方法保留 export interface FetchAIStreamParams { content: string; conversation_id?: null | number; @@ -11,28 +22,14 @@ export async function fetchAIStream({ content, conversation_id, }: FetchAIStreamParams) { - const accessStore = useAccessStore(); - const token = accessStore.accessToken; - const headers = new Headers(); - - headers.append('Content-Type', 'application/json'); - headers.append('Authorization', formatToken(token)); - - const response = await fetch('/chat/api/v1/stream', { + const res = await fetchWithAuth('/chat/api/v1/stream', { method: 'POST', - headers, - body: JSON.stringify({ - content, - conversation_id, - }), + body: JSON.stringify({ content, conversation_id }), }); - - if (!response.body) throw new Error('No stream body'); - - const reader = response.body.getReader(); + if (!res.body) throw new Error('No stream body'); + const reader = res.body.getReader(); const decoder = new TextDecoder('utf8'); let buffer = ''; - return { async *[Symbol.asyncIterator]() { while (true) { diff --git a/web/apps/web-antd/src/utils/fetch-with-auth.ts b/web/apps/web-antd/src/utils/fetch-with-auth.ts new file mode 100644 index 0000000..1a1543b --- /dev/null +++ b/web/apps/web-antd/src/utils/fetch-with-auth.ts @@ -0,0 +1,11 @@ +import { formatToken } from '#/utils/auth'; +import { useAccessStore } from '@vben/stores'; + +export function fetchWithAuth(input: RequestInfo, init: RequestInit = {}) { + const accessStore = useAccessStore(); + const token = accessStore.accessToken; + const headers = new Headers(init.headers || {}); + headers.append('Content-Type', 'application/json'); + headers.append('Authorization', formatToken(token) as string); + return fetch(input, { ...init, headers }); +} diff --git a/web/apps/web-antd/src/views/ai/chat/index.vue b/web/apps/web-antd/src/views/ai/chat/index.vue index 6777777..97ea09c 100644 --- a/web/apps/web-antd/src/views/ai/chat/index.vue +++ b/web/apps/web-antd/src/views/ai/chat/index.vue @@ -14,7 +14,7 @@ import { Select, } from 'ant-design-vue'; -import { fetchAIStream } from '#/api/ai/chat'; +import { fetchAIStream, getConversations, getMessages } from '#/api/ai/chat'; interface Message { id: number; @@ -33,6 +33,7 @@ const chatList = ref([]); // mock 聊天消息 const messages = ref([]); +const currentMessages = ref([]); // mock 模型列表 const modelOptions = [ @@ -52,17 +53,11 @@ const filteredChats = computed(() => { if (!search.value) return chatList.value; return chatList.value.filter((chat) => chat.title.includes(search.value)); }); -// 直接用conversationId过滤 -const currentMessages = computed(() => { - if (!selectedChatId.value) return []; - return []; - // return messages.value.filter( - // (msg) => msg.conversationId === selectedChatId.value, - // ); -}); -function selectChat(id: number) { +async function selectChat(id: number) { selectedChatId.value = id; + const data = await getMessages(id); + currentMessages.value = data; nextTick(scrollToBottom); } @@ -78,7 +73,6 @@ function handleNewChat() { } async function handleSend() { - console.log(111); const msg: Message = { id: Date.now(), role: 'user', @@ -128,10 +122,7 @@ function scrollToBottom() { // 获取历史对话 async function fetchConversations() { - // 这里假设user_id为1,实际应从登录信息获取 - const params = new URLSearchParams({ user_id: '1' }); - const res = await fetch(`/chat/api/v1/conversations?${params.toString()}`); - const data = await res.json(); + const data = await getConversations(); chatList.value = data.map((item: any) => ({ id: item.id, title: item.title, @@ -141,6 +132,7 @@ async function fetchConversations() { // 默认选中第一个对话 if (chatList.value.length > 0) { selectedChatId.value = chatList.value[0].id; + selectChat(selectedChatId.value) } } @@ -491,7 +483,7 @@ onMounted(() => { .chat-list { flex: 1; overflow-y: auto; /* 只在对话列表区滚动 */ - min-height: 0; /* 关键:flex子项内滚动时必须加 */ + min-height: 0; /* 关键:flex子项内滚动时必须加 */ max-height: calc(100vh - 120px); /* 可根据实际header/footer高度调整 */ }