d0057ecc2c
- volcengine_provider: Ark SDK 同步迭代器 → AsyncOpenAI → httpx 原始 SSE - generate_stream_with_progress 使用 httpx 直接请求,消除 80s+ 缓冲 - 新增 generate_stream (AsyncOpenAI) 作为备用方案 - enable_thinking 替换为 reasoning_effort,支持思考程度控制 - ai_models.yaml: 默认 LLM 改为 doubao-seed-2-0-pro,添加 reasoning_effort: minimal - model_router: 透传 reasoning_effort 参数 - script_service: 4 阶段 SSE 精简 (start→analyzing→generating→complete) - script.py: SSE 直连端点 /script/generate/stream - 前端 ScriptCreation: 直连 SSE 端点,弃用调度器轮询模式
208 lines
6.1 KiB
Python
208 lines
6.1 KiB
Python
"""
|
||
脚本生成 API
|
||
============
|
||
|
||
提供脚本生成、润色、模型健康检查等功能。
|
||
支持 SSE 流式响应。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
|
||
from fastapi import APIRouter, Request
|
||
from fastapi.responses import StreamingResponse
|
||
|
||
from app.schemas.common import ApiResponse, success_response
|
||
from app.ai.prompts import list_categories
|
||
from app.schemas.script import (
|
||
CategoryItem,
|
||
GenerateScriptRequest,
|
||
ModelHealthResponse,
|
||
PolishRequest,
|
||
ScriptGenerationEvent,
|
||
ScriptShot,
|
||
TestModelRequest,
|
||
TestModelResponse,
|
||
)
|
||
from app.services.script_service import get_script_service
|
||
|
||
router = APIRouter()
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@router.get("/categories", response_model=ApiResponse[list[CategoryItem]])
async def get_categories():
    """
    Return the list of prompt categories.

    The data is whatever ``list_categories()`` produces, wrapped in the
    standard success envelope so the frontend can offer it for selection.
    """
    # NOTE(review): the original docstring describes the payload as the full
    # main-category / sub-category structure — confirm against list_categories().
    return success_response(
        data=list_categories(),
        message="获取分类列表成功",
    )
|
||
|
||
|
||
@router.post("/generate", response_model=ApiResponse[list[ScriptShot]])
async def generate_script(request: GenerateScriptRequest):
    """
    Generate a script in a single request/response round trip.

    The generated shot list is returned directly in the response body
    (no streaming), which suits quick previews.
    """
    generated_shots = await get_script_service().generate_script(
        category=request.category,
        subcategory=request.subcategory,
        duration=request.duration,
        script_type=request.script_type,
        model=request.model,
    )

    # The success message reports how many shots were produced.
    shot_count = len(generated_shots)
    return success_response(
        data=generated_shots,
        message=f"成功生成 {shot_count} 个分镜",
    )
|
||
|
||
|
||
@router.post("/generate/stream")
async def generate_script_stream(request: Request, data: GenerateScriptRequest):
    """
    Generate a script and stream progress as Server-Sent Events (SSE).

    Every event produced by the script service is serialized to JSON and sent
    as one ``data: <json>`` frame followed by a blank line. After the last
    event a literal ``data: [DONE]`` frame marks the end of the stream; it is
    skipped when the client has already disconnected. If the generator fails,
    an ``error`` event and ``[DONE]`` are sent on a best-effort basis.

    NOTE(review): this is a POST route. The browser ``EventSource`` API only
    issues GET requests, so clients have to read the streamed response body
    themselves (e.g. ``fetch`` + a stream reader) — confirm against the
    frontend.

    **SSE event types:**
    - `start`: generation started
    - `analyzing`: analyzing the topic
    - `planning`: planning the structure
    - `generating`: AI generation in progress
    - `parsing`: parsing the result
    - `complete`: finished, carries the `result` field
    - `error`: failure

    NOTE(review): which of these types are actually emitted is decided by the
    script service; `planning` / `parsing` may no longer be sent — confirm
    against script_service.

    **Example event stream:**
    ```
    data: {"type": "start", "progress": 0, "message": "开始生成脚本"}

    data: {"type": "analyzing", "progress": 15, "message": "分析目标受众..."}

    data: {"type": "complete", "progress": 100, "message": "成功生成 5 个分镜", "result": [...]}
    ```
    """
    service = get_script_service()

    async def event_generator():
        """Yield SSE frames, stopping early once the client has disconnected."""
        stream = None
        try:
            stream = service.generate_script_stream(
                category=data.category,
                subcategory=data.subcategory,
                duration=data.duration,
                script_type=data.script_type,
                model=data.model,
            )
            async for event in stream:
                # Stop producing as soon as nobody is listening any more.
                if await request.is_disconnected():
                    logger.info("[SSE] 客户端已断开连接,停止生成")
                    break

                # Only serialization is guarded here: one event that cannot be
                # serialized is logged and skipped instead of aborting the stream.
                try:
                    payload = event.model_dump_json()
                except Exception as e:
                    logger.error("[SSE] 序列化事件失败: %s", e)
                    continue

                # SSE frame format: data: {...}\n\n
                yield f"data: {payload}\n\n"
                # Short pause after each frame so it gets flushed and the
                # frontend has time to render it.
                await asyncio.sleep(0.05)

            # Send the end-of-stream marker only if the client is still connected.
            if not await request.is_disconnected():
                yield "data: [DONE]\n\n"

        except Exception as e:
            logger.exception("[SSE] 事件生成器异常")
            # Best effort: tell the client what went wrong. Only ``Exception``
            # is caught so that task cancellation and generator shutdown
            # (``CancelledError`` / ``GeneratorExit``) still propagate.
            try:
                error_event = ScriptGenerationEvent(
                    type="error",
                    progress=0,
                    message=f"服务器错误: {str(e)}",
                )
                yield f"data: {error_event.model_dump_json()}\n\n"
                yield "data: [DONE]\n\n"
            except Exception:
                logger.exception("[SSE] 发送错误事件失败")
        finally:
            # Leaving the loop early (disconnect, error, cancellation) does not
            # close the upstream stream by itself; close it explicitly so the
            # service can release its resources promptly. ``aclose`` is looked
            # up defensively because only async generators are guaranteed to
            # provide it.
            aclose = getattr(stream, "aclose", None)
            if aclose is not None:
                await aclose()

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable Nginx response buffering
        },
    )
|
||
|
||
|
||
@router.post("/polish", response_model=ApiResponse[str])
async def polish_content(request: PolishRequest):
    """
    Polish a scene description or a voice-over text with AI.

    - `polishType=scene`: polish the scene description (segment vs. empty shot
      is distinguished via `shot_type`)
    - `polishType=voiceover`: polish the voice-over copy

    Parameters:
    - `shot_type`: "segment" or "empty_shot"; meant to be supplied when
      polishing a scene. When it is missing or empty, "segment" is used.
    """
    polished_text = await get_script_service().polish_content(
        content=request.content,
        polish_type=request.polish_type,
        shot_type=request.shot_type or "segment",
    )

    # Pick the label used in the success message: anything other than
    # "scene" is reported as copy text.
    if request.polish_type == "scene":
        label = "画面"
    else:
        label = "文案"

    return success_response(
        data=polished_text,
        message=f"{label}润色完成",
    )
|
||
|
||
|
||
@router.get("/model-health", response_model=ApiResponse[ModelHealthResponse])
async def check_model_health():
    """
    Check the health of the AI models.

    The report returned by the script service is validated into a
    ``ModelHealthResponse`` and wrapped in the standard success envelope.
    """
    # NOTE(review): the original docstring says the report covers every
    # configured model and its availability — confirm against the service.
    health_report = await get_script_service().check_model_health()
    payload = ModelHealthResponse(**health_report)

    return success_response(
        data=payload,
        message="模型健康检查完成",
    )
|
||
|
||
|
||
@router.post("/test-model", response_model=ApiResponse[TestModelResponse])
async def test_model(request: TestModelRequest):
    """
    Test the connection to one specific model.

    The script service is asked to test the model identified by
    ``request.model_id``; its result is returned as a ``TestModelResponse``.
    The response message reflects the result's ``success`` flag and, on
    failure, includes its ``error`` value.
    """
    probe = await get_script_service().test_model(request.model_id)

    # Build the response model first, then derive the message from the raw
    # result (same evaluation order as before).
    payload = TestModelResponse(**probe)
    if probe["success"]:
        summary = "模型测试完成"
    else:
        summary = f"模型测试失败: {probe.get('error')}"

    return success_response(
        data=payload,
        message=summary,
    )
|