Files
meijiaka-zy/python-api/app/services/script_service.py
T

239 lines
7.5 KiB
Python

"""
脚本生成服务
============
"""
import asyncio
import logging
import time
from pathlib import Path
from typing import Any
from app.ai.model_router import get_model_router
from app.ai.prompts import load_prompt_file, load_script_user_prompt
from app.core.exceptions import (
AIEmptyResponseException,
AIParseErrorException,
AITimeoutException,
PromptNotFoundException,
)
from app.schemas.script import ScriptShot
from app.services.ai_response_utils import (
safe_parse_ai_json_response,
validate_and_normalize_shots,
)
# Logger named after this module.
logger = logging.getLogger(name=__name__)
class ScriptService:
    """Script-generation service.

    Uses the shared model router to generate storyboard scripts, polish
    single pieces of text, and report model health / connectivity.
    """

    def __init__(self):
        # This file lives in app/services/, so prompt templates resolve to
        # app/ai/prompts/.
        self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"

    def _load_prompt(self, name: str) -> str:
        """Load the prompt template ``<prompts_dir>/<name>.txt``.

        Args:
            name: Template name relative to ``prompts_dir``, without the
                ``.txt`` suffix (e.g. "polish/voiceover").

        Returns:
            The template text, or "" when the file does not exist so the
            caller can fall back to another template.
        """
        prompt_file = self.prompts_dir / f"{name}.txt"
        # EAFP: no separate exists() check that could race with the read.
        try:
            return prompt_file.read_text(encoding="utf-8")
        except FileNotFoundError:
            return ""

    async def generate_script(
        self,
        category: str,
        filename: str,
        model: str | None = None,
    ) -> list[ScriptShot]:
        """Generate a script and return its shots directly from this call.

        Args:
            category: Top-level category code, e.g. "bk".
            filename: Prompt file name, e.g. "水电改造避坑——水电改造的4个坑.txt".
            model: Optional model id passed to the router as ``model_id``.

        Returns:
            The list of storyboard shots.

        Raises:
            PromptNotFoundException: No system prompt was found for
                ``category``/``filename``.
            AIEmptyResponseException: The model returned blank content, or the
                normalized shot list was empty.
            AIParseErrorException: The response was not parseable JSON, or the
                shots could not be normalized / turned into ``ScriptShot``.
        """
        model_router = await get_model_router()

        system_prompt = load_prompt_file(category, filename)
        if not system_prompt:
            raise PromptNotFoundException(f"未找到提示词: category={category}, filename={filename}")

        user_prompt = load_script_user_prompt(topic=f"{category}/{filename}")

        # The system prompt already demands "pure JSON only", so this does not
        # rely on a response_format parameter.
        result = await model_router.generate(
            prompt=user_prompt,
            system_prompt=system_prompt,
            model_id=model,
            task_type="script",
        )
        if not result.content or not result.content.strip():
            raise AIEmptyResponseException("AI 返回内容为空,请检查模型配置或重试")

        success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
        if not success:
            raise AIParseErrorException(error_msg or "AI 返回格式错误,无法解析为 JSON")

        try:
            shots_data = validate_and_normalize_shots(parsed_data)
            if not shots_data:
                raise AIEmptyResponseException("AI 返回的分镜数据为空或格式不正确")
            return [ScriptShot(**shot) for shot in shots_data]
        except (AIEmptyResponseException, AIParseErrorException):
            raise
        except Exception as e:
            # Keep the original error as the cause so the traceback shows it.
            raise AIParseErrorException(f"分镜数据处理失败: {e}") from e

    def _select_polish_template(self, polish_type: str, shot_type: str) -> str:
        """Pick the polish prompt template, falling back to a built-in one."""
        if polish_type == "scene":
            # Scene polishing uses a shot-type-specific prompt; if that file is
            # missing, fall back to the generic scene prompt.
            specific = (
                "polish/scene_empty_shot"
                if shot_type == "empty_shot"
                else "polish/scene_segment"
            )
            template = self._load_prompt(specific) or self._load_prompt("polish/scene")
        else:
            # Voiceover text polishing.
            template = self._load_prompt("polish/voiceover")
        # Final fallback when no template file exists at all.
        return template or "请润色以下内容:\n\n{content}"

    async def polish_content(
        self,
        content: str,
        polish_type: str = "voiceover",
        shot_type: str = "segment",
    ) -> str:
        """Polish a piece of text with the model.

        Args:
            content: Text to polish.
            polish_type: "scene" (scene description) or "voiceover"
                (voiceover text); any other value is treated as "voiceover".
            shot_type: "segment" or "empty_shot"; only used when
                ``polish_type`` is "scene".

        Returns:
            The polished text, stripped of surrounding whitespace.

        Raises:
            AITimeoutException: The model did not answer within 15 seconds.
            AIEmptyResponseException: The model returned blank content.
            AIParseErrorException: Any other failure while calling the model.
        """
        model_router = await get_model_router()

        prompt_template = self._select_polish_template(polish_type, shot_type)
        # Templates are str.format templates: literal braces in a template file
        # must be doubled ({{ }}), otherwise format() raises here.
        prompt = prompt_template.format(content=content)

        try:
            async with asyncio.timeout(15):
                result = await model_router.generate(
                    prompt=prompt,
                    task_type="polish",
                    max_tokens=300,
                )
        except TimeoutError as e:
            raise AITimeoutException("润色请求超时,请重试") from e
        except (AIEmptyResponseException, AIParseErrorException, AITimeoutException):
            raise
        except Exception as e:
            raise AIParseErrorException(f"润色失败: {e}") from e

        # Never hand back an empty "polished" text (or crash on None content);
        # report it the same way generate_script does.
        if not result.content or not result.content.strip():
            raise AIEmptyResponseException("AI 返回内容为空,请检查模型配置或重试")
        return result.content.strip()

    async def check_model_health(self) -> dict:
        """Check the health of every model known to the router.

        Returns:
            A dict with ``status`` ("healthy" when at least one model is
            available, else "unhealthy"), ``models`` (one entry per health
            result), ``recommended_model`` (the available model with the
            lowest response time, or None), ``total_models`` and
            ``available_models``.
        """
        model_router = await get_model_router()
        health_results = await model_router.health_check()

        models: list[dict[str, Any]] = []
        available_count = 0
        recommended: dict[str, Any] | None = None
        best_time = float("inf")

        for health in health_results.values():
            model_info: dict[str, Any] = {
                "id": health.id,
                "name": health.name,
                "is_available": health.is_available,
                "response_time": health.response_time,
                "last_error": health.last_error,
            }
            models.append(model_info)
            if not health.is_available:
                continue

            available_count += 1
            # A missing response time ranks last instead of raising TypeError.
            # Compare against None explicitly because 0 is a valid (fastest) time.
            response_time = (
                float("inf") if health.response_time is None else health.response_time
            )
            if recommended is None or response_time < best_time:
                recommended = model_info
                best_time = response_time

        return {
            "status": "healthy" if available_count > 0 else "unhealthy",
            "models": models,
            "recommended_model": recommended,
            "total_models": len(models),
            "available_models": available_count,
        }

    async def test_model(self, model_id: str | None = None) -> dict:
        """Test connectivity of one model by sending it a tiny prompt.

        Args:
            model_id: Model to test; ``None`` lets the router pick its default.

        Returns:
            On success: ``success`` True, the responding ``model``,
            ``response_time`` in milliseconds (rounded to 2 decimals) and
            ``checked_at`` (local time, no timezone). On failure: ``success``
            False, ``model``, ``error`` and ``checked_at``. Errors from the
            generate call are reported in the dict, not raised.
        """
        model_router = await get_model_router()
        # perf_counter is monotonic, so the measured duration is unaffected by
        # wall-clock adjustments.
        start = time.perf_counter()
        try:
            result = await model_router.generate(
                prompt="你好",
                model_id=model_id,
                max_tokens=5,
            )
            elapsed_ms = (time.perf_counter() - start) * 1000
            return {
                "success": True,
                "model": result.model,
                "response_time": round(elapsed_ms, 2),
                "checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
            }
        except Exception as e:
            # Connectivity tests report failures instead of raising; log them
            # so the cause is not lost.
            logger.warning("Model test failed (model=%s): %s", model_id or "default", e)
            return {
                "success": False,
                "model": model_id or "default",
                "error": str(e),
                "checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
            }
# Module-level cached instance, created lazily by get_script_service().
_script_service: ScriptService | None = None


def get_script_service() -> ScriptService:
    """Return the shared ScriptService, constructing it on the first call."""
    global _script_service
    instance = _script_service
    if instance is None:
        instance = _script_service = ScriptService()
    return instance