优化model (Optimize model)
@@ -0,0 +1,40 @@
# Generated by Django 5.2.1 on 2025-07-17 07:07

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ("ai", "0003_aimodel_model_type"),
    ]

    operations = [
        migrations.AlterField(
            model_name="chatconversation",
            name="model_id",
            field=models.ForeignKey(
                blank=True,
                db_column="model_id",
                db_comment="向量模型编号",
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                to="ai.aimodel",
                verbose_name="向量模型编号",
            ),
        ),
        migrations.AlterField(
            model_name="chatmessage",
            name="model_id",
            field=models.ForeignKey(
                blank=True,
                db_column="model_id",
                db_comment="向量模型编号",
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                to="ai.aimodel",
                verbose_name="向量模型编号",
            ),
        ),
    ]
@@ -258,6 +258,7 @@ class ChatConversation(CoreModel):
    model_id = models.ForeignKey(
        'AIModel',
        on_delete=models.CASCADE,
        null=True, blank=True,
        db_column='model_id',
        verbose_name="向量模型编号",
        db_comment='向量模型编号'
@@ -302,6 +303,7 @@ class ChatMessage(CoreModel):
    model_id = models.ForeignKey(
        'AIModel',
        on_delete=models.CASCADE,
        null=True, blank=True,
        db_column='model_id',
        verbose_name="向量模型编号",
        db_comment='向量模型编号'
@@ -3,16 +3,15 @@ import asyncio

from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from pydantic import BaseModel
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationChain
# from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI

from deps.auth import get_current_user
from services.chat_service import chat_service
from utils.resp import resp_success
from services.chat_service import ChatDBService
from db.session import get_db

router = APIRouter()

@@ -30,21 +29,37 @@ def get_deepseek_llm(api_key: str, model: str, openai_api_base: str):
    )

@router.post('/stream')
async def chat_stream(request: Request):
async def chat_stream(request: Request, db: Session = Depends(get_db), user=Depends(get_current_user)):
    body = await request.json()
    content = body.get('content')
    print(content, 'content')
    conversation_id = body.get('conversation_id')
    model = 'deepseek-chat'
    api_key = os.getenv("DEEPSEEK_API_KEY")
    openai_api_base="https://api.deepseek.com/v1"
    openai_api_base = "https://api.deepseek.com/v1"
    llm = get_deepseek_llm(api_key, model, openai_api_base)

    if not content or not isinstance(content, str):
        from fastapi.responses import JSONResponse
        return JSONResponse({"error": "content不能为空"}, status_code=400)

    user_id = user["user_id"]

    # 1. 获取或新建对话
    try:
        conversation = ChatDBService.get_or_create_conversation(db, conversation_id, user_id, model)
    except ValueError as e:
        from fastapi.responses import JSONResponse
        return JSONResponse({"error": str(e)}, status_code=400)
    # 2. 插入当前消息
    ChatDBService.add_message(db, conversation, user_id, content)

    # 3. 查询历史消息,组装上下文
    history = ChatDBService.get_history(db, conversation.id)
    history_contents = [msg.content for msg in history]
    context = '\n'.join(history_contents)

    async def event_generator():
        async for chunk in llm.astream(content):
        async for chunk in llm.astream(context):
            # 只返回 chunk.content 内容
            if hasattr(chunk, 'content'):
                yield f"data: {chunk.content}\n\n"
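
For reference, the endpoint above emits Server-Sent-Events-style chunks ("data: ...\n\n" per token). A minimal client sketch, assuming the router is mounted under /api/v1/ai and using a placeholder token (neither is shown in this diff):

    import requests  # assumed to be available in the client environment

    resp = requests.post(
        "http://localhost:8000/api/v1/ai/stream",    # hypothetical URL; actual prefix not shown here
        headers={"Authorization": "Bearer <token>"},  # placeholder token
        json={"content": "你好", "conversation_id": None},
        stream=True,
    )
    # Read the streamed response line by line and print only the SSE payload.
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            print(line[len("data: "):], end="", flush=True)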
@@ -1,5 +1,5 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import sessionmaker, declarative_base
from config import SQLALCHEMY_DATABASE_URL

engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
@@ -10,4 +10,6 @@ def get_db():
    try:
        yield db
    finally:
        db.close()
        db.close()

Base = declarative_base()
@@ -1,8 +1,13 @@
import os
from fastapi import FastAPI
from dotenv import load_dotenv
from api.v1 import ai_chat
from fastapi.middleware.cors import CORSMiddleware
from routers.ai_api_key import router as ai_api_key_router

# 加载.env环境变量,优先项目根目录
load_dotenv()

app = FastAPI()

origins = [
@@ -2,11 +2,11 @@ from sqlalchemy import (
    Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
)
from sqlalchemy.orm import relationship, declarative_base

from db.session import Base
from models.base import CoreModel
from models.user import DjangoUser  # 确保导入 DjangoUser

Base = declarative_base()


# 状态选择类(示例)
class CommonStatus:
    DISABLED = 0
@@ -37,16 +37,6 @@ class MessageType:
        return [('text', '文本'), ('image', '图片')]


# 基础模型类
class CoreModel(Base):
    __abstract__ = True

    id = Column(Integer, primary_key=True, autoincrement=True)
    create_time = Column(DateTime)
    update_time = Column(DateTime)
    is_deleted = Column(Boolean, default=False)


# AI API 密钥表
class AIApiKey(CoreModel):
    __tablename__ = 'ai_api_key'
@@ -83,161 +73,160 @@ class AIModel(CoreModel):


# AI 工具表
# class Tool(CoreModel):
# __tablename__ = 'ai_tool'
class Tool(CoreModel):
    __tablename__ = 'ai_tool'

# name = Column(String(128), nullable=False)
# description = Column(String(256), nullable=True)
# status = Column(Integer, default=0)
    name = Column(String(128), nullable=False)
    description = Column(String(256), nullable=True)
    status = Column(Integer, default=0)

# def __str__(self):
# return self.name
    def __str__(self):
        return self.name


# AI 知识库表
# class Knowledge(CoreModel):
# __tablename__ = 'ai_knowledge'
class Knowledge(CoreModel):
    __tablename__ = 'ai_knowledge'

# name = Column(String(255), nullable=False)
# description = Column(Text, nullable=True)
# embedding_model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
# embedding_model = Column(String(32), nullable=False)
# top_k = Column(Integer, default=0)
# similarity_threshold = Column(Float, nullable=False)
# status = Column(Integer, default=CommonStatus.DISABLED)
    name = Column(String(255), nullable=False)
    description = Column(Text, nullable=True)
    embedding_model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
    embedding_model = Column(String(32), nullable=False)
    top_k = Column(Integer, default=0)
    similarity_threshold = Column(Float, nullable=False)
    status = Column(Integer, default=CommonStatus.DISABLED)

# embedding_model_rel = relationship('AIModel', backref='knowledges')
# documents = relationship('KnowledgeDocument', backref='knowledge', cascade='all, delete-orphan')
# segments = relationship('KnowledgeSegment', backref='knowledge', cascade='all, delete-orphan')
# roles = relationship('ChatRole', secondary='ai_chat_role_knowledge', backref='knowledges')
    embedding_model_rel = relationship('AIModel', backref='knowledges')
    documents = relationship('KnowledgeDocument', backref='knowledge', cascade='all, delete-orphan')
    segments = relationship('KnowledgeSegment', backref='knowledge', cascade='all, delete-orphan')
    roles = relationship('ChatRole', secondary='ai_chat_role_knowledge', backref='knowledges')

# def __str__(self):
# return self.name
    def __str__(self):
        return self.name


# AI 知识库文档表
# class KnowledgeDocument(CoreModel):
# __tablename__ = 'ai_knowledge_document'
class KnowledgeDocument(CoreModel):
    __tablename__ = 'ai_knowledge_document'

# knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
# name = Column(String(255), nullable=False)
# url = Column(String(1024), nullable=False)
# content = Column(Text, nullable=False)
# content_length = Column(Integer, nullable=False)
# tokens = Column(Integer, nullable=False)
# segment_max_tokens = Column(Integer, nullable=False)
# retrieval_count = Column(Integer, default=0)
# status = Column(Integer, default=CommonStatus.DISABLED)
    knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
    name = Column(String(255), nullable=False)
    url = Column(String(1024), nullable=False)
    content = Column(Text, nullable=False)
    content_length = Column(Integer, nullable=False)
    tokens = Column(Integer, nullable=False)
    segment_max_tokens = Column(Integer, nullable=False)
    retrieval_count = Column(Integer, default=0)
    status = Column(Integer, default=CommonStatus.DISABLED)

# segments = relationship('KnowledgeSegment', backref='document', cascade='all, delete-orphan')
    segments = relationship('KnowledgeSegment', backref='document', cascade='all, delete-orphan')

# def __str__(self):
# return self.name
    def __str__(self):
        return self.name


# AI 知识库分段表
# class KnowledgeSegment(CoreModel):
# __tablename__ = 'ai_knowledge_segment'
class KnowledgeSegment(CoreModel):
    __tablename__ = 'ai_knowledge_segment'

# knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
# document_id = Column(Integer, ForeignKey('ai_knowledge_document.id'), nullable=False)
# content = Column(Text, nullable=False)
# content_length = Column(Integer, nullable=False)
# tokens = Column(Integer, nullable=False)
# vector_id = Column(String(100), nullable=True)
# retrieval_count = Column(Integer, default=0)
# status = Column(Integer, default=CommonStatus.DISABLED)
    knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
    document_id = Column(Integer, ForeignKey('ai_knowledge_document.id'), nullable=False)
    content = Column(Text, nullable=False)
    content_length = Column(Integer, nullable=False)
    tokens = Column(Integer, nullable=False)
    vector_id = Column(String(100), nullable=True)
    retrieval_count = Column(Integer, default=0)
    status = Column(Integer, default=CommonStatus.DISABLED)

# def __str__(self):
# return f"Segment {self.id}"
    def __str__(self):
        return f"Segment {self.id}"


# AI 聊天角色表
# class ChatRole(CoreModel):
# __tablename__ = 'ai_chat_role'
class ChatRole(CoreModel):
    __tablename__ = 'ai_chat_role'

# name = Column(String(128), nullable=False)
# avatar = Column(String(256), nullable=False)
# description = Column(String(256), nullable=True)
# status = Column(Integer, default=CommonStatus.DISABLED)
# sort = Column(Integer, default=0)
# public_status = Column(Boolean, default=False)
# category = Column(String(32), nullable=True)
# model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
# system_message = Column(String(1024), nullable=True)
# user_id = Column(
# Integer,
# ForeignKey('system_users.id'),  # 假设DjangoUser表名是system_users
# nullable=True  # 允许为空(如匿名角色)
# )
# user = relationship(DjangoUser, backref='chat_roles')  # 正确:DjangoUser 已定义并导入
    name = Column(String(128), nullable=False)
    avatar = Column(String(256), nullable=False)
    description = Column(String(256), nullable=True)
    status = Column(Integer, default=CommonStatus.DISABLED)
    sort = Column(Integer, default=0)
    public_status = Column(Boolean, default=False)
    category = Column(String(32), nullable=True)
    model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
    system_message = Column(String(1024), nullable=True)
    user_id = Column(
        Integer,
        ForeignKey('system_users.id'),  # 假设DjangoUser表名是system_users
        nullable=True  # 允许为空(如匿名角色)
    )
    user = relationship(DjangoUser, backref='chat_roles')  # 正确:DjangoUser 已定义并导入

# model = relationship('AIModel', backref='chat_roles')
# tools = relationship('Tool', secondary='ai_chat_role_tool', backref='roles')
# # conversations = relationship('ChatConversation', backref='role', cascade='all, delete-orphan')
# # messages = relationship('ChatMessage', backref='role', cascade='all, delete-orphan')

# def __str__(self):
# return self.name
    model = relationship('AIModel', backref='chat_roles')
    tools = relationship('Tool', secondary='ai_chat_role_tool', backref='roles')
    # conversations = relationship('ChatConversation', backref='role', cascade='all, delete-orphan')
    # messages = relationship('ChatMessage', backref='role', cascade='all, delete-orphan')

    def __str__(self):
        return self.name

# AI 聊天对话表
# class ChatConversation(CoreModel):
# __tablename__ = 'ai_chat_conversation'
class ChatConversation(CoreModel):
    __tablename__ = 'ai_chat_conversation'

# title = Column(String(256), nullable=False)
# pinned = Column(Boolean, default=False)
# pinned_time = Column(DateTime, nullable=True)
# # user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
# role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
# model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
# model = Column(String(32), nullable=False)
# system_message = Column(String(1024), nullable=True)
# temperature = Column(Float, nullable=False)
# max_tokens = Column(Integer, nullable=False)
# max_contexts = Column(Integer, nullable=False)
# # user = relationship(DjangoUser, backref='conversations')  # 正确:DjangoUser 已定义并导入
    title = Column(String(256), nullable=False)
    pinned = Column(Boolean, default=False)
    pinned_time = Column(DateTime, nullable=True)
    user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
    role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
    model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
    model = Column(String(32), nullable=False)
    system_message = Column(String(1024), nullable=True)
    temperature = Column(Float, nullable=False)
    max_tokens = Column(Integer, nullable=False)
    max_contexts = Column(Integer, nullable=False)
    user = relationship(DjangoUser, backref='conversations')  # 正确:DjangoUser 已定义并导入

# model_rel = relationship('AIModel', backref='conversations')
# messages = relationship('ChatMessage', backref='conversation', cascade='all, delete-orphan')
    model_rel = relationship('AIModel', backref='conversations')
    messages = relationship('ChatMessage', backref='conversation', cascade='all, delete-orphan')

# def __str__(self):
# return self.title
    def __str__(self):
        return self.title


# AI 聊天消息表
# class ChatMessage(CoreModel):
# __tablename__ = 'ai_chat_message'
class ChatMessage(CoreModel):
    __tablename__ = 'ai_chat_message'

# conversation_id = Column(Integer, nullable=False)
# # user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
# role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
# model = Column(String(32), nullable=False)
# model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
# type = Column(String(16), nullable=False)
# reply_id = Column(Integer, nullable=True)
# content = Column(String(2048), nullable=False)
# use_context = Column(Boolean, default=False)
# segment_ids = Column(String(2048), nullable=True)
    conversation_id = Column(Integer, ForeignKey('ai_chat_conversation.id'), nullable=False)
    user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
    role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
    model = Column(String(32), nullable=False)
    model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
    type = Column(String(16), nullable=False)
    reply_id = Column(Integer, nullable=True)
    content = Column(String(2048), nullable=False)
    use_context = Column(Boolean, default=False)
    segment_ids = Column(String(2048), nullable=True)

# # user = relationship(DjangoUser, backref='messages')  # 正确:DjangoUser 已定义并导入
# model_rel = relationship('AIModel', backref='messages')
    user = relationship(DjangoUser, backref='messages')  # 正确:DjangoUser 已定义并导入
    model_rel = relationship('AIModel', backref='messages')

# def __str__(self):
# return self.content[:30]
    def __str__(self):
        return self.content[:30]


# # 聊天角色与知识库的关联表
# class ChatRoleKnowledge(Base):
# __tablename__ = 'ai_chat_role_knowledge'
# 聊天角色与知识库的关联表
class ChatRoleKnowledge(Base):
    __tablename__ = 'ai_chat_role_knowledge'

# chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
# knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), primary_key=True)
    chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
    knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), primary_key=True)


# # 聊天角色与工具的关联表
# class ChatRoleTool(Base):
# __tablename__ = 'ai_chat_role_tool'
# 聊天角色与工具的关联表
class ChatRoleTool(Base):
    __tablename__ = 'ai_chat_role_tool'

# chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
# tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True)
    chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
    tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True)
chat/models/base.py (new file)
@@ -0,0 +1,13 @@
from db.session import Base
from sqlalchemy import (
    Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
)

# 基础模型类
class CoreModel(Base):
    __abstract__ = True

    id = Column(Integer, primary_key=True, autoincrement=True)
    create_time = Column(DateTime)
    update_time = Column(DateTime)
    is_deleted = Column(Boolean, default=False)
@@ -1,7 +1,7 @@
from sqlalchemy import Column, Integer, String, DateTime, Boolean
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()
from db.session import Base


class AuthToken(Base):
    __tablename__ = 'authtoken_token'
@@ -12,6 +12,7 @@ class AuthToken(Base):

class DjangoUser(Base):
    __tablename__ = 'system_users'

    id = Column(Integer, primary_key=True)
    username = Column(String(150), nullable=False)
    email = Column(String(254))
@@ -1,15 +1,58 @@
# LangChain集成示例
from langchain_openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
from sqlalchemy.orm import Session
from datetime import datetime
from models.ai import ChatConversation, ChatMessage, MessageType

class ChatService:
    def __init__(self):
        # 这里以OpenAI为例,实际可根据需要配置
        self.llm = OpenAI(temperature=0.7, api_key='sssss')
class ChatDBService:
    @staticmethod
    def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str) -> ChatConversation:
        if not conversation_id:
            conversation = ChatConversation(
                title="新对话",
                user_id=user_id,
                role_id=None,
                model_id=1,  # 需根据实际模型id调整
                model=model,
                system_message=None,
                temperature=0.7,
                max_tokens=2048,
                max_contexts=10,
                create_time=datetime.now(),
                update_time=datetime.now(),
                is_deleted=False
            )
            db.add(conversation)
            db.commit()
            db.refresh(conversation)
            return conversation
        else:
            conversation = db.query(ChatConversation).get(conversation_id)
            if not conversation:
                raise ValueError("无效的conversation_id")
            return conversation

    def chat(self, prompt: str) -> str:
        # 简单调用LLM
        return self.llm(prompt)
    @staticmethod
    def add_message(db: Session, conversation: ChatConversation, user_id: int, content: str) -> ChatMessage:
        message = ChatMessage(
            conversation_id=conversation.id,
            user_id=user_id,
            role_id=None,
            model=conversation.model,
            model_id=conversation.model_id,
            type=MessageType.TEXT,
            reply_id=None,
            content=content,
            use_context=True,
            segment_ids=None,
            create_time=datetime.now(),
            update_time=datetime.now(),
            is_deleted=False
        )
        db.add(message)
        db.commit()
        return message

    @staticmethod
    def get_history(db: Session, conversation_id: int) -> list[ChatMessage]:
        return db.query(ChatMessage).filter_by(conversation_id=conversation_id).order_by(ChatMessage.id).all()

chat_service = ChatService()
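
As a quick orientation, the new ChatDBService can be exercised directly with a SQLAlchemy session. A minimal sketch, assuming a SessionLocal factory in db.session and an existing user with id 1 (neither is shown in this hunk):

    from db.session import SessionLocal  # assumed session factory, not shown in this diff
    from services.chat_service import ChatDBService

    db = SessionLocal()
    try:
        # Create a conversation, store one user message, then read the history back.
        conv = ChatDBService.get_or_create_conversation(db, None, user_id=1, model="deepseek-chat")
        ChatDBService.add_message(db, conv, user_id=1, content="你好")
        for msg in ChatDBService.get_history(db, conv.id):
            print(msg.id, msg.content)
    finally:
        db.close()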
@@ -4,9 +4,13 @@ import { formatToken } from '#/utils/auth';

export interface FetchAIStreamParams {
  content: string;
  conversation_id?: null | number;
}

export async function fetchAIStream({ content }: FetchAIStreamParams) {
export async function fetchAIStream({
  content,
  conversation_id,
}: FetchAIStreamParams) {
  const accessStore = useAccessStore();
  const token = accessStore.accessToken;
  const headers = new Headers();
@@ -19,6 +23,7 @@ export async function fetchAIStream({ content }: FetchAIStreamParams) {
    headers,
    body: JSON.stringify({
      content,
      conversation_id,
    }),
  });
@@ -15,7 +15,6 @@ import {
} from 'ant-design-vue';

import { fetchAIStream } from '#/api/ai/chat';
// 移除 import typingSound from '@/assets/typing.mp3';

interface Message {
  id: number;
@@ -23,27 +22,17 @@ interface Message {
  content: string;
}

interface ChatItem {
  id: number;
  title: string;
  lastMessage: string;
}

// mock 历史对话
const chatList = ref([
  {
    id: 1,
    title: '和deepseek的对话',
    lastMessage: 'AI: 你好,有什么可以帮您?',
  },
  { id: 2, title: '工作助理', lastMessage: 'AI: 今天的日程已为您安排。' },
]);
const chatList = ref<ChatItem[]>([]);

// mock 聊天消息
const messages = ref<Record<number, Message[]>>({
  1: [
    { id: 1, role: 'user', content: '你好' },
    { id: 2, role: 'ai', content: '你好,有什么可以帮您?' },
  ],
  2: [
    { id: 1, role: 'user', content: '帮我安排下今天的日程' },
    { id: 2, role: 'ai', content: '今天的日程已为您安排。' },
  ],
});
const messages = ref<Message[]>([]);

// mock 模型列表
const modelOptions = [
@@ -51,8 +40,8 @@ const modelOptions = [
  { label: 'GPT-4', value: 'gpt-4' },
];

const selectedChatId = ref(chatList.value[0]?.id || 1);
const selectedModel = ref(modelOptions[0].value);
const selectedChatId = ref<null | number>(chatList.value[0]?.id ?? null);
const selectedModel = ref(modelOptions[0]?.value);
const search = ref('');
const input = ref('');
const messagesRef = ref<HTMLElement | null>(null);
@@ -64,9 +53,14 @@ const filteredChats = computed(() => {
  return chatList.value.filter((chat) => chat.title.includes(search.value));
});

const currentMessages = computed(
  () => messages.value?.[selectedChatId.value] || [],
);
// 直接用conversationId过滤
const currentMessages = computed(() => {
  if (!selectedChatId.value) return [];
  return [];
  // return messages.value.filter(
  //   (msg) => msg.conversationId === selectedChatId.value,
  // );
});

function selectChat(id: number) {
  selectedChatId.value = id;
@@ -80,38 +74,46 @@ function handleNewChat() {
    title: `新对话${chatList.value.length + 1}`,
    lastMessage: '',
  });
  messages.value[newId] = [];
  selectedChatId.value = newId;
  nextTick(scrollToBottom);
}

async function handleSend() {
  if (!input.value.trim()) return;
  const msg: Message = { id: Date.now(), role: 'user', content: input.value };
  if (!messages.value[selectedChatId.value]) {
    messages.value[selectedChatId.value] = [];
  }
  messages.value[selectedChatId.value].push(msg);
  console.log(111);
  const msg: Message = {
    id: Date.now(),
    role: 'user',
    content: input.value,
  };
  messages.value.push(msg);

  // 预留AI消息
  const aiMsgObj: Message = { id: Date.now() + 1, role: 'ai', content: '' };
  messages.value[selectedChatId.value].push(aiMsgObj);
  const aiMsgObj: Message = {
    id: Date.now() + 1,
    role: 'ai',
    content: '',
  };
  messages.value.push(aiMsgObj);
  currentAiMessage.value = aiMsgObj;
  isAiTyping.value = true;

  const stream = await fetchAIStream({
    content: input.value,
    conversation_id: selectedChatId.value, // 新增
  });

  // 移除打字音效播放

  for await (const chunk of stream) {
    for (const char of chunk) {
      aiMsgObj.content += char;
      // 保证messages数组响应式更新
      const idx = messages.value.indexOf(aiMsgObj);
      if (idx !== -1) {
        messages.value.splice(idx, 1, { ...aiMsgObj });
      }
      currentAiMessage.value = { ...aiMsgObj };
      // 移除打字音效播放
      await new Promise(resolve => setTimeout(resolve, 15));
      nextTick(scrollToBottom);
      await nextTick();
      scrollToBottom();
      await new Promise((resolve) => setTimeout(resolve, 15));
    }
  }
  isAiTyping.value = false;
@@ -365,7 +367,13 @@ function scrollToBottom() {
  border-radius: 2px;
}
@keyframes blink-cursor {
  0%, 50% { opacity: 1; }
  51%, 100% { opacity: 0; }
  0%,
  50% {
    opacity: 1;
  }
  51%,
  100% {
    opacity: 0;
  }
}
</style>