ai chat init
This commit is contained in:
26
backend/ai/chat.py
Normal file
26
backend/ai/chat.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
import json
|
||||
from ai.langchain_client import get_ai_reply_stream
|
||||
from ai.utils import get_first_available_ai_config
|
||||
|
||||
|
||||
class ChatConsumer(AsyncWebsocketConsumer):
|
||||
async def connect(self):
|
||||
await self.accept()
|
||||
|
||||
async def disconnect(self, close_code):
|
||||
pass
|
||||
|
||||
async def receive(self, text_data):
|
||||
data = json.loads(text_data)
|
||||
user_message = data.get("message", "")
|
||||
|
||||
model, api_key, api_base = await get_first_available_ai_config()
|
||||
|
||||
async def send_chunk(chunk):
|
||||
await self.send(text_data=json.dumps({"is_streaming": True, "message": chunk}))
|
||||
|
||||
await get_ai_reply_stream(user_message, send_chunk, model_name=model, api_key=api_key, api_base=api_base)
|
||||
|
||||
# 结束标记
|
||||
await self.send(text_data=json.dumps({"done": True}))
|
||||
25
backend/ai/langchain_client.py
Normal file
25
backend/ai/langchain_client.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackHandler
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
|
||||
|
||||
class MyHandler(AsyncCallbackHandler):
|
||||
def __init__(self, send_func):
|
||||
super().__init__()
|
||||
self.send_func = send_func
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
await self.send_func(token)
|
||||
|
||||
async def get_ai_reply_stream(message: str, send_func, api_key, api_base, model_name):
|
||||
# 实例化时就带回调
|
||||
chat = ChatOpenAI(
|
||||
openai_api_key=api_key,
|
||||
openai_api_base=api_base,
|
||||
model_name=model_name,
|
||||
temperature=0.7,
|
||||
streaming=True,
|
||||
callbacks=[MyHandler(send_func)]
|
||||
)
|
||||
await chat.ainvoke([HumanMessage(content=message)])
|
||||
@@ -218,14 +218,12 @@ class ChatRole(CoreModel):
|
||||
blank=True,
|
||||
related_name="roles",
|
||||
verbose_name="关联的知识库",
|
||||
db_comment="关联的知识库"
|
||||
)
|
||||
tools = models.ManyToManyField(
|
||||
'Tool',
|
||||
blank=True,
|
||||
related_name="roles",
|
||||
verbose_name="关联的工具",
|
||||
db_comment="关联的工具"
|
||||
)
|
||||
|
||||
class Meta:
|
||||
|
||||
7
backend/ai/routing.py
Normal file
7
backend/ai/routing.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from django.urls import re_path
|
||||
|
||||
from ai.chat import ChatConsumer
|
||||
|
||||
websocket_urlpatterns = [
|
||||
re_path(r'ws/chat/$', ChatConsumer.as_asgi()),
|
||||
]
|
||||
11
backend/ai/utils.py
Normal file
11
backend/ai/utils.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from ai.models import AIModel
|
||||
from utils.models import CommonStatus
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
@sync_to_async
|
||||
def get_first_available_ai_config():
|
||||
# 这里只取第一个可用的,可以根据实际业务加筛选条件
|
||||
ai = AIModel.objects.filter(status=CommonStatus.ENABLED).prefetch_related('key').first()
|
||||
if not ai:
|
||||
raise Exception('没有可用的AI配置')
|
||||
return ai.model, ai.key.api_key, ai.key.url
|
||||
@@ -1,16 +1,17 @@
|
||||
"""
|
||||
ASGI config for backend project.
|
||||
|
||||
It exposes the ASGI callable as a module-level variable named ``application``.
|
||||
|
||||
For more information on this file, see
|
||||
https://docs.djangoproject.com/en/5.2/howto/deployment/asgi/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from django.core.asgi import get_asgi_application
|
||||
from channels.routing import ProtocolTypeRouter, URLRouter
|
||||
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings')
|
||||
|
||||
application = get_asgi_application()
|
||||
# 延迟导入,避免 AppRegistryNotReady 错误
|
||||
def get_websocket_urlpatterns():
|
||||
from ai.routing import websocket_urlpatterns
|
||||
return websocket_urlpatterns
|
||||
|
||||
application = ProtocolTypeRouter({
|
||||
"http": get_asgi_application(),
|
||||
"websocket": URLRouter(
|
||||
get_websocket_urlpatterns()
|
||||
),
|
||||
})
|
||||
@@ -53,6 +53,7 @@ INSTALLED_APPS = [
|
||||
'django_filters',
|
||||
'corsheaders',
|
||||
'rest_framework.authtoken',
|
||||
'channels',
|
||||
"system",
|
||||
"ai",
|
||||
]
|
||||
@@ -231,5 +232,15 @@ LOGGING = {
|
||||
}
|
||||
}
|
||||
|
||||
ASGI_APPLICATION = 'backend.asgi.application'
|
||||
|
||||
|
||||
# 简单用内存通道层
|
||||
CHANNEL_LAYERS = {
|
||||
'default': {
|
||||
'BACKEND': 'channels.layers.InMemoryChannelLayer'
|
||||
}
|
||||
}
|
||||
|
||||
if os.path.exists(os.path.join(BASE_DIR, 'backend/local_settings.py')):
|
||||
from backend.local_settings import *
|
||||
@@ -13,4 +13,9 @@ eventlet==0.40.0
|
||||
goofish_api==0.0.6
|
||||
flower==2.0.1
|
||||
gunicorn==23.0.0
|
||||
django_redis==6.0.0
|
||||
django_redis==6.0.0
|
||||
django-ninja==1.4.3
|
||||
openai==1.95
|
||||
daphne==4.2.1
|
||||
langchain==0.3.26
|
||||
langchain-community==0.3.27
|
||||
Reference in New Issue
Block a user