43609de2f1
- script_service: 删除流式生成各阶段的 info/debug 日志 - model_router: 删除首 chunk 延迟、provider 信息日志 - volcengine_provider: 删除 SDK request、首 chunk 耗时、流结束统计日志 - 保留: 初始化日志、降级/错误日志、API 层异常日志 - 为后续统一结构化日志规划做准备
371 lines
11 KiB
Python
371 lines
11 KiB
Python
"""
|
||
脚本生成服务
|
||
============
|
||
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import re
|
||
import time
|
||
|
||
from collections.abc import AsyncIterator
|
||
from pathlib import Path
|
||
|
||
from app.ai.model_router import get_model_router
|
||
from app.ai.prompts import load_script_user_prompt, load_system_prompt
|
||
from app.schemas.script import ScriptGenerationEvent, ScriptShot
|
||
from app.services.ai_response_utils import (
|
||
safe_parse_ai_json_response,
|
||
validate_and_normalize_shots,
|
||
)
|
||
|
||
# Module-level logger named after this module (logging.getLogger(__name__)).
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ScriptService:
|
||
"""脚本生成服务"""
|
||
|
||
|
||
def __init__(self):
|
||
self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"
|
||
|
||
|
||
|
||
def _load_prompt(self, name: str) -> str:
|
||
"""加载 Prompt 模板"""
|
||
prompt_file = self.prompts_dir / f"{name}.txt"
|
||
if prompt_file.exists():
|
||
return prompt_file.read_text(encoding="utf-8")
|
||
return ""
|
||
|
||
@staticmethod
|
||
def _extract_json(content: str) -> str:
|
||
"""
|
||
从 Markdown 代码块中提取 JSON,或返回原始内容
|
||
|
||
支持格式:
|
||
- ```json {...} ```
|
||
- ``` {...} ```
|
||
- 纯 JSON 文本
|
||
"""
|
||
if not content:
|
||
return ""
|
||
|
||
content = content.strip()
|
||
|
||
# 匹配 ```json ... ``` 或 ``` ... ```
|
||
pattern = r"```(?:json)?\s*([\s\S]*?)\s*```"
|
||
matches = re.findall(pattern, content)
|
||
|
||
if matches:
|
||
# 取最后一个匹配(避免前面有示例代码)
|
||
return matches[-1].strip()
|
||
|
||
# 如果没有代码块,返回原始内容
|
||
return content
|
||
|
||
async def generate_script(
|
||
self,
|
||
category: str,
|
||
subcategory: str,
|
||
duration: int,
|
||
script_type: str,
|
||
model: str | None = None,
|
||
) -> list[ScriptShot]:
|
||
"""
|
||
同步生成脚本
|
||
|
||
Args:
|
||
category: 大类代码,如 "bk"
|
||
subcategory: 小类代码,如 "ht"
|
||
duration: 视频时长(秒)
|
||
script_type: 脚本类型
|
||
model: 指定模型
|
||
|
||
Returns:
|
||
分镜列表
|
||
"""
|
||
# 获取 model_router
|
||
model_router = await get_model_router()
|
||
|
||
# 加载 Prompt
|
||
system_prompt = load_system_prompt(category, subcategory)
|
||
if not system_prompt:
|
||
raise ValueError(f"未找到提示词: category={category}, subcategory={subcategory}")
|
||
|
||
# 用户提示词
|
||
user_prompt = load_script_user_prompt(
|
||
topic=f"{category}/{subcategory}",
|
||
duration=duration,
|
||
)
|
||
|
||
# 调用 AI 生成
|
||
result = await model_router.generate(
|
||
prompt=user_prompt,
|
||
system_prompt=system_prompt,
|
||
model_id=model,
|
||
task_type="script",
|
||
temperature=0.7,
|
||
response_format="json_object",
|
||
)
|
||
|
||
if not result.content or not result.content.strip():
|
||
raise ValueError("AI 返回内容为空,请检查模型配置或重试")
|
||
|
||
success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
|
||
|
||
if not success:
|
||
raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")
|
||
|
||
try:
|
||
shots_data = validate_and_normalize_shots(parsed_data)
|
||
|
||
if not shots_data:
|
||
raise ValueError("AI 返回的分镜数据为空或格式不正确")
|
||
|
||
shots = [ScriptShot(**shot) for shot in shots_data]
|
||
return shots
|
||
|
||
except Exception as e:
|
||
raise ValueError(f"分镜数据处理失败: {str(e)}")
|
||
|
||
async def generate_script_stream(
|
||
self,
|
||
category: str,
|
||
subcategory: str,
|
||
duration: int,
|
||
script_type: str,
|
||
model: str | None = None,
|
||
) -> AsyncIterator[ScriptGenerationEvent]:
|
||
"""流式生成脚本(SSE)"""
|
||
model_router = await get_model_router()
|
||
|
||
try:
|
||
# 加载 Prompt
|
||
system_prompt = load_system_prompt(category, subcategory)
|
||
if not system_prompt:
|
||
yield ScriptGenerationEvent(
|
||
type="error",
|
||
message=f"未找到提示词: category={category}, subcategory={subcategory}",
|
||
)
|
||
return
|
||
|
||
user_prompt = load_script_user_prompt(
|
||
topic=f"{category}/{subcategory}",
|
||
duration=duration,
|
||
)
|
||
|
||
yield ScriptGenerationEvent(
|
||
type="start",
|
||
message="准备生成脚本...",
|
||
)
|
||
|
||
full_content = ""
|
||
chunk_count = 0
|
||
has_shown_generating = False
|
||
|
||
async for chunk in model_router.generate_stream_with_progress(
|
||
prompt=user_prompt,
|
||
system_prompt=system_prompt,
|
||
model_id=model,
|
||
task_type="script",
|
||
temperature=0.7,
|
||
response_format="json_object",
|
||
):
|
||
chunk_count += 1
|
||
|
||
if chunk_count == 1:
|
||
yield ScriptGenerationEvent(
|
||
type="analyzing",
|
||
message="分析创作要点",
|
||
)
|
||
|
||
if chunk["type"] == "chunk":
|
||
chunk_content = chunk.get("content", "")
|
||
if not chunk_content:
|
||
continue
|
||
full_content += chunk_content
|
||
|
||
if not has_shown_generating:
|
||
yield ScriptGenerationEvent(
|
||
type="generating",
|
||
message="正在创作脚本...",
|
||
)
|
||
has_shown_generating = True
|
||
|
||
if not full_content or not full_content.strip():
|
||
yield ScriptGenerationEvent(
|
||
type="error",
|
||
message="AI 返回内容为空,请检查模型配置或重试",
|
||
)
|
||
return
|
||
|
||
success, parsed_data, error_msg = safe_parse_ai_json_response(full_content)
|
||
|
||
if not success:
|
||
yield ScriptGenerationEvent(
|
||
type="error",
|
||
message=f"脚本解析失败: {error_msg or '无法解析 AI 返回的内容'}",
|
||
)
|
||
return
|
||
|
||
try:
|
||
shots_data = validate_and_normalize_shots(parsed_data)
|
||
|
||
if not shots_data:
|
||
yield ScriptGenerationEvent(
|
||
type="error",
|
||
message="AI 返回的分镜数据为空或格式不正确",
|
||
)
|
||
return
|
||
|
||
shots = [ScriptShot(**shot) for shot in shots_data]
|
||
|
||
yield ScriptGenerationEvent(
|
||
type="complete",
|
||
message="脚本生成成功",
|
||
result=shots,
|
||
)
|
||
|
||
except Exception as e:
|
||
yield ScriptGenerationEvent(
|
||
type="error",
|
||
message=f"分镜数据处理失败: {str(e)}",
|
||
)
|
||
|
||
except Exception as e:
|
||
yield ScriptGenerationEvent(
|
||
type="error",
|
||
message=f"生成失败: {str(e)}",
|
||
)
|
||
|
||
async def polish_content(
|
||
self,
|
||
content: str,
|
||
polish_type: str = "voiceover",
|
||
shot_type: str = "segment",
|
||
) -> str:
|
||
"""
|
||
润色内容
|
||
|
||
Args:
|
||
content: 待润色内容
|
||
polish_type: 润色类型,可选 "scene"(画面描述)或 "voiceover"(配音文案)
|
||
shot_type: 镜头类型,可选 "segment"(分镜)或 "empty_shot"(空镜),仅用于画面润色
|
||
|
||
Returns:
|
||
润色后的内容
|
||
"""
|
||
# 获取 model_router
|
||
model_router = await get_model_router()
|
||
|
||
# 从文件加载提示词模板
|
||
if polish_type == "scene":
|
||
# 画面润色需要根据镜头类型选择不同提示词
|
||
if shot_type == "empty_shot":
|
||
prompt_template = self._load_prompt("polish/scene_empty_shot")
|
||
else:
|
||
prompt_template = self._load_prompt("polish/scene_segment")
|
||
|
||
# 如果特定类型的提示词不存在,回退到通用 scene 提示词
|
||
if not prompt_template:
|
||
prompt_template = self._load_prompt("polish/scene")
|
||
else:
|
||
# 配音文案润色
|
||
prompt_template = self._load_prompt("polish/voiceover")
|
||
|
||
if not prompt_template:
|
||
# 最终回退
|
||
prompt_template = "请润色以下内容:\n\n{content}"
|
||
|
||
prompt = prompt_template.format(content=content)
|
||
|
||
result = await model_router.generate(
|
||
prompt=prompt,
|
||
task_type="polish",
|
||
temperature=0.5,
|
||
max_tokens=300,
|
||
)
|
||
|
||
return result.content.strip()
|
||
|
||
async def check_model_health(self) -> dict:
|
||
"""检查模型健康状态"""
|
||
model_router = await get_model_router()
|
||
health_results = await model_router.health_check()
|
||
|
||
models = []
|
||
available_count = 0
|
||
recommended = None
|
||
|
||
for provider_id, health in health_results.items():
|
||
model_info = {
|
||
"id": health.id,
|
||
"name": health.name,
|
||
"is_available": health.is_available,
|
||
"response_time": health.response_time,
|
||
"last_error": health.last_error,
|
||
}
|
||
models.append(model_info)
|
||
|
||
if health.is_available:
|
||
available_count += 1
|
||
if recommended is None or health.response_time < recommended.get(
|
||
"response_time", float("inf")
|
||
):
|
||
recommended = model_info
|
||
|
||
total = len(models)
|
||
|
||
return {
|
||
"status": "healthy" if available_count > 0 else "unhealthy",
|
||
"models": models,
|
||
"recommended_model": recommended,
|
||
"total_models": total,
|
||
"available_models": available_count,
|
||
}
|
||
|
||
async def test_model(self, model_id: str | None = None) -> dict:
|
||
"""测试指定模型连接"""
|
||
model_router = await get_model_router()
|
||
|
||
import time
|
||
|
||
start_time = time.time()
|
||
|
||
try:
|
||
result = await model_router.generate(
|
||
prompt="你好",
|
||
model_id=model_id,
|
||
max_tokens=5,
|
||
)
|
||
|
||
response_time = (time.time() - start_time) * 1000
|
||
|
||
return {
|
||
"success": True,
|
||
"model": result.model,
|
||
"response_time": round(response_time, 2),
|
||
"checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||
}
|
||
|
||
except Exception as e:
|
||
return {
|
||
"success": False,
|
||
"model": model_id or "default",
|
||
"error": str(e),
|
||
"checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||
}
|
||
|
||
|
||
# Process-wide ScriptService instance; created lazily by get_script_service().
_script_service: ScriptService | None = None


def get_script_service() -> ScriptService:
    """Return the shared ScriptService, constructing it on the first call."""
    global _script_service
    if _script_service is not None:
        return _script_service
    _script_service = ScriptService()
    return _script_service
|