aa818b75a8
- 删除 system/bk/ 下所有旧提示词,平铺替换为 23 个新文件 - 文件名格式统一为: 文案——描述.txt - 后端: _meta.json 扁平化,loader.py 新增 list_prompt_files() + load_prompt_file() - 后端: API 从 subcategory 改为 filename,按指定文件读取 - 后端: categories 接口返回文件列表(label/desc/filename)供前端展示 - 前端: ScriptCreation 分类选择改为卡片网格,展示文案+描述 - 前端: 清理 subcategoryCode,统一改为 filename - 前端: 字幕字号调整为 64/96/80px
229 lines
6.9 KiB
Python
229 lines
6.9 KiB
Python
"""
|
|
脚本生成服务
|
|
============
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from pathlib import Path
|
|
|
|
from app.ai.model_router import get_model_router
|
|
from app.ai.prompts import load_prompt_file, load_script_user_prompt
|
|
from app.schemas.script import ScriptShot
|
|
from app.services.ai_response_utils import (
|
|
safe_parse_ai_json_response,
|
|
validate_and_normalize_shots,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ScriptService:
|
|
"""脚本生成服务"""
|
|
|
|
|
|
def __init__(self):
|
|
self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"
|
|
|
|
|
|
|
|
def _load_prompt(self, name: str) -> str:
|
|
"""加载 Prompt 模板"""
|
|
prompt_file = self.prompts_dir / f"{name}.txt"
|
|
if prompt_file.exists():
|
|
return prompt_file.read_text(encoding="utf-8")
|
|
return ""
|
|
|
|
async def generate_script(
|
|
self,
|
|
category: str,
|
|
filename: str,
|
|
model: str | None = None,
|
|
) -> list[ScriptShot]:
|
|
"""
|
|
同步生成脚本
|
|
|
|
Args:
|
|
category: 大类代码,如 "bk"
|
|
filename: 提示词文件名,如 "水电改造避坑——水电改造的4个坑.txt"
|
|
model: 指定模型
|
|
|
|
Returns:
|
|
分镜列表
|
|
"""
|
|
# 获取 model_router
|
|
model_router = await get_model_router()
|
|
|
|
# 加载 Prompt
|
|
system_prompt = load_prompt_file(category, filename)
|
|
if not system_prompt:
|
|
raise ValueError(f"未找到提示词: category={category}, filename={filename}")
|
|
|
|
# 用户提示词
|
|
user_prompt = load_script_user_prompt(
|
|
topic=f"{category}/{filename}",
|
|
)
|
|
|
|
# 调用 AI 生成
|
|
# 注意:system prompt 中已要求"只输出纯 JSON",不依赖 response_format 参数
|
|
result = await model_router.generate(
|
|
prompt=user_prompt,
|
|
system_prompt=system_prompt,
|
|
model_id=model,
|
|
task_type="script",
|
|
)
|
|
|
|
if not result.content or not result.content.strip():
|
|
raise ValueError("AI 返回内容为空,请检查模型配置或重试")
|
|
|
|
success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
|
|
|
|
if not success:
|
|
raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")
|
|
|
|
try:
|
|
shots_data = validate_and_normalize_shots(parsed_data)
|
|
|
|
if not shots_data:
|
|
raise ValueError("AI 返回的分镜数据为空或格式不正确")
|
|
|
|
shots = [ScriptShot(**shot) for shot in shots_data]
|
|
return shots
|
|
|
|
except Exception as e:
|
|
raise ValueError(f"分镜数据处理失败: {str(e)}")
|
|
|
|
async def polish_content(
|
|
self,
|
|
content: str,
|
|
polish_type: str = "voiceover",
|
|
shot_type: str = "segment",
|
|
) -> str:
|
|
"""
|
|
润色内容
|
|
|
|
Args:
|
|
content: 待润色内容
|
|
polish_type: 润色类型,可选 "scene"(画面描述)或 "voiceover"(配音文本)
|
|
shot_type: 镜头类型,可选 "segment"(分镜)或 "empty_shot"(空镜),仅用于画面润色
|
|
|
|
Returns:
|
|
润色后的内容
|
|
"""
|
|
# 获取 model_router
|
|
model_router = await get_model_router()
|
|
|
|
# 从文件加载提示词模板
|
|
if polish_type == "scene":
|
|
# 画面润色需要根据镜头类型选择不同提示词
|
|
if shot_type == "empty_shot":
|
|
prompt_template = self._load_prompt("polish/scene_empty_shot")
|
|
else:
|
|
prompt_template = self._load_prompt("polish/scene_segment")
|
|
|
|
# 如果特定类型的提示词不存在,回退到通用 scene 提示词
|
|
if not prompt_template:
|
|
prompt_template = self._load_prompt("polish/scene")
|
|
else:
|
|
# 配音文本润色
|
|
prompt_template = self._load_prompt("polish/voiceover")
|
|
|
|
if not prompt_template:
|
|
# 最终回退
|
|
prompt_template = "请润色以下内容:\n\n{content}"
|
|
|
|
prompt = prompt_template.format(content=content)
|
|
|
|
try:
|
|
async with asyncio.timeout(15):
|
|
result = await model_router.generate(
|
|
prompt=prompt,
|
|
task_type="polish",
|
|
max_tokens=300,
|
|
)
|
|
return result.content.strip()
|
|
except TimeoutError:
|
|
raise ValueError("润色请求超时,请重试")
|
|
except Exception as e:
|
|
raise ValueError(f"润色失败: {str(e)}")
|
|
|
|
async def check_model_health(self) -> dict:
|
|
"""检查模型健康状态"""
|
|
model_router = await get_model_router()
|
|
health_results = await model_router.health_check()
|
|
|
|
models = []
|
|
available_count = 0
|
|
recommended = None
|
|
|
|
for _provider_id, health in health_results.items():
|
|
model_info = {
|
|
"id": health.id,
|
|
"name": health.name,
|
|
"is_available": health.is_available,
|
|
"response_time": health.response_time,
|
|
"last_error": health.last_error,
|
|
}
|
|
models.append(model_info)
|
|
|
|
if health.is_available:
|
|
available_count += 1
|
|
if recommended is None or health.response_time < recommended.get(
|
|
"response_time", float("inf")
|
|
):
|
|
recommended = model_info
|
|
|
|
total = len(models)
|
|
|
|
return {
|
|
"status": "healthy" if available_count > 0 else "unhealthy",
|
|
"models": models,
|
|
"recommended_model": recommended,
|
|
"total_models": total,
|
|
"available_models": available_count,
|
|
}
|
|
|
|
async def test_model(self, model_id: str | None = None) -> dict:
|
|
"""测试指定模型连接"""
|
|
model_router = await get_model_router()
|
|
|
|
|
|
start_time = time.time()
|
|
|
|
try:
|
|
result = await model_router.generate(
|
|
prompt="你好",
|
|
model_id=model_id,
|
|
max_tokens=5,
|
|
)
|
|
|
|
response_time = (time.time() - start_time) * 1000
|
|
|
|
return {
|
|
"success": True,
|
|
"model": result.model,
|
|
"response_time": round(response_time, 2),
|
|
"checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
|
}
|
|
|
|
except Exception as e:
|
|
return {
|
|
"success": False,
|
|
"model": model_id or "default",
|
|
"error": str(e),
|
|
"checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
|
}
|
|
|
|
|
|
# 全局单例
|
|
_script_service: ScriptService | None = None
|
|
|
|
|
|
def get_script_service() -> ScriptService:
|
|
"""获取 ScriptService 单例"""
|
|
global _script_service
|
|
if _script_service is None:
|
|
_script_service = ScriptService()
|
|
return _script_service
|