""" Async Engine 核心调度器 ======================= 驱动所有 Handler 的 Tick 循环,批量查询、批量更新。 """ import asyncio import logging from datetime import UTC, datetime from typing import Any from app.core.redis_client import get_redis_client from app.scheduler.handlers.base import AsyncHandler from app.scheduler.models import StateChange from app.scheduler.registry import TaskRegistry from app.scheduler.slot_manager import SlotManager logger = logging.getLogger(__name__) # 各任务类型最大执行时间(秒),超过后自动标记为 failed TASK_TIMEOUT_SECONDS = { "script": 5 * 60, "subtitle": 10 * 60, "video": 30 * 60, } class AsyncEngine: """统一异步作业调度引擎""" def __init__(self, handlers: list[AsyncHandler] | None = None): self.redis = get_redis_client() self.registry = TaskRegistry(self.redis) self.slots = SlotManager(self.redis) self.handlers: dict[str, AsyncHandler] = {} if handlers: for h in handlers: self.handlers[h.name] = h def register(self, handler: AsyncHandler) -> None: """注册一个 Handler""" self.handlers[handler.name] = handler logger.info(f"Registered handler: {handler.name}") async def tick(self) -> None: """执行一次完整的调度 Tick""" tick_start = asyncio.get_event_loop().time() try: # 1. 加载所有 running 的作业 ID running_ids = await self.registry.get_running_task_ids() if not running_ids: logger.debug("Tick: no running tasks") return # 2. 按 task_type 分组,并处理超时任务 tasks_by_type: dict[str, list[Any]] = {} timeout_changes: list[StateChange] = [] now = datetime.now(UTC) for task_id in running_ids: record = await self.registry.get(task_id) if not record: await self.registry.remove_running(task_id) continue max_duration = TASK_TIMEOUT_SECONDS.get(record.task_type) is_timeout = ( max_duration and record.status == "running" and record.created_at and (now - datetime.fromisoformat(record.created_at)).total_seconds() > max_duration ) if is_timeout: logger.warning( f"Task timeout: {task_id}, type={record.task_type}, " f"created_at={record.created_at}" ) timeout_changes.append( StateChange(task_id=task_id, field_path="status", value="failed") ) timeout_changes.append( StateChange( task_id=task_id, field_path="message", value="任务执行超时,请稍后重试", ) ) timeout_changes.append( StateChange( task_id=task_id, field_path="error", value=f"任务执行超过 {max_duration} 秒", ) ) await self.registry.remove_running(task_id) continue tasks_by_type.setdefault(record.task_type, []).append(record) # 3. 并行执行各 Handler 的 tick results = await asyncio.gather( *[ self._safe_tick(handler_name, handler, tasks_by_type.get(handler_name, [])) for handler_name, handler in self.handlers.items() ] ) # 4. 收集并应用状态变更(包含超时任务) all_changes: list[StateChange] = [] for changes in results: if changes: all_changes.extend(changes) all_changes.extend(timeout_changes) if all_changes: await self._apply_changes(all_changes) # 5. 清理已结束的作业 await self._cleanup_finished() except Exception: logger.exception("Scheduler tick failed") finally: elapsed = asyncio.get_event_loop().time() - tick_start logger.debug(f"Tick completed in {elapsed:.2f}s") # 写入心跳,供 healthcheck 检查 await self.redis.set("scheduler:heartbeat", str(asyncio.get_event_loop().time()), ex=60) async def _safe_tick( self, name: str, handler: AsyncHandler, tasks: list[Any] ) -> list[StateChange]: """安全执行 Handler tick,捕获异常""" try: return await handler.tick(tasks, self.registry, self.slots) except Exception: logger.exception(f"Handler tick failed: {name}") return [] async def _apply_changes(self, changes: list[StateChange]) -> None: """批量应用状态变更到 Redis""" pipe = self.redis.pipeline() executed = False for change in changes: key, field, value = change.to_redis_command() pipe.hset(key, field, value) executed = True if executed: await pipe.execute() async def _cleanup_finished(self) -> None: """清理已完成的作业""" running_ids = await self.registry.get_running_task_ids() for task_id in running_ids: record = await self.registry.get(task_id) if not record: await self.registry.remove_running(task_id) continue if record.status in ("completed", "failed"): await self.registry.remove_running(task_id) logger.info(f"Task moved to finished: {task_id} ({record.status})") async def run_forever(self, interval: float = 10.0, min_interval: float = 2.0) -> None: """启动无限 Tick 循环""" logger.info("Async Engine started") while True: tick_start = asyncio.get_event_loop().time() await self.tick() elapsed = asyncio.get_event_loop().time() - tick_start sleep_time = max(interval - elapsed, min_interval) await asyncio.sleep(sleep_time)