""" 脚本生成服务 ============ """ import asyncio import logging import math import re import time from collections.abc import AsyncIterator from pathlib import Path from app.ai.model_router import get_model_router from app.ai.prompts import load_script_user_prompt, load_system_prompt from app.schemas.script import ScriptGenerationEvent, ScriptShot from app.services.ai_response_utils import ( safe_parse_ai_json_response, validate_and_normalize_shots, ) logger = logging.getLogger(__name__) class ScriptService: """脚本生成服务""" # 根据视频时长估算输出字符数(经验值) # 格式: {时长: (最小字符数, 最大字符数)} DURATION_ESTIMATES = { 30: (800, 1200), 45: (1200, 1600), 60: (1500, 2000), 90: (2000, 2800), } def __init__(self): self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts" def _estimate_total_chars(self, duration: int) -> int: """ 根据时长估算总输出字符数 Args: duration: 视频时长(秒) Returns: 预估字符数 """ # 找到最接近的预设 closest_duration = min(self.DURATION_ESTIMATES.keys(), key=lambda x: abs(x - duration)) min_chars, max_chars = self.DURATION_ESTIMATES[closest_duration] # 根据实际时长在区间内插值 ratio = duration / closest_duration estimated = int(min_chars + (max_chars - min_chars) * ratio) logger.debug(f"时长 {duration}s 预估字符数: {estimated}") return estimated def _calculate_progress( self, current_chars: int, estimated_total: int, elapsed_time: float, min_expected_time: float = 5.0, ) -> int: """ 计算平滑进度(使用对数曲线)- 优化版,避免抖动 设计思路: - 主要基于内容生成进度,时间只作为保底 - 使用单调递增函数,确保进度只增不减 - 前 20% 内容:进度到 30%(慢启动) - 中间 60% 内容:进度到 75%(稳定生成期) - 最后 20% 内容:进度到 85%(收尾阶段) Args: current_chars: 当前字符数 estimated_total: 预估总字符数 elapsed_time: 已过去时间(秒) min_expected_time: 最少预期的生成时间(避免太快跑完) Returns: 进度百分比 (0-85) """ # 基于内容的进度(对数曲线) ratio = min(current_chars / estimated_total, 1.5) # 允许超出生成 50% # 对数曲线:前期慢,后期快 if ratio <= 1.0: # 未完成或刚好完成:使用调整后的对数曲线,最高到 85% progress_ratio = math.log(1 + ratio * 2) / math.log(3) * 0.85 else: # 已超出生成:从 85% 线性增长,最多到 95%(预留空间给后续阶段) progress_ratio = 0.85 + min((ratio - 1) * 0.2, 0.1) # 基于时间的保底进度(只在内容很少时生效,避免生成太快时进度条没动) # 使用平滑的保底函数,只在前期(前3秒)和内容很少时生效 time_progress = 0 if current_chars < estimated_total * 0.3 and elapsed_time < min_expected_time: # 内容生成少于30%且时间少于5秒时,提供保底进度 time_progress = min(elapsed_time / min_expected_time * 0.15, 0.15) # 取较大值,但主要依赖 progress_ratio,time_progress 只作为早期保底 final_ratio = max(progress_ratio, time_progress) # 生成阶段最高到 85% return min(int(final_ratio * 100), 85) def _load_prompt(self, name: str) -> str: """加载 Prompt 模板""" prompt_file = self.prompts_dir / f"{name}.txt" if prompt_file.exists(): return prompt_file.read_text(encoding="utf-8") return "" @staticmethod def _extract_json(content: str) -> str: """ 从 Markdown 代码块中提取 JSON,或返回原始内容 支持格式: - ```json {...} ``` - ``` {...} ``` - 纯 JSON 文本 """ if not content: return "" content = content.strip() # 匹配 ```json ... ``` 或 ``` ... ``` pattern = r"```(?:json)?\s*([\s\S]*?)\s*```" matches = re.findall(pattern, content) if matches: # 取最后一个匹配(避免前面有示例代码) return matches[-1].strip() # 如果没有代码块,返回原始内容 return content async def generate_script( self, category: str, subcategory: str, duration: int, script_type: str, model: str | None = None, ) -> list[ScriptShot]: """ 同步生成脚本 Args: category: 大类代码,如 "bk" subcategory: 小类代码,如 "ht" duration: 视频时长(秒) script_type: 脚本类型 model: 指定模型 Returns: 分镜列表 """ # 获取 model_router model_router = await get_model_router() # 加载 Prompt system_prompt = load_system_prompt(category, subcategory) if not system_prompt: raise ValueError(f"未找到提示词: category={category}, subcategory={subcategory}") # 用户提示词 user_prompt = load_script_user_prompt( topic=f"{category}/{subcategory}", duration=duration, ) logger.info(f"同步生成脚本: category={category}, subcategory={subcategory}, duration={duration}") # 调用 AI 生成 result = await model_router.generate( prompt=user_prompt, system_prompt=system_prompt, model_id=model, task_type="script", temperature=0.7, ) # 检查返回内容 if not result.content or not result.content.strip(): logger.error("AI 返回内容为空") raise ValueError("AI 返回内容为空,请检查模型配置或重试") logger.info(f"AI 返回内容长度: {len(result.content)} 字符") # 使用安全的 JSON 解析 success, parsed_data, error_msg = safe_parse_ai_json_response(result.content) if not success: logger.error(f"JSON 解析失败: {error_msg}") logger.error(f"原始内容: {result.content[:500]!r}") raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON") # 验证并标准化分镜数据 try: shots_data = validate_and_normalize_shots(parsed_data) if not shots_data: raise ValueError("AI 返回的分镜数据为空或格式不正确") # 转换为 ScriptShot 对象 shots = [ScriptShot(**shot) for shot in shots_data] logger.info(f"成功解析 {len(shots)} 个分镜") return shots except Exception as e: logger.error(f"分镜数据标准化失败: {e}") raise ValueError(f"分镜数据处理失败: {str(e)}") async def generate_script_stream( self, category: str, subcategory: str, duration: int, script_type: str, model: str | None = None, ) -> AsyncIterator[ScriptGenerationEvent]: """ 流式生成脚本(SSE) 进度设计: - 0-5%: start(初始化) - 5-15%: analyzing(分析主题) - 15-85%: generating(AI 生成,平滑对数曲线增长) - 85-92%: validating(JSON 验证) - 92-98%: parsing(解析分镜) - 98-100%: complete(完成) """ model_router = await get_model_router() start_time = time.time() try: # 加载 Prompt system_prompt = load_system_prompt(category, subcategory) if not system_prompt: yield ScriptGenerationEvent( type="error", progress=0, message=f"未找到提示词: category={category}, subcategory={subcategory}", ) return user_prompt = load_script_user_prompt( topic=f"{category}/{subcategory}", duration=duration, ) logger.info(f"流式生成脚本: category={category}, subcategory={subcategory}, duration={duration}") # 1. 开始阶段(0-5%) yield ScriptGenerationEvent( type="start", progress=2, message="准备生成脚本...", ) # 2. 分析阶段(5-15%) yield ScriptGenerationEvent( type="analyzing", progress=10, message="分析创作要点", ) # 估算总长度(根据时长) estimated_total = self._estimate_total_chars(duration) # 3. 生成阶段(15-55%) yield ScriptGenerationEvent( type="generating", progress=15, message="正在创作脚本...", ) full_content = "" last_progress = 15 last_update_time = start_time update_interval = 0.5 # 最少 500ms 更新一次 chunk_count = 0 logger.info(f"开始流式生成: category={category}, subcategory={subcategory}, duration={duration}") async for chunk in model_router.generate_stream_with_progress( prompt=user_prompt, system_prompt=system_prompt, model_id=model, task_type="script", temperature=0.7, ): chunk_count += 1 if chunk["type"] == "chunk": chunk_content = chunk.get("content", "") if not chunk_content: logger.warning(f"收到空 chunk,序号: {chunk_count}") continue full_content += chunk_content current_chars = len(full_content) elapsed = time.time() - start_time # 计算平滑进度(对数曲线,最高到55%) base_progress = self._calculate_progress( current_chars=current_chars, estimated_total=estimated_total, elapsed_time=elapsed, ) # 将原来的 15-85 映射到 15-55 progress = 15 + int((base_progress - 15) * 40 / 70) # 限制更新频率,但确保每次有变化都上报(最小 2% 变化) current_time = time.time() if progress > last_progress and ( progress - last_progress >= 2 or current_time - last_update_time >= update_interval ): yield ScriptGenerationEvent( type="generating", progress=progress, message="正在创作脚本...", ) last_progress = progress last_update_time = current_time elif chunk["type"] == "usage": prompt_tokens = chunk.get("prompt_tokens", 0) completion_tokens = chunk.get("completion_tokens", 0) logger.info( f"Token 使用: prompt={prompt_tokens}, completion={completion_tokens}" ) logger.info(f"流式生成结束: 共 {chunk_count} 个 chunk, {len(full_content)} 字符") # 4. 验证阶段(55-70%) actual_chars = len(full_content) logger.info(f"生成完成: {actual_chars} 字符 (预估: {estimated_total})") yield ScriptGenerationEvent( type="validating", progress=60, message="验证脚本格式...", ) await asyncio.sleep(0.5) yield ScriptGenerationEvent( type="validating", progress=65, message="检查数据完整性...", ) await asyncio.sleep(0.5) yield ScriptGenerationEvent( type="validating", progress=70, message="验证通过", ) await asyncio.sleep(0.5) # 5. 解析阶段(70-80%) yield ScriptGenerationEvent( type="parsing", progress=75, message="解析分镜内容...", ) # 检查内容是否为空 if not full_content or not full_content.strip(): logger.error("AI 返回内容为空") yield ScriptGenerationEvent( type="error", progress=0, message="AI 返回内容为空,请检查模型配置或重试", ) return # 记录原始内容(调试用) logger.info(f"AI 原始输出: {full_content[:500]}...") # 使用安全的 JSON 解析 success, parsed_data, error_msg = safe_parse_ai_json_response(full_content) if not success: logger.error(f"JSON 解析失败: {error_msg}") logger.error(f"原始内容前500字符: {full_content[:500]!r}") # 给前端更详细的错误信息 error_detail = error_msg or "无法解析 AI 返回的内容" if not full_content or not full_content.strip(): error_detail = "AI 返回内容为空,请检查模型配置或重试" yield ScriptGenerationEvent( type="error", progress=0, message=f"脚本解析失败: {error_detail}", ) return # 验证并标准化分镜数据 try: shots_data = validate_and_normalize_shots(parsed_data) if not shots_data: logger.error("标准化后分镜列表为空") yield ScriptGenerationEvent( type="error", progress=0, message="AI 返回的分镜数据为空或格式不正确", ) return # 转换为 ScriptShot 对象 shots = [ScriptShot(**shot) for shot in shots_data] # 6. 完成阶段(80-100%)- 细分为多个步骤,让用户感知进度 yield ScriptGenerationEvent( type="finalizing", progress=80, message=f"整理 {len(shots)} 个分镜...", ) await asyncio.sleep(0.5) yield ScriptGenerationEvent( type="finalizing", progress=85, message="优化镜头顺序...", ) await asyncio.sleep(0.5) yield ScriptGenerationEvent( type="finalizing", progress=90, message="检查时长分配...", ) await asyncio.sleep(0.5) yield ScriptGenerationEvent( type="finalizing", progress=95, message="准备完成...", ) await asyncio.sleep(0.5) yield ScriptGenerationEvent( type="complete", progress=100, message=f"成功生成 {len(shots)} 个分镜", result=shots, ) except Exception as e: logger.error(f"分镜数据标准化失败: {e}") yield ScriptGenerationEvent( type="error", progress=0, message=f"分镜数据处理失败: {str(e)}", ) except Exception as e: logger.exception("脚本生成失败") yield ScriptGenerationEvent( type="error", progress=0, message=f"生成失败: {str(e)}", ) async def polish_content( self, content: str, polish_type: str = "voiceover", shot_type: str = "segment", ) -> str: """ 润色内容 Args: content: 待润色内容 polish_type: 润色类型,可选 "scene"(画面描述)或 "voiceover"(配音文案) shot_type: 镜头类型,可选 "segment"(分镜)或 "empty_shot"(空镜),仅用于画面润色 Returns: 润色后的内容 """ # 获取 model_router model_router = await get_model_router() # 从文件加载提示词模板 if polish_type == "scene": # 画面润色需要根据镜头类型选择不同提示词 if shot_type == "empty_shot": prompt_template = self._load_prompt("polish/scene_empty_shot") else: prompt_template = self._load_prompt("polish/scene_segment") # 如果特定类型的提示词不存在,回退到通用 scene 提示词 if not prompt_template: prompt_template = self._load_prompt("polish/scene") else: # 配音文案润色 prompt_template = self._load_prompt("polish/voiceover") if not prompt_template: # 最终回退 prompt_template = "请润色以下内容:\n\n{content}" prompt = prompt_template.format(content=content) result = await model_router.generate( prompt=prompt, task_type="polish", temperature=0.5, max_tokens=300, ) return result.content.strip() async def check_model_health(self) -> dict: """检查模型健康状态""" model_router = await get_model_router() health_results = await model_router.health_check() models = [] available_count = 0 recommended = None for provider_id, health in health_results.items(): model_info = { "id": health.id, "name": health.name, "is_available": health.is_available, "response_time": health.response_time, "last_error": health.last_error, } models.append(model_info) if health.is_available: available_count += 1 if recommended is None or health.response_time < recommended.get( "response_time", float("inf") ): recommended = model_info total = len(models) return { "status": "healthy" if available_count > 0 else "unhealthy", "models": models, "recommended_model": recommended, "total_models": total, "available_models": available_count, } async def test_model(self, model_id: str | None = None) -> dict: """测试指定模型连接""" model_router = await get_model_router() import time start_time = time.time() try: result = await model_router.generate( prompt="你好", model_id=model_id, max_tokens=5, ) response_time = (time.time() - start_time) * 1000 return { "success": True, "model": result.model, "response_time": round(response_time, 2), "checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"), } except Exception as e: return { "success": False, "model": model_id or "default", "error": str(e), "checked_at": time.strftime("%Y-%m-%dT%H:%M:%S"), } # 全局单例 _script_service: ScriptService | None = None def get_script_service() -> ScriptService: """获取 ScriptService 单例""" global _script_service if _script_service is None: _script_service = ScriptService() return _script_service