对话ds
This commit is contained in:
2
chat/.env.example
Normal file
2
chat/.env.example
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
OPENAI_API_KEY=你的API密钥
|
||||||
|
DEEPSEEK_API_KEY='你的API密钥'
|
||||||
@@ -4,27 +4,20 @@ from fastapi import APIRouter, Depends, Request, Query
|
|||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import List
|
from typing import List
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from pydantic import BaseModel, SecretStr
|
from pydantic import BaseModel, SecretStr
|
||||||
from langchain.chains import ConversationChain
|
from langchain.chains import ConversationChain
|
||||||
from langchain_community.chat_models import ChatOpenAI
|
|
||||||
|
|
||||||
from api.v1.vo import MessageVO, ConversationsVO
|
from api.v1.vo import MessageVO
|
||||||
from deps.auth import get_current_user
|
from deps.auth import get_current_user
|
||||||
from services.chat_service import ChatDBService
|
from services.chat_service import ChatDBService
|
||||||
from db.session import get_db
|
from db.session import get_db
|
||||||
from models.ai import ChatConversation, ChatMessage, MessageType
|
from models.ai import ChatConversation, ChatMessage
|
||||||
from utils.resp import resp_success, Response
|
from utils.resp import resp_success, Response
|
||||||
from langchain_deepseek import ChatDeepSeek
|
from langchain_deepseek import ChatDeepSeek
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
|
||||||
prompt: str
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_deepseek_llm(api_key: SecretStr, model: str):
|
def get_deepseek_llm(api_key: SecretStr, model: str):
|
||||||
# deepseek 兼容 OpenAI API,需指定 base_url
|
# deepseek 兼容 OpenAI API,需指定 base_url
|
||||||
return ChatDeepSeek(
|
return ChatDeepSeek(
|
||||||
@@ -38,26 +31,22 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
|
|||||||
body = await request.json()
|
body = await request.json()
|
||||||
content = body.get('content')
|
content = body.get('content')
|
||||||
conversation_id = body.get('conversation_id')
|
conversation_id = body.get('conversation_id')
|
||||||
print(content, 'content')
|
|
||||||
model = 'deepseek-chat'
|
model = 'deepseek-chat'
|
||||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||||
openai_api_base = "https://api.deepseek.com/v1"
|
llm = get_deepseek_llm(SecretStr(api_key), model)
|
||||||
llm = get_deepseek_llm(api_key, model)
|
|
||||||
|
|
||||||
if not content or not isinstance(content, str):
|
if not content or not isinstance(content, str):
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
return JSONResponse({"error": "content不能为空"}, status_code=400)
|
return JSONResponse({"error": "content不能为空"}, status_code=400)
|
||||||
|
|
||||||
user_id = user["user_id"]
|
user_id = user["user_id"]
|
||||||
print(conversation_id, 'conversation_id')
|
# 1. 获取对话
|
||||||
# 1. 获取或新建对话
|
|
||||||
try:
|
try:
|
||||||
conversation = ChatDBService.get_or_create_conversation(db, conversation_id, user_id, model, content)
|
conversation = ChatDBService.get_conversation(db, conversation_id)
|
||||||
|
conversation = db.merge(conversation) # ✅ 防止 DetachedInstanceError
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(23232)
|
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
return JSONResponse({"error": str(e)}, status_code=400)
|
return JSONResponse({"error": str(e)}, status_code=400)
|
||||||
print(conversation, 'dsds')
|
|
||||||
# 2. 插入当前消息
|
# 2. 插入当前消息
|
||||||
ChatDBService.add_message(db, conversation, user_id, content)
|
ChatDBService.add_message(db, conversation, user_id, content)
|
||||||
context = [
|
context = [
|
||||||
@@ -65,13 +54,16 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
|
|||||||
]
|
]
|
||||||
# 3. 查询历史消息,组装上下文
|
# 3. 查询历史消息,组装上下文
|
||||||
history = ChatDBService.get_history(db, conversation.id)
|
history = ChatDBService.get_history(db, conversation.id)
|
||||||
|
# === 新增:如果只有一条消息,更新 title ===
|
||||||
|
if len(history) == 1:
|
||||||
|
ChatDBService.update_conversation_title(db, conversation.id, content[:255])
|
||||||
|
|
||||||
for msg in history:
|
for msg in history:
|
||||||
# 假设 msg.type 存储的是 'user' 或 'assistant'
|
# 假设 msg.type 存储的是 'user' 或 'assistant'
|
||||||
# role = msg.type if msg.type in ("user", "assistant") else "user"
|
# role = msg.type if msg.type in ("user", "assistant") else "user"
|
||||||
context.append((msg.type, msg.content))
|
context.append((msg.type, msg.content))
|
||||||
print('context', context)
|
|
||||||
ai_reply = ""
|
|
||||||
|
|
||||||
|
ai_reply = ""
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
nonlocal ai_reply
|
nonlocal ai_reply
|
||||||
async for chunk in llm.astream(context):
|
async for chunk in llm.astream(context):
|
||||||
@@ -88,6 +80,13 @@ async def chat_stream(request: Request, user=Depends(get_current_user), db: Sess
|
|||||||
|
|
||||||
return StreamingResponse(event_generator(), media_type='text/event-stream')
|
return StreamingResponse(event_generator(), media_type='text/event-stream')
|
||||||
|
|
||||||
|
@router.post("/conversations")
|
||||||
|
def create_conversation(db: Session = Depends(get_db), user=Depends(get_current_user),):
|
||||||
|
user_id = user["user_id"]
|
||||||
|
model = 'deepseek-chat'
|
||||||
|
conversation = ChatDBService.get_or_create_conversation(db, None, user_id, model, '新对话')
|
||||||
|
return resp_success(data=conversation.id)
|
||||||
|
|
||||||
@router.get('/conversations')
|
@router.get('/conversations')
|
||||||
async def get_conversations(
|
async def get_conversations(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
from fastapi import HTTPException
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from crud.base import CRUDBase
|
|
||||||
from models.ai import AIApiKey # SQLAlchemy模型
|
|
||||||
from schemas.ai_api_key import AIApiKeyCreate, AIApiKeyUpdate
|
|
||||||
|
|
||||||
# 继承通用CRUD基类,指定模型和Pydantic类型
|
|
||||||
class CRUDApiKey(CRUDBase[AIApiKey, AIApiKeyCreate, AIApiKeyUpdate]):
|
|
||||||
# 如有特殊逻辑,可重写父类方法(如创建时验证平台唯一性)
|
|
||||||
def create(self, db: Session, *, obj_in: AIApiKeyCreate):
|
|
||||||
# 示例:验证平台+名称唯一
|
|
||||||
if self.get_by(db, platform=obj_in.platform, name=obj_in.name):
|
|
||||||
raise HTTPException(status_code=400, detail="该平台下名称已存在")
|
|
||||||
return super().create(db, obj_in=obj_in)
|
|
||||||
|
|
||||||
# 创建CRUD实例
|
|
||||||
ai_api_key_crud = CRUDApiKey(AIApiKey)
|
|
||||||
@@ -3,7 +3,6 @@ from fastapi import FastAPI
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from api.v1 import ai_chat
|
from api.v1 import ai_chat
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from routers.ai_api_key import router as ai_api_key_router
|
|
||||||
|
|
||||||
# 加载.env环境变量,优先项目根目录
|
# 加载.env环境变量,优先项目根目录
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -25,7 +24,6 @@ app.add_middleware(
|
|||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
app.include_router(ai_chat.router, prefix="/chat/api/v1", tags=["chat"])
|
app.include_router(ai_chat.router, prefix="/chat/api/v1", tags=["chat"])
|
||||||
app.include_router(ai_api_key_router, tags=["chat"])
|
|
||||||
|
|
||||||
# 健康检查
|
# 健康检查
|
||||||
@app.get("/ping")
|
@app.get("/ping")
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
fastapi
|
fastapi==0.116.1
|
||||||
uvicorn[standard]
|
uvicorn[standard]==0.35.0
|
||||||
langchain-openai
|
langchain-openai==0.3.28
|
||||||
|
langchain-deepseek==0.1.3
|
||||||
|
langchain==0.3.26
|
||||||
|
langchain-community==0.3.26
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
from schemas.ai_api_key import AIApiKeyCreate, AIApiKeyUpdate, AIApiKeyRead
|
|
||||||
from crud.ai_api_key import ai_api_key_crud
|
|
||||||
from routers.base import GenericRouter
|
|
||||||
|
|
||||||
# 继承通用路由基类,传入参数即可生成所有CRUD接口
|
|
||||||
router = GenericRouter(
|
|
||||||
crud=ai_api_key_crud,
|
|
||||||
create_schema=AIApiKeyCreate,
|
|
||||||
update_schema=AIApiKeyUpdate,
|
|
||||||
read_schema=AIApiKeyRead,
|
|
||||||
prefix="/chat/api/ai-api-keys",
|
|
||||||
tags=["AI API密钥"]
|
|
||||||
)
|
|
||||||
10
chat/schemas/ai_chat.py
Normal file
10
chat/schemas/ai_chat.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class ChatCreate(BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Chat(ChatCreate):
|
||||||
|
id: int
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
@@ -4,10 +4,13 @@ from datetime import datetime
|
|||||||
from models.ai import ChatConversation, ChatMessage, MessageType
|
from models.ai import ChatConversation, ChatMessage, MessageType
|
||||||
|
|
||||||
class ChatDBService:
|
class ChatDBService:
|
||||||
|
@staticmethod
|
||||||
|
def get_conversation(db: Session, conversation_id: int):
|
||||||
|
return db.query(ChatConversation).filter(ChatConversation.id == conversation_id).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str, content: str) -> ChatConversation:
|
def get_or_create_conversation(db: Session, conversation_id: int | None, user_id: int, model: str, content: str) -> ChatConversation:
|
||||||
if not conversation_id:
|
if not conversation_id:
|
||||||
print(conversation_id, 'conversation_id')
|
|
||||||
conversation = ChatConversation(
|
conversation = ChatConversation(
|
||||||
title=content,
|
title=content,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -32,6 +35,17 @@ class ChatDBService:
|
|||||||
raise ValueError("无效的conversation_id")
|
raise ValueError("无效的conversation_id")
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_conversation_title(db, conversation_id: int, title: str):
|
||||||
|
conversation = db.query(ChatConversation).filter(ChatConversation.id == conversation_id).first()
|
||||||
|
if conversation:
|
||||||
|
conversation.title = title[:255] # 保证不超过255字符
|
||||||
|
db.add(conversation)
|
||||||
|
db.commit()
|
||||||
|
return conversation
|
||||||
|
else:
|
||||||
|
raise ValueError("Conversation not found")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_message(db: Session, conversation: ChatConversation, user_id: int, content: str) -> ChatMessage:
|
def add_message(db: Session, conversation: ChatConversation, user_id: int, content: str) -> ChatMessage:
|
||||||
message = ChatMessage(
|
message = ChatMessage(
|
||||||
|
|||||||
@@ -12,5 +12,5 @@ class Response(BaseModel, Generic[T]):
|
|||||||
def resp_success(data: T, message: str = "success") -> Response[T]:
|
def resp_success(data: T, message: str = "success") -> Response[T]:
|
||||||
return Response(code=0, message=message, data=data)
|
return Response(code=0, message=message, data=data)
|
||||||
|
|
||||||
def resp_error(message="error", code=1) -> Response[None]:
|
def resp_error(message="error", code=1) -> Response[T]:
|
||||||
return Response(code=code, message=message, data=None)
|
return Response(code=code, message=message, data=None)
|
||||||
@@ -5,6 +5,16 @@ export async function getConversations() {
|
|||||||
return await res.json();
|
return await res.json();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function createConversation() {
|
||||||
|
const response = await fetchWithAuth('/chat/api/v1/conversations', {
|
||||||
|
method: 'POST',
|
||||||
|
});
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error('创建对话失败');
|
||||||
|
}
|
||||||
|
return await response.json();
|
||||||
|
}
|
||||||
|
|
||||||
export async function getMessages(conversationId: number) {
|
export async function getMessages(conversationId: number) {
|
||||||
const res = await fetchWithAuth(
|
const res = await fetchWithAuth(
|
||||||
`/chat/api/v1/messages?conversation_id=${conversationId}`,
|
`/chat/api/v1/messages?conversation_id=${conversationId}`,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"title": "AI Management",
|
"title": "AI Management",
|
||||||
"ai_api_key": {
|
"api_key": {
|
||||||
"title": "KEY Management",
|
"title": "KEY Management",
|
||||||
"name": "KEY Management"
|
"name": "KEY Management"
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"title": "AI大模型",
|
"title": "AI大模型",
|
||||||
"ai_api_key": {
|
"api_key": {
|
||||||
"title": "API 密钥",
|
"title": "API 密钥",
|
||||||
"name": "API 密钥"
|
"name": "API 密钥"
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ function refreshGrid() {
|
|||||||
v-permission="'ai:ai_api_key:create'"
|
v-permission="'ai:ai_api_key:create'"
|
||||||
>
|
>
|
||||||
<Plus class="size-5" />
|
<Plus class="size-5" />
|
||||||
{{ $t('ui.actionTitle.create', [$t('ai.ai_api_key.name')]) }}
|
{{ $t('ui.actionTitle.create', [$t('ai.api_key.name')]) }}
|
||||||
</Button>
|
</Button>
|
||||||
</template>
|
</template>
|
||||||
</Grid>
|
</Grid>
|
||||||
@@ -20,8 +20,8 @@ const formModel = new AiAIApiKeyModel();
|
|||||||
const formData = ref<AiAIApiKeyApi.AiAIApiKey>();
|
const formData = ref<AiAIApiKeyApi.AiAIApiKey>();
|
||||||
const getTitle = computed(() => {
|
const getTitle = computed(() => {
|
||||||
return formData.value?.id
|
return formData.value?.id
|
||||||
? $t('ui.actionTitle.edit', [$t('ai.ai_api_key.name')])
|
? $t('ui.actionTitle.edit', [$t('ai.api_key.name')])
|
||||||
: $t('ui.actionTitle.create', [$t('ai.ai_api_key.name')]);
|
: $t('ui.actionTitle.create', [$t('ai.api_key.name')]);
|
||||||
});
|
});
|
||||||
|
|
||||||
const [Form, formApi] = useVbenForm({
|
const [Form, formApi] = useVbenForm({
|
||||||
@@ -14,16 +14,21 @@ import {
|
|||||||
Select,
|
Select,
|
||||||
} from 'ant-design-vue';
|
} from 'ant-design-vue';
|
||||||
|
|
||||||
import { fetchAIStream, getConversations, getMessages } from '#/api/ai/chat';
|
import {
|
||||||
|
createConversation,
|
||||||
|
fetchAIStream,
|
||||||
|
getConversations,
|
||||||
|
getMessages,
|
||||||
|
} from '#/api/ai/chat';
|
||||||
|
|
||||||
interface Message {
|
interface Message {
|
||||||
id: null | number;
|
id: number;
|
||||||
type: 'assistant' | 'user';
|
type: 'assistant' | 'user';
|
||||||
content: string;
|
content: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ChatItem {
|
interface ChatItem {
|
||||||
id: null | number;
|
id: number;
|
||||||
title: string;
|
title: string;
|
||||||
lastMessage: string;
|
lastMessage: string;
|
||||||
}
|
}
|
||||||
@@ -60,14 +65,13 @@ async function selectChat(id: number) {
|
|||||||
nextTick(scrollToBottom);
|
nextTick(scrollToBottom);
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleNewChat() {
|
async function handleNewChat() {
|
||||||
const newId = null;
|
// 调用后端新建对话
|
||||||
chatList.value.unshift({
|
const { data } = await createConversation();
|
||||||
id: newId,
|
// 刷新对话列表
|
||||||
title: `新对话${chatList.value.length + 1}`,
|
await fetchConversations();
|
||||||
lastMessage: '',
|
// 选中新建的对话
|
||||||
});
|
selectedChatId.value = data;
|
||||||
selectedChatId.value = newId;
|
|
||||||
messages.value = [];
|
messages.value = [];
|
||||||
nextTick(scrollToBottom);
|
nextTick(scrollToBottom);
|
||||||
}
|
}
|
||||||
@@ -195,7 +199,7 @@ onMounted(() => {
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="chat-messages" style="height: 100%;" ref="messagesRef">
|
<div class="chat-messages" style="height: 100%" ref="messagesRef">
|
||||||
<div
|
<div
|
||||||
v-for="msg in messages"
|
v-for="msg in messages"
|
||||||
:key="msg.id"
|
:key="msg.id"
|
||||||
|
|||||||
Reference in New Issue
Block a user