"""
Script generation service
=========================
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import math
|
||
import re
|
||
import time
|
||
from collections.abc import AsyncIterator
|
||
from pathlib import Path
|
||
|
||
from app.ai.model_router import get_model_router
|
||
from app.ai.prompts import load_script_system, load_script_user
|
||
from app.schemas.script import ScriptGenerationEvent, ScriptShot
|
||
from app.services.ai_response_utils import (
|
||
safe_parse_ai_json_response,
|
||
validate_and_normalize_shots,
|
||
)
|
||
from app.services.anytocopy_service import (
|
||
AnyToCopyService,
|
||
get_anytocopy_service,
|
||
)
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ScriptService:
    """Script generation service.

    Uses the model router to generate storyboard scripts (one-shot and
    streaming), polish text, and report model health.
    """

    # Empirical estimate of the generated output size by video duration.
    # Format: {duration_seconds: (min_chars, max_chars)}
    DURATION_ESTIMATES = {
        30: (800, 1200),
        45: (1200, 1600),
        60: (1500, 2000),
        90: (2000, 2800),
    }

    # Fenced Markdown block, with an optional (case-insensitive) "json" tag.
    _JSON_FENCE_RE = re.compile(r"```(?:json)?\s*([\s\S]*?)\s*```", re.IGNORECASE)

    def __init__(self):
        # Prompt templates live in the sibling "ai/prompts" directory.
        self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"

    def _estimate_total_chars(self, duration: int) -> int:
        """Estimate the total number of output characters for a duration.

        The closest preset in ``DURATION_ESTIMATES`` is selected and its
        ``(min_chars, max_chars)`` range is scaled by
        ``duration / closest_duration``; the result may therefore fall outside
        the preset range.

        Args:
            duration: Video duration in seconds.

        Returns:
            Estimated character count, always at least 1 so that it can be
            used safely as a divisor.
        """
        closest_duration = min(
            self.DURATION_ESTIMATES, key=lambda preset: abs(preset - duration)
        )
        min_chars, max_chars = self.DURATION_ESTIMATES[closest_duration]

        ratio = duration / closest_duration
        # Clamp: a zero/negative duration must not yield a non-positive total.
        estimated = max(1, int(min_chars + (max_chars - min_chars) * ratio))

        logger.debug(f"时长 {duration}s 预估字符数: {estimated}")
        return estimated

    def _calculate_progress(
        self,
        current_chars: int,
        estimated_total: int,
        elapsed_time: float,
        min_expected_time: float = 5.0,
    ) -> int:
        """Compute a smooth progress value for the generation phase.

        Progress is driven mainly by the share of content generated, mapped
        through a logarithmic curve. A time-based floor (at most 15%) applies
        only while less than 30% of the estimated content exists and less than
        ``min_expected_time`` seconds have elapsed.

        Args:
            current_chars: Number of characters generated so far.
            estimated_total: Estimated total number of characters.
            elapsed_time: Seconds elapsed so far.
            min_expected_time: Time window (seconds) of the time-based floor.

        Returns:
            Progress percentage in the range 0-85.
        """
        # Guard against a zero/negative estimate (would divide by zero).
        total = max(estimated_total, 1)

        # Content ratio, clamped to [0, 1.5] (up to 50% over-generation).
        ratio = max(min(current_chars / total, 1.5), 0.0)

        if ratio <= 1.0:
            # Logarithmic curve reaching 0.85 when ratio == 1.
            progress_ratio = math.log(1 + ratio * 2) / math.log(3) * 0.85
        else:
            # Over-generation: linear growth from 0.85 up to 0.95. The final
            # clamp below still limits the returned value to 85.
            progress_ratio = 0.85 + min((ratio - 1) * 0.2, 0.1)

        # Time-based floor so the bar moves even before content arrives.
        time_progress = 0
        if current_chars < total * 0.3 and elapsed_time < min_expected_time:
            time_progress = min(elapsed_time / min_expected_time * 0.15, 0.15)

        final_ratio = max(progress_ratio, time_progress)

        # The generation phase never reports more than 85%.
        return min(int(final_ratio * 100), 85)

    def _load_prompt(self, name: str) -> str:
        """Load a prompt template.

        Args:
            name: Template name relative to ``prompts_dir``, without ".txt".

        Returns:
            The template text, or an empty string when the file does not exist.
        """
        prompt_file = self.prompts_dir / f"{name}.txt"
        if prompt_file.exists():
            return prompt_file.read_text(encoding="utf-8")
        return ""

    @staticmethod
    def _extract_json(content: str) -> str:
        """Extract JSON from a Markdown code fence, or return the raw text.

        Supported forms:
            - ```json {...} ``` (the tag is matched case-insensitively)
            - ``` {...} ```
            - plain JSON text

        Args:
            content: Raw model output.

        Returns:
            The stripped body of the last fenced block when one exists,
            otherwise the stripped input; "" for empty input.
        """
        if not content:
            return ""

        content = content.strip()
        matches = ScriptService._JSON_FENCE_RE.findall(content)

        if matches:
            # Use the last block: earlier ones may be example snippets.
            return matches[-1].strip()

        # No code fence: return the content as-is.
        return content

    async def generate_script(
        self,
        topic: str,
        duration: int,
        script_type: str,
        model: str | None = None,
    ) -> list[ScriptShot]:
        """Generate a script in one shot (non-streaming).

        Args:
            topic: Creative topic; video links are supported and their text is
                extracted automatically.
            duration: Video duration in seconds.
            script_type: Script type.
            model: Model to use, or None for the router's choice.

        Returns:
            The list of storyboard shots.

        Raises:
            ValueError: If the model output is empty, cannot be parsed as
                JSON, or does not yield valid shot data.
        """
        # 1. Detect a video link and extract its text.
        anytocopy = get_anytocopy_service()
        extract_result = await anytocopy.extract_text_from_input(topic)

        if extract_result["error"]:
            # Extraction failure is not fatal: continue with the raw input.
            logger.warning(f"视频文案提取失败: {extract_result['error']}")

        if extract_result["is_video_url"]:
            # The extracted text may be missing (e.g. after a failed
            # extraction), so normalise it before measuring its length.
            extracted_text = extract_result["extracted_text"] or ""
            logger.info(f"检测到视频链接,提取文案长度: {len(extracted_text)}")
            topic = extracted_text or topic

        # 2. Resolve the model router and build the prompts.
        model_router = await get_model_router()

        system_prompt = load_script_system()
        user_prompt = load_script_user(
            topic=topic,
            duration=duration,
            script_type=script_type,
        )

        logger.info(f"同步生成脚本: topic={topic[:20]}, duration={duration}")

        # 3. Call the model.
        result = await model_router.generate(
            prompt=user_prompt,
            system_prompt=system_prompt,
            model_id=model,
            task_type="script",
            temperature=0.7,
        )

        if not result.content or not result.content.strip():
            logger.error("AI 返回内容为空")
            raise ValueError("AI 返回内容为空,请检查模型配置或重试")

        logger.info(f"AI 返回内容长度: {len(result.content)} 字符")

        # 4. Parse the JSON payload defensively.
        success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)

        if not success:
            logger.error(f"JSON 解析失败: {error_msg}")
            logger.error(f"原始内容: {result.content[:500]!r}")
            raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")

        # 5. Validate/normalise the shots and build ScriptShot objects.
        try:
            shots_data = validate_and_normalize_shots(parsed_data)
            if not shots_data:
                raise ValueError("AI 返回的分镜数据为空或格式不正确")
            shots = [ScriptShot(**shot) for shot in shots_data]
        except Exception as e:
            logger.error(f"分镜数据标准化失败: {e}")
            # Chain the cause so the original traceback is preserved.
            raise ValueError(f"分镜数据处理失败: {e}") from e

        logger.info(f"成功解析 {len(shots)} 个分镜")
        return shots

    async def generate_script_stream(
        self,
        topic: str,
        duration: int,
        script_type: str,
        model: str | None = None,
    ) -> AsyncIterator[ScriptGenerationEvent]:
        """Generate a script as a stream of progress events (SSE).

        Video links are supported: their text is extracted and used as topic.
        Every failure is reported as an event of type "error" (progress 0)
        instead of being raised to the consumer.

        Progress layout (monotonically increasing on the success path):
            - 2%: start
            - 5-10%: analyzing (including video text extraction)
            - 15-55%: generating (smooth logarithmic growth)
            - 60-70%: validating
            - 75%: parsing
            - 80-95%: finalizing
            - 100%: complete (carries the shots and extracted video info)

        Args:
            topic: Creative topic or video link.
            duration: Video duration in seconds.
            script_type: Script type.
            model: Model to use, or None for the router's choice.

        Yields:
            ScriptGenerationEvent objects describing the current stage.
        """
        original_topic = topic
        extracted_info = None  # video metadata attached to the final event

        try:
            model_router = await get_model_router()
            start_time = time.time()

            # 1. Start stage. Emitted first so progress never moves backwards.
            yield ScriptGenerationEvent(
                type="start",
                progress=2,
                message="准备生成脚本...",
            )

            # 2. Analysis stage: extract the text when the topic is a video link.
            if AnyToCopyService.is_video_url(topic) or AnyToCopyService.extract_url_from_text(
                topic
            ):
                yield ScriptGenerationEvent(
                    type="analyzing",
                    progress=5,
                    message="检测到视频链接,正在提取文案...",
                )

                anytocopy = get_anytocopy_service()
                extract_result = await anytocopy.extract_text_from_input(topic)

                if extract_result["error"]:
                    # Not fatal: keep generating from the raw input.
                    logger.warning(f"视频文案提取失败: {extract_result['error']}")
                    yield ScriptGenerationEvent(
                        type="analyzing",
                        progress=8,
                        message="视频文案提取失败,使用原始输入继续生成...",
                    )
                elif extract_result["is_video_url"]:
                    # Normalise a missing text so len() below cannot fail.
                    extracted_text = extract_result["extracted_text"] or ""
                    logger.info(f"视频文案提取成功,长度: {len(extracted_text)}")
                    topic = extracted_text or topic

                    # Keep the video metadata whenever it is available.
                    video_info = extract_result.get("video_info")
                    if video_info:
                        extracted_info = {
                            "title": video_info.title,
                            "content": video_info.content,
                            "text_content": video_info.text_content,
                            "platform": video_info.platform,
                            "duration": video_info.duration,
                            "original_url": original_topic,
                        }

                    yield ScriptGenerationEvent(
                        type="analyzing",
                        progress=10,
                        message=f"视频文案提取成功,共 {len(extracted_text)} 字符",
                    )

            # Build the prompts from the (possibly replaced) topic.
            system_prompt = load_script_system()
            user_prompt = load_script_user(
                topic=topic,
                duration=duration,
                script_type=script_type,
            )

            yield ScriptGenerationEvent(
                type="analyzing",
                progress=10,
                message="分析创作要点",
            )

            # Estimated output size, used to scale the progress bar.
            estimated_total = self._estimate_total_chars(duration)

            # 3. Generation stage (15-55%).
            yield ScriptGenerationEvent(
                type="generating",
                progress=15,
                message="正在创作脚本...",
            )

            parts: list[str] = []  # joined once after the stream ends
            current_chars = 0
            last_progress = 15
            last_update_time = start_time
            update_interval = 0.5  # seconds between small progress updates
            chunk_count = 0

            logger.info(f"开始流式生成: topic={topic[:20]}, duration={duration}")

            async for chunk in model_router.generate_stream_with_progress(
                prompt=user_prompt,
                system_prompt=system_prompt,
                model_id=model,
                task_type="script",
                temperature=0.7,
            ):
                chunk_count += 1

                if chunk["type"] == "chunk":
                    chunk_content = chunk.get("content", "")
                    if not chunk_content:
                        logger.warning(f"收到空 chunk,序号: {chunk_count}")
                        continue

                    parts.append(chunk_content)
                    current_chars += len(chunk_content)
                    current_time = time.time()

                    base_progress = self._calculate_progress(
                        current_chars=current_chars,
                        estimated_total=estimated_total,
                        elapsed_time=current_time - start_time,
                    )
                    # Map the helper's 15-85 range onto 15-55.
                    progress = 15 + int((base_progress - 15) * 40 / 70)

                    # Report only increases, and only when the step is at
                    # least 2% or the update interval has elapsed.
                    if progress > last_progress and (
                        progress - last_progress >= 2
                        or current_time - last_update_time >= update_interval
                    ):
                        yield ScriptGenerationEvent(
                            type="generating",
                            progress=progress,
                            message="正在创作脚本...",
                        )
                        last_progress = progress
                        last_update_time = current_time

                elif chunk["type"] == "usage":
                    prompt_tokens = chunk.get("prompt_tokens", 0)
                    completion_tokens = chunk.get("completion_tokens", 0)
                    logger.info(
                        f"Token 使用: prompt={prompt_tokens}, completion={completion_tokens}"
                    )

            full_content = "".join(parts)
            logger.info(f"流式生成结束: 共 {chunk_count} 个 chunk, {len(full_content)} 字符")

            # 4. Validation stage (60-70%). The sleeps pace the UI updates.
            actual_chars = len(full_content)
            logger.info(f"生成完成: {actual_chars} 字符 (预估: {estimated_total})")

            yield ScriptGenerationEvent(
                type="validating",
                progress=60,
                message="验证脚本格式...",
            )
            await asyncio.sleep(0.5)

            yield ScriptGenerationEvent(
                type="validating",
                progress=65,
                message="检查数据完整性...",
            )
            await asyncio.sleep(0.5)

            yield ScriptGenerationEvent(
                type="validating",
                progress=70,
                message="验证通过",
            )
            await asyncio.sleep(0.5)

            # 5. Parsing stage (75%).
            yield ScriptGenerationEvent(
                type="parsing",
                progress=75,
                message="解析分镜内容...",
            )

            if not full_content or not full_content.strip():
                logger.error("AI 返回内容为空")
                yield ScriptGenerationEvent(
                    type="error",
                    progress=0,
                    message="AI 返回内容为空,请检查模型配置或重试",
                )
                return

            # Log the head of the raw output for debugging.
            logger.info(f"AI 原始输出: {full_content[:500]}...")

            success, parsed_data, error_msg = safe_parse_ai_json_response(full_content)

            if not success:
                logger.error(f"JSON 解析失败: {error_msg}")
                logger.error(f"原始内容前500字符: {full_content[:500]!r}")

                # Empty content was handled above, so only the parser's
                # message (or a generic fallback) can apply here.
                error_detail = error_msg or "无法解析 AI 返回的内容"
                yield ScriptGenerationEvent(
                    type="error",
                    progress=0,
                    message=f"脚本解析失败: {error_detail}",
                )
                return

            # Validate/normalise the shots and build ScriptShot objects.
            try:
                shots_data = validate_and_normalize_shots(parsed_data)
                shots = [ScriptShot(**shot) for shot in shots_data or []]
            except Exception as e:
                logger.error(f"分镜数据标准化失败: {e}")
                yield ScriptGenerationEvent(
                    type="error",
                    progress=0,
                    message=f"分镜数据处理失败: {e}",
                )
                return

            if not shots:
                logger.error("标准化后分镜列表为空")
                yield ScriptGenerationEvent(
                    type="error",
                    progress=0,
                    message="AI 返回的分镜数据为空或格式不正确",
                )
                return

            # 6. Finalizing stage (80-95%), split into paced steps.
            yield ScriptGenerationEvent(
                type="finalizing",
                progress=80,
                message=f"整理 {len(shots)} 个分镜...",
            )
            await asyncio.sleep(0.5)

            yield ScriptGenerationEvent(
                type="finalizing",
                progress=85,
                message="优化镜头顺序...",
            )
            await asyncio.sleep(0.5)

            yield ScriptGenerationEvent(
                type="finalizing",
                progress=90,
                message="检查时长分配...",
            )
            await asyncio.sleep(0.5)

            yield ScriptGenerationEvent(
                type="finalizing",
                progress=95,
                message="准备完成...",
            )
            await asyncio.sleep(0.5)

            yield ScriptGenerationEvent(
                type="complete",
                progress=100,
                message=f"成功生成 {len(shots)} 个分镜",
                result=shots,
                extracted_info=extracted_info,
            )

        except Exception as e:
            logger.exception("脚本生成失败")
            yield ScriptGenerationEvent(
                type="error",
                progress=0,
                message=f"生成失败: {e}",
            )

    async def polish_content(
        self,
        content: str,
        polish_type: str = "voiceover",
        shot_type: str = "segment",
    ) -> str:
        """Polish a piece of content.

        Args:
            content: Text to polish.
            polish_type: "scene" (visual description) or "voiceover"
                (narration copy); any value other than "scene" is handled as
                voiceover.
            shot_type: "segment" or "empty_shot"; only used for scene polishing.

        Returns:
            The polished text, stripped; "" when the model returns no content.
        """
        model_router = await get_model_router()

        if polish_type == "scene":
            # Scene polishing has a dedicated prompt per shot type.
            if shot_type == "empty_shot":
                prompt_template = self._load_prompt("polish/scene_empty_shot")
            else:
                prompt_template = self._load_prompt("polish/scene_segment")

            # Fall back to the generic scene prompt when the specific one
            # does not exist.
            if not prompt_template:
                prompt_template = self._load_prompt("polish/scene")
        else:
            prompt_template = self._load_prompt("polish/voiceover")

        if not prompt_template:
            # Last-resort built-in prompt.
            prompt_template = "请润色以下内容:\n\n{content}"

        prompt = prompt_template.format(content=content)

        result = await model_router.generate(
            prompt=prompt,
            task_type="polish",
            temperature=0.5,
            max_tokens=300,
        )

        # Tolerate a missing content value instead of raising AttributeError.
        return (result.content or "").strip()

    async def check_model_health(self) -> dict:
        """Check the health of the configured models.

        Returns:
            A dict with "status" ("healthy" when at least one model is
            available, else "unhealthy"), "models" (per-model info),
            "recommended_model" (the available model with the lowest response
            time, or None), "total_models" and "available_models".
        """
        model_router = await get_model_router()
        health_results = await model_router.health_check()

        def _latency(value):
            # NOTE(review): response_time is assumed to be numeric or None;
            # a missing value sorts last instead of raising TypeError.
            return float("inf") if value is None else value

        models = []
        available_count = 0
        recommended = None

        for health in health_results.values():
            model_info = {
                "id": health.id,
                "name": health.name,
                "is_available": health.is_available,
                "response_time": health.response_time,
                "last_error": health.last_error,
            }
            models.append(model_info)

            if not health.is_available:
                continue

            available_count += 1
            # Strict "<" keeps the first model on ties.
            if recommended is None or _latency(health.response_time) < _latency(
                recommended["response_time"]
            ):
                recommended = model_info

        return {
            "status": "healthy" if available_count > 0 else "unhealthy",
            "models": models,
            "recommended_model": recommended,
            "total_models": len(models),
            "available_models": available_count,
        }

    async def test_model(self, model_id: str | None = None) -> dict:
        """Test the connection to a model with a minimal request.

        Args:
            model_id: Model to test, or None for the router's default.

        Returns:
            On success: "success" (True), "model", "response_time" in
            milliseconds (rounded to 2 decimals) and "checked_at". On failure:
            "success" (False), "model", "error" and "checked_at".
        """
        model_router = await get_model_router()

        # Monotonic clock: elapsed time is immune to wall-clock adjustments.
        started = time.perf_counter()

        try:
            result = await model_router.generate(
                prompt="你好",
                model_id=model_id,
                max_tokens=5,
            )
        except Exception as e:
            return {
                "success": False,
                "model": model_id or "default",
                "error": str(e),
                "checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
            }

        response_time = (time.perf_counter() - started) * 1000

        return {
            "success": True,
            "model": result.model,
            "response_time": round(response_time, 2),
            "checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
        }
|
||
|
||
|
||
# Module-level singleton: holds the shared instance once it has been created.
_script_service: ScriptService | None = None


def get_script_service() -> ScriptService:
    """Return the shared ScriptService, creating it on the first call."""
    global _script_service
    if _script_service is not None:
        return _script_service
    _script_service = ScriptService()
    return _script_service
|