优化model

This commit is contained in:
XIE7654
2025-07-17 16:17:57 +08:00
parent 9b30115444
commit 6ed606f7a4
11 changed files with 316 additions and 193 deletions

View File

@@ -0,0 +1,40 @@
# Generated by Django 5.2.1 on 2025-07-17 07:07
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("ai", "0003_aimodel_model_type"),
]
operations = [
migrations.AlterField(
model_name="chatconversation",
name="model_id",
field=models.ForeignKey(
blank=True,
db_column="model_id",
db_comment="向量模型编号",
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="ai.aimodel",
verbose_name="向量模型编号",
),
),
migrations.AlterField(
model_name="chatmessage",
name="model_id",
field=models.ForeignKey(
blank=True,
db_column="model_id",
db_comment="向量模型编号",
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="ai.aimodel",
verbose_name="向量模型编号",
),
),
]

View File

@@ -258,6 +258,7 @@ class ChatConversation(CoreModel):
model_id = models.ForeignKey( model_id = models.ForeignKey(
'AIModel', 'AIModel',
on_delete=models.CASCADE, on_delete=models.CASCADE,
null=True, blank=True,
db_column='model_id', db_column='model_id',
verbose_name="向量模型编号", verbose_name="向量模型编号",
db_comment='向量模型编号' db_comment='向量模型编号'
@@ -302,6 +303,7 @@ class ChatMessage(CoreModel):
model_id = models.ForeignKey( model_id = models.ForeignKey(
'AIModel', 'AIModel',
on_delete=models.CASCADE, on_delete=models.CASCADE,
null=True, blank=True,
db_column='model_id', db_column='model_id',
verbose_name="向量模型编号", verbose_name="向量模型编号",
db_comment='向量模型编号' db_comment='向量模型编号'

View File

@@ -3,16 +3,15 @@ import asyncio
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from pydantic import BaseModel from pydantic import BaseModel
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationChain from langchain.chains import ConversationChain
# from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI from langchain_community.chat_models import ChatOpenAI
from deps.auth import get_current_user from deps.auth import get_current_user
from services.chat_service import chat_service from services.chat_service import ChatDBService
from utils.resp import resp_success from db.session import get_db
router = APIRouter() router = APIRouter()
@@ -30,10 +29,10 @@ def get_deepseek_llm(api_key: str, model: str, openai_api_base: str):
) )
@router.post('/stream') @router.post('/stream')
async def chat_stream(request: Request): async def chat_stream(request: Request, db: Session = Depends(get_db), user=Depends(get_current_user)):
body = await request.json() body = await request.json()
content = body.get('content') content = body.get('content')
print(content, 'content') conversation_id = body.get('conversation_id')
model = 'deepseek-chat' model = 'deepseek-chat'
api_key = os.getenv("DEEPSEEK_API_KEY") api_key = os.getenv("DEEPSEEK_API_KEY")
openai_api_base = "https://api.deepseek.com/v1" openai_api_base = "https://api.deepseek.com/v1"
@@ -43,8 +42,24 @@ async def chat_stream(request: Request):
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
return JSONResponse({"error": "content不能为空"}, status_code=400) return JSONResponse({"error": "content不能为空"}, status_code=400)
user_id = user["user_id"]
# 1. 获取或新建对话
try:
conversation = ChatDBService.get_or_create_conversation(db, conversation_id, user_id, model)
except ValueError as e:
from fastapi.responses import JSONResponse
return JSONResponse({"error": str(e)}, status_code=400)
# 2. 插入当前消息
ChatDBService.add_message(db, conversation, user_id, content)
# 3. 查询历史消息,组装上下文
history = ChatDBService.get_history(db, conversation.id)
history_contents = [msg.content for msg in history]
context = '\n'.join(history_contents)
async def event_generator(): async def event_generator():
async for chunk in llm.astream(content): async for chunk in llm.astream(context):
# 只返回 chunk.content 内容 # 只返回 chunk.content 内容
if hasattr(chunk, 'content'): if hasattr(chunk, 'content'):
yield f"data: {chunk.content}\n\n" yield f"data: {chunk.content}\n\n"

View File

@@ -1,5 +1,5 @@
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker, declarative_base
from config import SQLALCHEMY_DATABASE_URL from config import SQLALCHEMY_DATABASE_URL
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
@@ -11,3 +11,5 @@ def get_db():
yield db yield db
finally: finally:
db.close() db.close()
Base = declarative_base()

View File

@@ -1,8 +1,13 @@
import os
from fastapi import FastAPI from fastapi import FastAPI
from dotenv import load_dotenv
from api.v1 import ai_chat from api.v1 import ai_chat
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from routers.ai_api_key import router as ai_api_key_router from routers.ai_api_key import router as ai_api_key_router
# 加载.env环境变量优先项目根目录
load_dotenv()
app = FastAPI() app = FastAPI()
origins = [ origins = [

View File

@@ -2,11 +2,11 @@ from sqlalchemy import (
Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
) )
from sqlalchemy.orm import relationship, declarative_base from sqlalchemy.orm import relationship, declarative_base
from db.session import Base
from models.base import CoreModel
from models.user import DjangoUser # 确保导入 DjangoUser from models.user import DjangoUser # 确保导入 DjangoUser
Base = declarative_base()
# 状态选择类(示例) # 状态选择类(示例)
class CommonStatus: class CommonStatus:
DISABLED = 0 DISABLED = 0
@@ -37,16 +37,6 @@ class MessageType:
return [('text', '文本'), ('image', '图片')] return [('text', '文本'), ('image', '图片')]
# 基础模型类
class CoreModel(Base):
__abstract__ = True
id = Column(Integer, primary_key=True, autoincrement=True)
create_time = Column(DateTime)
update_time = Column(DateTime)
is_deleted = Column(Boolean, default=False)
# AI API 密钥表 # AI API 密钥表
class AIApiKey(CoreModel): class AIApiKey(CoreModel):
__tablename__ = 'ai_api_key' __tablename__ = 'ai_api_key'
@@ -83,161 +73,160 @@ class AIModel(CoreModel):
# AI 工具表 # AI 工具表
# class Tool(CoreModel): class Tool(CoreModel):
# __tablename__ = 'ai_tool' __tablename__ = 'ai_tool'
# name = Column(String(128), nullable=False) name = Column(String(128), nullable=False)
# description = Column(String(256), nullable=True) description = Column(String(256), nullable=True)
# status = Column(Integer, default=0) status = Column(Integer, default=0)
# def __str__(self): def __str__(self):
# return self.name return self.name
# AI 知识库表 # AI 知识库表
# class Knowledge(CoreModel): class Knowledge(CoreModel):
# __tablename__ = 'ai_knowledge' __tablename__ = 'ai_knowledge'
# name = Column(String(255), nullable=False) name = Column(String(255), nullable=False)
# description = Column(Text, nullable=True) description = Column(Text, nullable=True)
# embedding_model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False) embedding_model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
# embedding_model = Column(String(32), nullable=False) embedding_model = Column(String(32), nullable=False)
# top_k = Column(Integer, default=0) top_k = Column(Integer, default=0)
# similarity_threshold = Column(Float, nullable=False) similarity_threshold = Column(Float, nullable=False)
# status = Column(Integer, default=CommonStatus.DISABLED) status = Column(Integer, default=CommonStatus.DISABLED)
# embedding_model_rel = relationship('AIModel', backref='knowledges') embedding_model_rel = relationship('AIModel', backref='knowledges')
# documents = relationship('KnowledgeDocument', backref='knowledge', cascade='all, delete-orphan') documents = relationship('KnowledgeDocument', backref='knowledge', cascade='all, delete-orphan')
# segments = relationship('KnowledgeSegment', backref='knowledge', cascade='all, delete-orphan') segments = relationship('KnowledgeSegment', backref='knowledge', cascade='all, delete-orphan')
# roles = relationship('ChatRole', secondary='ai_chat_role_knowledge', backref='knowledges') roles = relationship('ChatRole', secondary='ai_chat_role_knowledge', backref='knowledges')
# def __str__(self): def __str__(self):
# return self.name return self.name
# AI 知识库文档表 # AI 知识库文档表
# class KnowledgeDocument(CoreModel): class KnowledgeDocument(CoreModel):
# __tablename__ = 'ai_knowledge_document' __tablename__ = 'ai_knowledge_document'
# knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False) knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
# name = Column(String(255), nullable=False) name = Column(String(255), nullable=False)
# url = Column(String(1024), nullable=False) url = Column(String(1024), nullable=False)
# content = Column(Text, nullable=False) content = Column(Text, nullable=False)
# content_length = Column(Integer, nullable=False) content_length = Column(Integer, nullable=False)
# tokens = Column(Integer, nullable=False) tokens = Column(Integer, nullable=False)
# segment_max_tokens = Column(Integer, nullable=False) segment_max_tokens = Column(Integer, nullable=False)
# retrieval_count = Column(Integer, default=0) retrieval_count = Column(Integer, default=0)
# status = Column(Integer, default=CommonStatus.DISABLED) status = Column(Integer, default=CommonStatus.DISABLED)
# segments = relationship('KnowledgeSegment', backref='document', cascade='all, delete-orphan') segments = relationship('KnowledgeSegment', backref='document', cascade='all, delete-orphan')
# def __str__(self): def __str__(self):
# return self.name return self.name
# AI 知识库分段表 # AI 知识库分段表
# class KnowledgeSegment(CoreModel): class KnowledgeSegment(CoreModel):
# __tablename__ = 'ai_knowledge_segment' __tablename__ = 'ai_knowledge_segment'
# knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False) knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
# document_id = Column(Integer, ForeignKey('ai_knowledge_document.id'), nullable=False) document_id = Column(Integer, ForeignKey('ai_knowledge_document.id'), nullable=False)
# content = Column(Text, nullable=False) content = Column(Text, nullable=False)
# content_length = Column(Integer, nullable=False) content_length = Column(Integer, nullable=False)
# tokens = Column(Integer, nullable=False) tokens = Column(Integer, nullable=False)
# vector_id = Column(String(100), nullable=True) vector_id = Column(String(100), nullable=True)
# retrieval_count = Column(Integer, default=0) retrieval_count = Column(Integer, default=0)
# status = Column(Integer, default=CommonStatus.DISABLED) status = Column(Integer, default=CommonStatus.DISABLED)
# def __str__(self): def __str__(self):
# return f"Segment {self.id}" return f"Segment {self.id}"
# AI 聊天角色表 # AI 聊天角色表
# class ChatRole(CoreModel): class ChatRole(CoreModel):
# __tablename__ = 'ai_chat_role' __tablename__ = 'ai_chat_role'
# name = Column(String(128), nullable=False) name = Column(String(128), nullable=False)
# avatar = Column(String(256), nullable=False) avatar = Column(String(256), nullable=False)
# description = Column(String(256), nullable=True) description = Column(String(256), nullable=True)
# status = Column(Integer, default=CommonStatus.DISABLED) status = Column(Integer, default=CommonStatus.DISABLED)
# sort = Column(Integer, default=0) sort = Column(Integer, default=0)
# public_status = Column(Boolean, default=False) public_status = Column(Boolean, default=False)
# category = Column(String(32), nullable=True) category = Column(String(32), nullable=True)
# model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False) model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
# system_message = Column(String(1024), nullable=True) system_message = Column(String(1024), nullable=True)
# user_id = Column( user_id = Column(
# Integer, Integer,
# ForeignKey('system_users.id'), # 假设DjangoUser表名是system_users ForeignKey('system_users.id'), # 假设DjangoUser表名是system_users
# nullable=True # 允许为空(如匿名角色) nullable=True # 允许为空(如匿名角色)
# ) )
# user = relationship(DjangoUser, backref='chat_roles') # 正确DjangoUser 已定义并导入 user = relationship(DjangoUser, backref='chat_roles') # 正确DjangoUser 已定义并导入
# model = relationship('AIModel', backref='chat_roles') model = relationship('AIModel', backref='chat_roles')
# tools = relationship('Tool', secondary='ai_chat_role_tool', backref='roles') tools = relationship('Tool', secondary='ai_chat_role_tool', backref='roles')
# # conversations = relationship('ChatConversation', backref='role', cascade='all, delete-orphan') # conversations = relationship('ChatConversation', backref='role', cascade='all, delete-orphan')
# # messages = relationship('ChatMessage', backref='role', cascade='all, delete-orphan') # messages = relationship('ChatMessage', backref='role', cascade='all, delete-orphan')
# def __str__(self):
# return self.name
def __str__(self):
return self.name
# AI 聊天对话表 # AI 聊天对话表
# class ChatConversation(CoreModel): class ChatConversation(CoreModel):
# __tablename__ = 'ai_chat_conversation' __tablename__ = 'ai_chat_conversation'
# title = Column(String(256), nullable=False) title = Column(String(256), nullable=False)
# pinned = Column(Boolean, default=False) pinned = Column(Boolean, default=False)
# pinned_time = Column(DateTime, nullable=True) pinned_time = Column(DateTime, nullable=True)
# # user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True) user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
# role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True) role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
# model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False) model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
# model = Column(String(32), nullable=False) model = Column(String(32), nullable=False)
# system_message = Column(String(1024), nullable=True) system_message = Column(String(1024), nullable=True)
# temperature = Column(Float, nullable=False) temperature = Column(Float, nullable=False)
# max_tokens = Column(Integer, nullable=False) max_tokens = Column(Integer, nullable=False)
# max_contexts = Column(Integer, nullable=False) max_contexts = Column(Integer, nullable=False)
# # user = relationship(DjangoUser, backref='conversations') # 正确DjangoUser 已定义并导入 user = relationship(DjangoUser, backref='conversations') # 正确DjangoUser 已定义并导入
# model_rel = relationship('AIModel', backref='conversations') model_rel = relationship('AIModel', backref='conversations')
# messages = relationship('ChatMessage', backref='conversation', cascade='all, delete-orphan') messages = relationship('ChatMessage', backref='conversation', cascade='all, delete-orphan')
# def __str__(self): def __str__(self):
# return self.title return self.title
# AI 聊天消息表 # AI 聊天消息表
# class ChatMessage(CoreModel): class ChatMessage(CoreModel):
# __tablename__ = 'ai_chat_message' __tablename__ = 'ai_chat_message'
# conversation_id = Column(Integer, nullable=False) conversation_id = Column(Integer, ForeignKey('ai_chat_conversation.id'), nullable=False)
# # user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True) user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
# role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True) role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
# model = Column(String(32), nullable=False) model = Column(String(32), nullable=False)
# model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False) model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
# type = Column(String(16), nullable=False) type = Column(String(16), nullable=False)
# reply_id = Column(Integer, nullable=True) reply_id = Column(Integer, nullable=True)
# content = Column(String(2048), nullable=False) content = Column(String(2048), nullable=False)
# use_context = Column(Boolean, default=False) use_context = Column(Boolean, default=False)
# segment_ids = Column(String(2048), nullable=True) segment_ids = Column(String(2048), nullable=True)
# # user = relationship(DjangoUser, backref='messages') # 正确DjangoUser 已定义并导入 user = relationship(DjangoUser, backref='messages') # 正确DjangoUser 已定义并导入
# model_rel = relationship('AIModel', backref='messages') model_rel = relationship('AIModel', backref='messages')
# def __str__(self): def __str__(self):
# return self.content[:30] return self.content[:30]
# # 聊天角色与知识库的关联表 # 聊天角色与知识库的关联表
# class ChatRoleKnowledge(Base): class ChatRoleKnowledge(Base):
# __tablename__ = 'ai_chat_role_knowledge' __tablename__ = 'ai_chat_role_knowledge'
# chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True) chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
# knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), primary_key=True) knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), primary_key=True)
# # 聊天角色与工具的关联表 # 聊天角色与工具的关联表
# class ChatRoleTool(Base): class ChatRoleTool(Base):
# __tablename__ = 'ai_chat_role_tool' __tablename__ = 'ai_chat_role_tool'
# chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True) chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
# tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True) tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True)

13
chat/models/base.py Normal file
View File

@@ -0,0 +1,13 @@
from db.session import Base
from sqlalchemy import (
Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
)
# 基础模型类
class CoreModel(Base):
__abstract__ = True
id = Column(Integer, primary_key=True, autoincrement=True)
create_time = Column(DateTime)
update_time = Column(DateTime)
is_deleted = Column(Boolean, default=False)

View File

@@ -1,7 +1,7 @@
from sqlalchemy import Column, Integer, String, DateTime, Boolean from sqlalchemy import Column, Integer, String, DateTime, Boolean
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base() from db.session import Base
class AuthToken(Base): class AuthToken(Base):
__tablename__ = 'authtoken_token' __tablename__ = 'authtoken_token'
@@ -12,6 +12,7 @@ class AuthToken(Base):
class DjangoUser(Base): class DjangoUser(Base):
__tablename__ = 'system_users' __tablename__ = 'system_users'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
username = Column(String(150), nullable=False) username = Column(String(150), nullable=False)
email = Column(String(254)) email = Column(String(254))

View File

@@ -1,15 +1,58 @@
# LangChain集成示例 # LangChain集成示例
from langchain_openai import OpenAI from sqlalchemy.orm import Session
from dotenv import load_dotenv from datetime import datetime
load_dotenv() from models.ai import ChatConversation, ChatMessage, MessageType
class ChatService: class ChatDBService:
def __init__(self): @staticmethod
# 这里以OpenAI为例实际可根据需要配置 def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str) -> ChatConversation:
self.llm = OpenAI(temperature=0.7, api_key='sssss') if not conversation_id:
conversation = ChatConversation(
title="新对话",
user_id=user_id,
role_id=None,
model_id=1, # 需根据实际模型id调整
model=model,
system_message=None,
temperature=0.7,
max_tokens=2048,
max_contexts=10,
create_time=datetime.now(),
update_time=datetime.now(),
is_deleted=False
)
db.add(conversation)
db.commit()
db.refresh(conversation)
return conversation
else:
conversation = db.query(ChatConversation).get(conversation_id)
if not conversation:
raise ValueError("无效的conversation_id")
return conversation
def chat(self, prompt: str) -> str: @staticmethod
# 简单调用LLM def add_message(db: Session, conversation: ChatConversation, user_id: int, content: str) -> ChatMessage:
return self.llm(prompt) message = ChatMessage(
conversation_id=conversation.id,
user_id=user_id,
role_id=None,
model=conversation.model,
model_id=conversation.model_id,
type=MessageType.TEXT,
reply_id=None,
content=content,
use_context=True,
segment_ids=None,
create_time=datetime.now(),
update_time=datetime.now(),
is_deleted=False
)
db.add(message)
db.commit()
return message
@staticmethod
def get_history(db: Session, conversation_id: int) -> list[ChatMessage]:
return db.query(ChatMessage).filter_by(conversation_id=conversation_id).order_by(ChatMessage.id).all()
chat_service = ChatService()

View File

@@ -4,9 +4,13 @@ import { formatToken } from '#/utils/auth';
export interface FetchAIStreamParams { export interface FetchAIStreamParams {
content: string; content: string;
conversation_id?: null | number;
} }
export async function fetchAIStream({ content }: FetchAIStreamParams) { export async function fetchAIStream({
content,
conversation_id,
}: FetchAIStreamParams) {
const accessStore = useAccessStore(); const accessStore = useAccessStore();
const token = accessStore.accessToken; const token = accessStore.accessToken;
const headers = new Headers(); const headers = new Headers();
@@ -19,6 +23,7 @@ export async function fetchAIStream({ content }: FetchAIStreamParams) {
headers, headers,
body: JSON.stringify({ body: JSON.stringify({
content, content,
conversation_id,
}), }),
}); });

View File

@@ -15,7 +15,6 @@ import {
} from 'ant-design-vue'; } from 'ant-design-vue';
import { fetchAIStream } from '#/api/ai/chat'; import { fetchAIStream } from '#/api/ai/chat';
// 移除 import typingSound from '@/assets/typing.mp3';
interface Message { interface Message {
id: number; id: number;
@@ -23,27 +22,17 @@ interface Message {
content: string; content: string;
} }
interface ChatItem {
id: number;
title: string;
lastMessage: string;
}
// mock 历史对话 // mock 历史对话
const chatList = ref([ const chatList = ref<ChatItem[]>([]);
{
id: 1,
title: '和deepseek的对话',
lastMessage: 'AI: 你好,有什么可以帮您?',
},
{ id: 2, title: '工作助理', lastMessage: 'AI: 今天的日程已为您安排。' },
]);
// mock 聊天消息 // mock 聊天消息
const messages = ref<Record<number, Message[]>>({ const messages = ref<Message[]>([]);
1: [
{ id: 1, role: 'user', content: '你好' },
{ id: 2, role: 'ai', content: '你好,有什么可以帮您?' },
],
2: [
{ id: 1, role: 'user', content: '帮我安排下今天的日程' },
{ id: 2, role: 'ai', content: '今天的日程已为您安排。' },
],
});
// mock 模型列表 // mock 模型列表
const modelOptions = [ const modelOptions = [
@@ -51,8 +40,8 @@ const modelOptions = [
{ label: 'GPT-4', value: 'gpt-4' }, { label: 'GPT-4', value: 'gpt-4' },
]; ];
const selectedChatId = ref(chatList.value[0]?.id || 1); const selectedChatId = ref<null | number>(chatList.value[0]?.id ?? null);
const selectedModel = ref(modelOptions[0].value); const selectedModel = ref(modelOptions[0]?.value);
const search = ref(''); const search = ref('');
const input = ref(''); const input = ref('');
const messagesRef = ref<HTMLElement | null>(null); const messagesRef = ref<HTMLElement | null>(null);
@@ -64,9 +53,14 @@ const filteredChats = computed(() => {
return chatList.value.filter((chat) => chat.title.includes(search.value)); return chatList.value.filter((chat) => chat.title.includes(search.value));
}); });
const currentMessages = computed( // 直接用conversationId过滤
() => messages.value?.[selectedChatId.value] || [], const currentMessages = computed(() => {
); if (!selectedChatId.value) return [];
return [];
// return messages.value.filter(
// (msg) => msg.conversationId === selectedChatId.value,
// );
});
function selectChat(id: number) { function selectChat(id: number) {
selectedChatId.value = id; selectedChatId.value = id;
@@ -80,38 +74,46 @@ function handleNewChat() {
title: `新对话${chatList.value.length + 1}`, title: `新对话${chatList.value.length + 1}`,
lastMessage: '', lastMessage: '',
}); });
messages.value[newId] = [];
selectedChatId.value = newId; selectedChatId.value = newId;
nextTick(scrollToBottom); nextTick(scrollToBottom);
} }
async function handleSend() { async function handleSend() {
if (!input.value.trim()) return; console.log(111);
const msg: Message = { id: Date.now(), role: 'user', content: input.value }; const msg: Message = {
if (!messages.value[selectedChatId.value]) { id: Date.now(),
messages.value[selectedChatId.value] = []; role: 'user',
} content: input.value,
messages.value[selectedChatId.value].push(msg); };
messages.value.push(msg);
// 预留AI消息 // 预留AI消息
const aiMsgObj: Message = { id: Date.now() + 1, role: 'ai', content: '' }; const aiMsgObj: Message = {
messages.value[selectedChatId.value].push(aiMsgObj); id: Date.now() + 1,
role: 'ai',
content: '',
};
messages.value.push(aiMsgObj);
currentAiMessage.value = aiMsgObj; currentAiMessage.value = aiMsgObj;
isAiTyping.value = true; isAiTyping.value = true;
const stream = await fetchAIStream({ const stream = await fetchAIStream({
content: input.value, content: input.value,
conversation_id: selectedChatId.value, // 新增
}); });
// 移除打字音效播放
for await (const chunk of stream) { for await (const chunk of stream) {
for (const char of chunk) { for (const char of chunk) {
aiMsgObj.content += char; aiMsgObj.content += char;
// 保证messages数组响应式更新
const idx = messages.value.indexOf(aiMsgObj);
if (idx !== -1) {
messages.value.splice(idx, 1, { ...aiMsgObj });
}
currentAiMessage.value = { ...aiMsgObj }; currentAiMessage.value = { ...aiMsgObj };
// 移除打字音效播放 await nextTick();
await new Promise(resolve => setTimeout(resolve, 15)); scrollToBottom();
nextTick(scrollToBottom); await new Promise((resolve) => setTimeout(resolve, 15));
} }
} }
isAiTyping.value = false; isAiTyping.value = false;
@@ -365,7 +367,13 @@ function scrollToBottom() {
border-radius: 2px; border-radius: 2px;
} }
@keyframes blink-cursor { @keyframes blink-cursor {
0%, 50% { opacity: 1; } 0%,
51%, 100% { opacity: 0; } 50% {
opacity: 1;
}
51%,
100% {
opacity: 0;
}
} }
</style> </style>