"""
Script task handler
===================

Runs script-generation tasks.
Does not occupy Volc slots; uses an independent "script" slot pool instead.
"""

import logging
from typing import Any

from app.ai.prompts.loader import _load_system_meta
from app.core.platform_config import get_platform_config_loader
from app.db.session import AsyncSessionLocal
from app.scheduler.handlers.base import AsyncHandler
from app.scheduler.models import StateChange
from app.scheduler.registry import TaskRegistry
from app.scheduler.slot_manager import SlotManager
from app.services import point_service as ps
from app.services.script_service import ScriptService

logger = logging.getLogger(__name__)


def _get_category_name(category: str, filename: str) -> str:
    """Build the display title for a generated script.

    The category code is mapped to its display name using the ``categories``
    list from the system meta (falling back to the code itself when no entry
    matches). If ``filename`` (minus a trailing ``.txt``) contains the ``——``
    separator, the non-empty text before the separator is appended, giving
    ``"<category name> · <label>"``.

    Args:
        category: Category code to look up.
        filename: Source filename; may be empty.

    Returns:
        The title string.
    """
    # `or []` also covers an explicit null ``categories`` entry in the meta.
    categories = _load_system_meta().get("categories") or []
    cat_name = next(
        (c.get("name", category) for c in categories if c.get("code") == category),
        category,
    )

    if filename:
        # Strip only a real ".txt" extension; str.replace would also remove
        # ".txt" occurring anywhere inside the name.
        stem = filename.removesuffix(".txt")
        label, sep, _ = stem.partition("——")
        label = label.strip()
        # Skip an empty label so we never produce a dangling "<name> · " title.
        if sep and label:
            return f"{cat_name} · {label}"

    return cat_name


SLOT_KEY = "script:slots"

# Upper bound (and fallback) for concurrent script slots.
_DEFAULT_MAX_SLOTS = 10


def _get_script_max_slots() -> int:
    """Return the script slot-pool size derived from platform-config.yaml.

    Uses ``rate_limit_qps`` of the ``volcengine_ark`` platform, capped at
    ``_DEFAULT_MAX_SLOTS`` and floored at 1. Falls back to
    ``_DEFAULT_MAX_SLOTS`` when the platform is missing or the config cannot be
    read or converted.
    """
    try:
        platform = get_platform_config_loader().get_platform("volcengine_ark")
        if platform:
            # LLM inference requests are slow, so max_slots must not exceed what
            # the server can carry, whatever QPS is configured. Floor at 1: a
            # QPS below 1 would otherwise give 0 slots and no script task could
            # ever acquire one.
            return max(1, min(int(platform.rate_limit_qps), _DEFAULT_MAX_SLOTS))
    except Exception as e:
        logger.warning("读取脚本平台 rate_limit 配置失败: %s", e)
    return _DEFAULT_MAX_SLOTS


def _failure_changes(task_id: Any, exc: BaseException) -> list[StateChange]:
    """Return the state changes that mark ``task_id`` as failed because of ``exc``.

    Sets ``status``, ``message`` (truncated to 200 chars) and ``error``
    (truncated to 500 chars), so every failure path leaves the same fields.
    """
    # Some exceptions (e.g. a bare TimeoutError()) stringify to "", which would
    # leave a blank error; fall back to the exception type name.
    text = str(exc) or type(exc).__name__
    return [
        StateChange(task_id=task_id, field_path="status", value="failed"),
        StateChange(task_id=task_id, field_path="message", value=text[:200]),
        StateChange(task_id=task_id, field_path="error", value=text[:500]),
    ]


def _build_result(category: str, filename: str, shots: list[Any]) -> dict[str, Any]:
    """Assemble the task result payload from the generated shots.

    Each shot is expected to expose ``duration`` and ``model_dump()``.
    """
    return {
        "title": _get_category_name(category, filename),
        "scenes": [s.model_dump() for s in shots],
        # Real total duration of the storyboard; shots without a duration are skipped.
        "total_duration": sum(s.duration for s in shots if s.duration),
        "shot_count": len(shots),
    }


def _success_changes(task_id: Any, result_data: dict[str, Any]) -> list[StateChange]:
    """Return the state changes that mark ``task_id`` as completed with ``result_data``."""
    fields: dict[str, Any] = {
        "status": "completed",
        "progress": 100,
        "message": "脚本生成完成",
        "completed": 1,
        "total": 1,
        "result": result_data,
    }
    return [
        StateChange(task_id=task_id, field_path=path, value=value)
        for path, value in fields.items()
    ]


class ScriptHandler(AsyncHandler):
    """Scheduler handler that generates scripts using its own slot pool."""

    name = "script"
    slot_key = SLOT_KEY
    # Evaluated once, at class-definition (import) time.
    max_slots = _get_script_max_slots()

    def __init__(self, service: ScriptService | None = None):
        self.service = service

    def _get_service(self) -> ScriptService:
        """Return the injected ScriptService.

        Raises:
            RuntimeError: if no service was passed to the constructor.
        """
        if self.service is None:
            raise RuntimeError(
                "ScriptHandler 需要通过构造函数传入 ScriptService 实例"
            )
        return self.service

    async def tick(
        self, tasks: list[Any], registry: TaskRegistry, slots: SlotManager
    ) -> list[StateChange]:
        """Process each task that can acquire a script slot; return all state changes.

        Tasks that fail to acquire a slot are skipped in this tick. An exception
        escaping ``_process_task`` marks only that task as failed instead of
        aborting the whole tick.
        """
        # NOTE(review): tasks are awaited one after another, so within a single
        # tick at most one slot is held at a time; max_slots only limits
        # concurrency across concurrently running ticks/workers — confirm intent.
        changes: list[StateChange] = []
        for task in tasks:
            async with slots.acquire_ctx(SLOT_KEY, task.task_id, self.max_slots) as acquired:
                if not acquired:
                    continue
                try:
                    changes.extend(await self._process_task(task, registry, slots))
                except Exception as e:
                    logger.exception("[Script %s] failed", task.task_id)
                    changes.extend(_failure_changes(task.task_id, e))
        return changes

    async def _process_task(
        self, task: Any, registry: TaskRegistry, slots: SlotManager
    ) -> list[StateChange]:
        """Generate the script for one task and return its final state changes.

        Reads ``category`` and ``filename`` from ``task.params``. Generation
        errors are converted into "failed" state changes. Errors from the
        initial "running" registry update propagate to the caller. ``slots`` is
        not used here.
        """
        params = task.params or {}
        category = params.get("category", "")
        filename = params.get("filename", "")

        await registry.update(
            task.task_id,
            status="running",
            progress=10,
            message="分析需求中...",
            completed=0,
            total=1,
        )

        # Keep the try narrow: only generation and result building can fail the
        # task, so success and failure changes are never emitted together.
        try:
            await registry.update(task.task_id, progress=10, message="构思脚本中...")
            shots = await self._get_service().generate_script(
                category=category,
                filename=filename,
            )
            result_data = _build_result(category, filename, shots)
        except Exception as exc:
            logger.exception("[ScriptTask %s] Failed", task.task_id)
            return _failure_changes(task.task_id, exc)

        changes = _success_changes(task.task_id, result_data)
        # Post-hoc billing; best-effort, never changes the task outcome.
        await self._charge_points(task)
        return changes

    async def _charge_points(self, task: Any) -> None:
        """Deduct the script-generation cost from the task's user (best-effort).

        Uses an independent DB session and commits it. Any failure is logged and
        swallowed so billing problems never affect the task result.
        """
        try:
            async with AsyncSessionLocal() as db:
                points = ps._calculate_cost("script")
                await ps.consume(
                    db,
                    user_id=task.user_id,
                    points=points,
                    source_type="script",
                    source_id=task.task_id,
                    description="【脚本生成】",
                )
                await db.commit()
        except Exception as e:
            logger.exception("[Script %s] 扣费失败: %s", task.task_id, e)