This commit is contained in:
XIE7654
2025-07-18 10:39:05 +08:00
parent 66d5971570
commit aef25112f6
16 changed files with 83 additions and 74 deletions

2
chat/.env.example Normal file
View File

@@ -0,0 +1,2 @@
OPENAI_API_KEY=你的API密钥
DEEPSEEK_API_KEY='你的API密钥'

View File

@@ -4,27 +4,20 @@ from fastapi import APIRouter, Depends, Request, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
from datetime import datetime
from pydantic import BaseModel, SecretStr from pydantic import BaseModel, SecretStr
from langchain.chains import ConversationChain from langchain.chains import ConversationChain
from langchain_community.chat_models import ChatOpenAI
from api.v1.vo import MessageVO, ConversationsVO from api.v1.vo import MessageVO
from deps.auth import get_current_user from deps.auth import get_current_user
from services.chat_service import ChatDBService from services.chat_service import ChatDBService
from db.session import get_db from db.session import get_db
from models.ai import ChatConversation, ChatMessage, MessageType from models.ai import ChatConversation, ChatMessage
from utils.resp import resp_success, Response from utils.resp import resp_success, Response
from langchain_deepseek import ChatDeepSeek from langchain_deepseek import ChatDeepSeek
router = APIRouter() router = APIRouter()
class ChatRequest(BaseModel):
prompt: str
def get_deepseek_llm(api_key: SecretStr, model: str): def get_deepseek_llm(api_key: SecretStr, model: str):
# deepseek 兼容 OpenAI API需指定 base_url # deepseek 兼容 OpenAI API需指定 base_url
return ChatDeepSeek( return ChatDeepSeek(
@@ -38,26 +31,22 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
body = await request.json() body = await request.json()
content = body.get('content') content = body.get('content')
conversation_id = body.get('conversation_id') conversation_id = body.get('conversation_id')
print(content, 'content')
model = 'deepseek-chat' model = 'deepseek-chat'
api_key = os.getenv("DEEPSEEK_API_KEY") api_key = os.getenv("DEEPSEEK_API_KEY")
openai_api_base = "https://api.deepseek.com/v1" llm = get_deepseek_llm(SecretStr(api_key), model)
llm = get_deepseek_llm(api_key, model)
if not content or not isinstance(content, str): if not content or not isinstance(content, str):
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
return JSONResponse({"error": "content不能为空"}, status_code=400) return JSONResponse({"error": "content不能为空"}, status_code=400)
user_id = user["user_id"] user_id = user["user_id"]
print(conversation_id, 'conversation_id') # 1. 获取对话
# 1. 获取或新建对话
try: try:
conversation = ChatDBService.get_or_create_conversation(db, conversation_id, user_id, model, content) conversation = ChatDBService.get_conversation(db, conversation_id)
conversation = db.merge(conversation) # ✅ 防止 DetachedInstanceError
except ValueError as e: except ValueError as e:
print(23232)
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
return JSONResponse({"error": str(e)}, status_code=400) return JSONResponse({"error": str(e)}, status_code=400)
print(conversation, 'dsds')
# 2. 插入当前消息 # 2. 插入当前消息
ChatDBService.add_message(db, conversation, user_id, content) ChatDBService.add_message(db, conversation, user_id, content)
context = [ context = [
@@ -65,13 +54,16 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
] ]
# 3. 查询历史消息,组装上下文 # 3. 查询历史消息,组装上下文
history = ChatDBService.get_history(db, conversation.id) history = ChatDBService.get_history(db, conversation.id)
# === 新增:如果只有一条消息,更新 title ===
if len(history) == 1:
ChatDBService.update_conversation_title(db, conversation.id, content[:255])
for msg in history: for msg in history:
# 假设 msg.type 存储的是 'user' 或 'assistant' # 假设 msg.type 存储的是 'user' 或 'assistant'
# role = msg.type if msg.type in ("user", "assistant") else "user" # role = msg.type if msg.type in ("user", "assistant") else "user"
context.append((msg.type, msg.content)) context.append((msg.type, msg.content))
print('context', context)
ai_reply = ""
ai_reply = ""
async def event_generator(): async def event_generator():
nonlocal ai_reply nonlocal ai_reply
async for chunk in llm.astream(context): async for chunk in llm.astream(context):
@@ -88,6 +80,13 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
return StreamingResponse(event_generator(), media_type='text/event-stream') return StreamingResponse(event_generator(), media_type='text/event-stream')
@router.post("/conversations")
def create_conversation(db: Session = Depends(get_db), user=Depends(get_current_user),):
user_id = user["user_id"]
model = 'deepseek-chat'
conversation = ChatDBService.get_or_create_conversation(db, None, user_id, model, '新对话')
return resp_success(data=conversation.id)
@router.get('/conversations') @router.get('/conversations')
async def get_conversations( async def get_conversations(
db: Session = Depends(get_db), db: Session = Depends(get_db),

View File

@@ -1,18 +0,0 @@
from fastapi import HTTPException
from sqlalchemy.orm import Session
from crud.base import CRUDBase
from models.ai import AIApiKey # SQLAlchemy模型
from schemas.ai_api_key import AIApiKeyCreate, AIApiKeyUpdate
# 继承通用CRUD基类指定模型和Pydantic类型
class CRUDApiKey(CRUDBase[AIApiKey, AIApiKeyCreate, AIApiKeyUpdate]):
# 如有特殊逻辑,可重写父类方法(如创建时验证平台唯一性)
def create(self, db: Session, *, obj_in: AIApiKeyCreate):
# 示例:验证平台+名称唯一
if self.get_by(db, platform=obj_in.platform, name=obj_in.name):
raise HTTPException(status_code=400, detail="该平台下名称已存在")
return super().create(db, obj_in=obj_in)
# 创建CRUD实例
ai_api_key_crud = CRUDApiKey(AIApiKey)

View File

@@ -3,7 +3,6 @@ from fastapi import FastAPI
from dotenv import load_dotenv from dotenv import load_dotenv
from api.v1 import ai_chat from api.v1 import ai_chat
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from routers.ai_api_key import router as ai_api_key_router
# 加载.env环境变量优先项目根目录 # 加载.env环境变量优先项目根目录
load_dotenv() load_dotenv()
@@ -25,7 +24,6 @@ app.add_middleware(
# 注册路由 # 注册路由
app.include_router(ai_chat.router, prefix="/chat/api/v1", tags=["chat"]) app.include_router(ai_chat.router, prefix="/chat/api/v1", tags=["chat"])
app.include_router(ai_api_key_router, tags=["chat"])
# 健康检查 # 健康检查
@app.get("/ping") @app.get("/ping")

View File

@@ -1,3 +1,6 @@
fastapi fastapi==0.116.1
uvicorn[standard] uvicorn[standard]==0.35.0
langchain-openai langchain-openai==0.3.28
langchain-deepseek==0.1.3
langchain==0.3.26
langchain-community==0.3.26

View File

@@ -1,13 +0,0 @@
from schemas.ai_api_key import AIApiKeyCreate, AIApiKeyUpdate, AIApiKeyRead
from crud.ai_api_key import ai_api_key_crud
from routers.base import GenericRouter
# 继承通用路由基类传入参数即可生成所有CRUD接口
router = GenericRouter(
crud=ai_api_key_crud,
create_schema=AIApiKeyCreate,
update_schema=AIApiKeyUpdate,
read_schema=AIApiKeyRead,
prefix="/chat/api/ai-api-keys",
tags=["AI API密钥"]
)

10
chat/schemas/ai_chat.py Normal file
View File

@@ -0,0 +1,10 @@
from pydantic import BaseModel
class ChatCreate(BaseModel):
pass
class Chat(ChatCreate):
id: int
class Config:
orm_mode = True

View File

@@ -4,10 +4,13 @@ from datetime import datetime
from models.ai import ChatConversation, ChatMessage, MessageType from models.ai import ChatConversation, ChatMessage, MessageType
class ChatDBService: class ChatDBService:
@staticmethod
def get_conversation(db: Session, conversation_id: int):
return db.query(ChatConversation).filter(ChatConversation.id == conversation_id).first()
@staticmethod @staticmethod
def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str, content: str) -> ChatConversation: def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str, content: str) -> ChatConversation:
if not conversation_id: if not conversation_id:
print(conversation_id, 'conversation_id')
conversation = ChatConversation( conversation = ChatConversation(
title=content, title=content,
user_id=user_id, user_id=user_id,
@@ -32,6 +35,17 @@ class ChatDBService:
raise ValueError("无效的conversation_id") raise ValueError("无效的conversation_id")
return conversation return conversation
@staticmethod
def update_conversation_title(db, conversation_id: int, title: str):
conversation = db.query(ChatConversation).filter(ChatConversation.id == conversation_id).first()
if conversation:
conversation.title = title[:255] # 保证不超过255字符
db.add(conversation)
db.commit()
return conversation
else:
raise ValueError("Conversation not found")
@staticmethod @staticmethod
def add_message(db: Session, conversation: ChatConversation, user_id: int, content: str) -> ChatMessage: def add_message(db: Session, conversation: ChatConversation, user_id: int, content: str) -> ChatMessage:
message = ChatMessage( message = ChatMessage(

View File

@@ -12,5 +12,5 @@ class Response(BaseModel, Generic[T]):
def resp_success(data: T, message: str = "success") -> Response[T]: def resp_success(data: T, message: str = "success") -> Response[T]:
return Response(code=0, message=message, data=data) return Response(code=0, message=message, data=data)
def resp_error(message="error", code=1) -> Response[None]: def resp_error(message="error", code=1) -> Response[T]:
return Response(code=code, message=message, data=None) return Response(code=code, message=message, data=None)

View File

@@ -5,6 +5,16 @@ export async function getConversations() {
return await res.json(); return await res.json();
} }
export async function createConversation() {
const response = await fetchWithAuth('/chat/api/v1/conversations', {
method: 'POST',
});
if (!response.ok) {
throw new Error('创建对话失败');
}
return await response.json();
}
export async function getMessages(conversationId: number) { export async function getMessages(conversationId: number) {
const res = await fetchWithAuth( const res = await fetchWithAuth(
`/chat/api/v1/messages?conversation_id=${conversationId}`, `/chat/api/v1/messages?conversation_id=${conversationId}`,

View File

@@ -1,6 +1,6 @@
{ {
"title": "AI Management", "title": "AI Management",
"ai_api_key": { "api_key": {
"title": "KEY Management", "title": "KEY Management",
"name": "KEY Management" "name": "KEY Management"
}, },

View File

@@ -1,6 +1,6 @@
{ {
"title": "AI大模型", "title": "AI大模型",
"ai_api_key": { "api_key": {
"title": "API 密钥", "title": "API 密钥",
"name": "API 密钥" "name": "API 密钥"
}, },

View File

@@ -133,7 +133,7 @@ function refreshGrid() {
v-permission="'ai:ai_api_key:create'" v-permission="'ai:ai_api_key:create'"
> >
<Plus class="size-5" /> <Plus class="size-5" />
{{ $t('ui.actionTitle.create', [$t('ai.ai_api_key.name')]) }} {{ $t('ui.actionTitle.create', [$t('ai.api_key.name')]) }}
</Button> </Button>
</template> </template>
</Grid> </Grid>

View File

@@ -20,8 +20,8 @@ const formModel = new AiAIApiKeyModel();
const formData = ref<AiAIApiKeyApi.AiAIApiKey>(); const formData = ref<AiAIApiKeyApi.AiAIApiKey>();
const getTitle = computed(() => { const getTitle = computed(() => {
return formData.value?.id return formData.value?.id
? $t('ui.actionTitle.edit', [$t('ai.ai_api_key.name')]) ? $t('ui.actionTitle.edit', [$t('ai.api_key.name')])
: $t('ui.actionTitle.create', [$t('ai.ai_api_key.name')]); : $t('ui.actionTitle.create', [$t('ai.api_key.name')]);
}); });
const [Form, formApi] = useVbenForm({ const [Form, formApi] = useVbenForm({

View File

@@ -14,16 +14,21 @@ import {
Select, Select,
} from 'ant-design-vue'; } from 'ant-design-vue';
import { fetchAIStream, getConversations, getMessages } from '#/api/ai/chat'; import {
createConversation,
fetchAIStream,
getConversations,
getMessages,
} from '#/api/ai/chat';
interface Message { interface Message {
id: null | number; id: number;
type: 'assistant' | 'user'; type: 'assistant' | 'user';
content: string; content: string;
} }
interface ChatItem { interface ChatItem {
id: null | number; id: number;
title: string; title: string;
lastMessage: string; lastMessage: string;
} }
@@ -60,14 +65,13 @@ async function selectChat(id: number) {
nextTick(scrollToBottom); nextTick(scrollToBottom);
} }
function handleNewChat() { async function handleNewChat() {
const newId = null; // 调用后端新建对话
chatList.value.unshift({ const { data } = await createConversation();
id: newId, // 刷新对话列表
title: `新对话${chatList.value.length + 1}`, await fetchConversations();
lastMessage: '', // 选中新建的对话
}); selectedChatId.value = data;
selectedChatId.value = newId;
messages.value = []; messages.value = [];
nextTick(scrollToBottom); nextTick(scrollToBottom);
} }
@@ -195,7 +199,7 @@ onMounted(() => {
/> />
</div> </div>
</div> </div>
<div class="chat-messages" style="height: 100%;" ref="messagesRef"> <div class="chat-messages" style="height: 100%" ref="messagesRef">
<div <div
v-for="msg in messages" v-for="msg in messages"
:key="msg.id" :key="msg.id"