Files
meijiaka-zy/python-api/app/services/script_service.py
T
小鱼开发 bb08d0f586 refactor: 从智影 Fork 重构为智剪,独立 Docker 基础设施,开发模式认证兜底
主要变更:
- 修复 /tasks/script 路由 404(去掉重复 prefix)
- 开发模式自动认证兜底(无需登录即可测试流程)
- Docker 基础设施独立化(共用 db/redis)
- 前端 API 端口改为 8081
- 新增 TTS/语音克隆、视频粗剪、音频混音等智剪功能
- 删除智影专属模块(avatar、model_usage、qiniu 上传等)
2026-04-21 12:35:50 +08:00

681 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
脚本生成服务
============
"""
import asyncio
import logging
import math
import re
import time
from collections.abc import AsyncIterator
from pathlib import Path
from app.ai.model_router import get_model_router
from app.ai.prompts import load_script_system, load_script_user_prompt, load_topic_prompt, TOPIC_PROMPT_MAP
from app.schemas.script import ScriptGenerationEvent, ScriptShot
from app.services.ai_response_utils import (
safe_parse_ai_json_response,
validate_and_normalize_shots,
)
from app.services.anytocopy_service import (
AnyToCopyService,
get_anytocopy_service,
)
# Module-level logger, named after this module (standard logging convention).
logger = logging.getLogger(__name__)
class ScriptService:
    """Script generation service.

    Builds prompts, calls the AI model router and converts the model's JSON
    output into ``ScriptShot`` objects, either in a single request
    (:meth:`generate_script`) or as a stream of progress events
    (:meth:`generate_script_stream`). Also provides content polishing and
    model health / connectivity checks.
    """

    # Empirical estimate of the number of output characters per video duration.
    # Format: {duration_seconds: (min_chars, max_chars)}
    DURATION_ESTIMATES = {
        30: (800, 1200),
        45: (1200, 1600),
        60: (1500, 2000),
        90: (2000, 2800),
    }

    # Seconds to pause between the purely cosmetic progress events emitted by
    # generate_script_stream after generation (validation / finalizing steps).
    _PACING_DELAY = 0.5

    def __init__(self):
        # Prompt templates live in <package>/ai/prompts (two levels up from
        # this file, then ai/prompts); read by _load_prompt.
        self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"

    def _estimate_total_chars(self, duration: int) -> int:
        """Estimate how many characters the model will output for a duration.

        A preset duration maps to the upper bound of its range. Between two
        presets the estimate is linearly interpolated between their upper
        bounds, so it grows continuously with the duration (the previous
        "closest preset" formula jumped at preset midpoints). Below the
        shortest / above the longest preset, that preset's range is scaled
        proportionally to the duration.

        Args:
            duration: Video duration in seconds. Negative values are treated as 0.

        Returns:
            Estimated character count.
        """
        effective = max(duration, 0)
        presets = sorted(self.DURATION_ESTIMATES.items())
        first_duration, (first_min, first_max) = presets[0]
        last_duration, (last_min, last_max) = presets[-1]

        if effective <= first_duration:
            # At or below the shortest preset: scale within its range.
            ratio = effective / first_duration
            estimated = first_min + (first_max - first_min) * ratio
        elif effective >= last_duration:
            # At or beyond the longest preset: extrapolate its range.
            ratio = effective / last_duration
            estimated = last_min + (last_max - last_min) * ratio
        else:
            # Strictly between two presets: interpolate their upper bounds.
            estimated = float(last_max)
            for (lo_duration, (_, lo_max)), (hi_duration, (_, hi_max)) in zip(
                presets, presets[1:]
            ):
                if lo_duration <= effective <= hi_duration:
                    fraction = (effective - lo_duration) / (hi_duration - lo_duration)
                    estimated = lo_max + (hi_max - lo_max) * fraction
                    break

        result = int(estimated)
        logger.debug(f"时长 {duration}s 预估字符数: {result}")
        return result

    def _calculate_progress(
        self,
        current_chars: int,
        estimated_total: int,
        elapsed_time: float,
        min_expected_time: float = 5.0,
    ) -> int:
        """Map generation state to a smooth, monotonic progress percentage.

        Design:
        - Progress is driven by generated content; time is only an early floor.
        - For a fixed elapsed time the result never decreases as content grows.
        - Curve ``log(1 + 2r) / log(3) * 0.85`` over the content ratio ``r``:
          roughly 26% at r=0.2, 74% at r=0.8 and exactly 85% at r>=1. It is
          concave, i.e. it rises quickly at first and then flattens.

        Args:
            current_chars: Characters generated so far.
            estimated_total: Estimated total characters. Values <= 0 are
                treated as 1 instead of raising ZeroDivisionError.
            elapsed_time: Seconds elapsed since generation started.
            min_expected_time: Window (seconds) during which the time-based
                floor applies.

        Returns:
            Progress percentage in the range 0-85.
        """
        estimated_total = max(estimated_total, 1)

        # Content ratio clamped to 1.0: the result is capped at 85% anyway, so
        # output beyond the estimate cannot move the bar any further.
        ratio = min(max(current_chars, 0) / estimated_total, 1.0)
        content_progress = math.log(1 + ratio * 2) / math.log(3) * 0.85

        # Time-based floor (up to 15%), only while less than 30% of the
        # estimated content has arrived and we are still inside the
        # min_expected_time window, so the bar moves even if early chunks are slow.
        time_progress = 0.0
        if (
            min_expected_time > 0
            and current_chars < estimated_total * 0.3
            and elapsed_time < min_expected_time
        ):
            time_progress = min(elapsed_time / min_expected_time * 0.15, 0.15)

        final_ratio = max(content_progress, time_progress)
        # The generation phase never reports more than 85%.
        return min(int(final_ratio * 100), 85)

    def _load_prompt(self, name: str) -> str:
        """Load a prompt template from ``prompts_dir``.

        Args:
            name: Template name relative to ``prompts_dir`` without the
                ``.txt`` suffix; may include sub-directories (e.g. "polish/scene").

        Returns:
            The template text (UTF-8), or "" if no such file exists.
        """
        prompt_file = self.prompts_dir / f"{name}.txt"
        if not prompt_file.is_file():
            return ""
        return prompt_file.read_text(encoding="utf-8")

    @staticmethod
    def _extract_json(content: str) -> str:
        """Extract JSON from a Markdown code fence, or return the content as-is.

        Supported inputs:
        - ```json {...} ```
        - ``` {...} ```
        - plain JSON text

        Returns:
            The stripped body of the LAST fenced block if any exists, otherwise
            the stripped input; "" for empty input.
        """
        if not content:
            return ""
        content = content.strip()
        # Matches ```json ... ``` or ``` ... ```
        pattern = r"```(?:json)?\s*([\s\S]*?)\s*```"
        matches = re.findall(pattern, content)
        if matches:
            # Use the last block: earlier ones may be examples in the answer.
            return matches[-1].strip()
        return content

    async def generate_script(
        self,
        topic: str,
        duration: int,
        script_type: str,
        model: str | None = None,
    ) -> list[ScriptShot]:
        """Generate a script in a single (non-streaming) request.

        Args:
            topic: Creative topic: a preset topic name (key of TOPIC_PROMPT_MAP)
                or free-form input, which may contain a video link.
            duration: Target video duration in seconds.
            script_type: Script type. Currently not used by this method.
            model: Optional model id forwarded to the model router.

        Returns:
            The generated shots.

        Raises:
            ValueError: If the model returns empty content, content that cannot
                be parsed as JSON, or shot data that is empty or fails
                normalization.
        """
        is_preset_topic = topic in TOPIC_PROMPT_MAP

        # For non-preset input, try to extract a transcript from a video link.
        # Extraction errors are logged but not fatal: the raw input is used.
        actual_topic = topic
        if not is_preset_topic:
            anytocopy = get_anytocopy_service()
            extract_result = await anytocopy.extract_text_from_input(topic)
            if extract_result["error"]:
                logger.warning(f"视频文案提取失败: {extract_result['error']}")
            if extract_result["is_video_url"]:
                extracted_text = extract_result["extracted_text"] or ""
                logger.info(f"检测到视频链接,提取文案长度: {len(extracted_text)}")
                actual_topic = extracted_text or topic

        model_router = await get_model_router()

        # Preset topics use their dedicated system prompt, others the generic one.
        system_prompt = load_topic_prompt(topic) if is_preset_topic else load_script_system()
        # Build the user prompt from the extracted transcript when available
        # (same as generate_script_stream); previously the raw input was used.
        user_prompt = load_script_user_prompt(
            topic=actual_topic,
            duration=duration,
        )
        logger.info(
            f"同步生成脚本: topic={topic[:20]}, is_preset={is_preset_topic}, duration={duration}"
        )

        result = await model_router.generate(
            prompt=user_prompt,
            system_prompt=system_prompt,
            model_id=model,
            task_type="script",
            temperature=0.7,
        )

        if not result.content or not result.content.strip():
            logger.error("AI 返回内容为空")
            raise ValueError("AI 返回内容为空,请检查模型配置或重试")
        logger.info(f"AI 返回内容长度: {len(result.content)} 字符")

        success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
        if not success:
            logger.error(f"JSON 解析失败: {error_msg}")
            logger.error(f"原始内容: {result.content[:500]!r}")
            raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")

        # Normalize and build the shots; only wrap errors raised by that work
        # (the empty-result error below is raised outside to avoid re-wrapping).
        try:
            shots_data = validate_and_normalize_shots(parsed_data)
            shots = [ScriptShot(**shot) for shot in shots_data or []]
        except Exception as e:
            logger.error(f"分镜数据标准化失败: {e}")
            raise ValueError(f"分镜数据处理失败: {str(e)}") from e

        if not shots:
            logger.error("标准化后分镜列表为空")
            raise ValueError("AI 返回的分镜数据为空或格式不正确")

        logger.info(f"成功解析 {len(shots)} 个分镜")
        return shots

    async def generate_script_stream(
        self,
        topic: str,
        duration: int,
        script_type: str,
        model: str | None = None,
    ) -> AsyncIterator[ScriptGenerationEvent]:
        """Generate a script as a stream of progress events (for SSE).

        Supports preset topics and free-form input. When non-preset input
        contains a video link, its transcript is extracted first and used as
        the topic.

        Progress of the emitted events:
            - 2%:      start
            - 5-10%:   analyzing (including optional transcript extraction)
            - 15-55%:  generating (AI streaming, smooth logarithmic curve)
            - 60-70%:  validating (JSON parsing and shot normalization)
            - 75%:     parsing
            - 80-95%:  finalizing (cosmetic pacing steps)
            - 100%:    complete (carries ``result`` and ``extracted_info``)
        Any failure is reported as one ``error`` event with progress 0, after
        which the stream ends.

        Args:
            topic: Preset topic name or free-form input / video link.
            duration: Target video duration in seconds.
            script_type: Script type. Currently not used by this method.
            model: Optional model id forwarded to the model router.

        Yields:
            ScriptGenerationEvent objects.
        """
        is_preset_topic = topic in TOPIC_PROMPT_MAP
        original_topic = topic
        # Video metadata returned with the "complete" event, if a link was extracted.
        extracted_info = None

        try:
            model_router = await get_model_router()

            # 1. Start (0-5%). Emitted before any extraction events so that the
            #    reported progress never moves backwards.
            yield ScriptGenerationEvent(
                type="start",
                progress=2,
                message="准备生成脚本...",
            )

            # 2. Analysis (5-15%), including transcript extraction when a
            #    non-preset topic contains a video link.
            if not is_preset_topic and (
                AnyToCopyService.is_video_url(topic)
                or AnyToCopyService.extract_url_from_text(topic)
            ):
                yield ScriptGenerationEvent(
                    type="analyzing",
                    progress=5,
                    message="检测到视频链接,正在提取文案...",
                )
                anytocopy = get_anytocopy_service()
                extract_result = await anytocopy.extract_text_from_input(topic)
                if extract_result["error"]:
                    # Not fatal: continue with the raw input as the topic.
                    logger.warning(f"视频文案提取失败: {extract_result['error']}")
                    yield ScriptGenerationEvent(
                        type="analyzing",
                        progress=8,
                        message="视频文案提取失败,使用原始输入继续生成...",
                    )
                elif extract_result["is_video_url"]:
                    extracted_text = extract_result["extracted_text"] or ""
                    logger.info(f"视频文案提取成功,长度: {len(extracted_text)}")
                    topic = extracted_text or topic
                    # Return the video metadata whenever the extractor provided it.
                    video_info = extract_result.get("video_info")
                    if video_info:
                        extracted_info = {
                            "title": video_info.title,
                            "content": video_info.content,
                            "text_content": video_info.text_content,
                            "platform": video_info.platform,
                            "duration": video_info.duration,
                            "original_url": original_topic,
                        }
                    yield ScriptGenerationEvent(
                        type="analyzing",
                        progress=10,
                        message=f"视频文案提取成功,共 {len(extracted_text)} 字符",
                    )

            yield ScriptGenerationEvent(
                type="analyzing",
                progress=10,
                message="分析创作要点",
            )

            # Preset topics use their dedicated system prompt, others the generic one.
            system_prompt = load_topic_prompt(topic) if is_preset_topic else load_script_system()
            user_prompt = load_script_user_prompt(
                topic=topic,
                duration=duration,
            )
            # topic may now be a long transcript, so only log a prefix.
            logger.info(
                f"流式生成脚本: topic={topic[:20]}, is_preset={is_preset_topic}, duration={duration}"
            )

            estimated_total = self._estimate_total_chars(duration)

            # 3. Generation (15-55%)
            yield ScriptGenerationEvent(
                type="generating",
                progress=15,
                message="正在创作脚本...",
            )

            parts: list[str] = []
            current_chars = 0
            last_progress = 15
            # Increases of fewer than 2 points are reported at most once per
            # update_interval seconds; larger jumps are reported immediately.
            update_interval = 0.5
            chunk_count = 0
            logger.info(f"开始流式生成: topic={topic[:20]}, duration={duration}")
            # Monotonic clock, started when generation starts, so extraction
            # time does not consume the early time-based progress floor.
            generation_start = time.monotonic()
            last_update_time = generation_start

            async for chunk in model_router.generate_stream_with_progress(
                prompt=user_prompt,
                system_prompt=system_prompt,
                model_id=model,
                task_type="script",
                temperature=0.7,
            ):
                chunk_count += 1
                chunk_type = chunk.get("type")
                if chunk_type == "chunk":
                    chunk_content = chunk.get("content", "")
                    if not chunk_content:
                        logger.warning(f"收到空 chunk,序号: {chunk_count}")
                        continue
                    parts.append(chunk_content)
                    current_chars += len(chunk_content)

                    now = time.monotonic()
                    base_progress = self._calculate_progress(
                        current_chars=current_chars,
                        estimated_total=estimated_total,
                        elapsed_time=now - generation_start,
                    )
                    # Rescale the helper's 15-85 band onto this stage's 15-55 band.
                    progress = 15 + int((base_progress - 15) * 40 / 70)

                    if progress > last_progress and (
                        progress - last_progress >= 2
                        or now - last_update_time >= update_interval
                    ):
                        yield ScriptGenerationEvent(
                            type="generating",
                            progress=progress,
                            message="正在创作脚本...",
                        )
                        last_progress = progress
                        last_update_time = now
                elif chunk_type == "usage":
                    prompt_tokens = chunk.get("prompt_tokens", 0)
                    completion_tokens = chunk.get("completion_tokens", 0)
                    logger.info(
                        f"Token 使用: prompt={prompt_tokens}, completion={completion_tokens}"
                    )

            # Join once at the end instead of repeated string concatenation.
            full_content = "".join(parts)
            logger.info(f"流式生成结束: 共 {chunk_count} 个 chunk, {len(full_content)} 字符")
            logger.info(f"生成完成: {len(full_content)} 字符 (预估: {estimated_total})")

            # Empty output: fail fast, before any cosmetic validation steps.
            if not full_content.strip():
                logger.error("AI 返回内容为空")
                yield ScriptGenerationEvent(
                    type="error",
                    progress=0,
                    message="AI 返回内容为空,请检查模型配置或重试",
                )
                return

            # 4. Validation (55-70%). The real checks run between the events so
            #    "验证通过" is only reported after they actually succeeded.
            yield ScriptGenerationEvent(
                type="validating",
                progress=60,
                message="验证脚本格式...",
            )
            await asyncio.sleep(self._PACING_DELAY)

            logger.info(f"AI 原始输出: {full_content[:500]}...")
            success, parsed_data, error_msg = safe_parse_ai_json_response(full_content)
            if not success:
                logger.error(f"JSON 解析失败: {error_msg}")
                logger.error(f"原始内容前500字符: {full_content[:500]!r}")
                error_detail = error_msg or "无法解析 AI 返回的内容"
                yield ScriptGenerationEvent(
                    type="error",
                    progress=0,
                    message=f"脚本解析失败: {error_detail}",
                )
                return

            yield ScriptGenerationEvent(
                type="validating",
                progress=65,
                message="检查数据完整性...",
            )
            await asyncio.sleep(self._PACING_DELAY)

            try:
                shots_data = validate_and_normalize_shots(parsed_data)
                shots = [ScriptShot(**shot) for shot in shots_data or []]
            except Exception as e:
                logger.error(f"分镜数据标准化失败: {e}")
                yield ScriptGenerationEvent(
                    type="error",
                    progress=0,
                    message=f"分镜数据处理失败: {str(e)}",
                )
                return

            if not shots:
                logger.error("标准化后分镜列表为空")
                yield ScriptGenerationEvent(
                    type="error",
                    progress=0,
                    message="AI 返回的分镜数据为空或格式不正确",
                )
                return

            yield ScriptGenerationEvent(
                type="validating",
                progress=70,
                message="验证通过",
            )
            await asyncio.sleep(self._PACING_DELAY)

            # 5. Parsing (70-80%)
            yield ScriptGenerationEvent(
                type="parsing",
                progress=75,
                message="解析分镜内容...",
            )

            # 6. Finalizing (80-100%): cosmetic steps so users perceive progress.
            finalize_steps = (
                (80, f"整理 {len(shots)} 个分镜..."),
                (85, "优化镜头顺序..."),
                (90, "检查时长分配..."),
                (95, "准备完成..."),
            )
            for step_progress, step_message in finalize_steps:
                yield ScriptGenerationEvent(
                    type="finalizing",
                    progress=step_progress,
                    message=step_message,
                )
                await asyncio.sleep(self._PACING_DELAY)

            yield ScriptGenerationEvent(
                type="complete",
                progress=100,
                message=f"成功生成 {len(shots)} 个分镜",
                result=shots,
                extracted_info=extracted_info,
            )
        except Exception as e:
            logger.exception("脚本生成失败")
            yield ScriptGenerationEvent(
                type="error",
                progress=0,
                message=f"生成失败: {str(e)}",
            )

    async def polish_content(
        self,
        content: str,
        polish_type: str = "voiceover",
        shot_type: str = "segment",
    ) -> str:
        """Polish a piece of script content with the AI model.

        Args:
            content: Text to polish.
            polish_type: "scene" for scene descriptions; any other value is
                treated as voiceover copy.
            shot_type: "empty_shot" or "segment"; only used for scene polishing
                to select the prompt template.

        Returns:
            The polished text, stripped ("" if the model returned no content).
        """
        model_router = await get_model_router()

        if polish_type == "scene":
            # Scene polishing picks a template per shot type and falls back to
            # the generic scene template when the specific one is missing.
            specific_name = (
                "polish/scene_empty_shot" if shot_type == "empty_shot" else "polish/scene_segment"
            )
            prompt_template = self._load_prompt(specific_name) or self._load_prompt("polish/scene")
        else:
            prompt_template = self._load_prompt("polish/voiceover")

        if not prompt_template:
            # Last-resort built-in template.
            prompt_template = "请润色以下内容:\n\n{content}"

        # Rendered with str.format: literal braces in a template file must be
        # doubled ("{{" / "}}") or formatting will fail.
        prompt = prompt_template.format(content=content)

        result = await model_router.generate(
            prompt=prompt,
            task_type="polish",
            temperature=0.5,
            max_tokens=300,
        )
        return (result.content or "").strip()

    async def check_model_health(self) -> dict:
        """Check the health of all models known to the model router.

        Returns:
            A dict with ``status`` ("healthy" if at least one model is
            available, else "unhealthy"), ``models`` (one info dict per model),
            ``recommended_model`` (the available model with the lowest response
            time, or None), ``total_models`` and ``available_models``.
        """
        model_router = await get_model_router()
        health_results = await model_router.health_check()

        def response_time_key(info: dict) -> float:
            # A missing response time ranks as slowest instead of raising
            # TypeError when compared with a number.
            value = info["response_time"]
            return float("inf") if value is None else value

        models = []
        available_count = 0
        recommended = None
        for health in health_results.values():
            model_info = {
                "id": health.id,
                "name": health.name,
                "is_available": health.is_available,
                "response_time": health.response_time,
                "last_error": health.last_error,
            }
            models.append(model_info)
            if not health.is_available:
                continue
            available_count += 1
            # Strict "<" keeps the first model on ties.
            if recommended is None or response_time_key(model_info) < response_time_key(
                recommended
            ):
                recommended = model_info

        return {
            "status": "healthy" if available_count > 0 else "unhealthy",
            "models": models,
            "recommended_model": recommended,
            "total_models": len(models),
            "available_models": available_count,
        }

    async def test_model(self, model_id: str | None = None) -> dict:
        """Send a tiny prompt to a model to check connectivity.

        Args:
            model_id: Model to test; None lets the router use its default.

        Returns:
            On success: ``success`` True, ``model`` (as reported by the router),
            ``response_time`` in milliseconds (rounded to 2 decimals) and
            ``checked_at`` (local time). On failure: ``success`` False,
            ``model`` (the requested id or "default"), ``error`` and
            ``checked_at``.
        """
        model_router = await get_model_router()
        # perf_counter is monotonic, so wall-clock adjustments cannot distort
        # the measured latency (time.time() could).
        started = time.perf_counter()
        try:
            result = await model_router.generate(
                prompt="你好",
                model_id=model_id,
                max_tokens=5,
            )
            response_time = (time.perf_counter() - started) * 1000
            return {
                "success": True,
                "model": result.model,
                "response_time": round(response_time, 2),
                "checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
            }
        except Exception as e:
            return {
                "success": False,
                "model": model_id or "default",
                "error": str(e),
                "checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
            }
# Process-wide ScriptService instance, created lazily by get_script_service().
_script_service: ScriptService | None = None


def get_script_service() -> ScriptService:
    """Return the shared ScriptService, constructing it on first use."""
    global _script_service
    if _script_service is not None:
        return _script_service
    _script_service = ScriptService()
    return _script_service