修改docker-compose
This commit is contained in:
2
ai_service/.env.example
Normal file
2
ai_service/.env.example
Normal file
@@ -0,0 +1,2 @@
|
||||
OPENAI_API_KEY=你的API密钥
|
||||
DEEPSEEK_API_KEY='你的API密钥'
|
||||
24
ai_service/Dockerfile
Normal file
24
ai_service/Dockerfile
Normal file
@@ -0,0 +1,24 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
FROM python:3.12.2 AS base
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple
|
||||
# 入口命令由 docker-compose 控制
|
||||
# 默认命令,开发和生产通过 docker-compose 覆盖
|
||||
#CMD ["python", "manage.py", "runserver", "0.0.0.0:8000"]
|
||||
|
||||
FROM base AS dev
|
||||
|
||||
#CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
#CMD ["daphne", "backend.asgi:application"]
|
||||
|
||||
# CMD ["sh", "-c", "sleep 5 && python manage.py runserver 0.0.0.0:8000"]
|
||||
|
||||
|
||||
FROM base AS prod
|
||||
|
||||
CMD ["gunicorn", "main:app", "-k", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8010", "--workers", "4"]
|
||||
120
ai_service/api/v1/ai_chat.py
Normal file
120
ai_service/api/v1/ai_chat.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import os
|
||||
import asyncio
|
||||
from fastapi import APIRouter, Depends, Request, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from langchain.chains import ConversationChain
|
||||
|
||||
from api.v1.vo import MessageVO
|
||||
from deps.auth import get_current_user
|
||||
from services.chat_service import ChatDBService
|
||||
from db.session import get_db
|
||||
from models.ai import ChatConversation, ChatMessage
|
||||
from utils.resp import resp_success, Response
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def get_deepseek_llm(api_key: SecretStr, model: str):
|
||||
# deepseek 兼容 OpenAI API,需指定 base_url
|
||||
return ChatDeepSeek(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
@router.post('/stream')
|
||||
async def chat_stream(request: Request, user=Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
body = await request.json()
|
||||
content = body.get('content')
|
||||
conversation_id = body.get('conversation_id')
|
||||
model = 'deepseek-chat'
|
||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||
llm = get_deepseek_llm(SecretStr(api_key), model)
|
||||
|
||||
if not content or not isinstance(content, str):
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse({"error": "content不能为空"}, status_code=400)
|
||||
|
||||
user_id = user["user_id"]
|
||||
# 1. 获取对话
|
||||
try:
|
||||
conversation = ChatDBService.get_conversation(db, conversation_id)
|
||||
conversation = db.merge(conversation) # ✅ 防止 DetachedInstanceError
|
||||
except ValueError as e:
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse({"error": str(e)}, status_code=400)
|
||||
# 2. 插入当前消息
|
||||
ChatDBService.add_message(db, conversation, user_id, content)
|
||||
context = [
|
||||
("system", "You are a helpful assistant. Answer all questions to the best of your ability in {language}.")
|
||||
]
|
||||
# 3. 查询历史消息,组装上下文
|
||||
history = ChatDBService.get_history(db, conversation.id)
|
||||
# === 新增:如果只有一条消息,更新 title ===
|
||||
if len(history) == 1:
|
||||
ChatDBService.update_conversation_title(db, conversation.id, content[:255])
|
||||
|
||||
for msg in history:
|
||||
# 假设 msg.type 存储的是 'user' 或 'assistant'
|
||||
# role = msg.type if msg.type in ("user", "assistant") else "user"
|
||||
context.append((msg.type, msg.content))
|
||||
|
||||
ai_reply = ""
|
||||
async def event_generator():
|
||||
nonlocal ai_reply
|
||||
async for chunk in llm.astream(context):
|
||||
if hasattr(chunk, 'content'):
|
||||
ai_reply += chunk.content
|
||||
yield f"data: {chunk.content}\n\n"
|
||||
else:
|
||||
ai_reply += chunk
|
||||
yield f"data: {chunk}\n\n"
|
||||
await asyncio.sleep(0.01)
|
||||
# 只保留最新AI回复
|
||||
if ai_reply:
|
||||
ChatDBService.insert_ai_message(db, conversation, user_id, ai_reply, model)
|
||||
|
||||
return StreamingResponse(event_generator(), media_type='text/event-stream')
|
||||
|
||||
@router.post("/conversations")
|
||||
def create_conversation(db: Session = Depends(get_db), user=Depends(get_current_user),):
|
||||
user_id = user["user_id"]
|
||||
model = 'deepseek-chat'
|
||||
conversation = ChatDBService.get_or_create_conversation(db, None, user_id, model, '新对话')
|
||||
return resp_success(data=conversation.id)
|
||||
|
||||
@router.get('/conversations')
|
||||
async def get_conversations(
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user)
|
||||
):
|
||||
"""获取当前用户的聊天对话列表,last_message为字符串"""
|
||||
user_id = user["user_id"]
|
||||
conversations = db.query(ChatConversation).filter(ChatConversation.user_id == user_id).order_by(ChatConversation.update_time.desc()).all()
|
||||
return resp_success(data=[
|
||||
{
|
||||
'id': c.id,
|
||||
'title': c.title,
|
||||
'update_time': c.update_time,
|
||||
'last_message': c.messages[-1].content if c.messages else None,
|
||||
}
|
||||
for c in conversations
|
||||
])
|
||||
|
||||
|
||||
@router.get('/messages', response_model=Response[List[MessageVO]])
|
||||
def get_messages(
|
||||
conversation_id: int = Query(...),
|
||||
db: Session = Depends(get_db),
|
||||
user=Depends(get_current_user)
|
||||
):
|
||||
"""获取指定会话的消息列表(当前用户)"""
|
||||
user_id = user["user_id"]
|
||||
query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id,
|
||||
ChatMessage.user_id == user_id).order_by(ChatMessage.id).all()
|
||||
return resp_success(data=query)
|
||||
|
||||
20
ai_service/api/v1/vo.py
Normal file
20
ai_service/api/v1/vo.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
class MessageVO(BaseModel):
|
||||
id: int
|
||||
content: str
|
||||
conversation_id: int
|
||||
type: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True # 启用ORM模式支持
|
||||
|
||||
class ConversationsVO(BaseModel):
|
||||
id: int
|
||||
title: str
|
||||
update_time: datetime
|
||||
last_message: str | None = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
12
ai_service/config.py
Normal file
12
ai_service/config.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import os
|
||||
|
||||
# 数据库配置
|
||||
MYSQL_USER = os.getenv('DB_USER', 'root')
|
||||
MYSQL_PASSWORD = os.getenv('DB_PASSWORD', 'my-secret-pw')
|
||||
MYSQL_HOST = os.getenv('DB_HOST', 'localhost')
|
||||
MYSQL_PORT = os.getenv('DB_PORT', '3306')
|
||||
MYSQL_DB = os.getenv('DB_NAME', 'django_vue')
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = (
|
||||
f"mysql+pymysql://{MYSQL_USER}:{MYSQL_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DB}?charset=utf8mb4"
|
||||
)
|
||||
90
ai_service/crud/base.py
Normal file
90
ai_service/crud/base.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Generic, TypeVar, List, Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
|
||||
# 定义泛型变量(分别对应:SQLAlchemy模型、创建Pydantic模型、更新Pydantic模型)
|
||||
ModelType = TypeVar("ModelType")
|
||||
CreateSchemaType = TypeVar("CreateSchemaType")
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType")
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
def __init__(self, model: ModelType):
|
||||
"""
|
||||
初始化CRUD类,需要传入SQLAlchemy模型
|
||||
:param model: SQLAlchemy模型类(如AIApiKey、AIModel等)
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
# 创建
|
||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
"""创建一条记录"""
|
||||
obj_in_data = obj_in.model_dump() # 解构Pydantic模型为字典
|
||||
|
||||
# 自动填充时间字段(如果模型有created_at/updated_at)
|
||||
if hasattr(self.model, "created_at"):
|
||||
obj_in_data["created_at"] = datetime.now()
|
||||
if hasattr(self.model, "updated_at"):
|
||||
obj_in_data["updated_at"] = datetime.now()
|
||||
|
||||
db_obj = self.model(**obj_in_data) # 实例化模型
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
# 按ID查询
|
||||
def get(self, db: Session, id: int) -> Optional[ModelType]:
|
||||
"""按ID查询单条记录"""
|
||||
return db.query(self.model).filter(self.model.id == id).first()
|
||||
|
||||
# 按条件查询单条记录
|
||||
def get_by(self, db: Session, **kwargs) -> Optional[ModelType]:
|
||||
"""按条件查询单条记录(如get_by(name="test"))"""
|
||||
return db.query(self.model).filter_by(**kwargs).first()
|
||||
|
||||
# 分页查询所有
|
||||
def get_multi(
|
||||
self, db: Session, *, page: int = 0, limit: int = 100
|
||||
) -> List[ModelType]:
|
||||
"""分页查询多条记录"""
|
||||
return db.query(self.model).offset(page).limit(limit).all()
|
||||
|
||||
# 更新
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: UpdateSchemaType | Dict[str, Any]
|
||||
) -> ModelType:
|
||||
"""更新记录(支持Pydantic模型或字典)"""
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True) # 只更新提供的字段
|
||||
|
||||
# 遍历更新字段
|
||||
for field in update_data:
|
||||
if hasattr(db_obj, field):
|
||||
setattr(db_obj, field, update_data[field])
|
||||
|
||||
# 自动更新updated_at(如果模型有该字段)
|
||||
if hasattr(db_obj, "updated_at"):
|
||||
db_obj.updated_at = datetime.now()
|
||||
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
# 删除
|
||||
def remove(self, db: Session, *, id: int) -> ModelType:
|
||||
"""删除记录"""
|
||||
obj = db.query(self.model).get(id)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"{self.model.__name__}不存在")
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
15
ai_service/db/session.py
Normal file
15
ai_service/db/session.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
from config import SQLALCHEMY_DATABASE_URL
|
||||
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
Base = declarative_base()
|
||||
21
ai_service/deps/auth.py
Normal file
21
ai_service/deps/auth.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from fastapi import Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from db.session import get_db
|
||||
from models.user import AuthToken, DjangoUser
|
||||
|
||||
|
||||
def get_current_user(request: Request, db: Session = Depends(get_db)):
|
||||
auth = request.headers.get('Authorization')
|
||||
if not auth or not auth.startswith('Bearer '):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='未登录')
|
||||
|
||||
token = auth.split(' ', 1)[1]
|
||||
token_obj = db.query(AuthToken).filter(AuthToken.key == token).first()
|
||||
if not token_obj:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Token无效或已过期')
|
||||
|
||||
user = db.query(DjangoUser).filter(DjangoUser.id == token_obj.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='用户不存在')
|
||||
return {"user_id": user.id, "username": user.username, "email": user.email}
|
||||
31
ai_service/main.py
Normal file
31
ai_service/main.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import os
|
||||
from fastapi import FastAPI
|
||||
from dotenv import load_dotenv
|
||||
from api.v1 import ai_chat
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# 加载.env环境变量,优先项目根目录
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
origins = [
|
||||
"http://localhost",
|
||||
"http://localhost:8010",
|
||||
]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(ai_chat.router, prefix="/api/ai/v1", tags=["chat"])
|
||||
|
||||
# 健康检查
|
||||
@app.get("/ping")
|
||||
def ping():
|
||||
return {"msg": "pong"}
|
||||
253
ai_service/models/ai.py
Normal file
253
ai_service/models/ai.py
Normal file
@@ -0,0 +1,253 @@
|
||||
from sqlalchemy import (
|
||||
Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
|
||||
)
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
|
||||
from db.session import Base
|
||||
from models.base import CoreModel
|
||||
from models.user import DjangoUser # 确保导入 DjangoUser
|
||||
|
||||
# 状态选择类(示例)
|
||||
class CommonStatus:
|
||||
DISABLED = 0
|
||||
ENABLED = 1
|
||||
|
||||
@staticmethod
|
||||
def choices():
|
||||
return [(0, '禁用'), (1, '启用')]
|
||||
|
||||
|
||||
# 平台选择类(示例)
|
||||
class PlatformChoices:
|
||||
OPENAI = 'openai'
|
||||
ALIMNS = 'alimns'
|
||||
|
||||
@staticmethod
|
||||
def choices():
|
||||
return [('openai', 'OpenAI'), ('alimns', '阿里云MNS')]
|
||||
|
||||
|
||||
# 消息类型选择类(示例)
|
||||
class MessageType:
|
||||
SYSTEM = "system" # 系统指令
|
||||
USER = "user" # 用户消息
|
||||
ASSISTANT = "assistant" # 助手回复
|
||||
FUNCTION = "function" # 函数返回结果
|
||||
|
||||
@staticmethod
|
||||
def choices():
|
||||
"""返回可用的消息角色选项"""
|
||||
return [
|
||||
(MessageType.SYSTEM, "系统"),
|
||||
(MessageType.USER, "用户"),
|
||||
(MessageType.ASSISTANT, "助手"),
|
||||
(MessageType.FUNCTION, "函数")
|
||||
]
|
||||
|
||||
|
||||
class MessageContentType:
|
||||
"""消息内容类型"""
|
||||
TEXT = "text"
|
||||
FUNCTION_CALL = "function_call"
|
||||
|
||||
@staticmethod
|
||||
def choices():
|
||||
"""返回可用的内容类型选项"""
|
||||
return [
|
||||
(MessageContentType.TEXT, "文本"),
|
||||
(MessageContentType.FUNCTION_CALL, "函数调用")
|
||||
]
|
||||
|
||||
# AI API 密钥表
|
||||
class AIApiKey(CoreModel):
|
||||
__tablename__ = 'ai_api_key'
|
||||
|
||||
name = Column(String(255), nullable=False)
|
||||
platform = Column(String(100), nullable=False)
|
||||
api_key = Column(String(255), nullable=False)
|
||||
url = Column(String(255), nullable=True)
|
||||
status = Column(Integer, default=CommonStatus.DISABLED)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
# AI 模型表
|
||||
class AIModel(CoreModel):
|
||||
__tablename__ = 'ai_model'
|
||||
|
||||
name = Column(String(64), nullable=False)
|
||||
sort = Column(Integer, default=0)
|
||||
status = Column(Integer, default=CommonStatus.DISABLED)
|
||||
key_id = Column(Integer, ForeignKey('ai_api_key.id'), nullable=False)
|
||||
model_type = Column(String(32), nullable=True)
|
||||
platform = Column(String(32), nullable=False)
|
||||
model = Column(String(64), nullable=False)
|
||||
temperature = Column(Float, nullable=True)
|
||||
max_tokens = Column(Integer, nullable=True)
|
||||
max_contexts = Column(Integer, nullable=True)
|
||||
|
||||
key = relationship('AIApiKey', backref='models')
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
# AI 工具表
|
||||
class Tool(CoreModel):
|
||||
__tablename__ = 'ai_tool'
|
||||
|
||||
name = Column(String(128), nullable=False)
|
||||
description = Column(String(256), nullable=True)
|
||||
status = Column(Integer, default=0)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
# AI 知识库表
|
||||
class Knowledge(CoreModel):
|
||||
__tablename__ = 'ai_knowledge'
|
||||
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
embedding_model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
|
||||
embedding_model = Column(String(32), nullable=False)
|
||||
top_k = Column(Integer, default=0)
|
||||
similarity_threshold = Column(Float, nullable=False)
|
||||
status = Column(Integer, default=CommonStatus.DISABLED)
|
||||
|
||||
embedding_model_rel = relationship('AIModel', backref='knowledges')
|
||||
documents = relationship('KnowledgeDocument', backref='knowledge', cascade='all, delete-orphan')
|
||||
segments = relationship('KnowledgeSegment', backref='knowledge', cascade='all, delete-orphan')
|
||||
roles = relationship('ChatRole', secondary='ai_chat_role_knowledge', backref='knowledges')
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
# AI 知识库文档表
|
||||
class KnowledgeDocument(CoreModel):
|
||||
__tablename__ = 'ai_knowledge_document'
|
||||
|
||||
knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
|
||||
name = Column(String(255), nullable=False)
|
||||
url = Column(String(1024), nullable=False)
|
||||
content = Column(Text, nullable=False)
|
||||
content_length = Column(Integer, nullable=False)
|
||||
tokens = Column(Integer, nullable=False)
|
||||
segment_max_tokens = Column(Integer, nullable=False)
|
||||
retrieval_count = Column(Integer, default=0)
|
||||
status = Column(Integer, default=CommonStatus.DISABLED)
|
||||
|
||||
segments = relationship('KnowledgeSegment', backref='document', cascade='all, delete-orphan')
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
# AI 知识库分段表
|
||||
class KnowledgeSegment(CoreModel):
|
||||
__tablename__ = 'ai_knowledge_segment'
|
||||
|
||||
knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
|
||||
document_id = Column(Integer, ForeignKey('ai_knowledge_document.id'), nullable=False)
|
||||
content = Column(Text, nullable=False)
|
||||
content_length = Column(Integer, nullable=False)
|
||||
tokens = Column(Integer, nullable=False)
|
||||
vector_id = Column(String(100), nullable=True)
|
||||
retrieval_count = Column(Integer, default=0)
|
||||
status = Column(Integer, default=CommonStatus.DISABLED)
|
||||
|
||||
def __str__(self):
|
||||
return f"Segment {self.id}"
|
||||
|
||||
|
||||
# AI 聊天角色表
|
||||
class ChatRole(CoreModel):
|
||||
__tablename__ = 'ai_chat_role'
|
||||
|
||||
name = Column(String(128), nullable=False)
|
||||
avatar = Column(String(256), nullable=False)
|
||||
description = Column(String(256), nullable=True)
|
||||
status = Column(Integer, default=CommonStatus.DISABLED)
|
||||
sort = Column(Integer, default=0)
|
||||
public_status = Column(Boolean, default=False)
|
||||
category = Column(String(32), nullable=True)
|
||||
model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
|
||||
system_message = Column(String(1024), nullable=True)
|
||||
user_id = Column(
|
||||
Integer,
|
||||
ForeignKey('system_users.id'), # 假设DjangoUser表名是system_users
|
||||
nullable=True # 允许为空(如匿名角色)
|
||||
)
|
||||
user = relationship(DjangoUser, backref='chat_roles') # 正确:DjangoUser 已定义并导入
|
||||
|
||||
model = relationship('AIModel', backref='chat_roles')
|
||||
tools = relationship('Tool', secondary='ai_chat_role_tool', backref='roles')
|
||||
# conversations = relationship('ChatConversation', backref='role', cascade='all, delete-orphan')
|
||||
# messages = relationship('ChatMessage', backref='role', cascade='all, delete-orphan')
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
# AI 聊天对话表
|
||||
class ChatConversation(CoreModel):
|
||||
__tablename__ = 'ai_chat_conversation'
|
||||
|
||||
title = Column(String(256), nullable=False)
|
||||
pinned = Column(Boolean, default=False)
|
||||
pinned_time = Column(DateTime, nullable=True)
|
||||
user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
|
||||
role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
|
||||
model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
|
||||
model = Column(String(32), nullable=False)
|
||||
system_message = Column(String(1024), nullable=True)
|
||||
temperature = Column(Float, nullable=False)
|
||||
max_tokens = Column(Integer, nullable=False)
|
||||
max_contexts = Column(Integer, nullable=False)
|
||||
user = relationship(DjangoUser, backref='conversations') # 正确:DjangoUser 已定义并导入
|
||||
|
||||
model_rel = relationship('AIModel', backref='conversations')
|
||||
messages = relationship('ChatMessage', backref='conversation', cascade='all, delete-orphan')
|
||||
|
||||
def __str__(self):
|
||||
return self.title
|
||||
|
||||
|
||||
# AI 聊天消息表
|
||||
class ChatMessage(CoreModel):
|
||||
__tablename__ = 'ai_chat_message'
|
||||
|
||||
conversation_id = Column(Integer, ForeignKey('ai_chat_conversation.id'), nullable=False)
|
||||
user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
|
||||
role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
|
||||
model = Column(String(32), nullable=False)
|
||||
model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
|
||||
type = Column(String(16), nullable=False)
|
||||
reply_id = Column(Integer, nullable=True)
|
||||
content = Column(String(2048), nullable=False)
|
||||
use_context = Column(Boolean, default=False)
|
||||
segment_ids = Column(String(2048), nullable=True)
|
||||
|
||||
user = relationship(DjangoUser, backref='messages') # 正确:DjangoUser 已定义并导入
|
||||
model_rel = relationship('AIModel', backref='messages')
|
||||
|
||||
def __str__(self):
|
||||
return self.content[:30]
|
||||
|
||||
|
||||
# 聊天角色与知识库的关联表
|
||||
class ChatRoleKnowledge(Base):
|
||||
__tablename__ = 'ai_chat_role_knowledge'
|
||||
|
||||
chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
|
||||
knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), primary_key=True)
|
||||
|
||||
|
||||
# 聊天角色与工具的关联表
|
||||
class ChatRoleTool(Base):
|
||||
__tablename__ = 'ai_chat_role_tool'
|
||||
|
||||
chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
|
||||
tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True)
|
||||
13
ai_service/models/base.py
Normal file
13
ai_service/models/base.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from db.session import Base
|
||||
from sqlalchemy import (
|
||||
Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
|
||||
)
|
||||
|
||||
# 基础模型类
|
||||
class CoreModel(Base):
|
||||
__abstract__ = True
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
create_time = Column(DateTime)
|
||||
update_time = Column(DateTime)
|
||||
is_deleted = Column(Boolean, default=False)
|
||||
24
ai_service/models/user.py
Normal file
24
ai_service/models/user.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean
|
||||
|
||||
from db.session import Base
|
||||
|
||||
|
||||
class AuthToken(Base):
|
||||
__tablename__ = 'authtoken_token'
|
||||
key = Column(String(40), primary_key=True)
|
||||
user_id = Column(Integer, nullable=False)
|
||||
created = Column(DateTime)
|
||||
|
||||
|
||||
class DjangoUser(Base):
|
||||
__tablename__ = 'system_users'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
username = Column(String(150), nullable=False)
|
||||
email = Column(String(254))
|
||||
password = Column(String(128))
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_staff = Column(Boolean, default=False)
|
||||
is_superuser = Column(Boolean, default=False)
|
||||
last_login = Column(DateTime)
|
||||
date_joined = Column(DateTime)
|
||||
8
ai_service/requirements.txt
Normal file
8
ai_service/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi==0.116.1
|
||||
uvicorn[standard]==0.35.0
|
||||
langchain-openai==0.3.28
|
||||
langchain-deepseek==0.1.3
|
||||
langchain==0.3.26
|
||||
langchain-community==0.3.26
|
||||
PyMySQL==1.1.1
|
||||
SQLAlchemy==2.0.41
|
||||
101
ai_service/routers/base.py
Normal file
101
ai_service/routers/base.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Generic, TypeVar, List
|
||||
|
||||
from db.session import get_db
|
||||
from schemas.base import ReadSchemaType # 通用的响应模型基类
|
||||
from crud.base import CRUDBase
|
||||
|
||||
# 泛型变量(对应:CRUD类、创建模型、更新模型、响应模型)
|
||||
CRUDType = TypeVar("CRUDType")
|
||||
CreateSchemaType = TypeVar("CreateSchemaType")
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType")
|
||||
ReadSchemaType = TypeVar("ReadSchemaType")
|
||||
|
||||
|
||||
class GenericRouter(
|
||||
APIRouter,
|
||||
Generic[CRUDType, CreateSchemaType, UpdateSchemaType, ReadSchemaType]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
crud: CRUDType,
|
||||
create_schema: CreateSchemaType,
|
||||
update_schema: UpdateSchemaType,
|
||||
read_schema: ReadSchemaType,
|
||||
prefix: str,
|
||||
tags: List[str],
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
初始化通用路由
|
||||
:param crud: CRUD实例(如CRUDApiKey)
|
||||
:param create_schema: 创建Pydantic模型
|
||||
:param update_schema: 更新Pydantic模型
|
||||
:param read_schema: 响应Pydantic模型
|
||||
:param prefix: 路由前缀(如"/api/ai-api-keys")
|
||||
:param tags: 文档标签
|
||||
"""
|
||||
super().__init__(prefix=prefix, tags=tags,** kwargs)
|
||||
self.crud = crud
|
||||
self.create_schema = create_schema
|
||||
self.update_schema = update_schema
|
||||
self.read_schema = read_schema
|
||||
|
||||
# 注册通用路由
|
||||
self.add_api_route(
|
||||
"/",
|
||||
self.create,
|
||||
methods=["POST"],
|
||||
response_model=read_schema,
|
||||
status_code=201
|
||||
)
|
||||
self.add_api_route(
|
||||
"/",
|
||||
self.get_multi,
|
||||
methods=["GET"],
|
||||
response_model=List[read_schema]
|
||||
)
|
||||
self.add_api_route(
|
||||
"/{id}/",
|
||||
self.get,
|
||||
methods=["GET"],
|
||||
response_model=read_schema
|
||||
)
|
||||
self.add_api_route(
|
||||
"/{id}/",
|
||||
self.update,
|
||||
methods=["PUT"],
|
||||
response_model=read_schema
|
||||
)
|
||||
self.add_api_route(
|
||||
"/{id}/",
|
||||
self.remove,
|
||||
methods=["DELETE"]
|
||||
)
|
||||
|
||||
# 创建
|
||||
def create(self, obj_in: CreateSchemaType, db: Session = Depends(get_db)):
|
||||
return self.crud.create(db=db, obj_in=obj_in)
|
||||
|
||||
# 按ID查询
|
||||
def get(self, id: int, db: Session = Depends(get_db)):
|
||||
obj = self.crud.get(db=db, id=id)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"记录不存在")
|
||||
return obj
|
||||
|
||||
# 分页查询
|
||||
def get_multi(self, page: int = 0, limit: int = 10, db: Session = Depends(get_db)):
|
||||
return self.crud.get_multi(db=db, page=page, limit=limit)
|
||||
|
||||
# 更新
|
||||
def update(self, id: int, obj_in: UpdateSchemaType, db: Session = Depends(get_db)):
|
||||
obj = self.crud.get(db=db, id=id)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"记录不存在")
|
||||
return self.crud.update(db=db, db_obj=obj, obj_in=obj_in)
|
||||
|
||||
# 删除
|
||||
def remove(self, id: int, db: Session = Depends(get_db)):
|
||||
return self.crud.remove(db=db, id=id)
|
||||
33
ai_service/schemas/ai_api_key.py
Normal file
33
ai_service/schemas/ai_api_key.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
# 基础模型(共享字段)
|
||||
class AIApiKeyBase(BaseModel):
|
||||
name: str = Field(..., max_length=255, description="密钥名称")
|
||||
platform: str = Field(..., max_length=100, description="平台(如openai)")
|
||||
api_key: str = Field(..., max_length=255, description="API密钥")
|
||||
url: Optional[str] = Field(None, max_length=255, description="自定义API地址")
|
||||
status: int = Field(0, description="状态(0=禁用,1=启用)")
|
||||
|
||||
# 创建请求模型(无需ID和时间字段)
|
||||
class AIApiKeyCreate(AIApiKeyBase):
|
||||
pass
|
||||
|
||||
# 更新请求模型(所有字段可选)
|
||||
class AIApiKeyUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, max_length=255)
|
||||
platform: Optional[str] = Field(None, max_length=100)
|
||||
api_key: Optional[str] = Field(None, max_length=255)
|
||||
url: Optional[str] = Field(None, max_length=255)
|
||||
status: Optional[int] = None
|
||||
|
||||
# 响应模型(包含数据库自动生成的字段)
|
||||
class AIApiKeyRead(AIApiKeyBase):
|
||||
id: int
|
||||
created_at: Optional[datetime]
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
# 支持ORM模型直接转换为响应
|
||||
class Config:
|
||||
from_attributes = True # Pydantic v2用from_attributes,v1用orm_mode
|
||||
10
ai_service/schemas/ai_chat.py
Normal file
10
ai_service/schemas/ai_chat.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ChatCreate(BaseModel):
|
||||
pass
|
||||
|
||||
class Chat(ChatCreate):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
19
ai_service/schemas/base.py
Normal file
19
ai_service/schemas/base.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
class ReadSchemaType(BaseModel):
|
||||
"""
|
||||
所有响应模型的基类,包含公共字段和ORM转换配置
|
||||
"""
|
||||
id: int
|
||||
created_at: Optional[datetime] = None # 数据创建时间(可选,部分模型可能没有)
|
||||
updated_at: Optional[datetime] = None # 数据更新时间(可选)
|
||||
|
||||
class Config:
|
||||
"""
|
||||
配置Pydantic模型如何处理ORM对象:
|
||||
- from_attributes=True:支持直接从SQLAlchemy ORM模型转换(Pydantic v2)
|
||||
- 若使用Pydantic v1,需替换为 orm_mode=True
|
||||
"""
|
||||
from_attributes = True
|
||||
9
ai_service/schemas/user.py
Normal file
9
ai_service/schemas/user.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class UserOut(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
email: str = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
95
ai_service/services/chat_service.py
Normal file
95
ai_service/services/chat_service.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# LangChain集成示例
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime
|
||||
from models.ai import ChatConversation, ChatMessage, MessageType
|
||||
|
||||
class ChatDBService:
|
||||
@staticmethod
|
||||
def get_conversation(db: Session, conversation_id: int):
|
||||
return db.query(ChatConversation).filter(ChatConversation.id == conversation_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str, content: str) -> ChatConversation:
|
||||
if not conversation_id:
|
||||
conversation = ChatConversation(
|
||||
title=content,
|
||||
user_id=user_id,
|
||||
role_id=None,
|
||||
model_id=None, # 需根据实际模型id调整
|
||||
model=model,
|
||||
system_message=None,
|
||||
temperature=0.7,
|
||||
max_tokens=2048,
|
||||
max_contexts=10,
|
||||
create_time=datetime.now(),
|
||||
update_time=datetime.now(),
|
||||
is_deleted=False
|
||||
)
|
||||
db.add(conversation)
|
||||
db.commit()
|
||||
db.refresh(conversation)
|
||||
return conversation
|
||||
else:
|
||||
conversation = db.query(ChatConversation).get(conversation_id)
|
||||
if not conversation:
|
||||
raise ValueError("无效的conversation_id")
|
||||
return conversation
|
||||
|
||||
@staticmethod
|
||||
def update_conversation_title(db, conversation_id: int, title: str):
|
||||
conversation = db.query(ChatConversation).filter(ChatConversation.id == conversation_id).first()
|
||||
if conversation:
|
||||
conversation.title = title[:255] # 保证不超过255字符
|
||||
db.add(conversation)
|
||||
db.commit()
|
||||
return conversation
|
||||
else:
|
||||
raise ValueError("Conversation not found")
|
||||
|
||||
@staticmethod
|
||||
def add_message(db: Session, conversation: ChatConversation, user_id: int, content: str) -> ChatMessage:
|
||||
message = ChatMessage(
|
||||
conversation_id=conversation.id,
|
||||
user_id=user_id,
|
||||
role_id=None,
|
||||
model=conversation.model,
|
||||
model_id=conversation.model_id,
|
||||
type=MessageType.USER,
|
||||
reply_id=None,
|
||||
content=content,
|
||||
use_context=True,
|
||||
segment_ids=None,
|
||||
create_time=datetime.now(),
|
||||
update_time=datetime.now(),
|
||||
is_deleted=False
|
||||
)
|
||||
db.add(message)
|
||||
db.commit()
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def insert_ai_message(db: Session, conversation, user_id: int, content: str, model: str):
|
||||
from datetime import datetime
|
||||
from models.ai import MessageType
|
||||
message = ChatMessage(
|
||||
conversation_id=conversation.id,
|
||||
user_id=user_id,
|
||||
role_id=None,
|
||||
model=model,
|
||||
model_id=conversation.model_id,
|
||||
type=MessageType.ASSISTANT,
|
||||
reply_id=None,
|
||||
content=content,
|
||||
use_context=True,
|
||||
segment_ids=None,
|
||||
create_time=datetime.now(),
|
||||
update_time=datetime.now(),
|
||||
is_deleted=False
|
||||
)
|
||||
db.add(message)
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_history(db: Session, conversation_id: int) -> list[ChatMessage]:
|
||||
return db.query(ChatMessage).filter_by(conversation_id=conversation_id).order_by(ChatMessage.id).all()
|
||||
|
||||
1
ai_service/utils/jwt.py
Normal file
1
ai_service/utils/jwt.py
Normal file
@@ -0,0 +1 @@
|
||||
# 预留:如需JWT校验可在此实现
|
||||
16
ai_service/utils/resp.py
Normal file
16
ai_service/utils/resp.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Generic, TypeVar, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class Response(BaseModel, Generic[T]):
|
||||
code: int
|
||||
message: str
|
||||
data: Optional[T] = None # ✅ 明确 data 可为 None
|
||||
|
||||
def resp_success(data: T, message: str = "success") -> Response[T]:
|
||||
return Response(code=0, message=message, data=data)
|
||||
|
||||
def resp_error(message="error", code=1) -> Response[T]:
|
||||
return Response(code=code, message=message, data=None)
|
||||
Reference in New Issue
Block a user