Files

173 lines
6.4 KiB
Python

"""
Async Engine 核心调度器
=======================
驱动所有 Handler 的 Tick 循环,批量查询、批量更新。
"""
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any
from app.core.redis_client import get_redis_client
from app.scheduler.handlers.base import AsyncHandler
from app.scheduler.models import StateChange
from app.scheduler.registry import TaskRegistry
from app.scheduler.slot_manager import SlotManager
logger = logging.getLogger(__name__)

# Per-task-type execution limit, in seconds. AsyncEngine.tick marks a task of
# one of these types as "failed" once it has been "running" longer than this
# (measured from its created_at timestamp). Types not listed never time out.
TASK_TIMEOUT_SECONDS: dict[str, int] = dict(
    script=300,  # 5 minutes
    subtitle=600,  # 10 minutes
    video=1800,  # 30 minutes
)
class AsyncEngine:
    """Unified asynchronous job scheduling engine.

    One ``tick`` performs these steps:

    1. Loads every task ID in the registry's running set.
    2. Fails tasks that outlived their per-type limit in ``TASK_TIMEOUT_SECONDS``.
    3. Hands the remaining records to the handler registered for their
       ``task_type``.
    4. Writes all resulting ``StateChange`` objects to Redis in one pipeline.
    5. Removes completed/failed tasks from the running set.
    """

    def __init__(self, handlers: list[AsyncHandler] | None = None):
        """Create the engine and its Redis-backed registry and slot manager.

        Args:
            handlers: Optional handlers to register immediately. This is
                equivalent to calling ``register`` for each one, in order.
        """
        self.redis = get_redis_client()
        self.registry = TaskRegistry(self.redis)
        self.slots = SlotManager(self.redis)
        self.handlers: dict[str, AsyncHandler] = {}
        for handler in handlers or ():
            # Go through register() so construction-time and later
            # registrations behave (and log) the same way.
            self.register(handler)

    def register(self, handler: AsyncHandler) -> None:
        """Register ``handler`` under ``handler.name``, replacing any previous one."""
        self.handlers[handler.name] = handler
        logger.info("Registered handler: %s", handler.name)

    async def tick(self) -> None:
        """Run one complete scheduling tick.

        Any ``Exception`` raised by the steps below is logged, not propagated,
        so a single failing tick cannot stop ``run_forever``. The heartbeat key
        is written at the end of every tick regardless of outcome.
        """
        loop = asyncio.get_running_loop()
        tick_start = loop.time()
        try:
            # 1. Load the IDs of all jobs in the running set.
            running_ids = await self.registry.get_running_task_ids()
            if not running_ids:
                logger.debug("Tick: no running tasks")
                return

            # 2. Group records by task_type and detect timed-out tasks.
            tasks_by_type, timeout_changes, timed_out_ids = await self._partition_running(
                running_ids
            )

            # 3. Run every registered handler's tick concurrently. Handlers with
            #    no matching tasks are still ticked, with an empty list.
            results = await asyncio.gather(
                *(
                    self._safe_tick(name, handler, tasks_by_type.get(name, []))
                    for name, handler in self.handlers.items()
                )
            )

            # 4. Apply handler changes followed by the timeout changes.
            all_changes: list[StateChange] = []
            for changes in results:
                if changes:
                    all_changes.extend(changes)
            all_changes.extend(timeout_changes)
            if all_changes:
                await self._apply_changes(all_changes)

            # Drop timed-out tasks from the running set only after their
            # "failed" status has been written. If the write above raised, they
            # stay in the running set and are retried on the next tick, instead
            # of being orphaned with status "running".
            for task_id in timed_out_ids:
                await self.registry.remove_running(task_id)

            # 5. Remove finished jobs from the running set.
            await self._cleanup_finished()
        except Exception:
            logger.exception("Scheduler tick failed")
        finally:
            logger.debug("Tick completed in %.2fs", loop.time() - tick_start)
            await self._write_heartbeat(loop.time())

    async def _write_heartbeat(self, timestamp: float) -> None:
        """Write the ``scheduler:heartbeat`` key (60 s TTL) for the healthcheck.

        The stored value is the event loop's clock (``loop.time()``), not
        wall-clock time. NOTE(review): the healthcheck presumably relies on the
        key's TTL rather than the value; confirm.

        Failures are logged rather than raised. Raising here would escape
        ``tick`` from its ``finally`` clause and terminate ``run_forever`` on a
        transient Redis outage.
        """
        try:
            await self.redis.set("scheduler:heartbeat", str(timestamp), ex=60)
        except Exception:
            logger.exception("Failed to write scheduler heartbeat")

    async def _partition_running(
        self, running_ids
    ) -> tuple[dict[str, list[Any]], list[StateChange], list[Any]]:
        """Split the running tasks into per-type work lists and timed-out tasks.

        Args:
            running_ids: Task IDs as returned by
                ``registry.get_running_task_ids()``.

        Returns:
            A tuple ``(tasks_by_type, timeout_changes, timed_out_ids)``:

            * ``tasks_by_type`` maps ``task_type`` to records for the handlers.
            * ``timeout_changes`` holds the state changes that fail the
              timed-out tasks.
            * ``timed_out_ids`` lists those tasks' IDs, which are still in the
              running set.

        IDs with no stored record are removed from the running set immediately.
        """
        tasks_by_type: dict[str, list[Any]] = {}
        timeout_changes: list[StateChange] = []
        timed_out_ids: list[Any] = []
        now = datetime.now(UTC)

        for task_id in running_ids:
            record = await self.registry.get(task_id)
            if not record:
                await self.registry.remove_running(task_id)
                continue

            max_duration = TASK_TIMEOUT_SECONDS.get(record.task_type)
            if max_duration and record.status == "running" and record.created_at:
                age = self._age_seconds(record.created_at, now)
                if age is None:
                    logger.warning(
                        "Task %s has unparseable created_at=%r; skipping timeout check",
                        task_id,
                        record.created_at,
                    )
                elif age > max_duration:
                    logger.warning(
                        "Task timeout: %s, type=%s, created_at=%s",
                        task_id,
                        record.task_type,
                        record.created_at,
                    )
                    timeout_changes.extend(self._timeout_changes(task_id, max_duration))
                    timed_out_ids.append(task_id)
                    continue

            tasks_by_type.setdefault(record.task_type, []).append(record)

        return tasks_by_type, timeout_changes, timed_out_ids

    @staticmethod
    def _age_seconds(created_at: str, now: datetime) -> float | None:
        """Return seconds elapsed from ISO-8601 ``created_at`` to aware ``now``.

        Returns ``None`` if ``created_at`` cannot be parsed.

        NOTE(review): naive timestamps are assumed to be UTC; confirm against
        the code that writes ``created_at``. Without this, subtracting a naive
        datetime from the aware ``now`` raises ``TypeError``, which used to
        abort the whole tick for every task.
        """
        try:
            created = datetime.fromisoformat(created_at)
        except (TypeError, ValueError):
            return None
        if created.tzinfo is None:
            created = created.replace(tzinfo=UTC)
        return (now - created).total_seconds()

    @staticmethod
    def _timeout_changes(task_id: Any, max_duration: int) -> list[StateChange]:
        """Build the status/message/error changes that mark a task as timed out."""
        return [
            StateChange(task_id=task_id, field_path="status", value="failed"),
            StateChange(
                task_id=task_id,
                field_path="message",
                value="任务执行超时,请稍后重试",
            ),
            StateChange(
                task_id=task_id,
                field_path="error",
                # Include the unit; the bare number was ambiguous.
                value=f"任务执行超过 {max_duration} 秒",
            ),
        ]

    async def _safe_tick(
        self, name: str, handler: AsyncHandler, tasks: list[Any]
    ) -> list[StateChange]:
        """Run ``handler.tick``, logging any exception and returning ``[]`` instead.

        This keeps one failing handler from cancelling the others in the
        ``asyncio.gather`` inside ``tick``.
        """
        try:
            return await handler.tick(tasks, self.registry, self.slots)
        except Exception:
            logger.exception("Handler tick failed: %s", name)
            return []

    async def _apply_changes(self, changes: list[StateChange]) -> None:
        """Write all state changes to Redis in a single pipeline round-trip.

        Each change becomes one HSET built from ``StateChange.to_redis_command()``.
        An empty list is a no-op.
        """
        if not changes:
            return
        pipe = self.redis.pipeline()
        for change in changes:
            key, field, value = change.to_redis_command()
            pipe.hset(key, field, value)
        await pipe.execute()

    async def _cleanup_finished(self) -> None:
        """Remove tasks whose status is completed/failed (or whose record is gone).

        Reloads the running set and each record, so it sees the changes just
        written by ``_apply_changes``.
        """
        running_ids = await self.registry.get_running_task_ids()
        for task_id in running_ids:
            record = await self.registry.get(task_id)
            if not record:
                await self.registry.remove_running(task_id)
                continue
            if record.status in ("completed", "failed"):
                await self.registry.remove_running(task_id)
                logger.info("Task moved to finished: %s (%s)", task_id, record.status)

    async def run_forever(self, interval: float = 10.0, min_interval: float = 2.0) -> None:
        """Run ticks indefinitely, starting one roughly every ``interval`` seconds.

        After each tick the loop sleeps ``interval`` minus the tick's duration,
        but never less than ``min_interval``, so a slow tick still leaves a gap.
        """
        logger.info("Async Engine started")
        loop = asyncio.get_running_loop()
        while True:
            started = loop.time()
            await self.tick()
            elapsed = loop.time() - started
            await asyncio.sleep(max(interval - elapsed, min_interval))