02b5a89eaf
- Add 15s asyncio.timeout() to polish_content and generate_title - Add try/except to /polish route for unified error response - Add asyncio.TimeoutError handling to /generate-title route
282 lines
9.2 KiB
Python
282 lines
9.2 KiB
Python
"""
|
||
脚本生成 API
|
||
============
|
||
|
||
提供脚本生成、润色、模型健康检查等功能。
|
||
支持 SSE 流式响应。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
|
||
from fastapi import APIRouter, Request
|
||
from fastapi.responses import StreamingResponse
|
||
|
||
from app.schemas.common import ApiResponse, success_response
|
||
from app.ai.model_router import get_model_router
|
||
from app.ai.prompts import list_categories, load_prompt, render_template
|
||
from app.schemas.script import (
|
||
CategoryItem,
|
||
GenerateScriptRequest,
|
||
GenerateTitleRequest,
|
||
GenerateTitleResponse,
|
||
ModelHealthResponse,
|
||
PolishRequest,
|
||
ScriptGenerationEvent,
|
||
ScriptShot,
|
||
TestModelRequest,
|
||
TestModelResponse,
|
||
)
|
||
from app.services.script_service import get_script_service
|
||
|
||
router = APIRouter()
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@router.get("/categories", response_model=ApiResponse[list[CategoryItem]])
|
||
async def get_categories():
|
||
"""
|
||
获取提示词分类列表
|
||
|
||
返回所有大类和小类结构,供前端选择。
|
||
"""
|
||
categories = list_categories()
|
||
return success_response(
|
||
data=categories,
|
||
message="获取分类列表成功",
|
||
)
|
||
|
||
|
||
@router.post("/generate/stream")
|
||
async def generate_script_stream(request: Request, data: GenerateScriptRequest):
|
||
"""
|
||
流式生成脚本(SSE)
|
||
|
||
返回 Server-Sent Events,包含进度更新和最终结果。
|
||
前端通过 EventSource 接收实时进度。
|
||
|
||
**SSE 事件类型:**
|
||
- `start`: 开始生成
|
||
- `generating`: AI 生成中
|
||
- `complete`: 完成,包含 result 字段
|
||
- `error`: 错误
|
||
|
||
**示例事件流:**
|
||
```
|
||
data: {"type": "start", "message": "正在创作脚本..."}
|
||
|
||
data: {"type": "generating", "message": "正在创作脚本..."}
|
||
|
||
data: {"type": "complete", "message": "脚本生成成功", "result": [...]}
|
||
```
|
||
"""
|
||
service = get_script_service()
|
||
|
||
async def event_generator():
|
||
"""SSE 事件生成器,带客户端断开检测"""
|
||
try:
|
||
async for event in service.generate_script_stream(
|
||
category=data.category,
|
||
subcategory=data.subcategory,
|
||
duration=data.duration,
|
||
script_type=data.script_type,
|
||
model=data.model,
|
||
):
|
||
# 检查客户端是否已断开
|
||
if await request.is_disconnected():
|
||
logger.info("[SSE] 客户端已断开连接,停止生成")
|
||
break
|
||
|
||
# SSE 格式:data: {...}\n\n
|
||
try:
|
||
yield f"data: {event.model_dump_json()}\n\n"
|
||
await asyncio.sleep(0.05) # 确保事件被 flush,前端有时间渲染
|
||
except Exception as e:
|
||
logger.error(f"[SSE] 序列化事件失败: {e}")
|
||
continue
|
||
|
||
# 发送结束标记(如果客户端还连接着)
|
||
if not await request.is_disconnected():
|
||
yield "data: [DONE]\n\n"
|
||
|
||
except Exception as e:
|
||
logger.exception("[SSE] 事件生成器异常")
|
||
# 尝试发送错误信息给客户端
|
||
try:
|
||
error_event = ScriptGenerationEvent(
|
||
type="error",
|
||
progress=0,
|
||
message=f"服务器错误: {str(e)}",
|
||
)
|
||
yield f"data: {error_event.model_dump_json()}\n\n"
|
||
yield "data: [DONE]\n\n"
|
||
except:
|
||
pass
|
||
|
||
return StreamingResponse(
|
||
event_generator(),
|
||
media_type="text/event-stream",
|
||
headers={
|
||
"Cache-Control": "no-cache",
|
||
"Connection": "keep-alive",
|
||
"X-Accel-Buffering": "no", # 禁用 Nginx 缓冲
|
||
},
|
||
)
|
||
|
||
|
||
@router.post("/polish", response_model=ApiResponse[str])
|
||
async def polish_content(request: PolishRequest):
|
||
"""
|
||
AI 润色文案/画面描述
|
||
|
||
- `polishType=scene`: 润色画面描述(根据 shot_type 自动区分分镜/空镜)
|
||
- `polishType=voiceover`: 润色配音文案
|
||
|
||
参数:
|
||
- `shot_type`: "segment"(分镜)或 "empty_shot"(空镜),画面润色时必填
|
||
"""
|
||
service = get_script_service()
|
||
type_name = "画面" if request.polish_type == "scene" else "文案"
|
||
|
||
try:
|
||
polished = await service.polish_content(
|
||
content=request.content,
|
||
polish_type=request.polish_type,
|
||
shot_type=request.shot_type or "segment",
|
||
)
|
||
|
||
return success_response(
|
||
data=polished,
|
||
message=f"{type_name}润色完成",
|
||
)
|
||
except ValueError as e:
|
||
logger.warning(f"[Polish] 润色失败: {e}")
|
||
return success_response(
|
||
code=500,
|
||
message=str(e),
|
||
data=None,
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"[Polish] 润色异常: {e}")
|
||
return success_response(
|
||
code=500,
|
||
message=f"{type_name}润色失败: {str(e)}",
|
||
data=None,
|
||
)
|
||
|
||
|
||
@router.get("/model-health", response_model=ApiResponse[ModelHealthResponse])
|
||
async def check_model_health():
|
||
"""
|
||
检查 AI 模型健康状态
|
||
|
||
返回所有配置的模型及其可用性状态。
|
||
"""
|
||
service = get_script_service()
|
||
health_data = await service.check_model_health()
|
||
|
||
return success_response(
|
||
data=ModelHealthResponse(**health_data),
|
||
message="模型健康检查完成",
|
||
)
|
||
|
||
|
||
@router.post("/test-model", response_model=ApiResponse[TestModelResponse])
|
||
async def test_model(request: TestModelRequest):
|
||
"""
|
||
测试指定模型连接
|
||
|
||
发送一个简单的测试请求,验证模型是否可用。
|
||
"""
|
||
service = get_script_service()
|
||
result = await service.test_model(request.model_id)
|
||
|
||
return success_response(
|
||
data=TestModelResponse(**result),
|
||
message="模型测试完成" if result["success"] else f"模型测试失败: {result.get('error')}",
|
||
)
|
||
|
||
|
||
@router.post("/generate-title", response_model=ApiResponse[GenerateTitleResponse])
|
||
async def generate_title(request: GenerateTitleRequest):
|
||
"""
|
||
根据脚本内容智能生成标题
|
||
|
||
调用 LLM 根据脚本内容生成大标题或小标题。
|
||
提示词从文件加载,支持热更新。
|
||
"""
|
||
model_router = await get_model_router()
|
||
|
||
# 加载提示词
|
||
system_prompt = load_prompt("user/title_system")
|
||
user_template = load_prompt("user/title")
|
||
|
||
if not system_prompt or not user_template:
|
||
return success_response(
|
||
code=500,
|
||
message="标题生成提示词文件缺失",
|
||
data=None,
|
||
)
|
||
|
||
# 根据使用场景确定描述
|
||
if request.usage == "cover":
|
||
usage_desc = "视频封面标题——用于封面图设计,是决定用户是否点击的第一要素"
|
||
style_requirement = "极具冲击力、抓眼球,适合静态封面大图展示,善用爆款句式"
|
||
usage_note = "- 封面主标题必须极度吸睛,让用户一眼就想点进去,善用数字、疑问、痛点、冲突\n- 封面副标题要补充悬念或细节,激发点击欲望"
|
||
else:
|
||
usage_desc = "视频画面标题——直接叠加在视频画面上,与动态视频内容共存"
|
||
style_requirement = "口语化、精炼有力,适合视频内展示,避免遮挡画面主体"
|
||
usage_note = "- 视频画面上的标题需要精炼,聚焦核心关键词\n- 副标题与主标题形成呼应,补充说明但不喧宾夺主"
|
||
|
||
# 渲染用户提示词
|
||
title_type_desc = "大标题(主标题,提炼核心卖点,吸睛)" if request.title_type == "main" else "小标题(副标题,补充说明或制造悬念)"
|
||
user_prompt = render_template(
|
||
user_template,
|
||
title_type=request.title_type,
|
||
title_type_desc=title_type_desc,
|
||
script_content=request.script_content,
|
||
max_length=request.max_length,
|
||
usage=request.usage,
|
||
usage_desc=usage_desc,
|
||
style_requirement=style_requirement,
|
||
usage_note=usage_note,
|
||
)
|
||
|
||
try:
|
||
async with asyncio.timeout(15):
|
||
result = await model_router.generate(
|
||
prompt=user_prompt,
|
||
system_prompt=system_prompt,
|
||
task_type="script",
|
||
temperature=0.8,
|
||
max_tokens=64,
|
||
)
|
||
|
||
title = result.content.strip() if result.content else ""
|
||
# 去除可能的引号
|
||
title = title.strip('"').strip("'").strip('「」').strip('『』').strip('《》')
|
||
# 截断到最大长度
|
||
if len(title) > request.max_length:
|
||
title = title[:request.max_length]
|
||
|
||
return success_response(
|
||
data=GenerateTitleResponse(title=title),
|
||
message="标题生成成功",
|
||
)
|
||
except asyncio.TimeoutError:
|
||
logger.warning("[generate_title] 标题生成超时")
|
||
return success_response(
|
||
code=500,
|
||
message="标题生成超时,请重试",
|
||
data=None,
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"[generate_title] 标题生成失败: {e}")
|
||
return success_response(
|
||
code=500,
|
||
message=f"标题生成失败: {str(e)}",
|
||
data=None,
|
||
)
|