Files
meijiaka-zy/python-api/app/services/script_service.py
T
小鱼开发 7640dc23ba feat: 区分 SSE 阶段文案 + 前端请求去重锁
- script_service: start文案"准备生成...", generating文案"正在创作脚本..."
- ScriptCreation: 添加 requestLock ref,防止网络延迟期间快速点击发起多个请求
- 锁在请求开始时设置,完成后释放,与 generating 状态互补
2026-04-26 21:17:27 +08:00

362 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
脚本生成服务
============
"""
import asyncio
import logging
import re
import time
from collections.abc import AsyncIterator
from pathlib import Path
from app.ai.model_router import get_model_router
from app.ai.prompts import load_script_user_prompt, load_system_prompt
from app.schemas.script import ScriptGenerationEvent, ScriptShot
from app.services.ai_response_utils import (
safe_parse_ai_json_response,
validate_and_normalize_shots,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ScriptService:
    """Script generation service.

    Uses the shared model router to generate storyboard scripts (either in a
    single call or as a stream of progress events), polish individual text
    fields, and report model health / connectivity.
    """

    # Matches a fenced Markdown code block (```json ... ``` or ``` ... ```).
    # Case-insensitive so a ```JSON tag is stripped as well.
    _JSON_FENCE_RE = re.compile(r"```(?:json)?\s*([\s\S]*?)\s*```", re.IGNORECASE)

    def __init__(self):
        # Prompt templates directory: two levels up from this file, then ai/prompts.
        self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"

    def _load_prompt(self, name: str) -> str:
        """Load the prompt template ``<prompts_dir>/<name>.txt``.

        Args:
            name: Template path relative to ``prompts_dir`` without the
                ``.txt`` suffix, e.g. ``"polish/voiceover"``.

        Returns:
            The UTF-8 decoded template text, or ``""`` if the file is missing.
        """
        prompt_file = self.prompts_dir / f"{name}.txt"
        if prompt_file.exists():
            return prompt_file.read_text(encoding="utf-8")
        return ""

    @staticmethod
    def _extract_json(content: str) -> str:
        """Extract JSON text from a Markdown code fence, or return the raw text.

        Supported inputs:
        - ```json {...} ```  (the language tag is matched case-insensitively)
        - ``` {...} ```
        - plain JSON text

        Returns:
            The contents of the LAST fenced block (so example snippets earlier
            in the reply are skipped), otherwise the stripped input; ``""``
            for empty input.
        """
        if not content:
            return ""
        content = content.strip()
        matches = ScriptService._JSON_FENCE_RE.findall(content)
        if matches:
            return matches[-1].strip()
        # No code fence: assume the whole reply is the payload.
        return content

    @staticmethod
    def _build_shots(parsed_data) -> list[ScriptShot]:
        """Normalize parsed AI output and convert it into ``ScriptShot`` models.

        Shared by the streaming and non-streaming generators so both report
        identical error messages.

        Raises:
            ValueError: If normalization or model construction fails (message
                carries the processing-failure prefix and the original error),
                or if no shots remain after normalization. The message is
                suitable for returning to the client as-is.
        """
        try:
            shots_data = validate_and_normalize_shots(parsed_data)
            shots = [ScriptShot(**shot) for shot in shots_data or []]
        except Exception as e:
            raise ValueError(f"分镜数据处理失败: {e}") from e
        if not shots:
            raise ValueError("AI 返回的分镜数据为空或格式不正确")
        return shots

    async def generate_script(
        self,
        category: str,
        subcategory: str,
        duration: int,
        script_type: str,
        model: str | None = None,
    ) -> list[ScriptShot]:
        """Generate a script with a single (non-streaming) model call.

        Args:
            category: Top-level category code, e.g. "bk".
            subcategory: Sub-category code, e.g. "ht".
            duration: Target video duration in seconds.
            script_type: Script type; accepted for API compatibility but not
                currently used when building the prompt.
            model: Optional model id, forwarded to the router as ``model_id``.

        Returns:
            The generated list of shots.

        Raises:
            ValueError: If no system prompt exists for the category pair, the
                model returns empty or unparseable content, or the shot data
                fails validation.
        """
        model_router = await get_model_router()

        system_prompt = load_system_prompt(category, subcategory)
        if not system_prompt:
            raise ValueError(f"未找到提示词: category={category}, subcategory={subcategory}")
        user_prompt = load_script_user_prompt(
            topic=f"{category}/{subcategory}",
            duration=duration,
        )

        result = await model_router.generate(
            prompt=user_prompt,
            system_prompt=system_prompt,
            model_id=model,
            task_type="script",
            temperature=0.7,
            response_format="json_object",
        )
        if not result.content or not result.content.strip():
            raise ValueError("AI 返回内容为空,请检查模型配置或重试")

        success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
        if not success:
            raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")

        # _build_shots already raises a ready-to-show ValueError; do not
        # re-wrap it, or the "empty shots" message gets a second prefix.
        return self._build_shots(parsed_data)

    async def generate_script_stream(
        self,
        category: str,
        subcategory: str,
        duration: int,
        script_type: str,
        model: str | None = None,
    ) -> AsyncIterator[ScriptGenerationEvent]:
        """Generate a script as a stream of progress events (for SSE).

        Event sequence on success: ``start`` -> ``generating`` (emitted once,
        on the first non-empty content chunk) -> ``complete`` with the shots.
        Failures are reported as a final ``error`` event instead of being
        raised, so the client always receives a terminal event.

        Args are the same as for ``generate_script`` (``script_type`` is
        likewise unused).
        """
        try:
            # Inside the try so a router failure becomes an "error" event.
            model_router = await get_model_router()

            system_prompt = load_system_prompt(category, subcategory)
            if not system_prompt:
                yield ScriptGenerationEvent(
                    type="error",
                    message=f"未找到提示词: category={category}, subcategory={subcategory}",
                )
                return
            user_prompt = load_script_user_prompt(
                topic=f"{category}/{subcategory}",
                duration=duration,
            )

            yield ScriptGenerationEvent(
                type="start",
                message="准备生成...",
            )

            # Collect chunks and join once; repeated `str +=` can be quadratic
            # for long responses.
            parts: list[str] = []
            async for chunk in model_router.generate_stream_with_progress(
                prompt=user_prompt,
                system_prompt=system_prompt,
                model_id=model,
                task_type="script",
                temperature=0.7,
                response_format="json_object",
            ):
                # Only content chunks are consumed; other chunk types are ignored.
                if chunk.get("type") != "chunk":
                    continue
                chunk_content = chunk.get("content", "")
                if not chunk_content:
                    continue
                if not parts:
                    # First real content: tell the client generation is underway.
                    yield ScriptGenerationEvent(
                        type="generating",
                        message="正在创作脚本...",
                    )
                parts.append(chunk_content)

            full_content = "".join(parts)
            if not full_content.strip():
                yield ScriptGenerationEvent(
                    type="error",
                    message="AI 返回内容为空,请检查模型配置或重试",
                )
                return

            success, parsed_data, error_msg = safe_parse_ai_json_response(full_content)
            if not success:
                logger.warning("Script JSON parse failed: %s", error_msg)
                yield ScriptGenerationEvent(
                    type="error",
                    message=f"脚本解析失败: {error_msg or '无法解析 AI 返回的内容'}",
                )
                return

            try:
                shots = self._build_shots(parsed_data)
            except ValueError as e:
                logger.warning("Script shot validation failed: %s", e)
                yield ScriptGenerationEvent(
                    type="error",
                    message=str(e),
                )
                return

            yield ScriptGenerationEvent(
                type="complete",
                message="脚本生成成功",
                result=shots,
            )
        except Exception as e:
            logger.exception("Streaming script generation failed")
            yield ScriptGenerationEvent(
                type="error",
                message=f"生成失败: {str(e)}",
            )

    async def polish_content(
        self,
        content: str,
        polish_type: str = "voiceover",
        shot_type: str = "segment",
    ) -> str:
        """Polish a piece of script text with the model.

        Args:
            content: Text to polish.
            polish_type: ``"scene"`` (visual description) or ``"voiceover"``
                (narration copy). Any value other than ``"scene"`` is treated
                as voiceover.
            shot_type: ``"segment"`` (storyboard shot) or ``"empty_shot"``
                (empty shot). Only consulted when ``polish_type == "scene"``.

        Returns:
            The polished text with surrounding whitespace stripped.

        Raises:
            ValueError: If the model returns no content.
        """
        model_router = await get_model_router()

        if polish_type == "scene":
            # Scene polishing has shot-type specific templates; fall back to
            # the generic scene template if the specific one is missing.
            specific_name = (
                "polish/scene_empty_shot" if shot_type == "empty_shot" else "polish/scene_segment"
            )
            prompt_template = self._load_prompt(specific_name) or self._load_prompt("polish/scene")
        else:
            prompt_template = self._load_prompt("polish/voiceover")
        if not prompt_template:
            # Last-resort built-in template when no prompt file is available.
            prompt_template = "请润色以下内容:\n\n{content}"

        # NOTE(review): str.format treats every "{...}" in the template as a
        # placeholder, so literal braces in template files must be doubled.
        prompt = prompt_template.format(content=content)
        result = await model_router.generate(
            prompt=prompt,
            task_type="polish",
            temperature=0.5,
            max_tokens=300,
        )
        # Guard against a None/blank reply instead of crashing on .strip() or
        # handing the caller an empty "polished" text.
        polished = (result.content or "").strip()
        if not polished:
            raise ValueError("AI 返回内容为空,请检查模型配置或重试")
        return polished

    async def check_model_health(self) -> dict:
        """Run the router health check and summarise the results.

        Returns:
            A dict with ``status`` ("healthy" if at least one model is
            available, else "unhealthy"), per-model ``models`` entries, the
            available model with the lowest response time as
            ``recommended_model`` (``None`` if none is available), and the
            ``total_models`` / ``available_models`` counts.
        """
        model_router = await get_model_router()
        health_results = await model_router.health_check()

        models: list[dict] = []
        available_count = 0
        recommended: dict | None = None
        best_time = float("inf")
        for health in health_results.values():
            model_info = {
                "id": health.id,
                "name": health.name,
                "is_available": health.is_available,
                "response_time": health.response_time,
                "last_error": health.last_error,
            }
            models.append(model_info)
            if not health.is_available:
                continue
            available_count += 1
            # A missing response time ranks last instead of raising TypeError
            # in the comparison.
            response_time = (
                health.response_time if health.response_time is not None else float("inf")
            )
            if recommended is None or response_time < best_time:
                recommended = model_info
                best_time = response_time

        return {
            "status": "healthy" if available_count > 0 else "unhealthy",
            "models": models,
            "recommended_model": recommended,
            "total_models": len(models),
            "available_models": available_count,
        }

    async def test_model(self, model_id: str | None = None) -> dict:
        """Send a tiny prompt to a model to check connectivity.

        Args:
            model_id: Model to test; ``None`` lets the router pick (reported
                as ``"default"`` on failure).

        Returns:
            On success: ``success=True``, the responding ``model``,
            ``response_time`` in milliseconds (2 decimals) and ``checked_at``.
            On failure: ``success=False``, ``model``, ``error`` and
            ``checked_at``. ``checked_at`` is local time formatted as
            ``%Y-%m-%dT%H:%M:%S``.
        """
        try:
            model_router = await get_model_router()
            # perf_counter is monotonic, so the latency is not skewed by
            # system clock adjustments the way time.time() can be.
            start = time.perf_counter()
            result = await model_router.generate(
                prompt="你好",
                model_id=model_id,
                max_tokens=5,
            )
            response_time = (time.perf_counter() - start) * 1000
            return {
                "success": True,
                "model": result.model,
                "response_time": round(response_time, 2),
                "checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
            }
        except Exception as e:
            logger.warning("Model connectivity test failed (%s): %s", model_id or "default", e)
            return {
                "success": False,
                "model": model_id or "default",
                "error": str(e),
                "checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
            }
# Process-wide ScriptService instance; created lazily on first access below.
_script_service: ScriptService | None = None


def get_script_service() -> ScriptService:
    """Return the shared ScriptService, constructing it on the first call."""
    global _script_service
    service = _script_service
    if service is None:
        service = ScriptService()
        _script_service = service
    return service