新加ai 对话页面

This commit is contained in:
XIE7654
2025-07-17 10:59:48 +08:00
parent 682e3805eb
commit 6505e69d4f
14 changed files with 548 additions and 4 deletions

View File

@@ -11,7 +11,6 @@ class ChatRequest(BaseModel):
prompt: str
@router.post("/")
def chat_api(data: ChatRequest, user=Depends(get_current_user)):
# return {"msg": "pong"}

18
chat/crud/ai_api_key.py Normal file
View File

@@ -0,0 +1,18 @@
from fastapi import HTTPException
from sqlalchemy.orm import Session
from crud.base import CRUDBase
from models.ai import AIApiKey # SQLAlchemy模型
from schemas.ai_api_key import AIApiKeyCreate, AIApiKeyUpdate
# 继承通用CRUD基类指定模型和Pydantic类型
class CRUDApiKey(CRUDBase[AIApiKey, AIApiKeyCreate, AIApiKeyUpdate]):
# 如有特殊逻辑,可重写父类方法(如创建时验证平台唯一性)
def create(self, db: Session, *, obj_in: AIApiKeyCreate):
# 示例:验证平台+名称唯一
if self.get_by(db, platform=obj_in.platform, name=obj_in.name):
raise HTTPException(status_code=400, detail="该平台下名称已存在")
return super().create(db, obj_in=obj_in)
# 创建CRUD实例
ai_api_key_crud = CRUDApiKey(AIApiKey)

90
chat/crud/base.py Normal file
View File

@@ -0,0 +1,90 @@
from typing import Generic, TypeVar, List, Optional, Dict, Any
from fastapi import HTTPException
from sqlalchemy.orm import Session
from datetime import datetime
# 定义泛型变量分别对应SQLAlchemy模型、创建Pydantic模型、更新Pydantic模型
ModelType = TypeVar("ModelType")
CreateSchemaType = TypeVar("CreateSchemaType")
UpdateSchemaType = TypeVar("UpdateSchemaType")
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def __init__(self, model: ModelType):
"""
初始化CRUD类需要传入SQLAlchemy模型
:param model: SQLAlchemy模型类如AIApiKey、AIModel等
"""
self.model = model
# 创建
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
"""创建一条记录"""
obj_in_data = obj_in.model_dump() # 解构Pydantic模型为字典
# 自动填充时间字段如果模型有created_at/updated_at
if hasattr(self.model, "created_at"):
obj_in_data["created_at"] = datetime.now()
if hasattr(self.model, "updated_at"):
obj_in_data["updated_at"] = datetime.now()
db_obj = self.model(**obj_in_data) # 实例化模型
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
# 按ID查询
def get(self, db: Session, id: int) -> Optional[ModelType]:
"""按ID查询单条记录"""
return db.query(self.model).filter(self.model.id == id).first()
# 按条件查询单条记录
def get_by(self, db: Session, **kwargs) -> Optional[ModelType]:
"""按条件查询单条记录如get_by(name="test")"""
return db.query(self.model).filter_by(**kwargs).first()
# 分页查询所有
def get_multi(
self, db: Session, *, page: int = 0, limit: int = 100
) -> List[ModelType]:
"""分页查询多条记录"""
return db.query(self.model).offset(page).limit(limit).all()
# 更新
def update(
self,
db: Session,
*,
db_obj: ModelType,
obj_in: UpdateSchemaType | Dict[str, Any]
) -> ModelType:
"""更新记录支持Pydantic模型或字典"""
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True) # 只更新提供的字段
# 遍历更新字段
for field in update_data:
if hasattr(db_obj, field):
setattr(db_obj, field, update_data[field])
# 自动更新updated_at如果模型有该字段
if hasattr(db_obj, "updated_at"):
db_obj.updated_at = datetime.now()
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
# 删除
def remove(self, db: Session, *, id: int) -> ModelType:
"""删除记录"""
obj = db.query(self.model).get(id)
if not obj:
raise HTTPException(status_code=404, detail=f"{self.model.__name__}不存在")
db.delete(obj)
db.commit()
return obj

View File

@@ -1,6 +1,7 @@
from fastapi import FastAPI
from api.v1 import ai_chat
from fastapi.middleware.cors import CORSMiddleware
from routers.ai_api_key import router as ai_api_key_router
app = FastAPI()
@@ -19,6 +20,7 @@ app.add_middleware(
# 注册路由
app.include_router(ai_chat.router, prefix="/chat/api/v1", tags=["chat"])
app.include_router(ai_api_key_router, tags=["chat"])
# 健康检查
@app.get("/ping")

243
chat/models/ai.py Normal file
View File

@@ -0,0 +1,243 @@
from sqlalchemy import (
Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
)
from sqlalchemy.orm import relationship, declarative_base
from models.user import DjangoUser # 确保导入 DjangoUser
Base = declarative_base()
# 状态选择类(示例)
class CommonStatus:
DISABLED = 0
ENABLED = 1
@staticmethod
def choices():
return [(0, '禁用'), (1, '启用')]
# 平台选择类(示例)
class PlatformChoices:
OPENAI = 'openai'
ALIMNS = 'alimns'
@staticmethod
def choices():
return [('openai', 'OpenAI'), ('alimns', '阿里云MNS')]
# 消息类型选择类(示例)
class MessageType:
TEXT = 'text'
IMAGE = 'image'
@staticmethod
def choices():
return [('text', '文本'), ('image', '图片')]
# 基础模型类
class CoreModel(Base):
__abstract__ = True
id = Column(Integer, primary_key=True, autoincrement=True)
create_time = Column(DateTime)
update_time = Column(DateTime)
is_deleted = Column(Boolean, default=False)
# AI API 密钥表
class AIApiKey(CoreModel):
__tablename__ = 'ai_api_key'
name = Column(String(255), nullable=False)
platform = Column(String(100), nullable=False)
api_key = Column(String(255), nullable=False)
url = Column(String(255), nullable=True)
status = Column(Integer, default=CommonStatus.DISABLED)
def __str__(self):
return self.name
# AI 模型表
class AIModel(CoreModel):
__tablename__ = 'ai_model'
name = Column(String(64), nullable=False)
sort = Column(Integer, default=0)
status = Column(Integer, default=CommonStatus.DISABLED)
key_id = Column(Integer, ForeignKey('ai_api_key.id'), nullable=False)
model_type = Column(String(32), nullable=True)
platform = Column(String(32), nullable=False)
model = Column(String(64), nullable=False)
temperature = Column(Float, nullable=True)
max_tokens = Column(Integer, nullable=True)
max_contexts = Column(Integer, nullable=True)
key = relationship('AIApiKey', backref='models')
def __str__(self):
return self.name
# AI 工具表
# class Tool(CoreModel):
# __tablename__ = 'ai_tool'
# name = Column(String(128), nullable=False)
# description = Column(String(256), nullable=True)
# status = Column(Integer, default=0)
# def __str__(self):
# return self.name
# AI 知识库表
# class Knowledge(CoreModel):
# __tablename__ = 'ai_knowledge'
# name = Column(String(255), nullable=False)
# description = Column(Text, nullable=True)
# embedding_model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
# embedding_model = Column(String(32), nullable=False)
# top_k = Column(Integer, default=0)
# similarity_threshold = Column(Float, nullable=False)
# status = Column(Integer, default=CommonStatus.DISABLED)
# embedding_model_rel = relationship('AIModel', backref='knowledges')
# documents = relationship('KnowledgeDocument', backref='knowledge', cascade='all, delete-orphan')
# segments = relationship('KnowledgeSegment', backref='knowledge', cascade='all, delete-orphan')
# roles = relationship('ChatRole', secondary='ai_chat_role_knowledge', backref='knowledges')
# def __str__(self):
# return self.name
# AI 知识库文档表
# class KnowledgeDocument(CoreModel):
# __tablename__ = 'ai_knowledge_document'
# knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
# name = Column(String(255), nullable=False)
# url = Column(String(1024), nullable=False)
# content = Column(Text, nullable=False)
# content_length = Column(Integer, nullable=False)
# tokens = Column(Integer, nullable=False)
# segment_max_tokens = Column(Integer, nullable=False)
# retrieval_count = Column(Integer, default=0)
# status = Column(Integer, default=CommonStatus.DISABLED)
# segments = relationship('KnowledgeSegment', backref='document', cascade='all, delete-orphan')
# def __str__(self):
# return self.name
# AI 知识库分段表
# class KnowledgeSegment(CoreModel):
# __tablename__ = 'ai_knowledge_segment'
# knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
# document_id = Column(Integer, ForeignKey('ai_knowledge_document.id'), nullable=False)
# content = Column(Text, nullable=False)
# content_length = Column(Integer, nullable=False)
# tokens = Column(Integer, nullable=False)
# vector_id = Column(String(100), nullable=True)
# retrieval_count = Column(Integer, default=0)
# status = Column(Integer, default=CommonStatus.DISABLED)
# def __str__(self):
# return f"Segment {self.id}"
# AI 聊天角色表
# class ChatRole(CoreModel):
# __tablename__ = 'ai_chat_role'
# name = Column(String(128), nullable=False)
# avatar = Column(String(256), nullable=False)
# description = Column(String(256), nullable=True)
# status = Column(Integer, default=CommonStatus.DISABLED)
# sort = Column(Integer, default=0)
# public_status = Column(Boolean, default=False)
# category = Column(String(32), nullable=True)
# model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
# system_message = Column(String(1024), nullable=True)
# user_id = Column(
# Integer,
# ForeignKey('system_users.id'), # 假设DjangoUser表名是system_users
# nullable=True # 允许为空(如匿名角色)
# )
# user = relationship(DjangoUser, backref='chat_roles') # 正确DjangoUser 已定义并导入
# model = relationship('AIModel', backref='chat_roles')
# tools = relationship('Tool', secondary='ai_chat_role_tool', backref='roles')
# # conversations = relationship('ChatConversation', backref='role', cascade='all, delete-orphan')
# # messages = relationship('ChatMessage', backref='role', cascade='all, delete-orphan')
# def __str__(self):
# return self.name
# AI 聊天对话表
# class ChatConversation(CoreModel):
# __tablename__ = 'ai_chat_conversation'
# title = Column(String(256), nullable=False)
# pinned = Column(Boolean, default=False)
# pinned_time = Column(DateTime, nullable=True)
# # user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
# role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
# model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
# model = Column(String(32), nullable=False)
# system_message = Column(String(1024), nullable=True)
# temperature = Column(Float, nullable=False)
# max_tokens = Column(Integer, nullable=False)
# max_contexts = Column(Integer, nullable=False)
# # user = relationship(DjangoUser, backref='conversations') # 正确DjangoUser 已定义并导入
# model_rel = relationship('AIModel', backref='conversations')
# messages = relationship('ChatMessage', backref='conversation', cascade='all, delete-orphan')
# def __str__(self):
# return self.title
# AI 聊天消息表
# class ChatMessage(CoreModel):
# __tablename__ = 'ai_chat_message'
# conversation_id = Column(Integer, nullable=False)
# # user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
# role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
# model = Column(String(32), nullable=False)
# model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
# type = Column(String(16), nullable=False)
# reply_id = Column(Integer, nullable=True)
# content = Column(String(2048), nullable=False)
# use_context = Column(Boolean, default=False)
# segment_ids = Column(String(2048), nullable=True)
# # user = relationship(DjangoUser, backref='messages') # 正确DjangoUser 已定义并导入
# model_rel = relationship('AIModel', backref='messages')
# def __str__(self):
# return self.content[:30]
# # 聊天角色与知识库的关联表
# class ChatRoleKnowledge(Base):
# __tablename__ = 'ai_chat_role_knowledge'
# chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
# knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), primary_key=True)
# # 聊天角色与工具的关联表
# class ChatRoleTool(Base):
# __tablename__ = 'ai_chat_role_tool'
# chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
# tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True)

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, DateTime
from sqlalchemy import Column, Integer, String, DateTime, Boolean
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
@@ -9,8 +9,15 @@ class AuthToken(Base):
user_id = Column(Integer, nullable=False)
created = Column(DateTime)
class DjangoUser(Base):
__tablename__ = 'system_users'
id = Column(Integer, primary_key=True)
username = Column(String(150), nullable=False)
email = Column(String(254))
email = Column(String(254))
password = Column(String(128))
is_active = Column(Boolean, default=True)
is_staff = Column(Boolean, default=False)
is_superuser = Column(Boolean, default=False)
last_login = Column(DateTime)
date_joined = Column(DateTime)

View File

@@ -0,0 +1,13 @@
from schemas.ai_api_key import AIApiKeyCreate, AIApiKeyUpdate, AIApiKeyRead
from crud.ai_api_key import ai_api_key_crud
from routers.base import GenericRouter
# 继承通用路由基类传入参数即可生成所有CRUD接口
router = GenericRouter(
crud=ai_api_key_crud,
create_schema=AIApiKeyCreate,
update_schema=AIApiKeyUpdate,
read_schema=AIApiKeyRead,
prefix="/chat/api/ai-api-keys",
tags=["AI API密钥"]
)

101
chat/routers/base.py Normal file
View File

@@ -0,0 +1,101 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Generic, TypeVar, List
from db.session import get_db
from schemas.base import ReadSchemaType # 通用的响应模型基类
from crud.base import CRUDBase
# 泛型变量对应CRUD类、创建模型、更新模型、响应模型
CRUDType = TypeVar("CRUDType")
CreateSchemaType = TypeVar("CreateSchemaType")
UpdateSchemaType = TypeVar("UpdateSchemaType")
ReadSchemaType = TypeVar("ReadSchemaType")
class GenericRouter(
APIRouter,
Generic[CRUDType, CreateSchemaType, UpdateSchemaType, ReadSchemaType]
):
def __init__(
self,
crud: CRUDType,
create_schema: CreateSchemaType,
update_schema: UpdateSchemaType,
read_schema: ReadSchemaType,
prefix: str,
tags: List[str],
**kwargs
):
"""
初始化通用路由
:param crud: CRUD实例如CRUDApiKey
:param create_schema: 创建Pydantic模型
:param update_schema: 更新Pydantic模型
:param read_schema: 响应Pydantic模型
:param prefix: 路由前缀(如"/api/ai-api-keys"
:param tags: 文档标签
"""
super().__init__(prefix=prefix, tags=tags,** kwargs)
self.crud = crud
self.create_schema = create_schema
self.update_schema = update_schema
self.read_schema = read_schema
# 注册通用路由
self.add_api_route(
"/",
self.create,
methods=["POST"],
response_model=read_schema,
status_code=201
)
self.add_api_route(
"/",
self.get_multi,
methods=["GET"],
response_model=List[read_schema]
)
self.add_api_route(
"/{id}/",
self.get,
methods=["GET"],
response_model=read_schema
)
self.add_api_route(
"/{id}/",
self.update,
methods=["PUT"],
response_model=read_schema
)
self.add_api_route(
"/{id}/",
self.remove,
methods=["DELETE"]
)
# 创建
def create(self, obj_in: CreateSchemaType, db: Session = Depends(get_db)):
return self.crud.create(db=db, obj_in=obj_in)
# 按ID查询
def get(self, id: int, db: Session = Depends(get_db)):
obj = self.crud.get(db=db, id=id)
if not obj:
raise HTTPException(status_code=404, detail=f"记录不存在")
return obj
# 分页查询
def get_multi(self, page: int = 0, limit: int = 10, db: Session = Depends(get_db)):
return self.crud.get_multi(db=db, page=page, limit=limit)
# 更新
def update(self, id: int, obj_in: UpdateSchemaType, db: Session = Depends(get_db)):
obj = self.crud.get(db=db, id=id)
if not obj:
raise HTTPException(status_code=404, detail=f"记录不存在")
return self.crud.update(db=db, db_obj=obj, obj_in=obj_in)
# 删除
def remove(self, id: int, db: Session = Depends(get_db)):
return self.crud.remove(db=db, id=id)

View File

@@ -0,0 +1,33 @@
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
# 基础模型(共享字段)
class AIApiKeyBase(BaseModel):
name: str = Field(..., max_length=255, description="密钥名称")
platform: str = Field(..., max_length=100, description="平台如openai")
api_key: str = Field(..., max_length=255, description="API密钥")
url: Optional[str] = Field(None, max_length=255, description="自定义API地址")
status: int = Field(0, description="状态0=禁用1=启用)")
# 创建请求模型无需ID和时间字段
class AIApiKeyCreate(AIApiKeyBase):
pass
# 更新请求模型(所有字段可选)
class AIApiKeyUpdate(BaseModel):
name: Optional[str] = Field(None, max_length=255)
platform: Optional[str] = Field(None, max_length=100)
api_key: Optional[str] = Field(None, max_length=255)
url: Optional[str] = Field(None, max_length=255)
status: Optional[int] = None
# 响应模型(包含数据库自动生成的字段)
class AIApiKeyRead(AIApiKeyBase):
id: int
created_at: Optional[datetime]
updated_at: Optional[datetime]
# 支持ORM模型直接转换为响应
class Config:
from_attributes = True # Pydantic v2用from_attributesv1用orm_mode

19
chat/schemas/base.py Normal file
View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel
from datetime import datetime
from typing import Optional
class ReadSchemaType(BaseModel):
"""
所有响应模型的基类包含公共字段和ORM转换配置
"""
id: int
created_at: Optional[datetime] = None # 数据创建时间(可选,部分模型可能没有)
updated_at: Optional[datetime] = None # 数据更新时间(可选)
class Config:
"""
配置Pydantic模型如何处理ORM对象
- from_attributes=True支持直接从SQLAlchemy ORM模型转换Pydantic v2
- 若使用Pydantic v1需替换为 orm_mode=True
"""
from_attributes = True

View File

@@ -12,4 +12,4 @@ class ChatService:
# 简单调用LLM
return self.llm(prompt)
chat_service = ChatService()
chat_service = ChatService()

View File

@@ -15,5 +15,9 @@
"knowledge": {
"title": "KNOWLEDGE Management",
"name": "KNOWLEDGE Management"
},
"chat": {
"title": "AI CHAT",
"name": "AI CHAT"
}
}

View File

@@ -15,5 +15,9 @@
"knowledge": {
"title": "知识库管理",
"name": "知识库管理"
},
"chat": {
"title": "AI对话",
"name": "AI对话"
}
}

View File

@@ -0,0 +1,11 @@
<script setup lang="ts">
</script>
<template>
<div>dsads</div>
</template>
<style scoped lang="css">
</style>