Optimize model
@@ -0,0 +1,40 @@
+# Generated by Django 5.2.1 on 2025-07-17 07:07
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("ai", "0003_aimodel_model_type"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="chatconversation",
+            name="model_id",
+            field=models.ForeignKey(
+                blank=True,
+                db_column="model_id",
+                db_comment="向量模型编号",
+                null=True,
+                on_delete=django.db.models.deletion.CASCADE,
+                to="ai.aimodel",
+                verbose_name="向量模型编号",
+            ),
+        ),
+        migrations.AlterField(
+            model_name="chatmessage",
+            name="model_id",
+            field=models.ForeignKey(
+                blank=True,
+                db_column="model_id",
+                db_comment="向量模型编号",
+                null=True,
+                on_delete=django.db.models.deletion.CASCADE,
+                to="ai.aimodel",
+                verbose_name="向量模型编号",
+            ),
+        ),
+    ]
@@ -258,6 +258,7 @@ class ChatConversation(CoreModel):
     model_id = models.ForeignKey(
         'AIModel',
         on_delete=models.CASCADE,
+        null=True, blank=True,
         db_column='model_id',
         verbose_name="向量模型编号",
         db_comment='向量模型编号'
@@ -302,6 +303,7 @@ class ChatMessage(CoreModel):
     model_id = models.ForeignKey(
         'AIModel',
         on_delete=models.CASCADE,
+        null=True, blank=True,
         db_column='model_id',
         verbose_name="向量模型编号",
         db_comment='向量模型编号'
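For readability, the field as it reads after this change (identical on ChatConversation and ChatMessage) is sketched below. It is assembled from the hunks above; the surrounding class body and imports are omitted, so treat it as an illustration rather than the literal file contents.

```python
# Sketch of the resulting field: the FK to AIModel may now be left empty.
model_id = models.ForeignKey(
    'AIModel',
    on_delete=models.CASCADE,
    null=True, blank=True,
    db_column='model_id',
    verbose_name="向量模型编号",
    db_comment='向量模型编号'
)
```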
@@ -3,16 +3,15 @@ import asyncio
 
 from fastapi import APIRouter, Depends, Request
 from fastapi.responses import StreamingResponse
+from sqlalchemy.orm import Session
 
 from pydantic import BaseModel
-from langchain.memory import ConversationBufferMemory
 from langchain.chains import ConversationChain
-# from langchain.chat_models import ChatOpenAI
 from langchain_community.chat_models import ChatOpenAI
 
 from deps.auth import get_current_user
-from services.chat_service import chat_service
-from utils.resp import resp_success
+from services.chat_service import ChatDBService
+from db.session import get_db
 
 router = APIRouter()
 
@@ -30,21 +29,37 @@ def get_deepseek_llm(api_key: str, model: str, openai_api_base: str):
     )
 
 @router.post('/stream')
-async def chat_stream(request: Request):
+async def chat_stream(request: Request, db: Session = Depends(get_db), user=Depends(get_current_user)):
     body = await request.json()
     content = body.get('content')
-    print(content, 'content')
+    conversation_id = body.get('conversation_id')
     model = 'deepseek-chat'
    api_key = os.getenv("DEEPSEEK_API_KEY")
-    openai_api_base="https://api.deepseek.com/v1"
+    openai_api_base = "https://api.deepseek.com/v1"
     llm = get_deepseek_llm(api_key, model, openai_api_base)
 
     if not content or not isinstance(content, str):
         from fastapi.responses import JSONResponse
         return JSONResponse({"error": "content不能为空"}, status_code=400)
 
+    user_id = user["user_id"]
+
+    # 1. Get or create the conversation
+    try:
+        conversation = ChatDBService.get_or_create_conversation(db, conversation_id, user_id, model)
+    except ValueError as e:
+        from fastapi.responses import JSONResponse
+        return JSONResponse({"error": str(e)}, status_code=400)
+    # 2. Insert the current message
+    ChatDBService.add_message(db, conversation, user_id, content)
+
+    # 3. Query the message history and assemble the context
+    history = ChatDBService.get_history(db, conversation.id)
+    history_contents = [msg.content for msg in history]
+    context = '\n'.join(history_contents)
+
     async def event_generator():
-        async for chunk in llm.astream(content):
+        async for chunk in llm.astream(context):
             # Return only the chunk.content payload
             if hasattr(chunk, 'content'):
                 yield f"data: {chunk.content}\n\n"
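The reworked /stream endpoint now persists the incoming message, rebuilds the context from the stored history, and streams the model reply as server-sent events. A minimal client sketch follows; the base URL, route prefix, and bearer token are assumptions, since the real prefix depends on how ai_chat.router is mounted, which this diff does not show.

```python
# Hypothetical client for the streaming endpoint; URL and token are placeholders.
import httpx

def stream_chat(content: str, conversation_id: int | None = None) -> None:
    url = "http://localhost:8000/stream"  # adjust to the actual mount path
    headers = {"Authorization": "Bearer <token>"}
    payload = {"content": content, "conversation_id": conversation_id}
    with httpx.stream("POST", url, json=payload, headers=headers, timeout=None) as resp:
        for line in resp.iter_lines():
            # the endpoint yields SSE-style "data: ..." lines
            if line.startswith("data: "):
                print(line[len("data: "):], end="", flush=True)

if __name__ == "__main__":
    stream_chat("Hello", conversation_id=None)
```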
@@ -1,5 +1,5 @@
 from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import sessionmaker, declarative_base
 from config import SQLALCHEMY_DATABASE_URL
 
 engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
@@ -10,4 +10,6 @@ def get_db():
     try:
         yield db
     finally:
         db.close()
+
+Base = declarative_base()
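Moving Base into db.session means every model module now shares one declarative registry and one MetaData object, which is what lets string-based relationship() targets and cross-file ForeignKey references resolve. A minimal bootstrap sketch, assuming the module layout in this commit (the snippet itself is not part of the diff):

```python
# Importing the model modules registers their tables on the shared Base.metadata,
# after which the whole schema can be created in one call.
from db.session import Base, engine
import models.ai    # noqa: F401
import models.user  # noqa: F401

Base.metadata.create_all(bind=engine)
```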
@@ -1,8 +1,13 @@
+import os
 from fastapi import FastAPI
+from dotenv import load_dotenv
 from api.v1 import ai_chat
 from fastapi.middleware.cors import CORSMiddleware
 from routers.ai_api_key import router as ai_api_key_router
 
+# Load the .env file, preferring the project root
+load_dotenv()
+
 app = FastAPI()
 
 origins = [
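Calling load_dotenv() at import time in main.py makes the variables from the project's .env file visible to os.getenv() before any request reaches the /stream endpoint. A small sketch of the intended effect, assuming a .env file containing DEEPSEEK_API_KEY:

```python
# load_dotenv() copies .env entries into os.environ; variables that are already
# set in the environment win by default (override=False).
import os
from dotenv import load_dotenv

load_dotenv()
print(bool(os.getenv("DEEPSEEK_API_KEY")))  # True once the .env entry is present
```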
@@ -2,11 +2,11 @@ from sqlalchemy import (
     Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
 )
 from sqlalchemy.orm import relationship, declarative_base
 
+from db.session import Base
+from models.base import CoreModel
 from models.user import DjangoUser  # make sure DjangoUser is imported
 
-Base = declarative_base()
 
 
 # Status choice class (example)
 class CommonStatus:
     DISABLED = 0
@@ -37,16 +37,6 @@ class MessageType:
         return [('text', '文本'), ('image', '图片')]
 
 
-# Base model class
-class CoreModel(Base):
-    __abstract__ = True
-
-    id = Column(Integer, primary_key=True, autoincrement=True)
-    create_time = Column(DateTime)
-    update_time = Column(DateTime)
-    is_deleted = Column(Boolean, default=False)
-
-
 # AI API key table
 class AIApiKey(CoreModel):
     __tablename__ = 'ai_api_key'
@@ -83,161 +73,160 @@ class AIModel(CoreModel):
 
 
 # AI tool table
-# class Tool(CoreModel):
-# __tablename__ = 'ai_tool'
+class Tool(CoreModel):
+    __tablename__ = 'ai_tool'
 
-# name = Column(String(128), nullable=False)
-# description = Column(String(256), nullable=True)
-# status = Column(Integer, default=0)
+    name = Column(String(128), nullable=False)
+    description = Column(String(256), nullable=True)
+    status = Column(Integer, default=0)
 
-# def __str__(self):
-# return self.name
+    def __str__(self):
+        return self.name
 
 
 # AI knowledge base table
-# class Knowledge(CoreModel):
-# __tablename__ = 'ai_knowledge'
+class Knowledge(CoreModel):
+    __tablename__ = 'ai_knowledge'
 
-# name = Column(String(255), nullable=False)
-# description = Column(Text, nullable=True)
-# embedding_model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
-# embedding_model = Column(String(32), nullable=False)
-# top_k = Column(Integer, default=0)
-# similarity_threshold = Column(Float, nullable=False)
-# status = Column(Integer, default=CommonStatus.DISABLED)
+    name = Column(String(255), nullable=False)
+    description = Column(Text, nullable=True)
+    embedding_model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
+    embedding_model = Column(String(32), nullable=False)
+    top_k = Column(Integer, default=0)
+    similarity_threshold = Column(Float, nullable=False)
+    status = Column(Integer, default=CommonStatus.DISABLED)
 
-# embedding_model_rel = relationship('AIModel', backref='knowledges')
-# documents = relationship('KnowledgeDocument', backref='knowledge', cascade='all, delete-orphan')
-# segments = relationship('KnowledgeSegment', backref='knowledge', cascade='all, delete-orphan')
-# roles = relationship('ChatRole', secondary='ai_chat_role_knowledge', backref='knowledges')
+    embedding_model_rel = relationship('AIModel', backref='knowledges')
+    documents = relationship('KnowledgeDocument', backref='knowledge', cascade='all, delete-orphan')
+    segments = relationship('KnowledgeSegment', backref='knowledge', cascade='all, delete-orphan')
+    roles = relationship('ChatRole', secondary='ai_chat_role_knowledge', backref='knowledges')
 
-# def __str__(self):
-# return self.name
+    def __str__(self):
+        return self.name
 
 
 # AI knowledge base document table
-# class KnowledgeDocument(CoreModel):
-# __tablename__ = 'ai_knowledge_document'
+class KnowledgeDocument(CoreModel):
+    __tablename__ = 'ai_knowledge_document'
 
-# knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
-# name = Column(String(255), nullable=False)
-# url = Column(String(1024), nullable=False)
-# content = Column(Text, nullable=False)
-# content_length = Column(Integer, nullable=False)
-# tokens = Column(Integer, nullable=False)
-# segment_max_tokens = Column(Integer, nullable=False)
-# retrieval_count = Column(Integer, default=0)
-# status = Column(Integer, default=CommonStatus.DISABLED)
+    knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
+    name = Column(String(255), nullable=False)
+    url = Column(String(1024), nullable=False)
+    content = Column(Text, nullable=False)
+    content_length = Column(Integer, nullable=False)
+    tokens = Column(Integer, nullable=False)
+    segment_max_tokens = Column(Integer, nullable=False)
+    retrieval_count = Column(Integer, default=0)
+    status = Column(Integer, default=CommonStatus.DISABLED)
 
-# segments = relationship('KnowledgeSegment', backref='document', cascade='all, delete-orphan')
+    segments = relationship('KnowledgeSegment', backref='document', cascade='all, delete-orphan')
 
-# def __str__(self):
-# return self.name
+    def __str__(self):
+        return self.name
 
 
 # AI knowledge base segment table
-# class KnowledgeSegment(CoreModel):
-# __tablename__ = 'ai_knowledge_segment'
+class KnowledgeSegment(CoreModel):
+    __tablename__ = 'ai_knowledge_segment'
 
-# knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
-# document_id = Column(Integer, ForeignKey('ai_knowledge_document.id'), nullable=False)
-# content = Column(Text, nullable=False)
-# content_length = Column(Integer, nullable=False)
-# tokens = Column(Integer, nullable=False)
-# vector_id = Column(String(100), nullable=True)
-# retrieval_count = Column(Integer, default=0)
-# status = Column(Integer, default=CommonStatus.DISABLED)
+    knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
+    document_id = Column(Integer, ForeignKey('ai_knowledge_document.id'), nullable=False)
+    content = Column(Text, nullable=False)
+    content_length = Column(Integer, nullable=False)
+    tokens = Column(Integer, nullable=False)
+    vector_id = Column(String(100), nullable=True)
+    retrieval_count = Column(Integer, default=0)
+    status = Column(Integer, default=CommonStatus.DISABLED)
 
-# def __str__(self):
-# return f"Segment {self.id}"
+    def __str__(self):
+        return f"Segment {self.id}"
 
 
 # AI chat role table
-# class ChatRole(CoreModel):
-# __tablename__ = 'ai_chat_role'
+class ChatRole(CoreModel):
+    __tablename__ = 'ai_chat_role'
 
-# name = Column(String(128), nullable=False)
-# avatar = Column(String(256), nullable=False)
-# description = Column(String(256), nullable=True)
-# status = Column(Integer, default=CommonStatus.DISABLED)
-# sort = Column(Integer, default=0)
-# public_status = Column(Boolean, default=False)
-# category = Column(String(32), nullable=True)
-# model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
-# system_message = Column(String(1024), nullable=True)
-# user_id = Column(
-# Integer,
-# ForeignKey('system_users.id'), # assumes the DjangoUser table is named system_users
-# nullable=True # allowed to be empty (e.g. anonymous roles)
-# )
-# user = relationship(DjangoUser, backref='chat_roles') # DjangoUser is already defined and imported
+    name = Column(String(128), nullable=False)
+    avatar = Column(String(256), nullable=False)
+    description = Column(String(256), nullable=True)
+    status = Column(Integer, default=CommonStatus.DISABLED)
+    sort = Column(Integer, default=0)
+    public_status = Column(Boolean, default=False)
+    category = Column(String(32), nullable=True)
+    model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
+    system_message = Column(String(1024), nullable=True)
+    user_id = Column(
+        Integer,
+        ForeignKey('system_users.id'),  # assumes the DjangoUser table is named system_users
+        nullable=True  # allowed to be empty (e.g. anonymous roles)
+    )
+    user = relationship(DjangoUser, backref='chat_roles')  # DjangoUser is already defined and imported
 
-# model = relationship('AIModel', backref='chat_roles')
-# tools = relationship('Tool', secondary='ai_chat_role_tool', backref='roles')
-# # conversations = relationship('ChatConversation', backref='role', cascade='all, delete-orphan')
-# # messages = relationship('ChatMessage', backref='role', cascade='all, delete-orphan')
+    model = relationship('AIModel', backref='chat_roles')
+    tools = relationship('Tool', secondary='ai_chat_role_tool', backref='roles')
+    # conversations = relationship('ChatConversation', backref='role', cascade='all, delete-orphan')
+    # messages = relationship('ChatMessage', backref='role', cascade='all, delete-orphan')
 
-# def __str__(self):
-# return self.name
 
+    def __str__(self):
+        return self.name
 
 # AI chat conversation table
-# class ChatConversation(CoreModel):
-# __tablename__ = 'ai_chat_conversation'
+class ChatConversation(CoreModel):
+    __tablename__ = 'ai_chat_conversation'
 
-# title = Column(String(256), nullable=False)
-# pinned = Column(Boolean, default=False)
-# pinned_time = Column(DateTime, nullable=True)
-# # user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
-# role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
-# model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
-# model = Column(String(32), nullable=False)
-# system_message = Column(String(1024), nullable=True)
-# temperature = Column(Float, nullable=False)
-# max_tokens = Column(Integer, nullable=False)
-# max_contexts = Column(Integer, nullable=False)
-# # user = relationship(DjangoUser, backref='conversations') # DjangoUser is already defined and imported
+    title = Column(String(256), nullable=False)
+    pinned = Column(Boolean, default=False)
+    pinned_time = Column(DateTime, nullable=True)
+    user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
+    role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
+    model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
+    model = Column(String(32), nullable=False)
+    system_message = Column(String(1024), nullable=True)
+    temperature = Column(Float, nullable=False)
+    max_tokens = Column(Integer, nullable=False)
+    max_contexts = Column(Integer, nullable=False)
+    user = relationship(DjangoUser, backref='conversations')  # DjangoUser is already defined and imported
 
-# model_rel = relationship('AIModel', backref='conversations')
-# messages = relationship('ChatMessage', backref='conversation', cascade='all, delete-orphan')
+    model_rel = relationship('AIModel', backref='conversations')
+    messages = relationship('ChatMessage', backref='conversation', cascade='all, delete-orphan')
 
-# def __str__(self):
-# return self.title
+    def __str__(self):
+        return self.title
 
 
 # AI chat message table
-# class ChatMessage(CoreModel):
-# __tablename__ = 'ai_chat_message'
+class ChatMessage(CoreModel):
+    __tablename__ = 'ai_chat_message'
 
-# conversation_id = Column(Integer, nullable=False)
-# # user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
-# role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
-# model = Column(String(32), nullable=False)
-# model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
-# type = Column(String(16), nullable=False)
-# reply_id = Column(Integer, nullable=True)
-# content = Column(String(2048), nullable=False)
-# use_context = Column(Boolean, default=False)
-# segment_ids = Column(String(2048), nullable=True)
+    conversation_id = Column(Integer, ForeignKey('ai_chat_conversation.id'), nullable=False)
+    user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
+    role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
+    model = Column(String(32), nullable=False)
+    model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
+    type = Column(String(16), nullable=False)
+    reply_id = Column(Integer, nullable=True)
+    content = Column(String(2048), nullable=False)
+    use_context = Column(Boolean, default=False)
+    segment_ids = Column(String(2048), nullable=True)
 
-# # user = relationship(DjangoUser, backref='messages') # DjangoUser is already defined and imported
-# model_rel = relationship('AIModel', backref='messages')
+    user = relationship(DjangoUser, backref='messages')  # DjangoUser is already defined and imported
+    model_rel = relationship('AIModel', backref='messages')
 
-# def __str__(self):
-# return self.content[:30]
+    def __str__(self):
+        return self.content[:30]
 
 
-# # Association table between chat roles and knowledge bases
-# class ChatRoleKnowledge(Base):
-# __tablename__ = 'ai_chat_role_knowledge'
+# Association table between chat roles and knowledge bases
+class ChatRoleKnowledge(Base):
+    __tablename__ = 'ai_chat_role_knowledge'
 
-# chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
-# knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), primary_key=True)
+    chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
+    knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), primary_key=True)
 
 
-# # Association table between chat roles and tools
-# class ChatRoleTool(Base):
-# __tablename__ = 'ai_chat_role_tool'
+# Association table between chat roles and tools
+class ChatRoleTool(Base):
+    __tablename__ = 'ai_chat_role_tool'
 
-# chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
-# tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True)
+    chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
+    tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True)
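Turning conversation_id into a real ForeignKey is what lets SQLAlchemy infer the join behind ChatConversation.messages (backref='conversation' with cascade); with a plain Integer column there is no foreign-key path to derive the join from and mapper configuration fails. A small usage sketch, assuming the engine from db.session and an existing conversation row (the id used here is a placeholder):

```python
from sqlalchemy.orm import Session
from db.session import engine
from models.ai import ChatConversation

with Session(engine) as session:
    conv = session.get(ChatConversation, 1)  # placeholder primary key
    if conv is not None:
        for msg in conv.messages:  # join inferred from the new ForeignKey
            print(msg.content)
```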
chat/models/base.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+from db.session import Base
+from sqlalchemy import (
+    Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
+)
+
+# Base model class
+class CoreModel(Base):
+    __abstract__ = True
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    create_time = Column(DateTime)
+    update_time = Column(DateTime)
+    is_deleted = Column(Boolean, default=False)
@@ -1,7 +1,7 @@
 from sqlalchemy import Column, Integer, String, DateTime, Boolean
-from sqlalchemy.ext.declarative import declarative_base
 
-Base = declarative_base()
+from db.session import Base
 
 
 class AuthToken(Base):
     __tablename__ = 'authtoken_token'
@@ -12,6 +12,7 @@ class AuthToken(Base):
 
 class DjangoUser(Base):
     __tablename__ = 'system_users'
 
     id = Column(Integer, primary_key=True)
     username = Column(String(150), nullable=False)
     email = Column(String(254))
@@ -1,15 +1,58 @@
 # LangChain integration example
-from langchain_openai import OpenAI
-from dotenv import load_dotenv
-load_dotenv()
+from sqlalchemy.orm import Session
+from datetime import datetime
+from models.ai import ChatConversation, ChatMessage, MessageType
 
-class ChatService:
-    def __init__(self):
-        # Using OpenAI as an example; configure as needed
-        self.llm = OpenAI(temperature=0.7, api_key='sssss')
-
-    def chat(self, prompt: str) -> str:
-        # Simple LLM call
-        return self.llm(prompt)
-
-chat_service = ChatService()
+class ChatDBService:
+    @staticmethod
+    def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str) -> ChatConversation:
+        if not conversation_id:
+            conversation = ChatConversation(
+                title="新对话",
+                user_id=user_id,
+                role_id=None,
+                model_id=1,  # adjust to the actual model id
+                model=model,
+                system_message=None,
+                temperature=0.7,
+                max_tokens=2048,
+                max_contexts=10,
+                create_time=datetime.now(),
+                update_time=datetime.now(),
+                is_deleted=False
+            )
+            db.add(conversation)
+            db.commit()
+            db.refresh(conversation)
+            return conversation
+        else:
+            conversation = db.query(ChatConversation).get(conversation_id)
+            if not conversation:
+                raise ValueError("无效的conversation_id")
+            return conversation
+
+    @staticmethod
+    def add_message(db: Session, conversation: ChatConversation, user_id: int, content: str) -> ChatMessage:
+        message = ChatMessage(
+            conversation_id=conversation.id,
+            user_id=user_id,
+            role_id=None,
+            model=conversation.model,
+            model_id=conversation.model_id,
+            type=MessageType.TEXT,
+            reply_id=None,
+            content=content,
+            use_context=True,
+            segment_ids=None,
+            create_time=datetime.now(),
+            update_time=datetime.now(),
+            is_deleted=False
+        )
+        db.add(message)
+        db.commit()
+        return message
+
+    @staticmethod
+    def get_history(db: Session, conversation_id: int) -> list[ChatMessage]:
+        return db.query(ChatMessage).filter_by(conversation_id=conversation_id).order_by(ChatMessage.id).all()
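One side note on the lookup in get_or_create_conversation: db.query(ChatConversation).get(conversation_id) is the legacy 1.x spelling, and SQLAlchemy documents Session.get() as its 2.0-style replacement. A drop-in sketch under that assumption:

```python
from sqlalchemy.orm import Session
from models.ai import ChatConversation

def load_conversation(db: Session, conversation_id: int) -> ChatConversation:
    conversation = db.get(ChatConversation, conversation_id)  # 2.0-style primary-key lookup
    if conversation is None:
        raise ValueError("无效的conversation_id")
    return conversation
```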
@@ -4,9 +4,13 @@ import { formatToken } from '#/utils/auth';
 
 export interface FetchAIStreamParams {
   content: string;
+  conversation_id?: null | number;
 }
 
-export async function fetchAIStream({ content }: FetchAIStreamParams) {
+export async function fetchAIStream({
+  content,
+  conversation_id,
+}: FetchAIStreamParams) {
   const accessStore = useAccessStore();
   const token = accessStore.accessToken;
   const headers = new Headers();
@@ -19,6 +23,7 @@ export async function fetchAIStream({ content }: FetchAIStreamParams) {
     headers,
     body: JSON.stringify({
       content,
+      conversation_id,
     }),
   });
 
@@ -15,7 +15,6 @@ import {
 } from 'ant-design-vue';
 
 import { fetchAIStream } from '#/api/ai/chat';
-// removed: import typingSound from '@/assets/typing.mp3';
 
 interface Message {
   id: number;
@@ -23,27 +22,17 @@
   content: string;
 }
 
+interface ChatItem {
+  id: number;
+  title: string;
+  lastMessage: string;
+}
+
 // mock conversation history
-const chatList = ref([
-  {
-    id: 1,
-    title: '和deepseek的对话',
-    lastMessage: 'AI: 你好,有什么可以帮您?',
-  },
-  { id: 2, title: '工作助理', lastMessage: 'AI: 今天的日程已为您安排。' },
-]);
+const chatList = ref<ChatItem[]>([]);
 
 // mock chat messages
-const messages = ref<Record<number, Message[]>>({
-  1: [
-    { id: 1, role: 'user', content: '你好' },
-    { id: 2, role: 'ai', content: '你好,有什么可以帮您?' },
-  ],
-  2: [
-    { id: 1, role: 'user', content: '帮我安排下今天的日程' },
-    { id: 2, role: 'ai', content: '今天的日程已为您安排。' },
-  ],
-});
+const messages = ref<Message[]>([]);
 
 // mock model list
 const modelOptions = [
@@ -51,8 +40,8 @@
   { label: 'GPT-4', value: 'gpt-4' },
 ];
 
-const selectedChatId = ref(chatList.value[0]?.id || 1);
-const selectedModel = ref(modelOptions[0].value);
+const selectedChatId = ref<null | number>(chatList.value[0]?.id ?? null);
+const selectedModel = ref(modelOptions[0]?.value);
 const search = ref('');
 const input = ref('');
 const messagesRef = ref<HTMLElement | null>(null);
@@ -64,9 +53,14 @@ const filteredChats = computed(() => {
   return chatList.value.filter((chat) => chat.title.includes(search.value));
 });
 
-const currentMessages = computed(
-  () => messages.value?.[selectedChatId.value] || [],
-);
+// filter directly by conversationId
+const currentMessages = computed(() => {
+  if (!selectedChatId.value) return [];
+  return [];
+  // return messages.value.filter(
+  //   (msg) => msg.conversationId === selectedChatId.value,
+  // );
+});
 
 function selectChat(id: number) {
   selectedChatId.value = id;
@@ -80,38 +74,46 @@ function handleNewChat() {
     title: `新对话${chatList.value.length + 1}`,
     lastMessage: '',
   });
-  messages.value[newId] = [];
   selectedChatId.value = newId;
   nextTick(scrollToBottom);
 }
 
 async function handleSend() {
-  if (!input.value.trim()) return;
-  const msg: Message = { id: Date.now(), role: 'user', content: input.value };
-  if (!messages.value[selectedChatId.value]) {
-    messages.value[selectedChatId.value] = [];
-  }
-  messages.value[selectedChatId.value].push(msg);
+  console.log(111);
+  const msg: Message = {
+    id: Date.now(),
+    role: 'user',
+    content: input.value,
+  };
+  messages.value.push(msg);
 
   // placeholder AI message
-  const aiMsgObj: Message = { id: Date.now() + 1, role: 'ai', content: '' };
-  messages.value[selectedChatId.value].push(aiMsgObj);
+  const aiMsgObj: Message = {
+    id: Date.now() + 1,
+    role: 'ai',
+    content: '',
+  };
+  messages.value.push(aiMsgObj);
   currentAiMessage.value = aiMsgObj;
   isAiTyping.value = true;
 
   const stream = await fetchAIStream({
     content: input.value,
+    conversation_id: selectedChatId.value, // newly added
   });
 
-  // removed typing sound playback
 
   for await (const chunk of stream) {
    for (const char of chunk) {
       aiMsgObj.content += char;
+      // ensure the messages array updates reactively
+      const idx = messages.value.indexOf(aiMsgObj);
+      if (idx !== -1) {
+        messages.value.splice(idx, 1, { ...aiMsgObj });
+      }
       currentAiMessage.value = { ...aiMsgObj };
-      // removed typing sound playback
-      await new Promise(resolve => setTimeout(resolve, 15));
-      nextTick(scrollToBottom);
+      await nextTick();
+      scrollToBottom();
+      await new Promise((resolve) => setTimeout(resolve, 15));
     }
   }
   isAiTyping.value = false;
@@ -365,7 +367,13 @@ function scrollToBottom() {
   border-radius: 2px;
 }
 @keyframes blink-cursor {
-  0%, 50% { opacity: 1; }
-  51%, 100% { opacity: 0; }
+  0%,
+  50% {
+    opacity: 1;
+  }
+  51%,
+  100% {
+    opacity: 0;
+  }
 }
 </style>