84 lines
3.4 KiB
Python
84 lines
3.4 KiB
Python
import os
|
|
from datetime import datetime
|
|
|
|
from rest_framework.response import Response
|
|
|
|
from ai.models import Drawing
|
|
from backend import settings
|
|
from llm.enums import LLMProvider
|
|
from llm.factory import get_adapter
|
|
from utils.serializers import CustomModelSerializer
|
|
from utils.custom_model_viewSet import CustomModelViewSet
|
|
from django_filters import rest_framework as filters
|
|
|
|
|
|
class DrawingSerializer(CustomModelSerializer):
|
|
"""
|
|
AI 绘画表 序列化器
|
|
"""
|
|
class Meta:
|
|
model = Drawing
|
|
fields = '__all__'
|
|
read_only_fields = ['id', 'create_time', 'update_time']
|
|
|
|
|
|
class DrawingFilter(filters.FilterSet):
|
|
|
|
class Meta:
|
|
model = Drawing
|
|
fields = ['id', 'remark', 'creator', 'modifier', 'is_deleted', 'public_status', 'platform',
|
|
'model', 'width', 'height', 'status', 'pic_url', 'error_message', 'task_id', 'buttons']
|
|
|
|
|
|
class DrawingViewSet(CustomModelViewSet):
|
|
"""
|
|
AI 绘画表 视图集
|
|
"""
|
|
queryset = Drawing.objects.filter(is_deleted=False).order_by('-id')
|
|
serializer_class = DrawingSerializer
|
|
filterset_class = DrawingFilter
|
|
search_fields = ['name'] # 根据实际字段调整
|
|
ordering_fields = ['create_time', 'id']
|
|
ordering = ['-create_time']
|
|
|
|
def create(self, request, *args, **kwargs):
|
|
model = request.data.get('model')
|
|
prompt = request.data.get('prompt')
|
|
n = request.data.get('n', 1)
|
|
style = request.data.get('style')
|
|
size = request.data.get('size')
|
|
api_key = settings.DASHSCOPE_API_KEY
|
|
request.data['width'] = int(size.split('*')[0])
|
|
request.data['height'] = int(size.split('*')[1])
|
|
adapter = get_adapter(LLMProvider.TONGYI, api_key=api_key, model=model)
|
|
rsp = adapter.create_drawing_task(prompt=prompt, n=n, style=style, size=size)
|
|
if rsp['status_code'] != 200:
|
|
return Response(rsp['data'], status=rsp['status_code'])
|
|
else:
|
|
request.data['status'] = rsp['output']['task_status']
|
|
request.data['task_id'] = rsp['output']['task_id']
|
|
return super().create(request, *args, **kwargs)
|
|
|
|
def retrieve(self, request, *args, **kwargs):
|
|
instance = self.get_object()
|
|
if instance.status in ("PENDING", 'RUNNING'):
|
|
api_key = settings.DASHSCOPE_API_KEY
|
|
adapter = get_adapter(LLMProvider.TONGYI, api_key=api_key, model='')
|
|
rsp = adapter.fetch_drawing_task_status(instance.task_id)
|
|
print(rsp, 'sadsadas')
|
|
if rsp['status_code'] == 200:
|
|
# 可根据 status.output.task_status 更新数据库
|
|
if rsp['output']['task_status'] == 'SUCCEEDED':
|
|
instance.update_time = datetime.now()
|
|
instance.status = rsp['output']['task_status']
|
|
instance.pic_url = rsp['output']['results'][0]['url']
|
|
elif rsp['output']['task_status'] == 'FAILED':
|
|
instance.update_time = datetime.now()
|
|
instance.status = rsp['output']['task_status']
|
|
instance.error_message = rsp['output']['message']
|
|
elif rsp['output']['task_status'] == 'RUNNING':
|
|
instance.update_time = datetime.now()
|
|
instance.status = rsp['output']['task_status']
|
|
instance.save()
|
|
return super().retrieve(request, *args, **kwargs)
|