c79b2323f4
前端: - 删除 scriptType 字段及相关 store action、持久化、API 类型 - 删除 scriptDuration 字段及相关 store action、持久化、加载逻辑 - ScriptCreation 不再传 duration/style 参数给后端 后端: - ScriptParams 删除 duration/style 字段 - ScriptHandler 删除 duration/style 参数读取和传递 - ScriptService.generate_script 签名删除 duration/script_type - load_script_user_prompt 删除 duration 参数 影响:无,duration/style 在 prompt 模板中未被实际使用
229 lines
6.9 KiB
Python
229 lines
6.9 KiB
Python
"""
|
|
脚本生成服务
|
|
============
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from pathlib import Path
|
|
|
|
from app.ai.model_router import get_model_router
|
|
from app.ai.prompts import load_script_user_prompt, load_system_prompt
|
|
from app.schemas.script import ScriptShot
|
|
from app.services.ai_response_utils import (
|
|
safe_parse_ai_json_response,
|
|
validate_and_normalize_shots,
|
|
)
|
|
|
|
# Module-level logger named after this module.
# NOTE(review): not referenced by any code in this file as shown; either use it
# (e.g. when wrapping AI/parsing failures) or remove it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ScriptService:
|
|
"""脚本生成服务"""
|
|
|
|
|
|
def __init__(self):
|
|
self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"
|
|
|
|
|
|
|
|
def _load_prompt(self, name: str) -> str:
|
|
"""加载 Prompt 模板"""
|
|
prompt_file = self.prompts_dir / f"{name}.txt"
|
|
if prompt_file.exists():
|
|
return prompt_file.read_text(encoding="utf-8")
|
|
return ""
|
|
|
|
async def generate_script(
|
|
self,
|
|
category: str,
|
|
subcategory: str,
|
|
model: str | None = None,
|
|
) -> list[ScriptShot]:
|
|
"""
|
|
同步生成脚本
|
|
|
|
Args:
|
|
category: 大类代码,如 "bk"
|
|
subcategory: 小类代码,如 "ht"
|
|
model: 指定模型
|
|
|
|
Returns:
|
|
分镜列表
|
|
"""
|
|
# 获取 model_router
|
|
model_router = await get_model_router()
|
|
|
|
# 加载 Prompt
|
|
system_prompt = load_system_prompt(category, subcategory)
|
|
if not system_prompt:
|
|
raise ValueError(f"未找到提示词: category={category}, subcategory={subcategory}")
|
|
|
|
# 用户提示词
|
|
user_prompt = load_script_user_prompt(
|
|
topic=f"{category}/{subcategory}",
|
|
)
|
|
|
|
# 调用 AI 生成
|
|
# 注意:system prompt 中已要求"只输出纯 JSON",不依赖 response_format 参数
|
|
result = await model_router.generate(
|
|
prompt=user_prompt,
|
|
system_prompt=system_prompt,
|
|
model_id=model,
|
|
task_type="script",
|
|
)
|
|
|
|
if not result.content or not result.content.strip():
|
|
raise ValueError("AI 返回内容为空,请检查模型配置或重试")
|
|
|
|
success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
|
|
|
|
if not success:
|
|
raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")
|
|
|
|
try:
|
|
shots_data = validate_and_normalize_shots(parsed_data)
|
|
|
|
if not shots_data:
|
|
raise ValueError("AI 返回的分镜数据为空或格式不正确")
|
|
|
|
shots = [ScriptShot(**shot) for shot in shots_data]
|
|
return shots
|
|
|
|
except Exception as e:
|
|
raise ValueError(f"分镜数据处理失败: {str(e)}")
|
|
|
|
async def polish_content(
|
|
self,
|
|
content: str,
|
|
polish_type: str = "voiceover",
|
|
shot_type: str = "segment",
|
|
) -> str:
|
|
"""
|
|
润色内容
|
|
|
|
Args:
|
|
content: 待润色内容
|
|
polish_type: 润色类型,可选 "scene"(画面描述)或 "voiceover"(配音文案)
|
|
shot_type: 镜头类型,可选 "segment"(分镜)或 "empty_shot"(空镜),仅用于画面润色
|
|
|
|
Returns:
|
|
润色后的内容
|
|
"""
|
|
# 获取 model_router
|
|
model_router = await get_model_router()
|
|
|
|
# 从文件加载提示词模板
|
|
if polish_type == "scene":
|
|
# 画面润色需要根据镜头类型选择不同提示词
|
|
if shot_type == "empty_shot":
|
|
prompt_template = self._load_prompt("polish/scene_empty_shot")
|
|
else:
|
|
prompt_template = self._load_prompt("polish/scene_segment")
|
|
|
|
# 如果特定类型的提示词不存在,回退到通用 scene 提示词
|
|
if not prompt_template:
|
|
prompt_template = self._load_prompt("polish/scene")
|
|
else:
|
|
# 配音文案润色
|
|
prompt_template = self._load_prompt("polish/voiceover")
|
|
|
|
if not prompt_template:
|
|
# 最终回退
|
|
prompt_template = "请润色以下内容:\n\n{content}"
|
|
|
|
prompt = prompt_template.format(content=content)
|
|
|
|
try:
|
|
async with asyncio.timeout(15):
|
|
result = await model_router.generate(
|
|
prompt=prompt,
|
|
task_type="polish",
|
|
max_tokens=300,
|
|
)
|
|
return result.content.strip()
|
|
except TimeoutError:
|
|
raise ValueError("润色请求超时,请重试")
|
|
except Exception as e:
|
|
raise ValueError(f"润色失败: {str(e)}")
|
|
|
|
async def check_model_health(self) -> dict:
|
|
"""检查模型健康状态"""
|
|
model_router = await get_model_router()
|
|
health_results = await model_router.health_check()
|
|
|
|
models = []
|
|
available_count = 0
|
|
recommended = None
|
|
|
|
for _provider_id, health in health_results.items():
|
|
model_info = {
|
|
"id": health.id,
|
|
"name": health.name,
|
|
"is_available": health.is_available,
|
|
"response_time": health.response_time,
|
|
"last_error": health.last_error,
|
|
}
|
|
models.append(model_info)
|
|
|
|
if health.is_available:
|
|
available_count += 1
|
|
if recommended is None or health.response_time < recommended.get(
|
|
"response_time", float("inf")
|
|
):
|
|
recommended = model_info
|
|
|
|
total = len(models)
|
|
|
|
return {
|
|
"status": "healthy" if available_count > 0 else "unhealthy",
|
|
"models": models,
|
|
"recommended_model": recommended,
|
|
"total_models": total,
|
|
"available_models": available_count,
|
|
}
|
|
|
|
async def test_model(self, model_id: str | None = None) -> dict:
|
|
"""测试指定模型连接"""
|
|
model_router = await get_model_router()
|
|
|
|
|
|
start_time = time.time()
|
|
|
|
try:
|
|
result = await model_router.generate(
|
|
prompt="你好",
|
|
model_id=model_id,
|
|
max_tokens=5,
|
|
)
|
|
|
|
response_time = (time.time() - start_time) * 1000
|
|
|
|
return {
|
|
"success": True,
|
|
"model": result.model,
|
|
"response_time": round(response_time, 2),
|
|
"checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
|
}
|
|
|
|
except Exception as e:
|
|
return {
|
|
"success": False,
|
|
"model": model_id or "default",
|
|
"error": str(e),
|
|
"checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
|
}
|
|
|
|
|
|
# Module-level singleton instance; created lazily by get_script_service().
_script_service: ScriptService | None = None
|
|
|
|
|
|
def get_script_service() -> ScriptService:
    """Return the module-level ScriptService, constructing it on first call."""
    global _script_service
    if _script_service is not None:
        return _script_service
    # First call: build the instance and cache it for later callers.
    _script_service = ScriptService()
    return _script_service
|