修改docker-compose

This commit is contained in:
XIE7654
2025-07-18 22:14:37 +08:00
parent 5aaf78ae6c
commit b8bdb5d206
36 changed files with 739 additions and 59 deletions

2
ai_service/.env.example Normal file
View File

@@ -0,0 +1,2 @@
OPENAI_API_KEY=你的API密钥
DEEPSEEK_API_KEY='你的API密钥'

24
ai_service/Dockerfile Normal file
View File

@@ -0,0 +1,24 @@
# syntax=docker/dockerfile:1
FROM python:3.12.2 AS base
WORKDIR /app
COPY . .
RUN pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple
# 入口命令由 docker-compose 控制
# 默认命令,开发和生产通过 docker-compose 覆盖
#CMD ["python", "manage.py", "runserver", "0.0.0.0:8000"]
FROM base AS dev
#CMD ["tail", "-f", "/dev/null"]
#CMD ["daphne", "backend.asgi:application"]
# CMD ["sh", "-c", "sleep 5 && python manage.py runserver 0.0.0.0:8000"]
FROM base AS prod
CMD ["gunicorn", "main:app", "-k", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:8010", "--workers", "4"]

View File

@@ -0,0 +1,120 @@
import os
import asyncio
from fastapi import APIRouter, Depends, Request, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from typing import List
from pydantic import BaseModel, SecretStr
from langchain.chains import ConversationChain
from api.v1.vo import MessageVO
from deps.auth import get_current_user
from services.chat_service import ChatDBService
from db.session import get_db
from models.ai import ChatConversation, ChatMessage
from utils.resp import resp_success, Response
from langchain_deepseek import ChatDeepSeek
router = APIRouter()
def get_deepseek_llm(api_key: SecretStr, model: str):
# deepseek 兼容 OpenAI API需指定 base_url
return ChatDeepSeek(
api_key=api_key,
model=model,
streaming=True,
)
@router.post('/stream')
async def chat_stream(request: Request, user=Depends(get_current_user), db: Session = Depends(get_db)):
body = await request.json()
content = body.get('content')
conversation_id = body.get('conversation_id')
model = 'deepseek-chat'
api_key = os.getenv("DEEPSEEK_API_KEY")
llm = get_deepseek_llm(SecretStr(api_key), model)
if not content or not isinstance(content, str):
from fastapi.responses import JSONResponse
return JSONResponse({"error": "content不能为空"}, status_code=400)
user_id = user["user_id"]
# 1. 获取对话
try:
conversation = ChatDBService.get_conversation(db, conversation_id)
conversation = db.merge(conversation) # ✅ 防止 DetachedInstanceError
except ValueError as e:
from fastapi.responses import JSONResponse
return JSONResponse({"error": str(e)}, status_code=400)
# 2. 插入当前消息
ChatDBService.add_message(db, conversation, user_id, content)
context = [
("system", "You are a helpful assistant. Answer all questions to the best of your ability in {language}.")
]
# 3. 查询历史消息,组装上下文
history = ChatDBService.get_history(db, conversation.id)
# === 新增:如果只有一条消息,更新 title ===
if len(history) == 1:
ChatDBService.update_conversation_title(db, conversation.id, content[:255])
for msg in history:
# 假设 msg.type 存储的是 'user' 或 'assistant'
# role = msg.type if msg.type in ("user", "assistant") else "user"
context.append((msg.type, msg.content))
ai_reply = ""
async def event_generator():
nonlocal ai_reply
async for chunk in llm.astream(context):
if hasattr(chunk, 'content'):
ai_reply += chunk.content
yield f"data: {chunk.content}\n\n"
else:
ai_reply += chunk
yield f"data: {chunk}\n\n"
await asyncio.sleep(0.01)
# 只保留最新AI回复
if ai_reply:
ChatDBService.insert_ai_message(db, conversation, user_id, ai_reply, model)
return StreamingResponse(event_generator(), media_type='text/event-stream')
@router.post("/conversations")
def create_conversation(db: Session = Depends(get_db), user=Depends(get_current_user),):
user_id = user["user_id"]
model = 'deepseek-chat'
conversation = ChatDBService.get_or_create_conversation(db, None, user_id, model, '新对话')
return resp_success(data=conversation.id)
@router.get('/conversations')
async def get_conversations(
db: Session = Depends(get_db),
user=Depends(get_current_user)
):
"""获取当前用户的聊天对话列表last_message为字符串"""
user_id = user["user_id"]
conversations = db.query(ChatConversation).filter(ChatConversation.user_id == user_id).order_by(ChatConversation.update_time.desc()).all()
return resp_success(data=[
{
'id': c.id,
'title': c.title,
'update_time': c.update_time,
'last_message': c.messages[-1].content if c.messages else None,
}
for c in conversations
])
@router.get('/messages', response_model=Response[List[MessageVO]])
def get_messages(
conversation_id: int = Query(...),
db: Session = Depends(get_db),
user=Depends(get_current_user)
):
"""获取指定会话的消息列表(当前用户)"""
user_id = user["user_id"]
query = db.query(ChatMessage).filter(ChatMessage.conversation_id == conversation_id,
ChatMessage.user_id == user_id).order_by(ChatMessage.id).all()
return resp_success(data=query)

20
ai_service/api/v1/vo.py Normal file
View File

@@ -0,0 +1,20 @@
from pydantic import BaseModel
from datetime import datetime
class MessageVO(BaseModel):
id: int
content: str
conversation_id: int
type: str
class Config:
from_attributes = True # 启用ORM模式支持
class ConversationsVO(BaseModel):
id: int
title: str
update_time: datetime
last_message: str | None = None
class Config:
from_attributes = True

12
ai_service/config.py Normal file
View File

@@ -0,0 +1,12 @@
import os
# 数据库配置
MYSQL_USER = os.getenv('DB_USER', 'root')
MYSQL_PASSWORD = os.getenv('DB_PASSWORD', 'my-secret-pw')
MYSQL_HOST = os.getenv('DB_HOST', 'localhost')
MYSQL_PORT = os.getenv('DB_PORT', '3306')
MYSQL_DB = os.getenv('DB_NAME', 'django_vue')
SQLALCHEMY_DATABASE_URL = (
f"mysql+pymysql://{MYSQL_USER}:{MYSQL_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DB}?charset=utf8mb4"
)

90
ai_service/crud/base.py Normal file
View File

@@ -0,0 +1,90 @@
from typing import Generic, TypeVar, List, Optional, Dict, Any
from fastapi import HTTPException
from sqlalchemy.orm import Session
from datetime import datetime
# 定义泛型变量分别对应SQLAlchemy模型、创建Pydantic模型、更新Pydantic模型
ModelType = TypeVar("ModelType")
CreateSchemaType = TypeVar("CreateSchemaType")
UpdateSchemaType = TypeVar("UpdateSchemaType")
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def __init__(self, model: ModelType):
"""
初始化CRUD类需要传入SQLAlchemy模型
:param model: SQLAlchemy模型类如AIApiKey、AIModel等
"""
self.model = model
# 创建
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
"""创建一条记录"""
obj_in_data = obj_in.model_dump() # 解构Pydantic模型为字典
# 自动填充时间字段如果模型有created_at/updated_at
if hasattr(self.model, "created_at"):
obj_in_data["created_at"] = datetime.now()
if hasattr(self.model, "updated_at"):
obj_in_data["updated_at"] = datetime.now()
db_obj = self.model(**obj_in_data) # 实例化模型
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
# 按ID查询
def get(self, db: Session, id: int) -> Optional[ModelType]:
"""按ID查询单条记录"""
return db.query(self.model).filter(self.model.id == id).first()
# 按条件查询单条记录
def get_by(self, db: Session, **kwargs) -> Optional[ModelType]:
"""按条件查询单条记录如get_by(name="test")"""
return db.query(self.model).filter_by(**kwargs).first()
# 分页查询所有
def get_multi(
self, db: Session, *, page: int = 0, limit: int = 100
) -> List[ModelType]:
"""分页查询多条记录"""
return db.query(self.model).offset(page).limit(limit).all()
# 更新
def update(
self,
db: Session,
*,
db_obj: ModelType,
obj_in: UpdateSchemaType | Dict[str, Any]
) -> ModelType:
"""更新记录支持Pydantic模型或字典"""
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True) # 只更新提供的字段
# 遍历更新字段
for field in update_data:
if hasattr(db_obj, field):
setattr(db_obj, field, update_data[field])
# 自动更新updated_at如果模型有该字段
if hasattr(db_obj, "updated_at"):
db_obj.updated_at = datetime.now()
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
# 删除
def remove(self, db: Session, *, id: int) -> ModelType:
"""删除记录"""
obj = db.query(self.model).get(id)
if not obj:
raise HTTPException(status_code=404, detail=f"{self.model.__name__}不存在")
db.delete(obj)
db.commit()
return obj

15
ai_service/db/session.py Normal file
View File

@@ -0,0 +1,15 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from config import SQLALCHEMY_DATABASE_URL
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
Base = declarative_base()

21
ai_service/deps/auth.py Normal file
View File

@@ -0,0 +1,21 @@
from fastapi import Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from db.session import get_db
from models.user import AuthToken, DjangoUser
def get_current_user(request: Request, db: Session = Depends(get_db)):
auth = request.headers.get('Authorization')
if not auth or not auth.startswith('Bearer '):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='未登录')
token = auth.split(' ', 1)[1]
token_obj = db.query(AuthToken).filter(AuthToken.key == token).first()
if not token_obj:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Token无效或已过期')
user = db.query(DjangoUser).filter(DjangoUser.id == token_obj.user_id).first()
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='用户不存在')
return {"user_id": user.id, "username": user.username, "email": user.email}

31
ai_service/main.py Normal file
View File

@@ -0,0 +1,31 @@
import os
from fastapi import FastAPI
from dotenv import load_dotenv
from api.v1 import ai_chat
from fastapi.middleware.cors import CORSMiddleware
# 加载.env环境变量优先项目根目录
load_dotenv()
app = FastAPI()
origins = [
"http://localhost",
"http://localhost:8010",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 注册路由
app.include_router(ai_chat.router, prefix="/api/ai/v1", tags=["chat"])
# 健康检查
@app.get("/ping")
def ping():
return {"msg": "pong"}

253
ai_service/models/ai.py Normal file
View File

@@ -0,0 +1,253 @@
from sqlalchemy import (
Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
)
from sqlalchemy.orm import relationship, declarative_base
from db.session import Base
from models.base import CoreModel
from models.user import DjangoUser # 确保导入 DjangoUser
# 状态选择类(示例)
class CommonStatus:
DISABLED = 0
ENABLED = 1
@staticmethod
def choices():
return [(0, '禁用'), (1, '启用')]
# 平台选择类(示例)
class PlatformChoices:
OPENAI = 'openai'
ALIMNS = 'alimns'
@staticmethod
def choices():
return [('openai', 'OpenAI'), ('alimns', '阿里云MNS')]
# 消息类型选择类(示例)
class MessageType:
SYSTEM = "system" # 系统指令
USER = "user" # 用户消息
ASSISTANT = "assistant" # 助手回复
FUNCTION = "function" # 函数返回结果
@staticmethod
def choices():
"""返回可用的消息角色选项"""
return [
(MessageType.SYSTEM, "系统"),
(MessageType.USER, "用户"),
(MessageType.ASSISTANT, "助手"),
(MessageType.FUNCTION, "函数")
]
class MessageContentType:
"""消息内容类型"""
TEXT = "text"
FUNCTION_CALL = "function_call"
@staticmethod
def choices():
"""返回可用的内容类型选项"""
return [
(MessageContentType.TEXT, "文本"),
(MessageContentType.FUNCTION_CALL, "函数调用")
]
# AI API 密钥表
class AIApiKey(CoreModel):
__tablename__ = 'ai_api_key'
name = Column(String(255), nullable=False)
platform = Column(String(100), nullable=False)
api_key = Column(String(255), nullable=False)
url = Column(String(255), nullable=True)
status = Column(Integer, default=CommonStatus.DISABLED)
def __str__(self):
return self.name
# AI 模型表
class AIModel(CoreModel):
__tablename__ = 'ai_model'
name = Column(String(64), nullable=False)
sort = Column(Integer, default=0)
status = Column(Integer, default=CommonStatus.DISABLED)
key_id = Column(Integer, ForeignKey('ai_api_key.id'), nullable=False)
model_type = Column(String(32), nullable=True)
platform = Column(String(32), nullable=False)
model = Column(String(64), nullable=False)
temperature = Column(Float, nullable=True)
max_tokens = Column(Integer, nullable=True)
max_contexts = Column(Integer, nullable=True)
key = relationship('AIApiKey', backref='models')
def __str__(self):
return self.name
# AI 工具表
class Tool(CoreModel):
__tablename__ = 'ai_tool'
name = Column(String(128), nullable=False)
description = Column(String(256), nullable=True)
status = Column(Integer, default=0)
def __str__(self):
return self.name
# AI 知识库表
class Knowledge(CoreModel):
__tablename__ = 'ai_knowledge'
name = Column(String(255), nullable=False)
description = Column(Text, nullable=True)
embedding_model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
embedding_model = Column(String(32), nullable=False)
top_k = Column(Integer, default=0)
similarity_threshold = Column(Float, nullable=False)
status = Column(Integer, default=CommonStatus.DISABLED)
embedding_model_rel = relationship('AIModel', backref='knowledges')
documents = relationship('KnowledgeDocument', backref='knowledge', cascade='all, delete-orphan')
segments = relationship('KnowledgeSegment', backref='knowledge', cascade='all, delete-orphan')
roles = relationship('ChatRole', secondary='ai_chat_role_knowledge', backref='knowledges')
def __str__(self):
return self.name
# AI 知识库文档表
class KnowledgeDocument(CoreModel):
__tablename__ = 'ai_knowledge_document'
knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
name = Column(String(255), nullable=False)
url = Column(String(1024), nullable=False)
content = Column(Text, nullable=False)
content_length = Column(Integer, nullable=False)
tokens = Column(Integer, nullable=False)
segment_max_tokens = Column(Integer, nullable=False)
retrieval_count = Column(Integer, default=0)
status = Column(Integer, default=CommonStatus.DISABLED)
segments = relationship('KnowledgeSegment', backref='document', cascade='all, delete-orphan')
def __str__(self):
return self.name
# AI 知识库分段表
class KnowledgeSegment(CoreModel):
__tablename__ = 'ai_knowledge_segment'
knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), nullable=False)
document_id = Column(Integer, ForeignKey('ai_knowledge_document.id'), nullable=False)
content = Column(Text, nullable=False)
content_length = Column(Integer, nullable=False)
tokens = Column(Integer, nullable=False)
vector_id = Column(String(100), nullable=True)
retrieval_count = Column(Integer, default=0)
status = Column(Integer, default=CommonStatus.DISABLED)
def __str__(self):
return f"Segment {self.id}"
# AI 聊天角色表
class ChatRole(CoreModel):
__tablename__ = 'ai_chat_role'
name = Column(String(128), nullable=False)
avatar = Column(String(256), nullable=False)
description = Column(String(256), nullable=True)
status = Column(Integer, default=CommonStatus.DISABLED)
sort = Column(Integer, default=0)
public_status = Column(Boolean, default=False)
category = Column(String(32), nullable=True)
model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
system_message = Column(String(1024), nullable=True)
user_id = Column(
Integer,
ForeignKey('system_users.id'), # 假设DjangoUser表名是system_users
nullable=True # 允许为空(如匿名角色)
)
user = relationship(DjangoUser, backref='chat_roles') # 正确DjangoUser 已定义并导入
model = relationship('AIModel', backref='chat_roles')
tools = relationship('Tool', secondary='ai_chat_role_tool', backref='roles')
# conversations = relationship('ChatConversation', backref='role', cascade='all, delete-orphan')
# messages = relationship('ChatMessage', backref='role', cascade='all, delete-orphan')
def __str__(self):
return self.name
# AI 聊天对话表
class ChatConversation(CoreModel):
__tablename__ = 'ai_chat_conversation'
title = Column(String(256), nullable=False)
pinned = Column(Boolean, default=False)
pinned_time = Column(DateTime, nullable=True)
user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
model = Column(String(32), nullable=False)
system_message = Column(String(1024), nullable=True)
temperature = Column(Float, nullable=False)
max_tokens = Column(Integer, nullable=False)
max_contexts = Column(Integer, nullable=False)
user = relationship(DjangoUser, backref='conversations') # 正确DjangoUser 已定义并导入
model_rel = relationship('AIModel', backref='conversations')
messages = relationship('ChatMessage', backref='conversation', cascade='all, delete-orphan')
def __str__(self):
return self.title
# AI 聊天消息表
class ChatMessage(CoreModel):
__tablename__ = 'ai_chat_message'
conversation_id = Column(Integer, ForeignKey('ai_chat_conversation.id'), nullable=False)
user_id = Column(Integer, ForeignKey('system_users.id'), nullable=True)
role_id = Column(Integer, ForeignKey('ai_chat_role.id'), nullable=True)
model = Column(String(32), nullable=False)
model_id = Column(Integer, ForeignKey('ai_model.id'), nullable=False)
type = Column(String(16), nullable=False)
reply_id = Column(Integer, nullable=True)
content = Column(String(2048), nullable=False)
use_context = Column(Boolean, default=False)
segment_ids = Column(String(2048), nullable=True)
user = relationship(DjangoUser, backref='messages') # 正确DjangoUser 已定义并导入
model_rel = relationship('AIModel', backref='messages')
def __str__(self):
return self.content[:30]
# 聊天角色与知识库的关联表
class ChatRoleKnowledge(Base):
__tablename__ = 'ai_chat_role_knowledge'
chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
knowledge_id = Column(Integer, ForeignKey('ai_knowledge.id'), primary_key=True)
# 聊天角色与工具的关联表
class ChatRoleTool(Base):
__tablename__ = 'ai_chat_role_tool'
chat_role_id = Column(Integer, ForeignKey('ai_chat_role.id'), primary_key=True)
tool_id = Column(Integer, ForeignKey('ai_tool.id'), primary_key=True)

13
ai_service/models/base.py Normal file
View File

@@ -0,0 +1,13 @@
from db.session import Base
from sqlalchemy import (
Column, Integer, String, Text, DateTime, Boolean, Float, ForeignKey
)
# 基础模型类
class CoreModel(Base):
__abstract__ = True
id = Column(Integer, primary_key=True, autoincrement=True)
create_time = Column(DateTime)
update_time = Column(DateTime)
is_deleted = Column(Boolean, default=False)

24
ai_service/models/user.py Normal file
View File

@@ -0,0 +1,24 @@
from sqlalchemy import Column, Integer, String, DateTime, Boolean
from db.session import Base
class AuthToken(Base):
__tablename__ = 'authtoken_token'
key = Column(String(40), primary_key=True)
user_id = Column(Integer, nullable=False)
created = Column(DateTime)
class DjangoUser(Base):
__tablename__ = 'system_users'
id = Column(Integer, primary_key=True)
username = Column(String(150), nullable=False)
email = Column(String(254))
password = Column(String(128))
is_active = Column(Boolean, default=True)
is_staff = Column(Boolean, default=False)
is_superuser = Column(Boolean, default=False)
last_login = Column(DateTime)
date_joined = Column(DateTime)

View File

@@ -0,0 +1,8 @@
fastapi==0.116.1
uvicorn[standard]==0.35.0
langchain-openai==0.3.28
langchain-deepseek==0.1.3
langchain==0.3.26
langchain-community==0.3.26
PyMySQL==1.1.1
SQLAlchemy==2.0.41

101
ai_service/routers/base.py Normal file
View File

@@ -0,0 +1,101 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Generic, TypeVar, List
from db.session import get_db
from schemas.base import ReadSchemaType # 通用的响应模型基类
from crud.base import CRUDBase
# 泛型变量对应CRUD类、创建模型、更新模型、响应模型
CRUDType = TypeVar("CRUDType")
CreateSchemaType = TypeVar("CreateSchemaType")
UpdateSchemaType = TypeVar("UpdateSchemaType")
ReadSchemaType = TypeVar("ReadSchemaType")
class GenericRouter(
APIRouter,
Generic[CRUDType, CreateSchemaType, UpdateSchemaType, ReadSchemaType]
):
def __init__(
self,
crud: CRUDType,
create_schema: CreateSchemaType,
update_schema: UpdateSchemaType,
read_schema: ReadSchemaType,
prefix: str,
tags: List[str],
**kwargs
):
"""
初始化通用路由
:param crud: CRUD实例如CRUDApiKey
:param create_schema: 创建Pydantic模型
:param update_schema: 更新Pydantic模型
:param read_schema: 响应Pydantic模型
:param prefix: 路由前缀(如"/api/ai-api-keys"
:param tags: 文档标签
"""
super().__init__(prefix=prefix, tags=tags,** kwargs)
self.crud = crud
self.create_schema = create_schema
self.update_schema = update_schema
self.read_schema = read_schema
# 注册通用路由
self.add_api_route(
"/",
self.create,
methods=["POST"],
response_model=read_schema,
status_code=201
)
self.add_api_route(
"/",
self.get_multi,
methods=["GET"],
response_model=List[read_schema]
)
self.add_api_route(
"/{id}/",
self.get,
methods=["GET"],
response_model=read_schema
)
self.add_api_route(
"/{id}/",
self.update,
methods=["PUT"],
response_model=read_schema
)
self.add_api_route(
"/{id}/",
self.remove,
methods=["DELETE"]
)
# 创建
def create(self, obj_in: CreateSchemaType, db: Session = Depends(get_db)):
return self.crud.create(db=db, obj_in=obj_in)
# 按ID查询
def get(self, id: int, db: Session = Depends(get_db)):
obj = self.crud.get(db=db, id=id)
if not obj:
raise HTTPException(status_code=404, detail=f"记录不存在")
return obj
# 分页查询
def get_multi(self, page: int = 0, limit: int = 10, db: Session = Depends(get_db)):
return self.crud.get_multi(db=db, page=page, limit=limit)
# 更新
def update(self, id: int, obj_in: UpdateSchemaType, db: Session = Depends(get_db)):
obj = self.crud.get(db=db, id=id)
if not obj:
raise HTTPException(status_code=404, detail=f"记录不存在")
return self.crud.update(db=db, db_obj=obj, obj_in=obj_in)
# 删除
def remove(self, id: int, db: Session = Depends(get_db)):
return self.crud.remove(db=db, id=id)

View File

@@ -0,0 +1,33 @@
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
# 基础模型(共享字段)
class AIApiKeyBase(BaseModel):
name: str = Field(..., max_length=255, description="密钥名称")
platform: str = Field(..., max_length=100, description="平台如openai")
api_key: str = Field(..., max_length=255, description="API密钥")
url: Optional[str] = Field(None, max_length=255, description="自定义API地址")
status: int = Field(0, description="状态0=禁用1=启用)")
# 创建请求模型无需ID和时间字段
class AIApiKeyCreate(AIApiKeyBase):
pass
# 更新请求模型(所有字段可选)
class AIApiKeyUpdate(BaseModel):
name: Optional[str] = Field(None, max_length=255)
platform: Optional[str] = Field(None, max_length=100)
api_key: Optional[str] = Field(None, max_length=255)
url: Optional[str] = Field(None, max_length=255)
status: Optional[int] = None
# 响应模型(包含数据库自动生成的字段)
class AIApiKeyRead(AIApiKeyBase):
id: int
created_at: Optional[datetime]
updated_at: Optional[datetime]
# 支持ORM模型直接转换为响应
class Config:
from_attributes = True # Pydantic v2用from_attributesv1用orm_mode

View File

@@ -0,0 +1,10 @@
from pydantic import BaseModel
class ChatCreate(BaseModel):
pass
class Chat(ChatCreate):
id: int
class Config:
orm_mode = True

View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel
from datetime import datetime
from typing import Optional
class ReadSchemaType(BaseModel):
"""
所有响应模型的基类包含公共字段和ORM转换配置
"""
id: int
created_at: Optional[datetime] = None # 数据创建时间(可选,部分模型可能没有)
updated_at: Optional[datetime] = None # 数据更新时间(可选)
class Config:
"""
配置Pydantic模型如何处理ORM对象
- from_attributes=True支持直接从SQLAlchemy ORM模型转换Pydantic v2
- 若使用Pydantic v1需替换为 orm_mode=True
"""
from_attributes = True

View File

@@ -0,0 +1,9 @@
from pydantic import BaseModel
class UserOut(BaseModel):
id: int
username: str
email: str = None
class Config:
orm_mode = True

View File

@@ -0,0 +1,95 @@
# LangChain集成示例
from sqlalchemy.orm import Session
from datetime import datetime
from models.ai import ChatConversation, ChatMessage, MessageType
class ChatDBService:
@staticmethod
def get_conversation(db: Session, conversation_id: int):
return db.query(ChatConversation).filter(ChatConversation.id == conversation_id).first()
@staticmethod
def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str, content: str) -> ChatConversation:
if not conversation_id:
conversation = ChatConversation(
title=content,
user_id=user_id,
role_id=None,
model_id=None, # 需根据实际模型id调整
model=model,
system_message=None,
temperature=0.7,
max_tokens=2048,
max_contexts=10,
create_time=datetime.now(),
update_time=datetime.now(),
is_deleted=False
)
db.add(conversation)
db.commit()
db.refresh(conversation)
return conversation
else:
conversation = db.query(ChatConversation).get(conversation_id)
if not conversation:
raise ValueError("无效的conversation_id")
return conversation
@staticmethod
def update_conversation_title(db, conversation_id: int, title: str):
conversation = db.query(ChatConversation).filter(ChatConversation.id == conversation_id).first()
if conversation:
conversation.title = title[:255] # 保证不超过255字符
db.add(conversation)
db.commit()
return conversation
else:
raise ValueError("Conversation not found")
@staticmethod
def add_message(db: Session, conversation: ChatConversation, user_id: int, content: str) -> ChatMessage:
message = ChatMessage(
conversation_id=conversation.id,
user_id=user_id,
role_id=None,
model=conversation.model,
model_id=conversation.model_id,
type=MessageType.USER,
reply_id=None,
content=content,
use_context=True,
segment_ids=None,
create_time=datetime.now(),
update_time=datetime.now(),
is_deleted=False
)
db.add(message)
db.commit()
return message
@staticmethod
def insert_ai_message(db: Session, conversation, user_id: int, content: str, model: str):
from datetime import datetime
from models.ai import MessageType
message = ChatMessage(
conversation_id=conversation.id,
user_id=user_id,
role_id=None,
model=model,
model_id=conversation.model_id,
type=MessageType.ASSISTANT,
reply_id=None,
content=content,
use_context=True,
segment_ids=None,
create_time=datetime.now(),
update_time=datetime.now(),
is_deleted=False
)
db.add(message)
db.commit()
@staticmethod
def get_history(db: Session, conversation_id: int) -> list[ChatMessage]:
return db.query(ChatMessage).filter_by(conversation_id=conversation_id).order_by(ChatMessage.id).all()

1
ai_service/utils/jwt.py Normal file
View File

@@ -0,0 +1 @@
# 预留如需JWT校验可在此实现

16
ai_service/utils/resp.py Normal file
View File

@@ -0,0 +1,16 @@
from typing import Generic, TypeVar, Optional
from pydantic import BaseModel
T = TypeVar("T")
class Response(BaseModel, Generic[T]):
code: int
message: str
data: Optional[T] = None # ✅ 明确 data 可为 None
def resp_success(data: T, message: str = "success") -> Response[T]:
return Response(code=0, message=message, data=data)
def resp_error(message="error", code=1) -> Response[T]:
return Response(code=code, message=message, data=None)