对话ds
This commit is contained in:
2
chat/.env.example
Normal file
2
chat/.env.example
Normal file
@@ -0,0 +1,2 @@
|
||||
OPENAI_API_KEY=你的API密钥
|
||||
DEEPSEEK_API_KEY='你的API密钥'
|
||||
@@ -4,27 +4,20 @@ from fastapi import APIRouter, Depends, Request, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from langchain.chains import ConversationChain
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
|
||||
from api.v1.vo import MessageVO, ConversationsVO
|
||||
from api.v1.vo import MessageVO
|
||||
from deps.auth import get_current_user
|
||||
from services.chat_service import ChatDBService
|
||||
from db.session import get_db
|
||||
from models.ai import ChatConversation, ChatMessage, MessageType
|
||||
from models.ai import ChatConversation, ChatMessage
|
||||
from utils.resp import resp_success, Response
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
prompt: str
|
||||
|
||||
|
||||
|
||||
def get_deepseek_llm(api_key: SecretStr, model: str):
|
||||
# deepseek 兼容 OpenAI API,需指定 base_url
|
||||
return ChatDeepSeek(
|
||||
@@ -38,26 +31,22 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
|
||||
body = await request.json()
|
||||
content = body.get('content')
|
||||
conversation_id = body.get('conversation_id')
|
||||
print(content, 'content')
|
||||
model = 'deepseek-chat'
|
||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||
openai_api_base = "https://api.deepseek.com/v1"
|
||||
llm = get_deepseek_llm(api_key, model)
|
||||
llm = get_deepseek_llm(SecretStr(api_key), model)
|
||||
|
||||
if not content or not isinstance(content, str):
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse({"error": "content不能为空"}, status_code=400)
|
||||
|
||||
user_id = user["user_id"]
|
||||
print(conversation_id, 'conversation_id')
|
||||
# 1. 获取或新建对话
|
||||
# 1. 获取对话
|
||||
try:
|
||||
conversation = ChatDBService.get_or_create_conversation(db, conversation_id, user_id, model, content)
|
||||
conversation = ChatDBService.get_conversation(db, conversation_id)
|
||||
conversation = db.merge(conversation) # ✅ 防止 DetachedInstanceError
|
||||
except ValueError as e:
|
||||
print(23232)
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse({"error": str(e)}, status_code=400)
|
||||
print(conversation, 'dsds')
|
||||
# 2. 插入当前消息
|
||||
ChatDBService.add_message(db, conversation, user_id, content)
|
||||
context = [
|
||||
@@ -65,13 +54,16 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
|
||||
]
|
||||
# 3. 查询历史消息,组装上下文
|
||||
history = ChatDBService.get_history(db, conversation.id)
|
||||
# === 新增:如果只有一条消息,更新 title ===
|
||||
if len(history) == 1:
|
||||
ChatDBService.update_conversation_title(db, conversation.id, content[:255])
|
||||
|
||||
for msg in history:
|
||||
# 假设 msg.type 存储的是 'user' 或 'assistant'
|
||||
# role = msg.type if msg.type in ("user", "assistant") else "user"
|
||||
context.append((msg.type, msg.content))
|
||||
print('context', context)
|
||||
ai_reply = ""
|
||||
|
||||
ai_reply = ""
|
||||
async def event_generator():
|
||||
nonlocal ai_reply
|
||||
async for chunk in llm.astream(context):
|
||||
@@ -88,6 +80,13 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
|
||||
|
||||
return StreamingResponse(event_generator(), media_type='text/event-stream')
|
||||
|
||||
@router.post("/conversations")
|
||||
def create_conversation(db: Session = Depends(get_db), user=Depends(get_current_user),):
|
||||
user_id = user["user_id"]
|
||||
model = 'deepseek-chat'
|
||||
conversation = ChatDBService.get_or_create_conversation(db, None, user_id, model, '新对话')
|
||||
return resp_success(data=conversation.id)
|
||||
|
||||
@router.get('/conversations')
|
||||
async def get_conversations(
|
||||
db: Session = Depends(get_db),
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from crud.base import CRUDBase
|
||||
from models.ai import AIApiKey # SQLAlchemy模型
|
||||
from schemas.ai_api_key import AIApiKeyCreate, AIApiKeyUpdate
|
||||
|
||||
# 继承通用CRUD基类,指定模型和Pydantic类型
|
||||
class CRUDApiKey(CRUDBase[AIApiKey, AIApiKeyCreate, AIApiKeyUpdate]):
|
||||
# 如有特殊逻辑,可重写父类方法(如创建时验证平台唯一性)
|
||||
def create(self, db: Session, *, obj_in: AIApiKeyCreate):
|
||||
# 示例:验证平台+名称唯一
|
||||
if self.get_by(db, platform=obj_in.platform, name=obj_in.name):
|
||||
raise HTTPException(status_code=400, detail="该平台下名称已存在")
|
||||
return super().create(db, obj_in=obj_in)
|
||||
|
||||
# 创建CRUD实例
|
||||
ai_api_key_crud = CRUDApiKey(AIApiKey)
|
||||
@@ -3,7 +3,6 @@ from fastapi import FastAPI
|
||||
from dotenv import load_dotenv
|
||||
from api.v1 import ai_chat
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from routers.ai_api_key import router as ai_api_key_router
|
||||
|
||||
# 加载.env环境变量,优先项目根目录
|
||||
load_dotenv()
|
||||
@@ -25,7 +24,6 @@ app.add_middleware(
|
||||
|
||||
# 注册路由
|
||||
app.include_router(ai_chat.router, prefix="/chat/api/v1", tags=["chat"])
|
||||
app.include_router(ai_api_key_router, tags=["chat"])
|
||||
|
||||
# 健康检查
|
||||
@app.get("/ping")
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
langchain-openai
|
||||
fastapi==0.116.1
|
||||
uvicorn[standard]==0.35.0
|
||||
langchain-openai==0.3.28
|
||||
langchain-deepseek==0.1.3
|
||||
langchain==0.3.26
|
||||
langchain-community==0.3.26
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from schemas.ai_api_key import AIApiKeyCreate, AIApiKeyUpdate, AIApiKeyRead
|
||||
from crud.ai_api_key import ai_api_key_crud
|
||||
from routers.base import GenericRouter
|
||||
|
||||
# 继承通用路由基类,传入参数即可生成所有CRUD接口
|
||||
router = GenericRouter(
|
||||
crud=ai_api_key_crud,
|
||||
create_schema=AIApiKeyCreate,
|
||||
update_schema=AIApiKeyUpdate,
|
||||
read_schema=AIApiKeyRead,
|
||||
prefix="/chat/api/ai-api-keys",
|
||||
tags=["AI API密钥"]
|
||||
)
|
||||
10
chat/schemas/ai_chat.py
Normal file
10
chat/schemas/ai_chat.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ChatCreate(BaseModel):
|
||||
pass
|
||||
|
||||
class Chat(ChatCreate):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
@@ -4,10 +4,13 @@ from datetime import datetime
|
||||
from models.ai import ChatConversation, ChatMessage, MessageType
|
||||
|
||||
class ChatDBService:
|
||||
@staticmethod
|
||||
def get_conversation(db: Session, conversation_id: int):
|
||||
return db.query(ChatConversation).filter(ChatConversation.id == conversation_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str, content: str) -> ChatConversation:
|
||||
if not conversation_id:
|
||||
print(conversation_id, 'conversation_id')
|
||||
conversation = ChatConversation(
|
||||
title=content,
|
||||
user_id=user_id,
|
||||
@@ -32,6 +35,17 @@ class ChatDBService:
|
||||
raise ValueError("无效的conversation_id")
|
||||
return conversation
|
||||
|
||||
@staticmethod
|
||||
def update_conversation_title(db, conversation_id: int, title: str):
|
||||
conversation = db.query(ChatConversation).filter(ChatConversation.id == conversation_id).first()
|
||||
if conversation:
|
||||
conversation.title = title[:255] # 保证不超过255字符
|
||||
db.add(conversation)
|
||||
db.commit()
|
||||
return conversation
|
||||
else:
|
||||
raise ValueError("Conversation not found")
|
||||
|
||||
@staticmethod
|
||||
def add_message(db: Session, conversation: ChatConversation, user_id: int, content: str) -> ChatMessage:
|
||||
message = ChatMessage(
|
||||
|
||||
@@ -12,5 +12,5 @@ class Response(BaseModel, Generic[T]):
|
||||
def resp_success(data: T, message: str = "success") -> Response[T]:
|
||||
return Response(code=0, message=message, data=data)
|
||||
|
||||
def resp_error(message="error", code=1) -> Response[None]:
|
||||
def resp_error(message="error", code=1) -> Response[T]:
|
||||
return Response(code=code, message=message, data=None)
|
||||
@@ -5,6 +5,16 @@ export async function getConversations() {
|
||||
return await res.json();
|
||||
}
|
||||
|
||||
export async function createConversation() {
|
||||
const response = await fetchWithAuth('/chat/api/v1/conversations', {
|
||||
method: 'POST',
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error('创建对话失败');
|
||||
}
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
export async function getMessages(conversationId: number) {
|
||||
const res = await fetchWithAuth(
|
||||
`/chat/api/v1/messages?conversation_id=${conversationId}`,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"title": "AI Management",
|
||||
"ai_api_key": {
|
||||
"api_key": {
|
||||
"title": "KEY Management",
|
||||
"name": "KEY Management"
|
||||
},
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"title": "AI大模型",
|
||||
"ai_api_key": {
|
||||
"api_key": {
|
||||
"title": "API 密钥",
|
||||
"name": "API 密钥"
|
||||
},
|
||||
|
||||
@@ -133,7 +133,7 @@ function refreshGrid() {
|
||||
v-permission="'ai:ai_api_key:create'"
|
||||
>
|
||||
<Plus class="size-5" />
|
||||
{{ $t('ui.actionTitle.create', [$t('ai.ai_api_key.name')]) }}
|
||||
{{ $t('ui.actionTitle.create', [$t('ai.api_key.name')]) }}
|
||||
</Button>
|
||||
</template>
|
||||
</Grid>
|
||||
@@ -20,8 +20,8 @@ const formModel = new AiAIApiKeyModel();
|
||||
const formData = ref<AiAIApiKeyApi.AiAIApiKey>();
|
||||
const getTitle = computed(() => {
|
||||
return formData.value?.id
|
||||
? $t('ui.actionTitle.edit', [$t('ai.ai_api_key.name')])
|
||||
: $t('ui.actionTitle.create', [$t('ai.ai_api_key.name')]);
|
||||
? $t('ui.actionTitle.edit', [$t('ai.api_key.name')])
|
||||
: $t('ui.actionTitle.create', [$t('ai.api_key.name')]);
|
||||
});
|
||||
|
||||
const [Form, formApi] = useVbenForm({
|
||||
@@ -14,16 +14,21 @@ import {
|
||||
Select,
|
||||
} from 'ant-design-vue';
|
||||
|
||||
import { fetchAIStream, getConversations, getMessages } from '#/api/ai/chat';
|
||||
import {
|
||||
createConversation,
|
||||
fetchAIStream,
|
||||
getConversations,
|
||||
getMessages,
|
||||
} from '#/api/ai/chat';
|
||||
|
||||
interface Message {
|
||||
id: null | number;
|
||||
id: number;
|
||||
type: 'assistant' | 'user';
|
||||
content: string;
|
||||
}
|
||||
|
||||
interface ChatItem {
|
||||
id: null | number;
|
||||
id: number;
|
||||
title: string;
|
||||
lastMessage: string;
|
||||
}
|
||||
@@ -60,14 +65,13 @@ async function selectChat(id: number) {
|
||||
nextTick(scrollToBottom);
|
||||
}
|
||||
|
||||
function handleNewChat() {
|
||||
const newId = null;
|
||||
chatList.value.unshift({
|
||||
id: newId,
|
||||
title: `新对话${chatList.value.length + 1}`,
|
||||
lastMessage: '',
|
||||
});
|
||||
selectedChatId.value = newId;
|
||||
async function handleNewChat() {
|
||||
// 调用后端新建对话
|
||||
const { data } = await createConversation();
|
||||
// 刷新对话列表
|
||||
await fetchConversations();
|
||||
// 选中新建的对话
|
||||
selectedChatId.value = data;
|
||||
messages.value = [];
|
||||
nextTick(scrollToBottom);
|
||||
}
|
||||
@@ -195,7 +199,7 @@ onMounted(() => {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div class="chat-messages" style="height: 100%;" ref="messagesRef">
|
||||
<div class="chat-messages" style="height: 100%" ref="messagesRef">
|
||||
<div
|
||||
v-for="msg in messages"
|
||||
:key="msg.id"
|
||||
|
||||
Reference in New Issue
Block a user