Files
meijiaka-zy/python-api/app/services/script_service.py
T
小鱼开发 7c23cb3afb feat: 启用 JSON Mode 约束脚本生成输出
- script_service: 调用 model_router 时传入 response_format="json_object"
- volcengine_provider: generate 和 generate_stream_with_progress 支持 response_format 参数
- 强制模型输出合法 JSON,减少 Markdown 代码块包裹和说明文字导致的解析失败
2026-04-26 20:41:05 +08:00

409 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
脚本生成服务
============
"""
import asyncio
import logging
import re
import time
from collections.abc import AsyncIterator
from pathlib import Path
from app.ai.model_router import get_model_router
from app.ai.prompts import load_script_user_prompt, load_system_prompt
from app.schemas.script import ScriptGenerationEvent, ScriptShot
from app.services.ai_response_utils import (
safe_parse_ai_json_response,
validate_and_normalize_shots,
)
logger = logging.getLogger(__name__)

# Fenced Markdown code block: ```json ... ``` or ``` ... ``` (the language tag
# is matched case-insensitively). The body is captured non-greedily so several
# blocks in one response come back as separate matches.
_CODE_FENCE_RE = re.compile(r"```(?:json)?\s*([\s\S]*?)\s*```", re.IGNORECASE)


class ScriptService:
    """Script generation service.

    Loads prompts, calls the AI model router to generate storyboard shots
    (either in one call or as a stream of progress events), polishes text,
    and reports model health / connectivity.
    """

    def __init__(self):
        # Directory holding the ``*.txt`` templates read by ``_load_prompt``.
        self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"

    def _load_prompt(self, name: str) -> str:
        """Load the prompt template ``<prompts_dir>/<name>.txt``.

        Args:
            name: Template path relative to ``prompts_dir`` without the
                ``.txt`` suffix, e.g. ``"polish/voiceover"``.

        Returns:
            The template text, or ``""`` when no such regular file exists so
            that the caller can fall back to another template.
        """
        prompt_file = self.prompts_dir / f"{name}.txt"
        # is_file() rather than exists(): a same-named directory counts as
        # "missing" instead of making read_text() raise.
        if not prompt_file.is_file():
            return ""
        return prompt_file.read_text(encoding="utf-8")

    @staticmethod
    def _extract_json(content: str) -> str:
        """Extract JSON text from a Markdown code block, or return the input.

        Supported shapes:
            - ```json {...} ```  (tag matched case-insensitively)
            - ``` {...} ```
            - bare JSON text

        Returns:
            The body of the LAST fenced block if any exist, otherwise the
            stripped input; ``""`` for empty/None input.
        """
        if not content:
            return ""
        content = content.strip()
        matches = _CODE_FENCE_RE.findall(content)
        if matches:
            # Take the last block: earlier ones are usually examples.
            return matches[-1].strip()
        # No code fence: assume the whole text is the payload.
        return content

    async def generate_script(
        self,
        category: str,
        subcategory: str,
        duration: int,
        script_type: str,
        model: str | None = None,
    ) -> list[ScriptShot]:
        """Generate a script with a single (non-streaming) model call.

        Args:
            category: Top-level category code, e.g. ``"bk"``.
            subcategory: Subcategory code, e.g. ``"ht"``.
            duration: Target video duration in seconds.
            script_type: Script type. Not referenced by this method; kept for
                interface compatibility.
            model: Optional model id, forwarded to the router as ``model_id``.

        Returns:
            The parsed storyboard shots (never empty).

        Raises:
            ValueError: No system prompt exists for the category, the model
                returned empty or non-JSON content, the shot list is empty,
                or shot validation/conversion failed.
        """
        model_router = await get_model_router()

        system_prompt = load_system_prompt(category, subcategory)
        if not system_prompt:
            raise ValueError(f"未找到提示词: category={category}, subcategory={subcategory}")

        user_prompt = load_script_user_prompt(
            topic=f"{category}/{subcategory}",
            duration=duration,
        )
        logger.info(f"同步生成脚本: category={category}, subcategory={subcategory}, duration={duration}")

        # JSON mode constrains the model to emit a bare JSON object.
        result = await model_router.generate(
            prompt=user_prompt,
            system_prompt=system_prompt,
            model_id=model,
            task_type="script",
            temperature=0.7,
            response_format="json_object",
        )

        content = result.content
        if not content or not content.strip():
            logger.error("AI 返回内容为空")
            raise ValueError("AI 返回内容为空,请检查模型配置或重试")
        logger.info(f"AI 返回内容长度: {len(content)} 字符")

        success, parsed_data, error_msg = safe_parse_ai_json_response(content)
        if not success:
            logger.error(f"JSON 解析失败: {error_msg}")
            logger.error(f"原始内容: {content[:500]!r}")
            raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")

        # Only validation/conversion errors are wrapped; the empty-result check
        # below stays outside the try so its message is not double-prefixed.
        try:
            shots_data = validate_and_normalize_shots(parsed_data)
            shots = [ScriptShot(**shot) for shot in shots_data or []]
        except Exception as e:
            logger.exception(f"分镜数据标准化失败: {e}")
            raise ValueError(f"分镜数据处理失败: {str(e)}") from e

        if not shots:
            raise ValueError("AI 返回的分镜数据为空或格式不正确")

        logger.info(f"成功解析 {len(shots)} 个分镜")
        return shots

    async def generate_script_stream(
        self,
        category: str,
        subcategory: str,
        duration: int,
        script_type: str,
        model: str | None = None,
    ) -> AsyncIterator[ScriptGenerationEvent]:
        """Generate a script as a stream of progress events (for SSE).

        Emits ``start``, then ``analyzing`` on the first streamed message,
        ``generating`` on the first non-empty content chunk, and finally either
        ``complete`` (with the parsed shots) or ``error``. Any ``Exception``
        raised along the way is reported as a final ``error`` event instead of
        propagating.

        Args:
            category: Top-level category code, e.g. ``"bk"``.
            subcategory: Subcategory code, e.g. ``"ht"``.
            duration: Target video duration in seconds.
            script_type: Script type. Not referenced by this method; kept for
                interface compatibility.
            model: Optional model id, forwarded to the router as ``model_id``.
        """
        try:
            # Fetched inside the try so router initialisation failures become
            # an ``error`` event rather than escaping the stream.
            model_router = await get_model_router()

            system_prompt = load_system_prompt(category, subcategory)
            if not system_prompt:
                yield ScriptGenerationEvent(
                    type="error",
                    message=f"未找到提示词: category={category}, subcategory={subcategory}",
                )
                return

            user_prompt = load_script_user_prompt(
                topic=f"{category}/{subcategory}",
                duration=duration,
            )
            logger.info(f"流式生成脚本: category={category}, subcategory={subcategory}, duration={duration}")

            # Stage 1: announce the start, then send the LLM request at once.
            # perf_counter is monotonic, so latencies survive clock changes.
            start_time = time.perf_counter()
            yield ScriptGenerationEvent(
                type="start",
                message="准备生成脚本...",
            )

            # Chunks are collected and joined once (avoids quadratic `+=`).
            content_parts: list[str] = []
            chunk_count = 0
            has_shown_generating = False

            logger.info("[SSE] 开始 LLM 请求")
            async for chunk in model_router.generate_stream_with_progress(
                prompt=user_prompt,
                system_prompt=system_prompt,
                model_id=model,
                task_type="script",
                temperature=0.7,
                response_format="json_object",
            ):
                chunk_count += 1

                # Any first message (content or not) means the model answered.
                if chunk_count == 1:
                    logger.info(f"[SSE] 首 token 延迟: {time.perf_counter() - start_time:.3f}s")
                    yield ScriptGenerationEvent(
                        type="analyzing",
                        message="分析创作要点",
                    )

                chunk_type = chunk.get("type")
                if chunk_type == "chunk":
                    chunk_content = chunk.get("content", "")
                    if not chunk_content:
                        continue
                    content_parts.append(chunk_content)

                    # Switch to "generating" on the first real content.
                    if not has_shown_generating:
                        logger.info(f"[SSE] 首个内容 chunk 延迟: {time.perf_counter() - start_time:.3f}s")
                        yield ScriptGenerationEvent(
                            type="generating",
                            message="正在创作脚本...",
                        )
                        has_shown_generating = True
                elif chunk_type == "usage":
                    prompt_tokens = chunk.get("prompt_tokens", 0)
                    completion_tokens = chunk.get("completion_tokens", 0)
                    logger.info(
                        f"Token 使用: prompt={prompt_tokens}, completion={completion_tokens}"
                    )

            full_content = "".join(content_parts)
            total_time = time.perf_counter() - start_time
            logger.info(
                f"[SSE] 流式生成结束: 共 {chunk_count} 个 chunk, {len(full_content)} 字符, 总耗时 {total_time:.3f}s"
            )

            if not full_content.strip():
                logger.error("AI 返回内容为空")
                yield ScriptGenerationEvent(
                    type="error",
                    message="AI 返回内容为空,请检查模型配置或重试",
                )
                return

            success, parsed_data, error_msg = safe_parse_ai_json_response(full_content)
            if not success:
                logger.error(f"JSON 解析失败: {error_msg}")
                logger.error(f"原始内容: {full_content[:500]!r}")
                yield ScriptGenerationEvent(
                    type="error",
                    message=f"脚本解析失败: {error_msg or '无法解析 AI 返回的内容'}",
                )
                return

            # Only validation/conversion errors are wrapped here.
            try:
                shots_data = validate_and_normalize_shots(parsed_data)
                shots = [ScriptShot(**shot) for shot in shots_data or []]
            except Exception as e:
                logger.exception(f"分镜数据标准化失败: {e}")
                yield ScriptGenerationEvent(
                    type="error",
                    message=f"分镜数据处理失败: {str(e)}",
                )
                return

            if not shots:
                yield ScriptGenerationEvent(
                    type="error",
                    message="AI 返回的分镜数据为空或格式不正确",
                )
                return

            yield ScriptGenerationEvent(
                type="complete",
                message="脚本生成成功",
                result=shots,
            )
        except Exception as e:
            logger.exception("脚本生成失败")
            yield ScriptGenerationEvent(
                type="error",
                message=f"生成失败: {str(e)}",
            )

    async def polish_content(
        self,
        content: str,
        polish_type: str = "voiceover",
        shot_type: str = "segment",
    ) -> str:
        """Polish a piece of text with the model.

        Args:
            content: Text to polish.
            polish_type: ``"scene"`` (shot visual description); any other
                value is treated as ``"voiceover"`` (narration copy).
            shot_type: Only consulted when ``polish_type == "scene"``:
                ``"empty_shot"`` selects the empty-shot template; anything
                else selects the segment template.

        Returns:
            The stripped model output, or ``""`` if the model returned no
            content.
        """
        model_router = await get_model_router()

        if polish_type == "scene":
            # Scene polishing has a template per shot type, falling back to
            # the generic scene template when the specific one is missing.
            specific_template = (
                "polish/scene_empty_shot" if shot_type == "empty_shot" else "polish/scene_segment"
            )
            prompt_template = self._load_prompt(specific_template) or self._load_prompt("polish/scene")
        else:
            prompt_template = self._load_prompt("polish/voiceover")

        if not prompt_template:
            # Last-resort built-in template.
            prompt_template = "请润色以下内容:\n\n{content}"

        # str.format is used, so literal braces inside template files must be
        # escaped as "{{" / "}}".
        prompt = prompt_template.format(content=content)
        result = await model_router.generate(
            prompt=prompt,
            task_type="polish",
            temperature=0.5,
            max_tokens=300,
        )
        # Guard against None so callers get "" instead of an AttributeError.
        return (result.content or "").strip()

    @staticmethod
    def _response_time_or_inf(model_info: dict) -> float:
        """Sort key for ``model_info`` dicts: a missing/None response time ranks last."""
        response_time = model_info.get("response_time")
        return float("inf") if response_time is None else response_time

    async def check_model_health(self) -> dict:
        """Check every model known to the router.

        Returns:
            A dict with ``status`` (``"healthy"`` if at least one model is
            available, else ``"unhealthy"``), per-model ``models`` info, the
            available model with the lowest response time as
            ``recommended_model`` (``None`` if none is available), plus
            ``total_models`` and ``available_models`` counts.
        """
        model_router = await get_model_router()
        health_results = await model_router.health_check()

        models: list[dict] = []
        available_count = 0
        recommended: dict | None = None

        for health in health_results.values():
            model_info = {
                "id": health.id,
                "name": health.name,
                "is_available": health.is_available,
                "response_time": health.response_time,
                "last_error": health.last_error,
            }
            models.append(model_info)

            if not health.is_available:
                continue
            available_count += 1
            # A None response time would make `<` raise TypeError, so unknown
            # latencies are treated as infinitely slow.
            if recommended is None or self._response_time_or_inf(model_info) < self._response_time_or_inf(
                recommended
            ):
                recommended = model_info

        return {
            "status": "healthy" if available_count > 0 else "unhealthy",
            "models": models,
            "recommended_model": recommended,
            "total_models": len(models),
            "available_models": available_count,
        }

    async def test_model(self, model_id: str | None = None) -> dict:
        """Send a tiny prompt to a model and report whether it answered.

        Args:
            model_id: Model to test; forwarded to the router as ``model_id``.

        Returns:
            On success ``{"success": True, "model", "response_time", "checked_at"}``
            with ``response_time`` in milliseconds (2 decimals); if the call
            raises, ``{"success": False, "model", "error", "checked_at"}``.
            ``checked_at`` is local time formatted without a UTC offset.
        """
        model_router = await get_model_router()
        # perf_counter is monotonic, unlike time.time(), so the measured
        # latency cannot go negative on wall-clock adjustments.
        start_time = time.perf_counter()
        try:
            result = await model_router.generate(
                prompt="你好",
                model_id=model_id,
                max_tokens=5,
            )
            response_time = (time.perf_counter() - start_time) * 1000
            return {
                "success": True,
                "model": result.model,
                "response_time": round(response_time, 2),
                "checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
            }
        except Exception as e:
            # Best-effort probe: report the failure instead of raising.
            logger.warning(f"模型测试失败: model={model_id or 'default'}, error={e}")
            return {
                "success": False,
                "model": model_id or "default",
                "error": str(e),
                "checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
            }
# Process-wide ScriptService instance, created lazily by get_script_service().
_script_service: ScriptService | None = None


def get_script_service() -> ScriptService:
    """Return the shared ScriptService, constructing it on the first call."""
    global _script_service
    if _script_service is not None:
        return _script_service
    _script_service = ScriptService()
    return _script_service